diff --git a/egs/mars/sd/conf/SOND_ECAPATDNN_None_Dot_SAN_L4N512_FSMN_L6N512_n16k2.yaml b/egs/mars/sd/conf/SOND_ECAPATDNN_None_Dot_SAN_L4N512_FSMN_L6N512_n16k2.yaml
new file mode 100644
index 000000000..459a7414f
--- /dev/null
+++ b/egs/mars/sd/conf/SOND_ECAPATDNN_None_Dot_SAN_L4N512_FSMN_L6N512_n16k2.yaml
@@ -0,0 +1,121 @@
+model: sond
+model_conf:
+    lsm_weight: 0.0
+    length_normalized_loss: true
+    max_spk_num: 16
+
+# speech encoder
+encoder: ecapa_tdnn
+encoder_conf:
+    # input_size is passed in by the model; it equals the feature dim
+    # input_size: 80
+    pool_size: 20
+    stride: 1
+speaker_encoder: conv
+speaker_encoder_conf:
+    input_units: 256
+    num_layers: 3
+    num_units: 256
+    kernel_size: 1
+    dropout_rate: 0.0
+    position_encoder: null
+    out_units: 256
+    out_norm: false
+    auxiliary_states: false
+    tf2torch_tensor_name_prefix_torch: speaker_encoder
+    tf2torch_tensor_name_prefix_tf: EAND/speaker_encoder
+ci_scorer: dot
+ci_scorer_conf: {}
+cd_scorer: san
+cd_scorer_conf:
+    input_size: 512
+    output_size: 512
+    out_units: 1
+    attention_heads: 4
+    linear_units: 1024
+    num_blocks: 4
+    dropout_rate: 0.0
+    positional_dropout_rate: 0.0
+    attention_dropout_rate: 0.0
+    # use the string "null" (not a bare null) to remove the input layer
+    input_layer: "null"
+    pos_enc_class: null
+    normalize_before: true
+    tf2torch_tensor_name_prefix_torch: cd_scorer
+    tf2torch_tensor_name_prefix_tf: EAND/compute_distance_layer
+# post net
+decoder: fsmn
+decoder_conf:
+    in_units: 32
+    out_units: 2517
+    filter_size: 31
+    fsmn_num_layers: 6
+    dnn_num_layers: 1
+    num_memory_units: 512
+    ffn_inner_dim: 512
+    dropout_rate: 0.0
+    tf2torch_tensor_name_prefix_torch: decoder
+    tf2torch_tensor_name_prefix_tf: EAND/post_net
+frontend: wav_frontend
+frontend_conf:
+    fs: 16000
+    window: povey
+    n_mels: 80
+    frame_length: 25
+    frame_shift: 10
+    filter_length_min: -1
+    filter_length_max: -1
+    lfr_m: 1
+    lfr_n: 1
+    dither: 0.0
+    snip_edges: false
+
+# minibatch related
+batch_type: length
+# 16 s * 16 kHz * 16 utterances = 4,096,000 samples per batch
+batch_bins: 4096000
+num_workers: 8
+
+# optimization related
+accum_grad: 1
+grad_clip: 5
+max_epoch: 50
+val_scheduler_criterion:
+    - valid
+    - acc
+best_model_criterion:
+-   - valid
+    - der
+    - min
+-   - valid
+    - forward_steps
+    - max
+keep_nbest_models: 10
+
+optim: adam
+optim_conf:
+    lr: 0.001
+scheduler: warmuplr
+scheduler_conf:
+    warmup_steps: 10000
+
+# without spec aug
+specaug: null
+specaug_conf:
+    apply_time_warp: true
+    time_warp_window: 5
+    time_warp_mode: bicubic
+    apply_freq_mask: true
+    freq_mask_width_range:
+    - 0
+    - 30
+    num_freq_mask: 2
+    apply_time_mask: true
+    time_mask_width_range:
+    - 0
+    - 40
+    num_time_mask: 2
+
+log_interval: 50
+# without normalize
+normalize: null
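Note on the config above: a bare YAML null (as for specaug and normalize) disables the component, while cd_scorer_conf.input_layer deliberately uses the string "null". A minimal sanity check of that distinction, assuming PyYAML is available and the config path below:

import yaml

with open("egs/mars/sd/conf/SOND_ECAPATDNN_None_Dot_SAN_L4N512_FSMN_L6N512_n16k2.yaml", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

assert cfg["specaug"] is None  # bare null parses to Python None
assert cfg["normalize"] is None  # likewise disabled
assert cfg["cd_scorer_conf"]["input_layer"] == "null"  # quoted "null" stays a string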
diff --git a/egs/mars/sd/local_run.sh b/egs/mars/sd/local_run.sh
new file mode 100755
index 000000000..3b319f46e
--- /dev/null
+++ b/egs/mars/sd/local_run.sh
@@ -0,0 +1,171 @@
+#!/usr/bin/env bash
+
+. ./path.sh || exit 1;
+
+# machines configuration
+CUDA_VISIBLE_DEVICES="6,7"
+gpu_num=2
+count=1
+gpu_inference=true    # whether to perform gpu decoding; set false for cpu decoding
+# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob
+njob=5
+train_cmd=utils/run.pl
+infer_cmd=utils/run.pl
+
+# general configuration
+feats_dir="."    # feature output directory
+exp_dir="."
+lang=zh
+dumpdir=dump/raw
+feats_type=raw
+token_type=char
+scp=wav.scp
+type=kaldi_ark
+stage=3
+stop_stage=4
+
+# feature configuration
+feats_dim=
+sample_frequency=16000
+nj=32
+speed_perturb=
+
+# exp tag
+tag="exp1"
+
+. utils/parse_options.sh || exit 1;
+
+# Set bash to 'debug' mode: it will exit on
+# -e 'error', -u 'undefined variable' and -o pipefail 'error in pipeline'
+set -e
+set -u
+set -o pipefail
+
+train_set=train
+valid_set=dev
+test_sets="dev test"
+
+asr_config=conf/train_asr_conformer.yaml
+model_dir="baseline_$(basename "${asr_config}" .yaml)_${feats_type}_${lang}_${token_type}_${tag}"
+
+inference_config=conf/decode_asr_transformer.yaml
+inference_asr_model=valid.acc.ave_10best.pth
+
+# you can set the gpu num for decoding here
+gpuid_list=$CUDA_VISIBLE_DEVICES    # gpus for decoding, the same as the training stage by default
+ngpu=$(echo $gpuid_list | awk -F "," '{print NF}')
+
+if ${gpu_inference}; then
+    inference_nj=$((ngpu * njob))
+    _ngpu=1
+else
+    inference_nj=$njob
+    _ngpu=0
+fi
+
+feat_train_dir=${feats_dir}/${dumpdir}/train; mkdir -p ${feat_train_dir}
+feat_dev_dir=${feats_dir}/${dumpdir}/dev; mkdir -p ${feat_dev_dir}
+feat_test_dir=${feats_dir}/${dumpdir}/test; mkdir -p ${feat_test_dir}
+
+# Training Stage
+world_size=$gpu_num    # run on one machine
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+    echo "stage 3: Training"
+    mkdir -p ${exp_dir}/exp/${model_dir}
+    mkdir -p ${exp_dir}/exp/${model_dir}/log
+    INIT_FILE=${exp_dir}/exp/${model_dir}/ddp_init
+    if [ -f $INIT_FILE ]; then
+        rm -f $INIT_FILE
+    fi
+    init_method=file://$(readlink -f $INIT_FILE)
+    echo "$0: init method is $init_method"
+    for ((i = 0; i < $gpu_num; ++i)); do
+        {
+            rank=$i
+            local_rank=$i
+            gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$((i + 1)))
+            asr_train.py \
+                --gpu_id $gpu_id \
+                --use_preprocessor true \
+                --token_type char \
+                --token_list $token_list \
+                --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/${scp},speech,${type} \
+                --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/text,text,text \
+                --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/speech_shape \
+                --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/text_shape.char \
+                --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/${scp},speech,${type} \
+                --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/text,text,text \
+                --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/speech_shape \
+                --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/text_shape.char \
+                --resume true \
+                --output_dir ${exp_dir}/exp/${model_dir} \
+                --config $asr_config \
+                --input_size $feats_dim \
+                --ngpu $gpu_num \
+                --num_worker_count $count \
+                --multiprocessing_distributed true \
+                --dist_init_method $init_method \
+                --dist_world_size $world_size \
+                --dist_rank $rank \
+                --local_rank $local_rank 1> ${exp_dir}/exp/${model_dir}/log/train.log.$i 2>&1
+        } &
+    done
+    wait
+fi
+
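+# Example usage (values are illustrative): utils/parse_options.sh lets any of
+# the variables defined above be overridden from the command line, e.g.
+#   ./local_run.sh --stage 3 --stop_stage 3 --tag exp2                        # train only
+#   ./local_run.sh --stage 4 --stop_stage 4 --gpu_inference false --njob 8    # decode on cpu
+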
+# Testing Stage
+if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+    echo "stage 4: Inference"
+    for dset in ${test_sets}; do
+        asr_exp=${exp_dir}/exp/${model_dir}
+        inference_tag="$(basename "${inference_config}" .yaml)"
+        _dir="${asr_exp}/${inference_tag}/${inference_asr_model}/${dset}"
+        _logdir="${_dir}/logdir"
+        if [ -d ${_dir} ]; then
+            echo "${_dir} already exists. If you want to decode again, please delete this dir first."
+            exit 0
+        fi
+        mkdir -p "${_logdir}"
+        _data="${feats_dir}/${dumpdir}/${dset}"
+        key_file=${_data}/${scp}
+        num_scp_file="$(<${key_file} wc -l)"
+        _nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
+        split_scps=
+        for n in $(seq "${_nj}"); do
+            split_scps+=" ${_logdir}/keys.${n}.scp"
+        done
+        # shellcheck disable=SC2086
+        utils/split_scp.pl "${key_file}" ${split_scps}
+        _opts=
+        if [ -n "${inference_config}" ]; then
+            _opts+="--config ${inference_config} "
+        fi
+        ${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
+            python -m funasr.bin.asr_inference_launch \
+                --batch_size 1 \
+                --ngpu "${_ngpu}" \
+                --njob ${njob} \
+                --gpuid_list ${gpuid_list} \
+                --data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
+                --key_file "${_logdir}"/keys.JOB.scp \
+                --asr_train_config "${asr_exp}"/config.yaml \
+                --asr_model_file "${asr_exp}"/"${inference_asr_model}" \
+                --output_dir "${_logdir}"/output.JOB \
+                --mode asr \
+                ${_opts}
+
+        for f in token token_int score text; do
+            if [ -f "${_logdir}/output.1/1best_recog/${f}" ]; then
+                for i in $(seq "${_nj}"); do
+                    cat "${_logdir}/output.${i}/1best_recog/${f}"
+                done | sort -k1 > "${_dir}/${f}"
+            fi
+        done
+        python utils/proce_text.py ${_dir}/text ${_dir}/text.proc
+        python utils/proce_text.py ${_data}/text ${_data}/text.proc
+        python utils/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
+        tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
+        cat ${_dir}/text.cer.txt
+    done
+fi
+
diff --git a/egs/mars/sd/path.sh b/egs/mars/sd/path.sh
new file mode 100755
index 000000000..7972642d0
--- /dev/null
+++ b/egs/mars/sd/path.sh
@@ -0,0 +1,5 @@
+export FUNASR_DIR=$PWD/../../..
+
+# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
+export PYTHONIOENCODING=UTF-8
+export PATH=$FUNASR_DIR/funasr/bin:$PATH
diff --git a/egs/mars/sd/scripts/calculate_shapes.py b/egs/mars/sd/scripts/calculate_shapes.py
new file mode 100644
index 000000000..b207f2d3a
--- /dev/null
+++ b/egs/mars/sd/scripts/calculate_shapes.py
@@ -0,0 +1,45 @@
+import argparse
+import os
+
+import kaldiio
+
+from funasr.utils.job_runner import MultiProcessRunnerV3
+from funasr.utils.misc import load_scp_as_list
+
+
+class MyRunner(MultiProcessRunnerV3):
+
+    def prepare(self, parser: argparse.ArgumentParser):
+        parser.add_argument("--input_scp", type=str, required=True)
+        parser.add_argument("--out_path", type=str, required=True)
+        args = parser.parse_args()
+
+        out_dir = os.path.dirname(args.out_path)
+        if out_dir and not os.path.exists(out_dir):
+            os.makedirs(out_dir)
+
+        task_list = load_scp_as_list(args.input_scp)
+        return task_list, None, args
+
+    def post(self, result_list, args):
+        fd = open(args.out_path, "wt", encoding="utf-8")
+        for results in result_list:
+            for uttid, shape in results:
+                fd.write("{} {}\n".format(uttid, ",".join(shape)))
+        fd.close()
+
+
+def process(task_args):
+    task_idx, task_list, _, args = task_args
+    rst = []
+    for uttid, file_path in task_list:
+        # load the kaldi-format matrix and record its shape as strings
+        data = kaldiio.load_mat(file_path)
+        shape = [str(x) for x in data.shape]
+        rst.append((uttid, shape))
+    return rst
+
+
+if __name__ == '__main__':
+    my_runner = MyRunner(process)
+    my_runner.run()
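For reference, the shape files consumed via --train_shape_file hold one "uttid dim1,dim2" line per utterance, which is exactly what the runner above writes. A minimal single-process sketch of the same computation (a hypothetical standalone script, not part of this patch; it assumes the plain two-column "uttid path" scp format):

import argparse

import kaldiio


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_scp", type=str, required=True)
    parser.add_argument("--out_path", type=str, required=True)
    args = parser.parse_args()

    with open(args.input_scp, encoding="utf-8") as f_in, \
            open(args.out_path, "wt", encoding="utf-8") as f_out:
        for line in f_in:
            uttid, path = line.strip().split(maxsplit=1)
            # kaldiio accepts ark/scp specifiers such as "feats.ark:42"
            mat = kaldiio.load_mat(path)
            f_out.write("{} {}\n".format(uttid, ",".join(str(d) for d in mat.shape)))


if __name__ == "__main__":
    main()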
diff --git a/funasr/models/e2e_diar_sond.py b/funasr/models/e2e_diar_sond.py
index f55bbf65d..ad54723a9 100644
--- a/funasr/models/e2e_diar_sond.py
+++ b/funasr/models/e2e_diar_sond.py
@@ -90,6 +90,7 @@ class DiarSondModel(AbsESPnetModel):
         self.int_token_arr = torch.from_numpy(np.array(self.token_list).astype(int)[np.newaxis, np.newaxis, :])
         self.speaker_discrimination_loss_weight = speaker_discrimination_loss_weight
         self.inter_score_loss_weight = inter_score_loss_weight
+        self.forward_steps = 0
 
     def generate_pse_embedding(self):
         embedding = np.zeros((len(self.token_list), self.max_spk_num), dtype=np.float)
@@ -123,7 +124,7 @@ class DiarSondModel(AbsESPnetModel):
         """
         assert speech.shape[0] == binary_labels.shape[0], (speech.shape, binary_labels.shape)
         batch_size = speech.shape[0]
-
+        self.forward_steps = self.forward_steps + 1
         # 1. Network forward
         pred, inter_outputs = self.prediction_forward(
             speech, speech_lengths,
@@ -198,6 +199,7 @@ class DiarSondModel(AbsESPnetModel):
             cf=cf,
             acc=acc,
             der=der,
+            forward_steps=self.forward_steps,
         )
 
         loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
@@ -262,8 +264,10 @@ class DiarSondModel(AbsESPnetModel):
         self,
         speech: torch.Tensor,
         speech_lengths: torch.Tensor,
-        spk_labels: torch.Tensor = None,
-        spk_labels_lengths: torch.Tensor = None,
+        profile: torch.Tensor = None,
+        profile_lengths: torch.Tensor = None,
+        binary_labels: torch.Tensor = None,
+        binary_labels_lengths: torch.Tensor = None,
     ) -> Dict[str, torch.Tensor]:
         feats, feats_lengths = self._extract_feats(speech, speech_lengths)
         return {"feats": feats, "feats_lengths": feats_lengths}
diff --git a/funasr/models/encoder/ecapa_tdnn_encoder.py b/funasr/models/encoder/ecapa_tdnn_encoder.py
index 3a75e5c31..878a3c032 100644
--- a/funasr/models/encoder/ecapa_tdnn_encoder.py
+++ b/funasr/models/encoder/ecapa_tdnn_encoder.py
@@ -528,8 +528,6 @@ class ECAPA_TDNN(torch.nn.Module):
 
     Arguments
     ---------
-    device : str
-        Device used, e.g., "cpu" or "cuda".
     activation : torch class
         A class for constructing the activation layers.
     channels : list of ints
@@ -555,7 +553,6 @@ class ECAPA_TDNN(torch.nn.Module):
     def __init__(
         self,
         input_size,
-        device="cpu",
         lin_neurons=192,
         activation=torch.nn.ReLU,
         channels=[512, 512, 512, 512, 1536],
diff --git a/funasr/tasks/diar.py b/funasr/tasks/diar.py
index 9a439453f..7f154ef6c 100644
--- a/funasr/tasks/diar.py
+++ b/funasr/tasks/diar.py
@@ -24,6 +24,7 @@ from funasr.layers.utterance_mvn import UtteranceMVN
 from funasr.layers.label_aggregation import LabelAggregate
 from funasr.models.ctc import CTC
 from funasr.models.encoder.resnet34_encoder import ResNet34Diar
+from funasr.models.encoder.ecapa_tdnn_encoder import ECAPA_TDNN
 from funasr.models.encoder.opennmt_encoders.conv_encoder import ConvEncoder
 from funasr.models.encoder.opennmt_encoders.fsmn_encoder import FsmnEncoder
 from funasr.models.encoder.opennmt_encoders.self_attention_encoder import SelfAttentionEncoder
@@ -123,6 +124,7 @@ encoder_choices = ClassChoices(
         resnet34=ResNet34Diar,
         sanm_chunk_opt=SANMEncoderChunkOpt,
         data2vec_encoder=Data2VecEncoder,
+        ecapa_tdnn=ECAPA_TDNN,
     ),
     type_check=AbsEncoder,
     default="resnet34",
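With the device argument removed from ECAPA_TDNN, the encoder is placed on a device with the usual torch idiom. A minimal construction sketch (argument values are illustrative; only the input_size and lin_neurons parameters visible in the signature above are used):

import torch

from funasr.models.encoder.ecapa_tdnn_encoder import ECAPA_TDNN

# build the encoder, then move it; .to() replaces the removed device= argument
encoder = ECAPA_TDNN(input_size=80, lin_neurons=192)
encoder = encoder.to("cuda" if torch.cuda.is_available() else "cpu")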