diff --git a/README.md b/README.md index 7c289e05a..76e33019b 100644 --- a/README.md +++ b/README.md @@ -72,8 +72,8 @@ If you have any questions about FunASR, please contact us by ## Contributors -|
|
|
|
| | | -|:---------------------------------------------------------------:|:---------------------------------------------------------------:|:--------------------------------------------------------------:|:-------------------------------------------------------:|:-----------------------------------------------------------:|:-----------------------------------------------------------:| +|
|
|
|
| |
+|:---------------------------------------------------------------:|:---------------------------------------------------------------:|:--------------------------------------------------------------:|:-------------------------------------------------------:|:-----------------------------------------------------------:|
 
 ## Acknowledge
@@ -82,7 +82,6 @@ If you have any questions about FunASR, please contact us by
 3. We referred [Wenet](https://github.com/wenet-e2e/wenet) for building dataloader for large scale data training.
 4. We acknowledge [ChinaTelecom](https://github.com/zhuzizyf/damo-fsmn-vad-infer-httpserver) for contributing the VAD runtime.
 5. We acknowledge [RapidAI](https://github.com/RapidAI) for contributing the Paraformer and CT_Transformer-punc runtime.
-6. We acknowledge [DeepScience](https://www.deepscience.cn) for contributing the grpc service.
 6. We acknowledge [AiHealthx](http://www.aihealthx.com/) for contributing the websocket service and html5.
 
 ## License
diff --git a/egs/alimeeting/sa-asr/conf/train_lm_transformer.yaml b/egs/alimeeting/sa-asr/conf/train_lm_transformer.yaml
deleted file mode 100644
index 68520ae23..000000000
--- a/egs/alimeeting/sa-asr/conf/train_lm_transformer.yaml
+++ /dev/null
@@ -1,29 +0,0 @@
-lm: transformer
-lm_conf:
-    pos_enc: null
-    embed_unit: 128
-    att_unit: 512
-    head: 8
-    unit: 2048
-    layer: 16
-    dropout_rate: 0.1
-
-# optimization related
-grad_clip: 5.0
-batch_type: numel
-batch_bins: 500000  # 4gpus * 500000
-accum_grad: 1
-max_epoch: 15 # 15epoch is enougth
-
-optim: adam
-optim_conf:
-    lr: 0.001
-scheduler: warmuplr
-scheduler_conf:
-    warmup_steps: 25000
-
-best_model_criterion:
--   - valid
-    - loss
-    - min
-keep_nbest_models: 10 # 10 is good.
diff --git a/egs/alimeeting/sa_asr/README.md b/egs/alimeeting/sa_asr/README.md
new file mode 100644
index 000000000..1ae023a84
--- /dev/null
+++ b/egs/alimeeting/sa_asr/README.md
@@ -0,0 +1,86 @@
+# Get Started
+Speaker Attributed Automatic Speech Recognition (SA-ASR) is a task proposed to solve "who spoke what". Specifically, the goal of SA-ASR is not only to obtain multi-speaker transcriptions, but also to identify the corresponding speaker for each utterance. The method used in this example follows the paper: [End-to-End Speaker-Attributed ASR with Transformer](https://www.isca-speech.org/archive/pdfs/interspeech_2021/kanda21b_interspeech.pdf).
+# Train
+First, install FunASR and ModelScope ([installation](https://github.com/alibaba-damo-academy/FunASR#installation)).
+After FunASR and ModelScope are installed, you must manually download and unpack the [AliMeeting](http://www.openslr.org/119/) corpus and place it in the `./dataset` directory. The `./dataset` directory should be organized as follows:
+```shell
+dataset
+|—— Eval_Ali_far
+|—— Eval_Ali_near
+|—— Test_Ali_far
+|—— Test_Ali_near
+|—— Train_Ali_far
+|—— Train_Ali_near
+```
+Then you can run this recipe by running:
+```shell
+bash run.sh --stage 0 --stop-stage 6
+```
+There are 8 stages in `run.sh`:
+```shell
+stage 0: Data preparation and removal of audio that is too long or too short.
+stage 1: Speaker profile and CMVN generation.
+stage 2: Dictionary preparation.
+stage 3: LM training (not supported).
+stage 4: ASR training.
+stage 5: SA-ASR training.
+stage 6: Inference on the test sets.
+stage 7: Inference with Test_2023_Ali_far.
+```
+
+# Infer
+1. Download the final test set and extract it.
+2. Put the audios in `./dataset/Test_2023_Ali_far/` and put the `wav.scp`, `segments`, `utt2spk`, `spk2utt` in `./data/org/Test_2023_Ali_far/`.
+3. Set `test_2023` in `run.sh` to `Test_2023_Ali_far`.
+4. Run `run.sh` as follows:
+```shell
+# Prepare test_2023 set
+bash run.sh --stage 0 --stop-stage 1
+# Decode test_2023 set
+bash run.sh --stage 7 --stop-stage 7
+```
+# Format of Final Submission
+Finally, you need to submit a file called `text_spk_merge` with the following format:
+```shell
+Meeting_1 text_spk_1_A$text_spk_1_B$text_spk_1_C ...
+Meeting_2 text_spk_2_A$text_spk_2_B$text_spk_2_C ...
+...
+```
+Here, text_spk_1_A represents the full transcription of speaker_A of Meeting_1 (merged in chronological order), and $ represents the separator symbol. There is no need to worry about the speaker permutation, as the optimal permutation will be computed during scoring. For more information, please refer to the results generated after executing the baseline code.
+# Baseline Results
+The results of the baseline system are as follows. They include the speaker-independent character error rate (SI-CER) and the concatenated minimum-permutation character error rate (cpCER); the former ignores speaker labels, while the latter is speaker-dependent. During training, the speaker profile uses oracle speaker embeddings. Because oracle speaker labels are unavailable during evaluation, speaker profiles produced by an additional spectral clustering step are used instead. Results with the oracle speaker profile on the Test set are also reported to show the impact of speaker profile accuracy. A toy sketch of the permutation search behind cpCER is given after the reference below.
+
+|                |SI-CER(%)     |cpCER(%)   |
+|:---------------|:------------:|----------:|
+|oracle profile  |32.72         |42.92      |
+|cluster profile |32.73         |49.37      |
+
+
+# Reference
+N. Kanda, G. Ye, Y. Gaur, X. Wang, Z. Meng, Z. Chen, and T. Yoshioka, "End-to-end speaker-attributed ASR with transformer," in Interspeech. ISCA, 2021, pp. 4413–4417.
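The optimal permutation mentioned above is what `local/compute_cpcer.py` searches over when scoring. As a rough, self-contained illustration of the cpCER idea only (not the recipe's actual implementation), and assuming reference and hypothesis contain the same number of speaker streams, the metric can be sketched as:

```python
import itertools

def edit_distance(ref, hyp):
    # character-level Levenshtein distance with a single rolling row
    d = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, 1):
        prev, d[0] = d[0], i
        for j, h in enumerate(hyp, 1):
            prev, d[j] = d[j], min(d[j] + 1, d[j - 1] + 1, prev + (r != h))
    return d[len(hyp)]

def cpcer(ref_streams, hyp_streams):
    # concatenated minimum-permutation CER: try every assignment of
    # hypothesis speaker streams to reference speakers and keep the
    # permutation with the fewest total character errors
    best = min(
        sum(edit_distance(r, h) for r, h in zip(ref_streams, perm))
        for perm in itertools.permutations(hyp_streams)
    )
    return best / sum(len(r) for r in ref_streams)

# two speakers whose streams were swapped by the system: cpCER is still 0.0
print(cpcer(["今天天气不错", "是的"], ["是的", "今天天气不错"]))
```

A factorial search is affordable here because each meeting has at most four speakers (`max_spk_num: 4` in the config); handling a mismatched number of speaker streams is left to the recipe's own scorer.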
\ No newline at end of file diff --git a/egs/alimeeting/sa-asr/conf/decode_asr_rnn.yaml b/egs/alimeeting/sa_asr/conf/decode_asr_rnn.yaml similarity index 100% rename from egs/alimeeting/sa-asr/conf/decode_asr_rnn.yaml rename to egs/alimeeting/sa_asr/conf/decode_asr_rnn.yaml diff --git a/egs/alimeeting/sa_asr/conf/train_asr_conformer.yaml b/egs/alimeeting/sa_asr/conf/train_asr_conformer.yaml new file mode 100644 index 000000000..507ad3061 --- /dev/null +++ b/egs/alimeeting/sa_asr/conf/train_asr_conformer.yaml @@ -0,0 +1,102 @@ +# network architecture +frontend: multichannelfrontend +frontend_conf: + fs: 16000 + window: hann + n_fft: 400 + n_mels: 80 + frame_length: 25 + frame_shift: 10 + lfr_m: 1 + lfr_n: 1 + use_channel: 0 + +# encoder related +encoder: conformer +encoder_conf: + output_size: 256 # dimension of attention + attention_heads: 4 + linear_units: 2048 # the number of units of position-wise feed forward + num_blocks: 12 # the number of encoder blocks + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + attention_dropout_rate: 0.0 + input_layer: conv2d # encoder architecture type + normalize_before: true + rel_pos_type: latest + pos_enc_layer_type: rel_pos + selfattention_layer_type: rel_selfattn + activation_type: swish + macaron_style: true + use_cnn_module: true + cnn_module_kernel: 15 + +# decoder related +decoder: transformer +decoder_conf: + attention_heads: 4 + linear_units: 2048 + num_blocks: 6 + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + self_attention_dropout_rate: 0.0 + src_attention_dropout_rate: 0.0 + +# ctc related +ctc_conf: + ignore_nan_grad: true + +# hybrid CTC/attention +model_conf: + ctc_weight: 0.3 + lsm_weight: 0.1 # label smoothing option + length_normalized_loss: false + + +dataset_conf: + data_names: speech,text + data_types: sound,text + shuffle: True + shuffle_conf: + shuffle_size: 2048 + sort_size: 500 + batch_conf: + batch_type: token + batch_size: 7000 + num_workers: 8 + +# optimization related +accum_grad: 1 +grad_clip: 5 +max_epoch: 100 +val_scheduler_criterion: + - valid + - acc +best_model_criterion: +- - valid + - acc + - max +keep_nbest_models: 10 + +optim: adam +optim_conf: + lr: 0.001 +scheduler: warmuplr +scheduler_conf: + warmup_steps: 25000 + +specaug: specaug +specaug_conf: + apply_time_warp: true + time_warp_window: 5 + time_warp_mode: bicubic + apply_freq_mask: true + freq_mask_width_range: + - 0 + - 30 + num_freq_mask: 2 + apply_time_mask: true + time_mask_width_range: + - 0 + - 40 + num_time_mask: 2 diff --git a/egs/alimeeting/sa_asr/conf/train_sa_asr_conformer.yaml b/egs/alimeeting/sa_asr/conf/train_sa_asr_conformer.yaml new file mode 100644 index 000000000..47bc6bdb6 --- /dev/null +++ b/egs/alimeeting/sa_asr/conf/train_sa_asr_conformer.yaml @@ -0,0 +1,131 @@ +# network architecture +frontend: multichannelfrontend +frontend_conf: + fs: 16000 + window: hann + n_fft: 400 + n_mels: 80 + frame_length: 25 + frame_shift: 10 + lfr_m: 1 + lfr_n: 1 + use_channel: 0 + +# encoder related +asr_encoder: conformer +asr_encoder_conf: + output_size: 256 # dimension of attention + attention_heads: 4 + linear_units: 2048 # the number of units of position-wise feed forward + num_blocks: 12 # the number of encoder blocks + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + attention_dropout_rate: 0.0 + input_layer: conv2d # encoder architecture type + normalize_before: true + pos_enc_layer_type: rel_pos + selfattention_layer_type: rel_selfattn + activation_type: swish + macaron_style: true + use_cnn_module: true + cnn_module_kernel: 15 + 
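+# The ASR (Conformer) encoder above is paired with the speaker encoder
+# below: resnet34_diar extracts frame-level speaker embeddings that the
+# sa_decoder's speaker branch consumes alongside the ASR encoder output.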
+spk_encoder: resnet34_diar +spk_encoder_conf: + use_head_conv: true + batchnorm_momentum: 0.5 + use_head_maxpool: false + num_nodes_pooling_layer: 256 + layers_in_block: + - 3 + - 4 + - 6 + - 3 + filters_in_block: + - 32 + - 64 + - 128 + - 256 + pooling_type: statistic + num_nodes_resnet1: 256 + num_nodes_last_layer: 256 + batchnorm_momentum: 0.5 + +# decoder related +decoder: sa_decoder +decoder_conf: + attention_heads: 4 + linear_units: 2048 + asr_num_blocks: 6 + spk_num_blocks: 3 + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + self_attention_dropout_rate: 0.0 + src_attention_dropout_rate: 0.0 + +# hybrid CTC/attention +model_conf: + spk_weight: 0.5 + ctc_weight: 0.3 + lsm_weight: 0.1 # label smoothing option + length_normalized_loss: false + max_spk_num: 4 + +ctc_conf: + ignore_nan_grad: true + +# minibatch related +dataset_conf: + data_names: speech,text,profile,text_id + data_types: sound,text,npy,text_int + shuffle: True + shuffle_conf: + shuffle_size: 2048 + sort_size: 500 + batch_conf: + batch_type: token + batch_size: 7000 + num_workers: 8 + +# optimization related +accum_grad: 1 +grad_clip: 5 +max_epoch: 60 +val_scheduler_criterion: + - valid + - loss +best_model_criterion: +- - valid + - acc + - max +- - valid + - acc_spk + - max +- - valid + - loss + - min +keep_nbest_models: 10 + +optim: adam +optim_conf: + lr: 0.0005 +scheduler: warmuplr +scheduler_conf: + warmup_steps: 8000 + +specaug: specaug +specaug_conf: + apply_time_warp: true + time_warp_window: 5 + time_warp_mode: bicubic + apply_freq_mask: true + freq_mask_width_range: + - 0 + - 30 + num_freq_mask: 2 + apply_time_mask: true + time_mask_width_range: + - 0 + - 40 + num_time_mask: 2 + diff --git a/egs/alimeeting/sa-asr/local/alimeeting_data_prep.sh b/egs/alimeeting/sa_asr/local/alimeeting_data_prep.sh similarity index 74% rename from egs/alimeeting/sa-asr/local/alimeeting_data_prep.sh rename to egs/alimeeting/sa_asr/local/alimeeting_data_prep.sh index c13ee429e..fd76837b1 100755 --- a/egs/alimeeting/sa-asr/local/alimeeting_data_prep.sh +++ b/egs/alimeeting/sa_asr/local/alimeeting_data_prep.sh @@ -21,6 +21,8 @@ EOF SECONDS=0 tgt=Train #Train or Eval +min_wav_duration=0.1 +max_wav_duration=20 log "$0 $*" @@ -57,27 +59,24 @@ stage=1 stop_stage=4 mkdir -p $far_dir mkdir -p $near_dir +mkdir -p data/org if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then log "stage 1:process alimeeting near dir" find -L $near_raw_dir/audio_dir -iname "*.wav" | sort > $near_dir/wavlist - awk -F '/' '{print $NF}' $near_dir/wavlist | awk -F '.' '{print $1}' > $near_dir/uttid - find -L $near_raw_dir/textgrid_dir -iname "*.TextGrid" | sort > $near_dir/textgrid.flist + awk -F '/' '{print $NF}' $near_dir/wavlist | awk -F '.' '{print $1}' | sort > $near_dir/uttid + find -L $near_raw_dir/textgrid_dir -iname "*.TextGrid" > $near_dir/textgrid.flist n1_wav=$(wc -l < $near_dir/wavlist) n2_text=$(wc -l < $near_dir/textgrid.flist) log near file found $n1_wav wav and $n2_text text. 
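+    # build wav.scp by pairing each utterance id with its wav path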
- paste $near_dir/uttid $near_dir/wavlist > $near_dir/wav_raw.scp - - # cat $near_dir/wav_raw.scp | awk '{printf("%s sox -t wav %s -r 16000 -b 16 -c 1 -t wav - |\n", $1, $2)}' > $near_dir/wav.scp - cat $near_dir/wav_raw.scp | awk '{printf("%s sox -t wav %s -r 16000 -b 16 -t wav - |\n", $1, $2)}' > $near_dir/wav.scp + paste $near_dir/uttid $near_dir/wavlist -d " " > $near_dir/wav.scp python local/alimeeting_process_textgrid.py --path $near_dir --no-overlap False cat $near_dir/text_all | local/text_normalize.pl | local/text_format.pl | sort -u > $near_dir/text utils/filter_scp.pl -f 1 $near_dir/text $near_dir/utt2spk_all | sort -u > $near_dir/utt2spk - #sed -e 's/ [a-z,A-Z,_,0-9,-]\+SPK/ SPK/' $near_dir/utt2spk_old >$near_dir/tmp1 - #sed -e 's/-[a-z,A-Z,0-9]\+$//' $near_dir/tmp1 | sort -u > $near_dir/utt2spk + local/utt2spk_to_spk2utt.pl $near_dir/utt2spk > $near_dir/spk2utt utils/filter_scp.pl -f 1 $near_dir/text $near_dir/segments_all | sort -u > $near_dir/segments sed -e 's/ $//g' $near_dir/text> $near_dir/tmp1 @@ -97,9 +96,7 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then n2_text=$(wc -l < $far_dir/textgrid.flist) log far file found $n1_wav wav and $n2_text text. - paste $far_dir/uttid $far_dir/wavlist > $far_dir/wav_raw.scp - - cat $far_dir/wav_raw.scp | awk '{printf("%s sox -t wav %s -r 16000 -b 16 -t wav - |\n", $1, $2)}' > $far_dir/wav.scp + paste $far_dir/uttid $far_dir/wavlist -d " " > $far_dir/wav.scp python local/alimeeting_process_overlap_force.py --path $far_dir \ --no-overlap false --mars True \ @@ -119,28 +116,28 @@ fi if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then - log "stage 3: finali data process" + log "stage 3: final data process" local/fix_data_dir.sh $near_dir local/fix_data_dir.sh $far_dir - local/copy_data_dir.sh $near_dir data/${tgt}_Ali_near - local/copy_data_dir.sh $far_dir data/${tgt}_Ali_far + local/copy_data_dir.sh $near_dir data/org/${tgt}_Ali_near + local/copy_data_dir.sh $far_dir data/org/${tgt}_Ali_far - sort $far_dir/utt2spk_all_fifo > data/${tgt}_Ali_far/utt2spk_all_fifo - sed -i "s/src/$/g" data/${tgt}_Ali_far/utt2spk_all_fifo + sort $far_dir/utt2spk_all_fifo > data/org/${tgt}_Ali_far/utt2spk_all_fifo + sed -i "s/src/$/g" data/org/${tgt}_Ali_far/utt2spk_all_fifo # remove space in text for x in ${tgt}_Ali_near ${tgt}_Ali_far; do - cp data/${x}/text data/${x}/text.org - paste -d " " <(cut -f 1 -d" " data/${x}/text.org) <(cut -f 2- -d" " data/${x}/text.org | tr -d " ") \ - > data/${x}/text - rm data/${x}/text.org + cp data/org/${x}/text data/org/${x}/text.org + paste -d " " <(cut -f 1 -d" " data/org/${x}/text.org) <(cut -f 2- -d" " data/org/${x}/text.org | tr -d " ") \ + > data/org/${x}/text + rm data/org/${x}/text.org done log "Successfully finished. 
[elapsed=${SECONDS}s]" fi if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then - log "stage 4: process alimeeting far dir (single speaker by oracle time strap)" + log "stage 4: process alimeeting far dir (single speaker by oracle time stamp)" cp -r $far_dir/* $far_single_speaker_dir mv $far_single_speaker_dir/textgrid.flist $far_single_speaker_dir/textgrid_oldpath paste -d " " $far_single_speaker_dir/uttid $far_single_speaker_dir/textgrid_oldpath > $far_single_speaker_dir/textgrid.flist @@ -150,14 +147,15 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then local/utt2spk_to_spk2utt.pl $far_single_speaker_dir/utt2spk > $far_single_speaker_dir/spk2utt ./local/fix_data_dir.sh $far_single_speaker_dir - local/copy_data_dir.sh $far_single_speaker_dir data/${tgt}_Ali_far_single_speaker + local/copy_data_dir.sh $far_single_speaker_dir data/org/${tgt}_Ali_far_single_speaker # remove space in text for x in ${tgt}_Ali_far_single_speaker; do - cp data/${x}/text data/${x}/text.org - paste -d " " <(cut -f 1 -d" " data/${x}/text.org) <(cut -f 2- -d" " data/${x}/text.org | tr -d " ") \ - > data/${x}/text - rm data/${x}/text.org + cp data/org/${x}/text data/org/${x}/text.org + paste -d " " <(cut -f 1 -d" " data/org/${x}/text.org) <(cut -f 2- -d" " data/org/${x}/text.org | tr -d " ") \ + > data/org/${x}/text + rm data/org/${x}/text.org done + rm -rf data/local log "Successfully finished. [elapsed=${SECONDS}s]" fi \ No newline at end of file diff --git a/egs/alimeeting/sa-asr/local/alimeeting_data_prep_test_2023.sh b/egs/alimeeting/sa_asr/local/alimeeting_data_prep_test_2023.sh similarity index 100% rename from egs/alimeeting/sa-asr/local/alimeeting_data_prep_test_2023.sh rename to egs/alimeeting/sa_asr/local/alimeeting_data_prep_test_2023.sh diff --git a/egs/alimeeting/sa-asr/local/alimeeting_process_overlap_force.py b/egs/alimeeting/sa_asr/local/alimeeting_process_overlap_force.py similarity index 100% rename from egs/alimeeting/sa-asr/local/alimeeting_process_overlap_force.py rename to egs/alimeeting/sa_asr/local/alimeeting_process_overlap_force.py diff --git a/egs/alimeeting/sa-asr/local/alimeeting_process_textgrid.py b/egs/alimeeting/sa_asr/local/alimeeting_process_textgrid.py similarity index 100% rename from egs/alimeeting/sa-asr/local/alimeeting_process_textgrid.py rename to egs/alimeeting/sa_asr/local/alimeeting_process_textgrid.py diff --git a/egs/alimeeting/sa-asr/local/apply_map.pl b/egs/alimeeting/sa_asr/local/apply_map.pl similarity index 100% rename from egs/alimeeting/sa-asr/local/apply_map.pl rename to egs/alimeeting/sa_asr/local/apply_map.pl diff --git a/egs/alimeeting/sa-asr/local/combine_data.sh b/egs/alimeeting/sa_asr/local/combine_data.sh similarity index 100% rename from egs/alimeeting/sa-asr/local/combine_data.sh rename to egs/alimeeting/sa_asr/local/combine_data.sh diff --git a/egs/alimeeting/sa_asr/local/compute_cmvn.py b/egs/alimeeting/sa_asr/local/compute_cmvn.py new file mode 100755 index 000000000..d16563a96 --- /dev/null +++ b/egs/alimeeting/sa_asr/local/compute_cmvn.py @@ -0,0 +1,134 @@ +import argparse +import json +import os + +import numpy as np +import torchaudio +import torchaudio.compliance.kaldi as kaldi +import yaml +from funasr.models.frontend.default import DefaultFrontend +import torch + +def get_parser(): + parser = argparse.ArgumentParser( + description="computer global cmvn", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--dim", + default=80, + type=int, + help="feature dimension", + ) + parser.add_argument( + 
"--wav_path", + default=False, + required=True, + type=str, + help="the path of wav scps", + ) + parser.add_argument( + "--config_file", + type=str, + help="the config file for computing cmvn", + ) + parser.add_argument( + "--idx", + default=1, + required=True, + type=int, + help="index", + ) + return parser + + +def compute_fbank(wav_file, + num_mel_bins=80, + frame_length=25, + frame_shift=10, + dither=0.0, + resample_rate=16000, + speed=1.0, + window_type="hamming"): + waveform, sample_rate = torchaudio.load(wav_file) + if resample_rate != sample_rate: + waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, + new_freq=resample_rate)(waveform) + if speed != 1.0: + waveform, _ = torchaudio.sox_effects.apply_effects_tensor( + waveform, resample_rate, + [['speed', str(speed)], ['rate', str(resample_rate)]] + ) + + waveform = waveform * (1 << 15) + mat = kaldi.fbank(waveform, + num_mel_bins=num_mel_bins, + frame_length=frame_length, + frame_shift=frame_shift, + dither=dither, + energy_floor=0.0, + window_type=window_type, + sample_frequency=resample_rate) + + return mat.numpy() + + +def main(): + parser = get_parser() + args = parser.parse_args() + + wav_scp_file = os.path.join(args.wav_path, "wav.{}.scp".format(args.idx)) + cmvn_file = os.path.join(args.wav_path, "cmvn.{}.json".format(args.idx)) + + mean_stats = np.zeros(args.dim) + var_stats = np.zeros(args.dim) + total_frames = 0 + + # with ReadHelper('ark:{}'.format(ark_file)) as ark_reader: + # for key, mat in ark_reader: + # mean_stats += np.sum(mat, axis=0) + # var_stats += np.sum(np.square(mat), axis=0) + # total_frames += mat.shape[0] + + with open(args.config_file) as f: + configs = yaml.safe_load(f) + frontend_configs = configs.get("frontend_conf", {}) + num_mel_bins = frontend_configs.get("n_mels", 80) + frame_length = frontend_configs.get("frame_length", 25) + frame_shift = frontend_configs.get("frame_shift", 10) + window_type = frontend_configs.get("window", "hamming") + resample_rate = frontend_configs.get("fs", 16000) + n_fft = frontend_configs.get("n_fft", "400") + use_channel = frontend_configs.get("use_channel", None) + assert num_mel_bins == args.dim + frontend = DefaultFrontend( + fs=resample_rate, + n_fft=n_fft, + win_length=frame_length * 16, + hop_length=frame_shift * 16, + window=window_type, + n_mels=num_mel_bins, + use_channel=use_channel, + ) + with open(wav_scp_file) as f: + lines = f.readlines() + for line in lines: + _, wav_file = line.strip().split() + wavform, _ = torchaudio.load(wav_file) + fbank, _ = frontend(wavform.transpose(0, 1).unsqueeze(0), torch.tensor([wavform.shape[1]])) + fbank = fbank.squeeze(0).numpy() + mean_stats += np.sum(fbank, axis=0) + var_stats += np.sum(np.square(fbank), axis=0) + total_frames += fbank.shape[0] + + cmvn_info = { + 'mean_stats': list(mean_stats.tolist()), + 'var_stats': list(var_stats.tolist()), + 'total_frames': total_frames + } + with open(cmvn_file, 'w') as fout: + fout.write(json.dumps(cmvn_info)) + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/egs/alimeeting/sa_asr/local/compute_cmvn.sh b/egs/alimeeting/sa_asr/local/compute_cmvn.sh new file mode 100755 index 000000000..00d08d14c --- /dev/null +++ b/egs/alimeeting/sa_asr/local/compute_cmvn.sh @@ -0,0 +1,39 @@ +#!/usr/bin/env bash + +. ./path.sh || exit 1; +# Begin configuration section. +fbankdir= +nj=32 +cmd=./utils/run.pl +feats_dim=80 +config_file= +scale=1.0 + +echo "$0 $@" + +. 
utils/parse_options.sh || exit 1; + +# shellcheck disable=SC2046 +head -n $(awk -v lines="$(wc -l < ${fbankdir}/wav.scp)" -v scale="$scale" 'BEGIN { printf "%.0f\n", lines*scale }') ${fbankdir}/wav.scp > ${fbankdir}/wav.scp.scale + +split_dir=${fbankdir}/cmvn/split_${nj}; +mkdir -p $split_dir +split_scps="" +for n in $(seq $nj); do + split_scps="$split_scps $split_dir/wav.$n.scp" +done +utils/split_scp.pl ${fbankdir}/wav.scp.scale $split_scps || exit 1; + +logdir=${fbankdir}/cmvn/log +$cmd JOB=1:$nj $logdir/cmvn.JOB.log \ + python local/compute_cmvn.py \ + --dim ${feats_dim} \ + --wav_path $split_dir \ + --config_file $config_file \ + --idx JOB \ + +python utils/combine_cmvn_file.py --dim ${feats_dim} --cmvn_dir $split_dir --nj $nj --output_dir ${fbankdir}/cmvn + +python utils/cmvn_converter.py --cmvn_json ${fbankdir}/cmvn/cmvn.json --am_mvn ${fbankdir}/cmvn/am.mvn + +echo "$0: Succeeded compute global cmvn" diff --git a/egs/alimeeting/sa-asr/local/compute_cpcer.py b/egs/alimeeting/sa_asr/local/compute_cpcer.py similarity index 100% rename from egs/alimeeting/sa-asr/local/compute_cpcer.py rename to egs/alimeeting/sa_asr/local/compute_cpcer.py diff --git a/egs/alimeeting/sa_asr/local/convert_model.py b/egs/alimeeting/sa_asr/local/convert_model.py new file mode 100644 index 000000000..f0f7997fc --- /dev/null +++ b/egs/alimeeting/sa_asr/local/convert_model.py @@ -0,0 +1,29 @@ +import codecs +import pdb +import sys +import torch + +char1 = sys.argv[1] +char2 = sys.argv[2] +model1 = torch.load(sys.argv[3], map_location='cpu') +model2_path = sys.argv[4] + +d_new = model1 +char1_list = [] +map_list = [] + + +with codecs.open(char1) as f: + for line in f.readlines(): + char1_list.append(line.strip()) + +with codecs.open(char2) as f: + for line in f.readlines(): + map_list.append(char1_list.index(line.strip())) +print(map_list) + +for k, v in d_new.items(): + if k == 'ctc.ctc_lo.weight' or k == 'ctc.ctc_lo.bias' or k == 'decoder.output_layer.weight' or k == 'decoder.output_layer.bias' or k == 'decoder.embed.0.weight': + d_new[k] = v[map_list] + +torch.save(d_new, model2_path) diff --git a/egs/alimeeting/sa-asr/local/copy_data_dir.sh b/egs/alimeeting/sa_asr/local/copy_data_dir.sh similarity index 100% rename from egs/alimeeting/sa-asr/local/copy_data_dir.sh rename to egs/alimeeting/sa_asr/local/copy_data_dir.sh diff --git a/egs/alimeeting/sa-asr/local/data/get_reco2dur.sh b/egs/alimeeting/sa_asr/local/data/get_reco2dur.sh similarity index 100% rename from egs/alimeeting/sa-asr/local/data/get_reco2dur.sh rename to egs/alimeeting/sa_asr/local/data/get_reco2dur.sh diff --git a/egs/alimeeting/sa-asr/local/data/get_segments_for_data.sh b/egs/alimeeting/sa_asr/local/data/get_segments_for_data.sh similarity index 100% rename from egs/alimeeting/sa-asr/local/data/get_segments_for_data.sh rename to egs/alimeeting/sa_asr/local/data/get_segments_for_data.sh diff --git a/egs/alimeeting/sa-asr/local/data/get_utt2dur.sh b/egs/alimeeting/sa_asr/local/data/get_utt2dur.sh similarity index 100% rename from egs/alimeeting/sa-asr/local/data/get_utt2dur.sh rename to egs/alimeeting/sa_asr/local/data/get_utt2dur.sh diff --git a/egs/alimeeting/sa-asr/local/data/split_data.sh b/egs/alimeeting/sa_asr/local/data/split_data.sh similarity index 100% rename from egs/alimeeting/sa-asr/local/data/split_data.sh rename to egs/alimeeting/sa_asr/local/data/split_data.sh diff --git a/egs/alimeeting/sa_asr/local/download_and_untar.sh b/egs/alimeeting/sa_asr/local/download_and_untar.sh new file mode 100755 index 000000000..d98255915 --- 
/dev/null +++ b/egs/alimeeting/sa_asr/local/download_and_untar.sh @@ -0,0 +1,105 @@ +#!/usr/bin/env bash + +# Copyright 2014 Johns Hopkins University (author: Daniel Povey) +# 2017 Xingyu Na +# Apache 2.0 + +remove_archive=false + +if [ "$1" == --remove-archive ]; then + remove_archive=true + shift +fi + +if [ $# -ne 3 ]; then + echo "Usage: $0 [--remove-archive] " + echo "e.g.: $0 /export/a05/xna/data www.openslr.org/resources/33 data_aishell" + echo "With --remove-archive it will remove the archive after successfully un-tarring it." + echo " can be one of: data_aishell, resource_aishell." +fi + +data=$1 +url=$2 +part=$3 + +if [ ! -d "$data" ]; then + echo "$0: no such directory $data" + exit 1; +fi + +part_ok=false +list="data_aishell resource_aishell" +for x in $list; do + if [ "$part" == $x ]; then part_ok=true; fi +done +if ! $part_ok; then + echo "$0: expected to be one of $list, but got '$part'" + exit 1; +fi + +if [ -z "$url" ]; then + echo "$0: empty URL base." + exit 1; +fi + +if [ -f $data/$part/.complete ]; then + echo "$0: data part $part was already successfully extracted, nothing to do." + exit 0; +fi + +# sizes of the archive files in bytes. +sizes="15582913665 1246920" + +if [ -f $data/$part.tgz ]; then + size=$(/bin/ls -l $data/$part.tgz | awk '{print $5}') + size_ok=false + for s in $sizes; do if [ $s == $size ]; then size_ok=true; fi; done + if ! $size_ok; then + echo "$0: removing existing file $data/$part.tgz because its size in bytes $size" + echo "does not equal the size of one of the archives." + rm $data/$part.tgz + else + echo "$data/$part.tgz exists and appears to be complete." + fi +fi + +if [ ! -f $data/$part.tgz ]; then + if ! command -v wget >/dev/null; then + echo "$0: wget is not installed." + exit 1; + fi + full_url=$url/$part.tgz + echo "$0: downloading data from $full_url. This may take some time, please be patient." + + cd $data || exit 1 + if ! wget --no-check-certificate $full_url; then + echo "$0: error executing wget $full_url" + exit 1; + fi +fi + +cd $data || exit 1 + +if ! tar -xvzf $part.tgz; then + echo "$0: error un-tarring archive $data/$part.tgz" + exit 1; +fi + +touch $data/$part/.complete + +if [ $part == "data_aishell" ]; then + cd $data/$part/wav || exit 1 + for wav in ./*.tar.gz; do + echo "Extracting wav from $wav" + tar -zxf $wav && rm $wav + done +fi + +echo "$0: Successfully downloaded and un-tarred $data/$part.tgz" + +if $remove_archive; then + echo "$0: removing $data/$part.tgz file since --remove-archive option was supplied." 
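+  # the archive is removed only after a successful extraction (the
+  # .complete marker has already been written at this point)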
+ rm $data/$part.tgz +fi + +exit 0; diff --git a/egs/alimeeting/sa-asr/local/download_pretrained_model_from_modelscope.py b/egs/alimeeting/sa_asr/local/download_pretrained_model_from_modelscope.py similarity index 100% rename from egs/alimeeting/sa-asr/local/download_pretrained_model_from_modelscope.py rename to egs/alimeeting/sa_asr/local/download_pretrained_model_from_modelscope.py diff --git a/egs/alimeeting/sa-asr/local/download_xvector_model.py b/egs/alimeeting/sa_asr/local/download_xvector_model.py similarity index 100% rename from egs/alimeeting/sa-asr/local/download_xvector_model.py rename to egs/alimeeting/sa_asr/local/download_xvector_model.py diff --git a/egs/alimeeting/sa-asr/local/filter_utt2spk_all_fifo.py b/egs/alimeeting/sa_asr/local/filter_utt2spk_all_fifo.py similarity index 100% rename from egs/alimeeting/sa-asr/local/filter_utt2spk_all_fifo.py rename to egs/alimeeting/sa_asr/local/filter_utt2spk_all_fifo.py diff --git a/egs/alimeeting/sa-asr/local/fix_data_dir.sh b/egs/alimeeting/sa_asr/local/fix_data_dir.sh similarity index 100% rename from egs/alimeeting/sa-asr/local/fix_data_dir.sh rename to egs/alimeeting/sa_asr/local/fix_data_dir.sh diff --git a/egs/alimeeting/sa-asr/local/format_wav_scp.py b/egs/alimeeting/sa_asr/local/format_wav_scp.py similarity index 100% rename from egs/alimeeting/sa-asr/local/format_wav_scp.py rename to egs/alimeeting/sa_asr/local/format_wav_scp.py diff --git a/egs/alimeeting/sa-asr/local/format_wav_scp.sh b/egs/alimeeting/sa_asr/local/format_wav_scp.sh similarity index 100% rename from egs/alimeeting/sa-asr/local/format_wav_scp.sh rename to egs/alimeeting/sa_asr/local/format_wav_scp.sh diff --git a/egs/alimeeting/sa-asr/local/gen_cluster_profile_infer.py b/egs/alimeeting/sa_asr/local/gen_cluster_profile_infer.py similarity index 97% rename from egs/alimeeting/sa-asr/local/gen_cluster_profile_infer.py rename to egs/alimeeting/sa_asr/local/gen_cluster_profile_infer.py index c37abf9a0..859b72fce 100644 --- a/egs/alimeeting/sa-asr/local/gen_cluster_profile_infer.py +++ b/egs/alimeeting/sa_asr/local/gen_cluster_profile_infer.py @@ -63,7 +63,7 @@ if __name__ == "__main__": wav_scp_file = open(path+'/wav.scp', 'r') wav_scp = wav_scp_file.readlines() wav_scp_file.close() - raw_meeting_scp_file = open(raw_path + '/wav_raw.scp', 'r') + raw_meeting_scp_file = open(raw_path + '/wav.scp', 'r') raw_meeting_scp = raw_meeting_scp_file.readlines() raw_meeting_scp_file.close() segments_scp_file = open(raw_path + '/segments', 'r') @@ -92,8 +92,8 @@ if __name__ == "__main__": cluster_spk_num_file = open(path + '/cluster_spk_num', 'w') meeting_map = {} for line in raw_meeting_scp: - meeting = line.strip().split('\t')[0] - wav_path = line.strip().split('\t')[1] + meeting = line.strip().split(' ')[0] + wav_path = line.strip().split(' ')[1] wav = soundfile.read(wav_path)[0] # take the first channel if wav.ndim == 2: diff --git a/egs/alimeeting/sa-asr/local/gen_oracle_embedding.py b/egs/alimeeting/sa_asr/local/gen_oracle_embedding.py similarity index 94% rename from egs/alimeeting/sa-asr/local/gen_oracle_embedding.py rename to egs/alimeeting/sa_asr/local/gen_oracle_embedding.py index 18286b42d..2a99b2b6b 100644 --- a/egs/alimeeting/sa-asr/local/gen_oracle_embedding.py +++ b/egs/alimeeting/sa_asr/local/gen_oracle_embedding.py @@ -9,7 +9,7 @@ import soundfile if __name__=="__main__": path = sys.argv[1] # dump2/raw/Eval_Ali_far raw_path = sys.argv[2] # data/local/Eval_Ali_far_correct_single_speaker - raw_meeting_scp_file = open(raw_path + '/wav_raw.scp', 'r') + 
raw_meeting_scp_file = open(raw_path + '/wav.scp', 'r') raw_meeting_scp = raw_meeting_scp_file.readlines() raw_meeting_scp_file.close() segments_scp_file = open(raw_path + '/segments', 'r') @@ -22,8 +22,8 @@ if __name__=="__main__": raw_wav_map = {} for line in raw_meeting_scp: - meeting = line.strip().split('\t')[0] - wav_path = line.strip().split('\t')[1] + meeting = line.strip().split(' ')[0] + wav_path = line.strip().split(' ')[1] raw_wav_map[meeting] = wav_path spk_map = {} diff --git a/egs/alimeeting/sa-asr/local/gen_oracle_profile_nopadding.py b/egs/alimeeting/sa_asr/local/gen_oracle_profile_nopadding.py similarity index 100% rename from egs/alimeeting/sa-asr/local/gen_oracle_profile_nopadding.py rename to egs/alimeeting/sa_asr/local/gen_oracle_profile_nopadding.py diff --git a/egs/alimeeting/sa-asr/local/gen_oracle_profile_padding.py b/egs/alimeeting/sa_asr/local/gen_oracle_profile_padding.py similarity index 96% rename from egs/alimeeting/sa-asr/local/gen_oracle_profile_padding.py rename to egs/alimeeting/sa_asr/local/gen_oracle_profile_padding.py index 186f1de9f..ff65a1f90 100644 --- a/egs/alimeeting/sa-asr/local/gen_oracle_profile_padding.py +++ b/egs/alimeeting/sa_asr/local/gen_oracle_profile_padding.py @@ -5,7 +5,7 @@ import sys if __name__=="__main__": - path = sys.argv[1] # dump2/raw/Train_Ali_far + path = sys.argv[1] wav_scp_file = open(path+"/wav.scp", 'r') wav_scp = wav_scp_file.readlines() wav_scp_file.close() @@ -29,7 +29,7 @@ if __name__=="__main__": line_list = line.strip().split(' ') meeting = line_list[0].split('-')[0] spk_id = line_list[0].split('-')[-1].split('_')[-1] - spk = meeting+'_' + spk_id + spk = meeting + '_' + spk_id global_spk_list.append(spk) if meeting in meeting_map_tmp.keys(): meeting_map_tmp[meeting].append(spk) diff --git a/egs/alimeeting/sa-asr/local/perturb_data_dir_speed.sh b/egs/alimeeting/sa_asr/local/perturb_data_dir_speed.sh similarity index 100% rename from egs/alimeeting/sa-asr/local/perturb_data_dir_speed.sh rename to egs/alimeeting/sa_asr/local/perturb_data_dir_speed.sh diff --git a/egs/alimeeting/sa-asr/local/process_sot_fifo_textchar2spk.py b/egs/alimeeting/sa_asr/local/process_sot_fifo_textchar2spk.py similarity index 94% rename from egs/alimeeting/sa-asr/local/process_sot_fifo_textchar2spk.py rename to egs/alimeeting/sa_asr/local/process_sot_fifo_textchar2spk.py index d900bb17a..488344fb0 100755 --- a/egs/alimeeting/sa-asr/local/process_sot_fifo_textchar2spk.py +++ b/egs/alimeeting/sa_asr/local/process_sot_fifo_textchar2spk.py @@ -30,8 +30,7 @@ def main(args): meetingid_map = {} for line in spk2utt: spkid = line.strip().split(" ")[0] - meeting_id_list = spkid.split("_")[:3] - meeting_id = meeting_id_list[0] + "_" + meeting_id_list[1] + "_" + meeting_id_list[2] + meeting_id = spkid.split("-")[0] if meeting_id not in meetingid_map: meetingid_map[meeting_id] = 1 else: diff --git a/egs/alimeeting/sa-asr/local/process_text_id.py b/egs/alimeeting/sa_asr/local/process_text_id.py similarity index 100% rename from egs/alimeeting/sa-asr/local/process_text_id.py rename to egs/alimeeting/sa_asr/local/process_text_id.py diff --git a/egs/alimeeting/sa-asr/local/process_text_spk_merge.py b/egs/alimeeting/sa_asr/local/process_text_spk_merge.py similarity index 100% rename from egs/alimeeting/sa-asr/local/process_text_spk_merge.py rename to egs/alimeeting/sa_asr/local/process_text_spk_merge.py diff --git a/egs/alimeeting/sa-asr/local/process_textgrid_to_single_speaker_wav.py b/egs/alimeeting/sa_asr/local/process_textgrid_to_single_speaker_wav.py 
similarity index 100% rename from egs/alimeeting/sa-asr/local/process_textgrid_to_single_speaker_wav.py rename to egs/alimeeting/sa_asr/local/process_textgrid_to_single_speaker_wav.py diff --git a/egs/alimeeting/sa-asr/local/spk2utt_to_utt2spk.pl b/egs/alimeeting/sa_asr/local/spk2utt_to_utt2spk.pl similarity index 100% rename from egs/alimeeting/sa-asr/local/spk2utt_to_utt2spk.pl rename to egs/alimeeting/sa_asr/local/spk2utt_to_utt2spk.pl diff --git a/egs/alimeeting/sa-asr/local/text_format.pl b/egs/alimeeting/sa_asr/local/text_format.pl similarity index 100% rename from egs/alimeeting/sa-asr/local/text_format.pl rename to egs/alimeeting/sa_asr/local/text_format.pl diff --git a/egs/alimeeting/sa-asr/local/text_normalize.pl b/egs/alimeeting/sa_asr/local/text_normalize.pl similarity index 100% rename from egs/alimeeting/sa-asr/local/text_normalize.pl rename to egs/alimeeting/sa_asr/local/text_normalize.pl diff --git a/egs/alimeeting/sa-asr/local/utt2spk_to_spk2utt.pl b/egs/alimeeting/sa_asr/local/utt2spk_to_spk2utt.pl similarity index 100% rename from egs/alimeeting/sa-asr/local/utt2spk_to_spk2utt.pl rename to egs/alimeeting/sa_asr/local/utt2spk_to_spk2utt.pl diff --git a/egs/alimeeting/sa-asr/local/validate_data_dir.sh b/egs/alimeeting/sa_asr/local/validate_data_dir.sh similarity index 100% rename from egs/alimeeting/sa-asr/local/validate_data_dir.sh rename to egs/alimeeting/sa_asr/local/validate_data_dir.sh diff --git a/egs/alimeeting/sa-asr/local/validate_text.pl b/egs/alimeeting/sa_asr/local/validate_text.pl similarity index 100% rename from egs/alimeeting/sa-asr/local/validate_text.pl rename to egs/alimeeting/sa_asr/local/validate_text.pl diff --git a/egs/alimeeting/sa_asr/path.sh b/egs/alimeeting/sa_asr/path.sh new file mode 100755 index 000000000..83ae507b7 --- /dev/null +++ b/egs/alimeeting/sa_asr/path.sh @@ -0,0 +1,6 @@ +export FUNASR_DIR=$PWD/../../.. + +# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C +export PYTHONIOENCODING=UTF-8 +export PATH=$FUNASR_DIR/funasr/bin:./utils:$FUNASR_DIR:$PATH +export PYTHONPATH=$FUNASR_DIR:$PYTHONPATH diff --git a/egs/alimeeting/sa_asr/run.sh b/egs/alimeeting/sa_asr/run.sh new file mode 100755 index 000000000..43d0da13f --- /dev/null +++ b/egs/alimeeting/sa_asr/run.sh @@ -0,0 +1,435 @@ +#!/usr/bin/env bash + +. ./path.sh || exit 1; + +# machines configuration +CUDA_VISIBLE_DEVICES="6,7" +gpu_num=2 +count=1 +gpu_inference=true # Whether to perform gpu decoding, set false for cpu decoding +# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob +njob=8 +train_cmd=utils/run.pl +infer_cmd=utils/run.pl + +# general configuration +feats_dir="data" #feature output dictionary +exp_dir="exp" +lang=zh +token_type=char +type=sound +scp=wav.scp +speed_perturb="1.0" +min_wav_duration=0.1 +max_wav_duration=20 +profile_modes="cluster oracle" +stage=7 +stop_stage=7 + +# feature configuration +feats_dim=80 +nj=32 + +# data +raw_data= +data_url= + +# exp tag +tag="" + +. utils/parse_options.sh || exit 1; + +# Set bash to 'debug' mode, it will exit on : +# -e 'error', -u 'undefined variable', -o ... 
'error in pipeline', -x 'print commands', +set -e +set -u +set -o pipefail + +train_set=Train_Ali_far +valid_set=Eval_Ali_far +test_sets="Test_Ali_far Eval_Ali_far" +test_2023="Test_2023_Ali_far_release" + +asr_config=conf/train_asr_conformer.yaml +sa_asr_config=conf/train_sa_asr_conformer.yaml +asr_model_dir="baseline_$(basename "${asr_config}" .yaml)_${lang}_${token_type}_${tag}" +sa_asr_model_dir="baseline_$(basename "${sa_asr_config}" .yaml)_${lang}_${token_type}_${tag}" +inference_config=conf/decode_asr_rnn.yaml +inference_sa_asr_model=valid.acc_spk.ave.pb + +# you can set gpu num for decoding here +gpuid_list=$CUDA_VISIBLE_DEVICES # set gpus for decoding, the same as training stage by default +ngpu=$(echo $gpuid_list | awk -F "," '{print NF}') + +if ${gpu_inference}; then + inference_nj=$[${ngpu}*${njob}] + _ngpu=1 +else + inference_nj=$njob + _ngpu=0 +fi + + +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + echo "stage 0: Data preparation" + # Data preparation + ./local/alimeeting_data_prep.sh --tgt Test --min_wav_duration $min_wav_duration --max_wav_duration $max_wav_duration + ./local/alimeeting_data_prep.sh --tgt Eval --min_wav_duration $min_wav_duration --max_wav_duration $max_wav_duration + ./local/alimeeting_data_prep.sh --tgt Train --min_wav_duration $min_wav_duration --max_wav_duration $max_wav_duration + # remove long/short data + for x in ${train_set} ${valid_set} ${test_sets}; do + cp -r ${feats_dir}/org/${x} ${feats_dir}/${x} + rm ${feats_dir}/"${x}"/wav.scp ${feats_dir}/"${x}"/segments + local/format_wav_scp.sh --nj "${nj}" --cmd "${train_cmd}" \ + --audio-format wav --segments ${feats_dir}/org/${x}/segments \ + "${feats_dir}/org/${x}/${scp}" "${feats_dir}/${x}" + _min_length=$(python3 -c "print(int(${min_wav_duration} * 16000))") + _max_length=$(python3 -c "print(int(${max_wav_duration} * 16000))") + <"${feats_dir}/${x}/utt2num_samples" \ + awk '{if($2 > '$_min_length' && $2 < '$_max_length')print $0;}' \ + >"${feats_dir}/${x}/utt2num_samples_rmls" + mv ${feats_dir}/${x}/utt2num_samples_rmls ${feats_dir}/${x}/utt2num_samples + <"${feats_dir}/${x}/wav.scp" \ + utils/filter_scp.pl "${feats_dir}/${x}/utt2num_samples" \ + >"${feats_dir}/${x}/wav.scp_rmls" + mv ${feats_dir}/${x}/wav.scp_rmls ${feats_dir}/${x}/wav.scp + <"${feats_dir}/${x}/text" \ + awk '{ if( NF != 1 ) print $0; }' >"${feats_dir}/${x}/text_rmblank" + mv ${feats_dir}/${x}/text_rmblank ${feats_dir}/${x}/text + local/fix_${feats_dir}_dir.sh "${feats_dir}/${x}" + <"${feats_dir}/${x}/utt2spk_all_fifo" \ + utils/filter_scp.pl "${feats_dir}/${x}/text" \ + >"${feats_dir}/${x}/utt2spk_all_fifo_rmls" + mv "${feats_dir}/${x}/utt2spk_all_fifo_rmls" "${feats_dir}/${x}/utt2spk_all_fifo" + done + for x in ${test_2023}; do + local/format_wav_scp.sh --nj "${nj}" --cmd "${train_cmd}" \ + --audio-format wav --segments ${feats_dir}/org/${x}/segments \ + "${feats_dir}/org/${x}/${scp}" "${feats_dir}/${x}" + cut -d " " -f1 ${feats_dir}/${x}/wav.scp > ${feats_dir}/${x}/uttid + paste -d " " ${feats_dir}/${x}/uttid ${feats_dir}/${x}/uttid > ${feats_dir}/${x}/utt2spk + cp ${feats_dir}/${x}/utt2spk ${feats_dir}/${x}/spk2utt + done +fi + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + echo "stage 1: Speaker profile and CMVN Generation" + + mkdir -p "profile_log" + for x in "${train_set}" "${valid_set}" "${test_sets}"; do + # generate text_id spk2id + python local/process_sot_fifo_textchar2spk.py --path ${feats_dir}/${x} + echo "Successfully generate ${feats_dir}/${x}/text_id ${feats_dir}/${x}/spk2id" + # generate 
text_id_train for sot + python local/process_text_id.py ${feats_dir}/${x} + echo "Successfully generate ${feats_dir}/${x}/text_id_train" + # generate oracle_embedding from single-speaker audio segment + echo "oracle_embedding is being generated in the background, and the log is profile_log/gen_oracle_embedding_${x}.log" + python local/gen_oracle_embedding.py "${feats_dir}/${x}" "data/org/${x}_single_speaker" &> "profile_log/gen_oracle_embedding_${x}.log" + echo "Successfully generate oracle embedding for ${x} (${feats_dir}/${x}/oracle_embedding.scp)" + # generate oracle_profile and cluster_profile from oracle_embedding and cluster_embedding (padding the speaker during training) + if [ "${x}" = "${train_set}" ]; then + python local/gen_oracle_profile_padding.py ${feats_dir}/${x} + echo "Successfully generate oracle profile for ${x} (${feats_dir}/${x}/oracle_profile_padding.scp)" + else + python local/gen_oracle_profile_nopadding.py ${feats_dir}/${x} + echo "Successfully generate oracle profile for ${x} (${feats_dir}/${x}/oracle_profile_nopadding.scp)" + fi + # generate cluster_profile with spectral-cluster directly (for infering and without oracle information) + if [ "${x}" = "${valid_set}" ] || [ "${x}" = "${test_sets}" ]; then + echo "cluster_profile is being generated in the background, and the log is profile_log/gen_cluster_profile_infer_${x}.log" + python local/gen_cluster_profile_infer.py "${feats_dir}/${x}" "${feats_dir}/org/${x}" 0.996 0.815 &> "profile_log/gen_cluster_profile_infer_${x}.log" + echo "Successfully generate cluster profile for ${x} (${feats_dir}/${x}/cluster_profile_infer.scp)" + fi + # compute CMVN + if [ "${x}" = "${train_set}" ]; then + local/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --fbankdir ${feats_dir}/${train_set} --feats_dim ${feats_dim} --config_file "$asr_config" --scale 1.0 + fi + done + + for x in "${test_2023}"; do + # generate cluster_profile with spectral-cluster directly (for infering and without oracle information) + python local/gen_cluster_profile_infer.py "${feats_dir}/${x}" "${feats_dir}/org/${x}" 0.996 0.815 &> "profile_log/gen_cluster_profile_infer_${x}.log" + echo "Successfully generate cluster profile for ${x} (${feats_dir}/${x}/cluster_profile_infer.scp)" + done +fi + +token_list=${feats_dir}/${lang}_token_list/char/tokens.txt +echo "dictionary: ${token_list}" +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + echo "stage 2: Dictionary Preparation" + mkdir -p ${feats_dir}/${lang}_token_list/char/ + + echo "make a dictionary" + echo "" > ${token_list} + echo "" >> ${token_list} + echo "" >> ${token_list} + utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/$train_set/text | cut -f 2- -d" " | tr " " "\n" \ + | sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list} + echo "" >> ${token_list} +fi + +# LM Training Stage +world_size=$gpu_num # run on one machine +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + echo "stage 3: LM Training" +fi + +# ASR Training Stage +if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then + echo "Stage 4: ASR Training" + asr_exp=${exp_dir}/${asr_model_dir} + mkdir -p ${asr_exp} + mkdir -p ${asr_exp}/log + INIT_FILE=${asr_exp}/ddp_init + if [ -f $INIT_FILE ];then + rm -f $INIT_FILE + fi + init_method=file://$(readlink -f $INIT_FILE) + echo "$0: init method is $init_method" + for ((i = 0; i < $ngpu; ++i)); do + { + # i=0 + rank=$i + local_rank=$i + gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1]) + train.py \ + --task_name asr \ + --model asr \ + --gpu_id $gpu_id \ + 
--use_preprocessor true \ + --split_with_space false \ + --token_type char \ + --token_list $token_list \ + --data_dir ${feats_dir} \ + --train_set ${train_set} \ + --valid_set ${valid_set} \ + --data_file_names "wav.scp,text" \ + --cmvn_file ${feats_dir}/${train_set}/cmvn/cmvn.mvn \ + --speed_perturb ${speed_perturb} \ + --resume true \ + --output_dir ${exp_dir}/${asr_model_dir} \ + --config $asr_config \ + --ngpu $gpu_num \ + --num_worker_count $count \ + --dist_init_method $init_method \ + --dist_world_size $world_size \ + --dist_rank $rank \ + --local_rank $local_rank 1> ${exp_dir}/${asr_model_dir}/log/train.log.$i 2>&1 + } & + done + wait + +fi + + + +if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then + echo "SA-ASR training" + asr_exp=${exp_dir}/${asr_model_dir} + sa_asr_exp=${exp_dir}/${sa_asr_model_dir} + mkdir -p ${sa_asr_exp} + mkdir -p ${sa_asr_exp}/log + INIT_FILE=${sa_asr_exp}/ddp_init + if [ ! -L ${feats_dir}/${train_set}/profile.scp ]; then + ln -sr ${feats_dir}/${train_set}/oracle_profile_padding.scp ${feats_dir}/${train_set}/profile.scp + ln -sr ${feats_dir}/${valid_set}/oracle_profile_nopadding.scp ${feats_dir}/${valid_set}/profile.scp + fi + + if [ ! -f "${exp_dir}/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pth" ]; then + # download xvector extractor model file + python local/download_xvector_model.py ${exp_dir} + echo "Successfully download the pretrained xvector extractor to exp/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pth" + fi + + if [ -f $INIT_FILE ];then + rm -f $INIT_FILE + fi + init_method=file://$(readlink -f $INIT_FILE) + echo "$0: init method is $init_method" + for ((i = 0; i < $ngpu; ++i)); do + { + rank=$i + local_rank=$i + gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1]) + train.py \ + --task_name asr \ + --model sa_asr \ + --gpu_id $gpu_id \ + --use_preprocessor true \ + --split_with_space false \ + --unused_parameters true \ + --token_type char \ + --resume true \ + --token_list $token_list \ + --data_dir ${feats_dir} \ + --train_set ${train_set} \ + --valid_set ${valid_set} \ + --data_file_names "wav.scp,text,profile.scp,text_id_train" \ + --cmvn_file ${feats_dir}/${train_set}/cmvn/cmvn.mvn \ + --speed_perturb ${speed_perturb} \ + --init_param "${asr_exp}/valid.acc.ave.pb:encoder:asr_encoder" \ + --init_param "${asr_exp}/valid.acc.ave.pb:ctc:ctc" \ + --init_param "${asr_exp}/valid.acc.ave.pb:decoder.embed:decoder.embed" \ + --init_param "${asr_exp}/valid.acc.ave.pb:decoder.output_layer:decoder.asr_output_layer" \ + --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.0.self_attn:decoder.decoder1.self_attn" \ + --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.0.src_attn:decoder.decoder3.src_attn" \ + --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.0.feed_forward:decoder.decoder3.feed_forward" \ + --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.1:decoder.decoder4.0" \ + --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.2:decoder.decoder4.1" \ + --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.3:decoder.decoder4.2" \ + --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.4:decoder.decoder4.3" \ + --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.5:decoder.decoder4.4" \ + --init_param "exp/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pth:encoder:spk_encoder" \ + --init_param "exp/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pth:decoder:spk_encoder:decoder.output_dense" \ + --output_dir 
${exp_dir}/${sa_asr_model_dir} \ + --config $sa_asr_config \ + --ngpu $gpu_num \ + --num_worker_count $count \ + --dist_init_method $init_method \ + --dist_world_size $world_size \ + --dist_rank $rank \ + --local_rank $local_rank 1> ${exp_dir}/${sa_asr_model_dir}/log/train.log.$i 2>&1 + } & + done + wait +fi + + +if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then + echo "stage 6: Inference test sets" + for x in ${test_sets}; do + for profile_mode in ${profile_modes}; do + echo "decoding ${x} with ${profile_mode} profile" + sa_asr_exp=${exp_dir}/${sa_asr_model_dir} + inference_tag="$(basename "${inference_config}" .yaml)" + _dir="${sa_asr_exp}/${inference_tag}_${profile_mode}/${inference_sa_asr_model}/${x}" + _logdir="${_dir}/logdir" + if [ -d ${_dir} ]; then + echo "${_dir} is already exists. if you want to decode again, please delete this dir first." + exit 0 + fi + mkdir -p "${_logdir}" + _data="${feats_dir}/${x}" + key_file=${_data}/${scp} + num_scp_file="$(<${key_file} wc -l)" + _nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file") + split_scps= + for n in $(seq "${_nj}"); do + split_scps+=" ${_logdir}/keys.${n}.scp" + done + # shellcheck disable=SC2086 + utils/split_scp.pl "${key_file}" ${split_scps} + _opts= + if [ -n "${inference_config}" ]; then + _opts+="--config ${inference_config} " + fi + if [ $profile_mode = "oracle" ]; then + profile_scp=${profile_mode}_profile_nopadding.scp + else + profile_scp=${profile_mode}_profile_infer.scp + fi + ${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \ + python -m funasr.bin.asr_inference_launch \ + --batch_size 1 \ + --mc True \ + --ngpu "${_ngpu}" \ + --njob ${njob} \ + --nbest 1 \ + --gpuid_list ${gpuid_list} \ + --allow_variable_data_keys true \ + --cmvn_file ${feats_dir}/${train_set}/cmvn/cmvn.mvn \ + --data_path_and_name_and_type "${_data}/${scp},speech,${type}" \ + --data_path_and_name_and_type "${_data}/$profile_scp,profile,npy" \ + --key_file "${_logdir}"/keys.JOB.scp \ + --asr_train_config "${sa_asr_exp}"/config.yaml \ + --asr_model_file "${sa_asr_exp}"/"${inference_sa_asr_model}" \ + --output_dir "${_logdir}"/output.JOB \ + --mode sa_asr \ + ${_opts} + + for f in token token_int score text text_id; do + if [ -f "${_logdir}/output.1/1best_recog/${f}" ]; then + for i in $(seq "${_nj}"); do + cat "${_logdir}/output.${i}/1best_recog/${f}" + done | sort -k1 >"${_dir}/${f}" + fi + done + sed 's/\$//g' ${_data}/text > ${_data}/text_nosrc + sed 's/\$//g' ${_dir}/text > ${_dir}/text_nosrc + python utils/proce_text.py ${_data}/text_nosrc ${_data}/text.proc + python utils/proce_text.py ${_dir}/text_nosrc ${_dir}/text.proc + + python utils/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer + tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt + cat ${_dir}/text.cer.txt + + python local/process_text_spk_merge.py ${_dir} + python local/process_text_spk_merge.py ${_data} + + python local/compute_cpcer.py ${_data}/text_spk_merge ${_dir}/text_spk_merge ${_dir}/text.cpcer + tail -n 1 ${_dir}/text.cpcer > ${_dir}/text.cpcer.txt + cat ${_dir}/text.cpcer.txt + done + done +fi + +if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then + echo "stage 7: Inference test 2023" + for x in ${test_2023}; do + sa_asr_exp=${exp_dir}/${sa_asr_model_dir} + inference_tag="$(basename "${inference_config}" .yaml)" + _dir="${sa_asr_exp}/${inference_tag}_cluster/${inference_sa_asr_model}/${x}" + _logdir="${_dir}/logdir" + if [ -d ${_dir} ]; then + echo "${_dir} 
is already exists. if you want to decode again, please delete this dir first." + exit 0 + fi + mkdir -p "${_logdir}" + _data="${feats_dir}/${x}" + key_file=${_data}/${scp} + num_scp_file="$(<${key_file} wc -l)" + _nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file") + split_scps= + for n in $(seq "${_nj}"); do + split_scps+=" ${_logdir}/keys.${n}.scp" + done + # shellcheck disable=SC2086 + utils/split_scp.pl "${key_file}" ${split_scps} + _opts= + if [ -n "${inference_config}" ]; then + _opts+="--config ${inference_config} " + fi + ${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \ + python -m funasr.bin.asr_inference_launch \ + --batch_size 1 \ + --mc True \ + --ngpu "${_ngpu}" \ + --njob ${njob} \ + --nbest 1 \ + --gpuid_list ${gpuid_list} \ + --allow_variable_data_keys true \ + --data_path_and_name_and_type "${_data}/${scp},speech,${type}" \ + --data_path_and_name_and_type "${_data}/cluster_profile_infer.scp,profile,npy" \ + --cmvn_file ${feats_dir}/${train_set}/cmvn/cmvn.mvn \ + --key_file "${_logdir}"/keys.JOB.scp \ + --asr_train_config "${sa_asr_exp}"/config.yaml \ + --asr_model_file "${sa_asr_exp}"/"${inference_sa_asr_model}" \ + --output_dir "${_logdir}"/output.JOB \ + --mode sa_asr \ + ${_opts} + + for f in token token_int score text text_id; do + if [ -f "${_logdir}/output.1/1best_recog/${f}" ]; then + for i in $(seq "${_nj}"); do + cat "${_logdir}/output.${i}/1best_recog/${f}" + done | sort -k1 >"${_dir}/${f}" + fi + done + + python local/process_text_spk_merge.py ${_dir} + + done +fi + + diff --git a/egs/alimeeting/sa-asr/utils b/egs/alimeeting/sa_asr/utils similarity index 100% rename from egs/alimeeting/sa-asr/utils rename to egs/alimeeting/sa_asr/utils diff --git a/egs/alimeeting/sa-asr/README.md b/egs/alimeeting/sa_asr_deprecated/README.md similarity index 100% rename from egs/alimeeting/sa-asr/README.md rename to egs/alimeeting/sa_asr_deprecated/README.md diff --git a/egs/alimeeting/sa-asr/asr_local.sh b/egs/alimeeting/sa_asr_deprecated/asr_local.sh similarity index 100% rename from egs/alimeeting/sa-asr/asr_local.sh rename to egs/alimeeting/sa_asr_deprecated/asr_local.sh diff --git a/egs/alimeeting/sa-asr/asr_local_m2met_2023_infer.sh b/egs/alimeeting/sa_asr_deprecated/asr_local_m2met_2023_infer.sh similarity index 100% rename from egs/alimeeting/sa-asr/asr_local_m2met_2023_infer.sh rename to egs/alimeeting/sa_asr_deprecated/asr_local_m2met_2023_infer.sh diff --git a/egs/alimeeting/sa_asr_deprecated/conf/decode_asr_rnn.yaml b/egs/alimeeting/sa_asr_deprecated/conf/decode_asr_rnn.yaml new file mode 100644 index 000000000..88fdbc20b --- /dev/null +++ b/egs/alimeeting/sa_asr_deprecated/conf/decode_asr_rnn.yaml @@ -0,0 +1,6 @@ +beam_size: 20 +penalty: 0.0 +maxlenratio: 0.0 +minlenratio: 0.0 +ctc_weight: 0.6 +lm_weight: 0.3 diff --git a/egs/alimeeting/sa-asr/conf/train_asr_conformer.yaml b/egs/alimeeting/sa_asr_deprecated/conf/train_asr_conformer.yaml similarity index 100% rename from egs/alimeeting/sa-asr/conf/train_asr_conformer.yaml rename to egs/alimeeting/sa_asr_deprecated/conf/train_asr_conformer.yaml diff --git a/egs/alimeeting/sa-asr/conf/train_sa_asr_conformer.yaml b/egs/alimeeting/sa_asr_deprecated/conf/train_sa_asr_conformer.yaml similarity index 100% rename from egs/alimeeting/sa-asr/conf/train_sa_asr_conformer.yaml rename to egs/alimeeting/sa_asr_deprecated/conf/train_sa_asr_conformer.yaml diff --git a/egs/alimeeting/sa_asr_deprecated/local 
b/egs/alimeeting/sa_asr_deprecated/local new file mode 120000 index 000000000..2ef6217d9 --- /dev/null +++ b/egs/alimeeting/sa_asr_deprecated/local @@ -0,0 +1 @@ +../sa_asr/local/ \ No newline at end of file diff --git a/egs/alimeeting/sa-asr/path.sh b/egs/alimeeting/sa_asr_deprecated/path.sh similarity index 100% rename from egs/alimeeting/sa-asr/path.sh rename to egs/alimeeting/sa_asr_deprecated/path.sh diff --git a/egs/alimeeting/sa-asr/run.sh b/egs/alimeeting/sa_asr_deprecated/run.sh similarity index 100% rename from egs/alimeeting/sa-asr/run.sh rename to egs/alimeeting/sa_asr_deprecated/run.sh diff --git a/egs/alimeeting/sa-asr/run_m2met_2023_infer.sh b/egs/alimeeting/sa_asr_deprecated/run_m2met_2023_infer.sh similarity index 100% rename from egs/alimeeting/sa-asr/run_m2met_2023_infer.sh rename to egs/alimeeting/sa_asr_deprecated/run_m2met_2023_infer.sh diff --git a/egs/alimeeting/sa_asr_deprecated/utils b/egs/alimeeting/sa_asr_deprecated/utils new file mode 120000 index 000000000..fe070dd3a --- /dev/null +++ b/egs/alimeeting/sa_asr_deprecated/utils @@ -0,0 +1 @@ +../../aishell/transformer/utils \ No newline at end of file diff --git a/funasr/bin/asr_infer.py b/funasr/bin/asr_infer.py index 140b4245c..c722ebc3a 100644 --- a/funasr/bin/asr_infer.py +++ b/funasr/bin/asr_infer.py @@ -1636,8 +1636,10 @@ class Speech2TextSAASR: ) frontend = None if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None: - if asr_train_args.frontend == 'wav_frontend': - frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf) + from funasr.tasks.sa_asr import frontend_choices + if asr_train_args.frontend == 'wav_frontend' or asr_train_args.frontend == "multichannelfrontend": + frontend_class = frontend_choices.get_class(asr_train_args.frontend) + frontend = frontend_class(cmvn_file=cmvn_file, **asr_train_args.frontend_conf).eval() else: frontend_class = frontend_choices.get_class(asr_train_args.frontend) frontend = frontend_class(**asr_train_args.frontend_conf).eval() diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py index 367b9a815..656a9657a 100644 --- a/funasr/bin/asr_inference_launch.py +++ b/funasr/bin/asr_inference_launch.py @@ -619,7 +619,12 @@ def inference_paraformer_vad_punc( data_with_index = [(vadsegments[i], i) for i in range(n)] sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0]) results_sorted = [] - batch_size_token_ms = batch_size_token * 60 + + batch_size_token_ms = batch_size_token*60 + if speech2text.device == "cpu": + batch_size_token_ms = 0 + batch_size_token_ms = max(batch_size_token_ms, sorted_data[0][0][1] - sorted_data[0][0][0]) + batch_size_token_ms_cum = 0 beg_idx = 0 for j, _ in enumerate(range(0, n)): diff --git a/funasr/bin/train.py b/funasr/bin/train.py index f4fc0a7e6..1dc3fb523 100755 --- a/funasr/bin/train.py +++ b/funasr/bin/train.py @@ -301,7 +301,7 @@ def get_parser(): "--freeze_param", type=str, default=[], - nargs="*", + action="append", help="Freeze parameters", ) diff --git a/funasr/build_utils/build_asr_model.py b/funasr/build_utils/build_asr_model.py index d4a954cb6..200395d28 100644 --- a/funasr/build_utils/build_asr_model.py +++ b/funasr/build_utils/build_asr_model.py @@ -6,7 +6,6 @@ from funasr.models.ctc import CTC from funasr.models.decoder.abs_decoder import AbsDecoder from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder from funasr.models.decoder.rnn_decoder import RNNDecoder -from funasr.models.decoder.rnnt_decoder import 
 from funasr.models.decoder.sanm_decoder import ParaformerSANMDecoder, FsmnDecoderSCAMAOpt
 from funasr.models.decoder.transformer_decoder import (
     DynamicConvolution2DTransformerDecoder,  # noqa: H301
@@ -20,17 +19,23 @@ from funasr.models.decoder.transformer_decoder import (
 )
 from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
 from funasr.models.decoder.transformer_decoder import TransformerDecoder
+from funasr.models.decoder.rnnt_decoder import RNNTDecoder
+from funasr.models.decoder.transformer_decoder import SAAsrTransformerDecoder
 from funasr.models.e2e_asr import ASRModel
 from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
 from funasr.models.e2e_asr_mfcca import MFCCA
-from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerOnline, ParaformerBert, BiCifParaformer, \
-    ContextualParaformer
+
 from funasr.models.e2e_asr_transducer import TransducerModel, UnifiedTransducerModel
+
+from funasr.models.e2e_sa_asr import SAASRModel
+from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerOnline, ParaformerBert, BiCifParaformer, ContextualParaformer
+
 from funasr.models.e2e_tp import TimestampPredictor
 from funasr.models.e2e_uni_asr import UniASR
 from funasr.models.encoder.conformer_encoder import ConformerEncoder, ConformerChunkEncoder
 from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
 from funasr.models.encoder.mfcca_encoder import MFCCAEncoder
+from funasr.models.encoder.resnet34_encoder import ResNet34Diar
 from funasr.models.encoder.rnn_encoder import RNNEncoder
 from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
 from funasr.models.encoder.transformer_encoder import TransformerEncoder
@@ -93,6 +98,8 @@ model_choices = ClassChoices(
         timestamp_prediction=TimestampPredictor,
         rnnt=TransducerModel,
         rnnt_unified=UnifiedTransducerModel,
+        sa_asr=SAASRModel,
+
     ),
     default="asr",
 )
@@ -110,6 +117,27 @@ encoder_choices = ClassChoices(
     ),
     default="rnn",
 )
+asr_encoder_choices = ClassChoices(
+    "asr_encoder",
+    classes=dict(
+        conformer=ConformerEncoder,
+        transformer=TransformerEncoder,
+        rnn=RNNEncoder,
+        sanm=SANMEncoder,
+        sanm_chunk_opt=SANMEncoderChunkOpt,
+        data2vec_encoder=Data2VecEncoder,
+        mfcca_enc=MFCCAEncoder,
+    ),
+    default="rnn",
+)
+
+spk_encoder_choices = ClassChoices(
+    "spk_encoder",
+    classes=dict(
+        resnet34_diar=ResNet34Diar,
+    ),
+    default="resnet34_diar",
+)
 encoder_choices2 = ClassChoices(
     "encoder2",
     classes=dict(
@@ -134,6 +162,7 @@ decoder_choices = ClassChoices(
         paraformer_decoder_sanm=ParaformerSANMDecoder,
         paraformer_decoder_san=ParaformerDecoderSAN,
         contextual_paraformer_decoder=ContextualParaformerDecoder,
+        sa_decoder=SAAsrTransformerDecoder,
     ),
     default="rnn",
 )
@@ -225,6 +254,10 @@ class_choices_list = [
     rnnt_decoder_choices,
     # --joint_network and --joint_network_conf
     joint_network_choices,
+    # --asr_encoder and --asr_encoder_conf
+    asr_encoder_choices,
+    # --spk_encoder and --spk_encoder_conf
+    spk_encoder_choices,
 ]
@@ -247,7 +280,7 @@ def build_asr_model(args):
     # frontend
     if hasattr(args, "input_size") and args.input_size is None:
         frontend_class = frontend_choices.get_class(args.frontend)
-        if args.frontend == 'wav_frontend':
+        if args.frontend == 'wav_frontend' or args.frontend == 'multichannelfrontend':
             frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
         else:
             frontend = frontend_class(**args.frontend_conf)
@@ -425,6 +458,33 @@ def build_asr_model(args):
             joint_network=joint_network,
             **args.model_conf,
         )
+    elif args.model == "sa_asr":
+        asr_encoder_class = asr_encoder_choices.get_class(args.asr_encoder)
+        asr_encoder = asr_encoder_class(input_size=input_size, **args.asr_encoder_conf)
+        spk_encoder_class = spk_encoder_choices.get_class(args.spk_encoder)
+        spk_encoder = spk_encoder_class(input_size=input_size, **args.spk_encoder_conf)
+        decoder = decoder_class(
+            vocab_size=vocab_size,
+            encoder_output_size=asr_encoder.output_size(),
+            **args.decoder_conf,
+        )
+        ctc = CTC(
+            odim=vocab_size, encoder_output_size=asr_encoder.output_size(), **args.ctc_conf
+        )
+
+        model_class = model_choices.get_class(args.model)
+        model = model_class(
+            vocab_size=vocab_size,
+            frontend=frontend,
+            specaug=specaug,
+            normalize=normalize,
+            asr_encoder=asr_encoder,
+            spk_encoder=spk_encoder,
+            decoder=decoder,
+            ctc=ctc,
+            token_list=token_list,
+            **args.model_conf,
+        )
     else:
         raise NotImplementedError("Not supported model: {}".format(args.model))
diff --git a/funasr/models/e2e_sa_asr.py b/funasr/models/e2e_sa_asr.py
index 830460721..e209d5192 100644
--- a/funasr/models/e2e_sa_asr.py
+++ b/funasr/models/e2e_sa_asr.py
@@ -40,7 +40,7 @@
 else:
         yield
 
-class ESPnetASRModel(FunASRModel):
+class SAASRModel(FunASRModel):
     """CTC-attention hybrid Encoder-Decoder model"""
 
     def __init__(
@@ -51,10 +51,8 @@
         frontend: Optional[AbsFrontend],
         specaug: Optional[AbsSpecAug],
         normalize: Optional[AbsNormalize],
-        preencoder: Optional[AbsPreEncoder],
         asr_encoder: AbsEncoder,
         spk_encoder: torch.nn.Module,
-        postencoder: Optional[AbsPostEncoder],
         decoder: AbsDecoder,
         ctc: CTC,
         spk_weight: float = 0.5,
@@ -89,8 +87,6 @@
         self.frontend = frontend
         self.specaug = specaug
         self.normalize = normalize
-        self.preencoder = preencoder
-        self.postencoder = postencoder
         self.asr_encoder = asr_encoder
         self.spk_encoder = spk_encoder
@@ -293,10 +289,6 @@
         if self.normalize is not None:
             feats, feats_lengths = self.normalize(feats, feats_lengths)
 
-        # Pre-encoder, e.g. used for raw input data
-        if self.preencoder is not None:
-            feats, feats_lengths = self.preencoder(feats, feats_lengths)
-
         # 4. Forward encoder
         # feats: (Batch, Length, Dim)
         # -> encoder_out: (Batch, Length2, Dim2)
@@ -317,11 +309,6 @@
             encoder_out_spk=F.interpolate(encoder_out_spk_ori.transpose(-2,-1), size=(encoder_out.size(1)), mode='nearest').transpose(-2,-1)
         else:
             encoder_out_spk=encoder_out_spk_ori
-        # Post-encoder, e.g. NLU
-        if self.postencoder is not None:
-            encoder_out, encoder_out_lens = self.postencoder(
-                encoder_out, encoder_out_lens
-            )
 
         assert encoder_out.size(0) == speech.size(0), (
             encoder_out.size(),
@@ -337,7 +324,7 @@
         )
 
         if intermediate_outs is not None:
-            return (encoder_out, intermediate_outs), encoder_out_lens
+            return (encoder_out, intermediate_outs), encoder_out_lens, encoder_out_spk
 
         return encoder_out, encoder_out_lens, encoder_out_spk
diff --git a/funasr/models/frontend/default.py b/funasr/models/frontend/default.py
index 19994f0e6..6718f3f6c 100644
--- a/funasr/models/frontend/default.py
+++ b/funasr/models/frontend/default.py
@@ -2,7 +2,7 @@ import copy
 from typing import Optional
 from typing import Tuple
 from typing import Union
-
+import logging
 import humanfriendly
 import numpy as np
 import torch
@@ -14,6 +14,7 @@ from funasr.layers.stft import Stft
 from funasr.models.frontend.abs_frontend import AbsFrontend
 from funasr.modules.frontends.frontend import Frontend
 from funasr.utils.get_default_kwargs import get_default_kwargs
+from funasr.modules.nets_utils import make_pad_mask
 
 class DefaultFrontend(AbsFrontend):
@@ -137,8 +138,6 @@
         return input_stft, feats_lens
 
-
-
 class MultiChannelFrontend(AbsFrontend):
     """Conventional frontend structure for ASR.
     Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
     """
 
     def __init__(
         self,
         fs: Union[int, str] = 16000,
-        n_fft: int = 512,
-        win_length: int = None,
-        hop_length: int = 128,
+        n_fft: int = 400,
+        frame_length: int = 25,
+        frame_shift: int = 10,
         window: Optional[str] = "hann",
         center: bool = True,
         normalized: bool = False,
@@ -160,10 +159,10 @@
         htk: bool = False,
         frontend_conf: Optional[dict] = get_default_kwargs(Frontend),
         apply_stft: bool = True,
-        frame_length: int = None,
-        frame_shift: int = None,
-        lfr_m: int = None,
-        lfr_n: int = None,
+        use_channel: int = None,
+        lfr_m: int = 1,
+        lfr_n: int = 1,
+        cmvn_file: str = None
     ):
         assert check_argument_types()
         super().__init__()
@@ -172,13 +171,14 @@
         # Deepcopy (In general, dict shouldn't be used as default arg)
         frontend_conf = copy.deepcopy(frontend_conf)
-        self.hop_length = hop_length
+        self.win_length = frame_length * 16
+        self.hop_length = frame_shift * 16
 
         if apply_stft:
             self.stft = Stft(
                 n_fft=n_fft,
-                win_length=win_length,
-                hop_length=hop_length,
+                win_length=self.win_length,
+                hop_length=self.hop_length,
                 center=center,
                 window=window,
                 normalized=normalized,
@@ -202,7 +202,17 @@
             htk=htk,
         )
         self.n_mels = n_mels
-        self.frontend_type = "multichannelfrontend"
+        self.frontend_type = "default"
+        self.use_channel = use_channel
+        if self.use_channel is not None:
+            logging.info("use the channel %d" % (self.use_channel))
+        else:
+            logging.info("random select channel")
+        self.cmvn_file = cmvn_file
+        if self.cmvn_file is not None:
+            mean, std = self._load_cmvn(self.cmvn_file)
+            self.register_buffer("mean", torch.from_numpy(mean))
+            self.register_buffer("std", torch.from_numpy(std))
 
     def output_size(self) -> int:
         return self.n_mels
@@ -215,16 +225,29 @@
         if self.stft is not None:
             input_stft, feats_lens = self._compute_stft(input, input_lengths)
         else:
-            if isinstance(input, ComplexTensor):
-                input_stft = input
-            else:
-                input_stft = ComplexTensor(input[..., 0], input[..., 1])
+            input_stft = ComplexTensor(input[..., 0], input[..., 1])
             feats_lens = input_lengths
 
         # 2. [Option] Speech enhancement
         if self.frontend is not None:
             assert isinstance(input_stft, ComplexTensor), type(input_stft)
             # input_stft: (Batch, Length, [Channel], Freq)
             input_stft, _, mask = self.frontend(input_stft, feats_lens)
+
+        # 3. [Multi channel case]: Select a channel
+        if input_stft.dim() == 4:
+            # h: (B, T, C, F) -> h: (B, T, F)
+            if self.training:
+                if self.use_channel is not None:
+                    input_stft = input_stft[:, :, self.use_channel, :]
+
+                else:
+                    # Select 1ch randomly
+                    ch = np.random.randint(input_stft.size(2))
+                    input_stft = input_stft[:, :, ch, :]
+            else:
+                # Use the first channel
+                input_stft = input_stft[:, :, 0, :]
+
         # 4. STFT -> Power spectrum
         # h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F)
         input_power = input_stft.real ** 2 + input_stft.imag ** 2
@@ -233,18 +256,27 @@
         # input_power: (Batch, [Channel,] Length, Freq)
         #        -> input_feats: (Batch, Length, Dim)
         input_feats, _ = self.logmel(input_power, feats_lens)
-        bt = input_feats.size(0)
-        if input_feats.dim() ==4:
-            channel_size = input_feats.size(2)
-            # batch * channel * T * D
-            #pdb.set_trace()
-            input_feats = input_feats.transpose(1,2).reshape(bt*channel_size,-1,80).contiguous()
-            # input_feats = input_feats.transpose(1,2)
-            # batch * channel
-            feats_lens = feats_lens.repeat(1,channel_size).squeeze()
-        else:
-            channel_size = 1
-        return input_feats, feats_lens, channel_size
+
+        # 6. Apply CMVN
+        if self.cmvn_file is not None:
+            if feats_lens is None:
+                feats_lens = input_feats.new_full([input_feats.size(0)], input_feats.size(1))
+            self.mean = self.mean.to(input_feats.device, input_feats.dtype)
+            self.std = self.std.to(input_feats.device, input_feats.dtype)
+            mask = make_pad_mask(feats_lens, input_feats, 1)
+
+            if input_feats.requires_grad:
+                input_feats = input_feats + self.mean
+            else:
+                input_feats += self.mean
+            if input_feats.requires_grad:
+                input_feats = input_feats.masked_fill(mask, 0.0)
+            else:
+                input_feats.masked_fill_(mask, 0.0)
+
+            input_feats *= self.std
+
+        return input_feats, feats_lens
 
     def _compute_stft(
         self, input: torch.Tensor, input_lengths: torch.Tensor
@@ -258,4 +290,27 @@
         # Change torch.Tensor to ComplexTensor
         # input_stft: (..., F, 2) -> (..., F)
         input_stft = ComplexTensor(input_stft[..., 0], input_stft[..., 1])
-        return input_stft, feats_lens
\ No newline at end of file
+        return input_stft, feats_lens
+
+    def _load_cmvn(self, cmvn_file):
+        with open(cmvn_file, 'r', encoding='utf-8') as f:
+            lines = f.readlines()
+        means_list = []
+        vars_list = []
+        for i in range(len(lines)):
+            line_item = lines[i].split()
+            if line_item[0] == '<AddShift>':
+                line_item = lines[i + 1].split()
+                if line_item[0] == '<LearnRateCoef>':
+                    add_shift_line = line_item[3:(len(line_item) - 1)]
+                    means_list = list(add_shift_line)
+                    continue
+            elif line_item[0] == '<Rescale>':
+                line_item = lines[i + 1].split()
+                if line_item[0] == '<LearnRateCoef>':
+                    rescale_line = line_item[3:(len(line_item) - 1)]
+                    vars_list = list(rescale_line)
+                    continue
+        means = np.array(means_list).astype(np.float32)
+        vars = np.array(vars_list).astype(np.float32)
+        return means, vars
\ No newline at end of file
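The `_load_cmvn` helper added above parses Kaldi-style nnet CMVN statistics: an `<AddShift>` block whose following `<LearnRateCoef>` line carries the per-dimension shifts (the negated means), and a `<Rescale>` block whose following `<LearnRateCoef>` line carries the per-dimension scales. A toy sketch of the layout and of the slicing it relies on (4 feature dimensions instead of 80, made-up values):

```python
import numpy as np

# Toy cmvn.mvn excerpt in the layout _load_cmvn expects (values invented).
toy_cmvn = """<AddShift> 4 4
<LearnRateCoef> 0 [ -8.1 -8.0 -7.9 -7.8 ]
<Rescale> 4 4
<LearnRateCoef> 0 [ 0.31 0.30 0.29 0.28 ]
"""

lines = toy_cmvn.splitlines()
means = scales = None
for i, line in enumerate(lines):
    item = line.split()
    if item and item[0] == '<AddShift>':
        nxt = lines[i + 1].split()
        if nxt[0] == '<LearnRateCoef>':
            # drop '<LearnRateCoef>', '0', '[' at the front and ']' at the end
            means = np.array(nxt[3:-1], dtype=np.float32)
    elif item and item[0] == '<Rescale>':
        nxt = lines[i + 1].split()
        if nxt[0] == '<LearnRateCoef>':
            scales = np.array(nxt[3:-1], dtype=np.float32)

print(means)   # [-8.1 -8.  -7.9 -7.8]
print(scales)  # [0.31 0.3  0.29 0.28]
```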
diff --git a/funasr/runtime/java/FunasrWsClient.java b/funasr/runtime/java/FunasrWsClient.java
new file mode 100644
index 000000000..ec55c9425
--- /dev/null
+++ b/funasr/runtime/java/FunasrWsClient.java
@@ -0,0 +1,344 @@
+//
+// Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
+// Reserved. MIT License (https://opensource.org/licenses/MIT)
+//
+/*
+ * // 2022-2023 by zhaomingwork@qq.com
+ */
+// java FunasrWsClient
+// usage: FunasrWsClient [-h] [--port PORT] [--host HOST] [--audio_in AUDIO_IN] [--num_threads NUM_THREADS]
+//                       [--chunk_size CHUNK_SIZE] [--chunk_interval CHUNK_INTERVAL] [--mode MODE]
+package websocket;
+
+import java.io.*;
+import java.net.URI;
+import java.net.URISyntaxException;
+import java.nio.*;
+import java.util.Map;
+import net.sourceforge.argparse4j.ArgumentParsers;
+import net.sourceforge.argparse4j.inf.ArgumentParser;
+import net.sourceforge.argparse4j.inf.ArgumentParserException;
+import net.sourceforge.argparse4j.inf.Namespace;
+import org.java_websocket.client.WebSocketClient;
+import org.java_websocket.drafts.Draft;
+import org.java_websocket.handshake.ServerHandshake;
+import org.json.simple.JSONArray;
+import org.json.simple.JSONObject;
+import org.json.simple.parser.JSONParser;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/** This example demonstrates how to connect to the websocket server. */
+public class FunasrWsClient extends WebSocketClient {
+
+  public class RecWavThread extends Thread {
+    private FunasrWsClient funasrClient;
+
+    public RecWavThread(FunasrWsClient funasrClient) {
+      this.funasrClient = funasrClient;
+    }
+
+    public void run() {
+      this.funasrClient.recWav();
+    }
+  }
+
+  private static final Logger logger = LoggerFactory.getLogger(FunasrWsClient.class);
+
+  public FunasrWsClient(URI serverUri, Draft draft) {
+    super(serverUri, draft);
+  }
+
+  public FunasrWsClient(URI serverURI) {
+    super(serverURI);
+  }
+
+  public FunasrWsClient(URI serverUri, Map<String, String> httpHeaders) {
+    super(serverUri, httpHeaders);
+  }
+
+  public void getSslContext(String keyfile, String certfile) {
+    // TODO
+    return;
+  }
+
+  // send the initial json config message
+  public void sendJson(
+      String mode, String strChunkSize, int chunkInterval, String wavName, boolean isSpeaking) {
+    try {
+
+      JSONObject obj = new JSONObject();
+      obj.put("mode", mode);
+      JSONArray array = new JSONArray();
+      String[] chunkList = strChunkSize.split(",");
+      for (int i = 0; i < chunkList.length; i++) {
+        array.add(Integer.valueOf(chunkList[i].trim()));
+      }
+
+      obj.put("chunk_size", array);
+      obj.put("chunk_interval", Integer.valueOf(chunkInterval));
+      obj.put("wav_name", wavName);
+      if (isSpeaking) {
+        obj.put("is_speaking", Boolean.valueOf(true));
+      } else {
+        obj.put("is_speaking", Boolean.valueOf(false));
+      }
+      logger.info("sendJson: " + obj);
+      // return;
+
+      send(obj.toString());
+
+      return;
+    } catch (Exception e) {
+      e.printStackTrace();
+    }
+  }
+
+  // send json at the end of the wav
+  public void sendEof() {
+    try {
+      JSONObject obj = new JSONObject();
+
+      obj.put("is_speaking", Boolean.valueOf(false));
+
+      logger.info("sendEof: " + obj);
+      // return;
+
+      send(obj.toString());
+      iseof = true;
+      return;
+    } catch (Exception e) {
+      e.printStackTrace();
+    }
+  }
+
+  // send one wav file for recognition
+  public void recWav() {
+    sendJson(mode, strChunkSize, chunkInterval, wavName, true);
+    File file = new File(FunasrWsClient.wavPath);
+
+    int chunkSize = sendChunkSize;
+    byte[] bytes = new byte[chunkSize];
+
+    int readSize = 0;
+    try (FileInputStream fis = new FileInputStream(file)) {
+      if (FunasrWsClient.wavPath.endsWith(".wav")) {
+        fis.read(bytes, 0, 44); // skip the 44-byte wav header
+      }
+      readSize = fis.read(bytes, 0, chunkSize);
+      while (readSize > 0) {
+        // send when it is chunk size
+        if (readSize == chunkSize) {
+          send(bytes); // send buf to server
+
+        } else {
+          // last chunk: send only the bytes actually read
+          byte[] tmpBytes = new byte[readSize];
+          for (int i = 0; i < readSize; i++) {
+            tmpBytes[i] = bytes[i];
+          }
+          send(tmpBytes);
+        }
+        // if not in offline mode, simulate an online stream by sleeping
+        if (!mode.equals("offline")) {
+          Thread.sleep(chunkSize / 32); // 16 kHz, 16-bit mono: 32 bytes per ms
+        }
+
+        readSize = fis.read(bytes, 0, chunkSize);
+      }
+
+      if (!mode.equals("offline")) {
+        // if not offline, we send eof and wait for 3 seconds to close
+        Thread.sleep(2000);
+        sendEof();
+        Thread.sleep(3000);
+        close();
+      } else {
+        // if offline, just send eof
+        sendEof();
+      }
+
+    } catch (Exception e) {
+      e.printStackTrace();
+    }
+  }
+
+  @Override
+  public void onOpen(ServerHandshake handshakedata) {
+
+    RecWavThread thread = new RecWavThread(this);
+    thread.start();
+  }
+
+  @Override
+  public void onMessage(String message) {
+    JSONObject jsonObject = new JSONObject();
+    JSONParser jsonParser = new JSONParser();
+    logger.info("received: " + message);
+    try {
+      jsonObject = (JSONObject) jsonParser.parse(message);
+      logger.info("text: " + jsonObject.get("text"));
+    } catch (org.json.simple.parser.ParseException e) {
+      e.printStackTrace();
+    }
+    if (iseof && mode.equals("offline")) {
+      close();
+    }
+  }
+
+  @Override
+  public void onClose(int code, String reason, boolean remote) {
+
+    logger.info(
+        "Connection closed by "
+            + (remote ? "remote peer" : "us")
+            + " Code: "
+            + code
+            + " Reason: "
+            + reason);
+  }
+
+  @Override
+  public void onError(Exception ex) {
+    logger.info("ex: " + ex);
+    ex.printStackTrace();
+    // if the error is fatal then onClose will be called additionally
+  }
+
+  private boolean iseof = false;
+  public static String wavPath;
+  static String mode = "online";
+  static String strChunkSize = "5,10,5";
+  static int chunkInterval = 10;
+  static int sendChunkSize = 1920;
+
+  String wavName = "javatest";
+
+  public static void main(String[] args) throws URISyntaxException {
+    ArgumentParser parser = ArgumentParsers.newArgumentParser("ws client").defaultHelp(true);
+    parser
+        .addArgument("--port")
+        .help("Port of the server.")
+        .setDefault("8889")
+        .type(String.class)
+        .required(false);
+    parser
+        .addArgument("--host")
+        .help("the IP address of the server.")
+        .setDefault("127.0.0.1")
+        .type(String.class)
+        .required(false);
+    parser
+        .addArgument("--audio_in")
+        .help("wav path for decoding.")
+        .setDefault("asr_example.wav")
+        .type(String.class)
+        .required(false);
+    parser
+        .addArgument("--num_threads")
+        .help("number of threads for test.")
+        .setDefault(1)
+        .type(Integer.class)
+        .required(false);
+    parser
+        .addArgument("--chunk_size")
+        .help("chunk size for asr.")
+        .setDefault("5, 10, 5")
+        .type(String.class)
+        .required(false);
+    parser
+        .addArgument("--chunk_interval")
+        .help("chunk interval for asr.")
+        .setDefault(10)
+        .type(Integer.class)
+        .required(false);
+
+    parser
+        .addArgument("--mode")
+        .help("mode for asr.")
+        .setDefault("offline")
+        .type(String.class)
+        .required(false);
+    String srvIp = "";
+    String srvPort = "";
+    String wavPath = "";
+    int numThreads = 1;
+    String chunk_size = "";
+    int chunk_interval = 10;
+    String strmode = "offline";
+
+    try {
+      Namespace ns = parser.parseArgs(args);
+      srvIp = ns.get("host");
+      srvPort = ns.get("port");
+      wavPath = ns.get("audio_in");
+      numThreads = ns.get("num_threads");
+      chunk_size = ns.get("chunk_size");
+      chunk_interval = ns.get("chunk_interval");
+      strmode = ns.get("mode");
+      System.out.println(srvPort);
+
+    } catch (ArgumentParserException ex) {
+      ex.getParser().handleError(ex);
+      return;
+    }
+
+    FunasrWsClient.strChunkSize = chunk_size;
+    FunasrWsClient.chunkInterval = chunk_interval;
+    FunasrWsClient.wavPath = wavPath;
+    FunasrWsClient.mode = strmode;
+    System.out.println(
+        "srvIp="
+            + srvIp
+            + ",srvPort="
+            + srvPort
+            + ",wavPath="
+            + wavPath
+            + ",strChunkSize="
+            + strChunkSize);
+
+    class ClientThread implements Runnable {
+
+      String srvIp;
+      String srvPort;
+
+      ClientThread(String srvIp, String srvPort, String wavPath) {
+        this.srvIp = srvIp;
+        this.srvPort = srvPort;
+      }
+
+      public void run() {
+        try {
+
+          int RATE = 16000;
+          String[] chunkList = strChunkSize.split(",");
+          int int_chunk_size = 60 * Integer.valueOf(chunkList[1].trim()) / chunkInterval;
+          int CHUNK = Integer.valueOf(RATE / 1000 * int_chunk_size);
+          int stride =
+              Integer.valueOf(
+                  60 * Integer.valueOf(chunkList[1].trim()) / chunkInterval / 1000 * 16000 * 2);
+          System.out.println("chunk_size:" + String.valueOf(int_chunk_size));
+          System.out.println("CHUNK:" + CHUNK);
+          System.out.println("stride:" + String.valueOf(stride));
+          FunasrWsClient.sendChunkSize = CHUNK * 2;
+
+          String wsAddress = "ws://" + srvIp + ":" + srvPort;
+
+          FunasrWsClient c = new FunasrWsClient(new URI(wsAddress));
+
+          c.connect();
+
+          System.out.println("wsAddress:" + wsAddress);
+        } catch (Exception e) {
+          e.printStackTrace();
+          System.out.println("e:" + e);
+        }
+      }
+    }
+    for (int i = 0; i < numThreads; i++) {
+      System.out.println("Client thread " + i + " is running...");
+      Thread t = new Thread(new ClientThread(srvIp, srvPort, wavPath));
+      t.start();
+    }
+  }
+}
diff --git a/funasr/runtime/java/Makefile b/funasr/runtime/java/Makefile
new file mode 100644
index 000000000..9a70ca5fc
--- /dev/null
+++ b/funasr/runtime/java/Makefile
@@ -0,0 +1,76 @@
+
+ENTRY_POINT = ./
+
+
+
+
+WEBSOCKET_DIR:= ./
+WEBSOCKET_FILES = \
+	$(WEBSOCKET_DIR)/FunasrWsClient.java \
+
+
+
+LIB_BUILD_DIR = ./lib
+
+
+
+
+JAVAC = javac
+
+BUILD_DIR = build
+
+
+RUNJFLAGS = -Dfile.encoding=utf-8
+
+
+vpath %.class $(BUILD_DIR)
+vpath %.java src
+
+
+
+
+rebuild: clean all
+
+.PHONY: clean run downjar
+
+downjar:
+	wget https://repo1.maven.org/maven2/org/slf4j/slf4j-api/1.7.25/slf4j-api-1.7.25.jar -P ./lib/
+	wget https://repo1.maven.org/maven2/org/slf4j/slf4j-simple/1.7.25/slf4j-simple-1.7.25.jar -P ./lib/
+	#wget https://github.com/TooTallNate/Java-WebSocket/releases/download/v1.5.3/Java-WebSocket-1.5.3.jar -P ./lib/
+	wget https://repo1.maven.org/maven2/org/java-websocket/Java-WebSocket/1.5.3/Java-WebSocket-1.5.3.jar -P ./lib/
+	wget https://storage.googleapis.com/google-code-archive-downloads/v2/code.google.com/json-simple/json-simple-1.1.1.jar -P ./lib/
+	wget https://github.com/argparse4j/argparse4j/releases/download/argparse4j-0.9.0/argparse4j-0.9.0.jar -P ./lib/
+	rm -frv build
+	mkdir build
+clean:
+	rm -frv $(BUILD_DIR)/*
+	rm -frv $(LIB_BUILD_DIR)/*
+	mkdir -p $(BUILD_DIR)
+	mkdir -p ./lib
+
+
+
+
+
+
+runclient:
+	java -cp $(BUILD_DIR):lib/Java-WebSocket-1.5.3.jar:lib/slf4j-simple-1.7.25.jar:lib/slf4j-api-1.7.25.jar:lib/json-simple-1.1.1.jar:lib/argparse4j-0.9.0.jar $(RUNJFLAGS) websocket.FunasrWsClient --host localhost --port 8889 --audio_in ./asr_example.wav --num_threads 1 --mode 2pass
+
+
+
+buildwebsocket: $(WEBSOCKET_FILES:.java=.class)
+
+
+%.class: %.java
+
+	$(JAVAC) -cp $(BUILD_DIR):lib/slf4j-simple-1.7.25.jar:lib/slf4j-api-1.7.25.jar:lib/Java-WebSocket-1.5.3.jar:lib/json-simple-1.1.1.jar:lib/argparse4j-0.9.0.jar -d $(BUILD_DIR) -encoding UTF-8 $<
+
+packjar:
+	jar cvfe lib/funasrclient.jar . -C $(BUILD_DIR) .
+
+all: clean downjar buildwebsocket packjar
+
+
+
+
+
diff --git a/funasr/runtime/java/readme.md b/funasr/runtime/java/readme.md
new file mode 100644
index 000000000..406a21a0a
--- /dev/null
+++ b/funasr/runtime/java/readme.md
@@ -0,0 +1,72 @@
+# Java websocket client example
+
+
+
+## Building for Linux/Unix
+
+### Install the java environment
+```shell
+# on Ubuntu
+apt-get install openjdk-11-jdk
+```
+
+
+
+### Build and run with make
+
+
+```shell
+cd funasr/runtime/java
+# download java libs
+make downjar
+# compile
+make buildwebsocket
+# run client
+make runclient
+
+```
+
+## Run the java websocket client from the shell
+
+```shell
+# for the full command, see the runclient target in the Makefile
+usage: FunasrWsClient [-h] [--port PORT] [--host HOST] [--audio_in AUDIO_IN] [--num_threads NUM_THREADS]
+                      [--chunk_size CHUNK_SIZE] [--chunk_interval CHUNK_INTERVAL] [--mode MODE]
+
+Where:
+  --host
+    server ip (default: 127.0.0.1)
+
+  --port
+    server port (default: 8889)
+
+  --audio_in
+    the wav or pcm file path (default: ./asr_example.wav)
+
+  --num_threads
+    number of client threads for test (default: 1)
+
+  --chunk_size
+    chunk size setting for streaming asr (default: "5, 10, 5")
+
+  --chunk_interval
+    chunk interval for streaming asr (default: 10)
+
+  --mode
+    asr mode, supports "offline" "online" "2pass"
+
+
+
+example:
+FunasrWsClient --host localhost --port 8889 --audio_in ./asr_example.wav --num_threads 1 --mode 2pass
+
+the result is json, for example:
+{"mode":"offline","text":"欢迎大家来体验达摩院推出的语音识别模型","wav_name":"javatest"}
+```
+
+
+## Acknowledge
+1. This project is maintained by [FunASR community](https://github.com/alibaba-damo-academy/FunASR).
+2. We acknowledge [zhaoming](https://github.com/zhaomingwork/FunASR/tree/java-ws-client-support/funasr/runtime/java) for contributing the java websocket client example.
+
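The Java client above fixes the wire protocol: one initial JSON config message, a stream of raw 16 kHz / 16-bit PCM chunks, then a final `{"is_speaking": false}` message, with results returned as JSON text frames. The same flow can be exercised from any websocket library; below is a minimal offline-mode sketch in Python, assuming the third-party `websocket-client` package (`pip install websocket-client`). Names and defaults mirror the Java client and are illustrative only, not a shipped FunASR client:

```python
import json
import websocket  # third-party package: websocket-client

def recognize(host="127.0.0.1", port=8889, wav_path="./asr_example.wav"):
    ws = websocket.create_connection(f"ws://{host}:{port}")
    # 1) initial JSON config, as in FunasrWsClient.sendJson()
    ws.send(json.dumps({
        "mode": "offline",
        "chunk_size": [5, 10, 5],
        "chunk_interval": 10,
        "wav_name": "pytest",
        "is_speaking": True,
    }))
    # 2) stream the PCM payload; skip the 44-byte wav header first
    with open(wav_path, "rb") as f:
        f.read(44)
        while chunk := f.read(1920):
            ws.send(chunk, opcode=websocket.ABNF.OPCODE_BINARY)
    # 3) end-of-stream marker, as in FunasrWsClient.sendEof()
    ws.send(json.dumps({"is_speaking": False}))
    print(json.loads(ws.recv()).get("text"))
    ws.close()

recognize()
```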
diff --git a/funasr/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp b/funasr/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp
index d2692ceb5..a4ee7f7df 100644
--- a/funasr/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp
+++ b/funasr/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp
@@ -65,7 +65,7 @@ void runReg(FUNASR_HANDLE asr_handle, vector<string> wav_list, vector<string> wa
             n_total_length += snippet_time;
             FunASRFreeResult(result);
         }else{
-            LOG(ERROR) << ("No return data!\n");
+            LOG(ERROR) << wav_ids[i] << (": No return data!\n");
         }
     }
     {
diff --git a/funasr/runtime/onnxruntime/src/ct-transformer.cpp b/funasr/runtime/onnxruntime/src/ct-transformer.cpp
index 58eec2540..2ee41140f 100644
--- a/funasr/runtime/onnxruntime/src/ct-transformer.cpp
+++ b/funasr/runtime/onnxruntime/src/ct-transformer.cpp
@@ -18,6 +18,7 @@ void CTTransformer::InitPunc(const std::string &punc_model, const std::string &p
     try{
         m_session = std::make_unique<Ort::Session>(env_, punc_model.c_str(), session_options);
+        LOG(INFO) << "Successfully loaded model from " << punc_model;
     }
     catch (std::exception const &e) {
         LOG(ERROR) << "Error when load punc onnx model: " << e.what();
diff --git a/funasr/runtime/onnxruntime/src/paraformer.cpp b/funasr/runtime/onnxruntime/src/paraformer.cpp
index 1957a12d4..b605ffff6 100644
--- a/funasr/runtime/onnxruntime/src/paraformer.cpp
+++ b/funasr/runtime/onnxruntime/src/paraformer.cpp
@@ -33,6 +33,7 @@ void Paraformer::InitAsr(const std::string &am_model, const std::string &am_cmvn
     try {
         m_session = std::make_unique<Ort::Session>(env_, am_model.c_str(), session_options);
+        LOG(INFO) << "Successfully loaded model from " << am_model;
     } catch (std::exception const &e) {
         LOG(ERROR) << "Error when load am onnx model: " << e.what();
         exit(0);
diff --git a/funasr/runtime/python/onnxruntime/setup.py b/funasr/runtime/python/onnxruntime/setup.py
index 64e363f81..246d67895 100644
--- a/funasr/runtime/python/onnxruntime/setup.py
+++ b/funasr/runtime/python/onnxruntime/setup.py
@@ -13,7 +13,7 @@ def get_readme():
 
 MODULE_NAME = 'funasr_onnx'
-VERSION_NUM = '0.1.0'
+VERSION_NUM = '0.1.1'
 
 setuptools.setup(
     name=MODULE_NAME,
@@ -31,7 +31,7 @@ setuptools.setup(
         "onnxruntime>=1.7.0",
         "scipy",
         "numpy>=1.19.3",
-        "typeguard",
+        "typeguard==2.13.3",
         "kaldi-native-fbank",
         "PyYAML>=5.1.2",
         "funasr",
diff --git a/funasr/runtime/ssl_key/readme.md b/funasr/runtime/ssl_key/readme.md
index a5989e6ec..8a48dd3e2 100644
--- a/funasr/runtime/ssl_key/readme.md
+++ b/funasr/runtime/ssl_key/readme.md
@@ -3,7 +3,7 @@ generated certificate may not be suitable for all browsers due to security concerns
 ```shell
 ### 1) Generate a private key
-openssl genrsa -des3 -out server.key 1024
+openssl genrsa -des3 -out server.key 2048
 
 ### 2) Generate a csr file
 openssl req -new -key server.key -out server.csr
@@ -14,4 +14,4 @@ openssl rsa -in server.key.org -out server.key
 
 ### 4) Generate a crt file, valid for 1 year
 openssl x509 -req -days 365 -in server.csr -signkey server.key -out server.crt
-```
\ No newline at end of file
+```
diff --git a/funasr/runtime/ssl_key/server.crt b/funasr/runtime/ssl_key/server.crt
index 808b73e6e..5a5079d08 100644
--- a/funasr/runtime/ssl_key/server.crt
+++ b/funasr/runtime/ssl_key/server.crt
@@ -1,15 +1,21 @@
 -----BEGIN CERTIFICATE-----
-MIICSDCCAbECFCObiVAMkMlCGmMDGDFx5Nx3XYvOMA0GCSqGSIb3DQEBCwUAMGMx
-CzAJBgNVBAYTAkNOMRAwDgYDVQQIDAdCZWlqaW5nMRAwDgYDVQQHDAdCZWlqaW5n
-MRAwDgYDVQQKDAdhbGliYWJhMQwwCgYDVQQLDANhc3IxEDAOBgNVBAMMB2FsaWJh
-YmEwHhcNMjMwNTEyMTQzNjAxWhcNMjQwNTExMTQzNjAxWjBjMQswCQYDVQQGEwJD
-TjEQMA4GA1UECAwHQmVpamluZzEQMA4GA1UEBwwHQmVpamluZzEQMA4GA1UECgwH
-YWxpYmFiYTEMMAoGA1UECwwDYXNyMRAwDgYDVQQDDAdhbGliYWJhMIGfMA0GCSqG
-SIb3DQEBAQUAA4GNADCBiQKBgQDEINLLMasJtJQPoesCfcwJsjiUkx3hLnoUyETS
-NBrrRfjbBv6ucAgZIF+/V15IfJZR6u2ULpJN0wUg8xNQReu4kdpjSdNGuQ0aoWbc
-38+VLo9UjjsoOeoeCro6b0u+GosPoEuI4t7Ky09zw+FBibD95daJ3GDY1DGCbDdL
-mV/toQIDAQABMA0GCSqGSIb3DQEBCwUAA4GBAB5KNWF1XIIYD1geMsyT6/ZRnGNA
-dmeUyMcwYvIlQG3boSipNk/JI4W5fFOg1O2sAqflYHmwZfmasAQsC2e5bSzHZ+PB
-uMJhKYxfj81p175GumHTw5Lbp2CvFSLrnuVB0ThRdcCqEh1MDt0D3QBuBr/ZKgGS
-hXtozVCgkSJzX6uD
+MIIDhTCCAm0CFGB0Po2IZ0hESavFpcSGRNb9xrNXMA0GCSqGSIb3DQEBCwUAMH8x
+CzAJBgNVBAYTAkNOMRAwDgYDVQQIDAdiZWlqaW5nMRAwDgYDVQQHDAdiZWlqaW5n
+MRAwDgYDVQQKDAdhbGliYWJhMRAwDgYDVQQLDAdhbGliYWJhMRAwDgYDVQQDDAdh
+bGliYWJhMRYwFAYJKoZIhvcNAQkBFgdhbGliYWJhMB4XDTIzMDYxODA2NTcxM1oX
+DTI0MDYxNzA2NTcxM1owfzELMAkGA1UEBhMCQ04xEDAOBgNVBAgMB2JlaWppbmcx
+EDAOBgNVBAcMB2JlaWppbmcxEDAOBgNVBAoMB2FsaWJhYmExEDAOBgNVBAsMB2Fs
+aWJhYmExEDAOBgNVBAMMB2FsaWJhYmExFjAUBgkqhkiG9w0BCQEWB2FsaWJhYmEw
+ggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDH9Np1oBunQKMt5M/nU2nD
+qVHojXwKKwyiK9DSeGikKwArH2S9NUZNu5RDg46u0iWmT+Vz+toQhkJnfatOVskW
+f2bsI54n5eOvmoWOKDXYm2MscvjkuNiYRbqzgUuP9ZSx8k3uyRs++wvmwIoU+PV1
+EYFcjk1P2jUGUvKaUlmIDsjs1wOMIbKO6I0UX20FNKlGWacqMR/Dx2ltmGKT1Kaz
+Y335lor0bcfQtH542rGS7PDz6JMRNjFT1VFcmnrjRElf4STbaOiIfOjMVZ/9O8Hr
+LFItyvkb01Mt7O0jhAXHuE1l/8Y0N3MCYkELG9mQA0BYCFHY0FLuJrGoU03b8KWj
+AgMBAAEwDQYJKoZIhvcNAQELBQADggEBAEjC9jB1WZe2ki2JgCS+eAMFsFegiNEz
+D0klVB3kiCPK0g7DCxvfWR6kAgEynxRxVX6TN9QcLr4paZItC1Fu2gUMTteNqEuc
+dcixJdu9jumuUMBlAKgL5Yyk3alSErsn9ZVF/Q8Kx5arMO/TW3Ulsd8SWQL5C/vq
+Fe0SRhpKKoADPfl8MT/XMfB/MwNxVhYDSHzJ1EiN8O5ce6q2tTdi1mlGquzNxhjC
+7Q0F36V1HksfzolrlRWRKYP16isnaKUdFfeAzaJsYw33o6VRbk6fo2fTQDHS0wOs
+Q48Moc5UxKMLaMMCqLPpWu0TZse+kIw1nTWXk7yJtK0HK5PN3rTocEw=
 -----END CERTIFICATE-----
diff --git a/funasr/runtime/ssl_key/server.key b/funasr/runtime/ssl_key/server.key
index aac8b2646..8efdcb832 100644
---
a/funasr/runtime/ssl_key/server.key +++ b/funasr/runtime/ssl_key/server.key @@ -1,15 +1,27 @@ -----BEGIN RSA PRIVATE KEY----- -MIICXQIBAAKBgQDEINLLMasJtJQPoesCfcwJsjiUkx3hLnoUyETSNBrrRfjbBv6u -cAgZIF+/V15IfJZR6u2ULpJN0wUg8xNQReu4kdpjSdNGuQ0aoWbc38+VLo9Ujjso -OeoeCro6b0u+GosPoEuI4t7Ky09zw+FBibD95daJ3GDY1DGCbDdLmV/toQIDAQAB -AoGARpA0pwygp+ZDWvh7kDLoZRitCK+BkZHiNHX1ZNeAU+Oh7FOw79u43ilqqXHq -pxPEFYb7oVO8Kanhb4BlE32EmApBlvhd3SW07kn0dS7WVGsTvPFwKKpF88W8E+pc -2i8At5tr2O1DZhvqNdIN7r8FRrGQ/Hpm3ItypUdz2lZnMwECQQD3dILOMJ84O2JE -NxUwk8iOYefMJftQUO57Gm7XBVke/i3r9uajSqB2xmOvUaSyaHoJfx/mmfgfxYcD -M+Re6mERAkEAyuaV5+eD82eG2I8PgxJ2p5SOb1x5F5qpb4KuKAlfHEkdolttMwN3 -7vl1ZWUZLVu2rHnUmvbYV2gkQO1os7/DkQJBAIDYfbN2xbC12vjB5ZqhmG/qspMt -w6mSOlqG7OewtTLaDncq2/RySxMNQaJr1GHA3KpNMwMTcIq6gw472tFBIMECQF0z -fjiASEROkcp4LI/ws0BXJPZSa+1DxgDK7mTFqUK88zfY91gvh6/mNt7UibQkJM0l -SVvFd6ru03hflXC77YECQQDDQrB9ApwVOMGQw+pwbxn9p8tPYVi3oBiUfYgd1RDO -uhcRgxv7gT4BSiyI4nFBMCYyI28azTLlUiJhMr9MNUpB +MIIEowIBAAKCAQEAx/TadaAbp0CjLeTP51Npw6lR6I18CisMoivQ0nhopCsAKx9k +vTVGTbuUQ4OOrtIlpk/lc/raEIZCZ32rTlbJFn9m7COeJ+Xjr5qFjig12JtjLHL4 +5LjYmEW6s4FLj/WUsfJN7skbPvsL5sCKFPj1dRGBXI5NT9o1BlLymlJZiA7I7NcD +jCGyjuiNFF9tBTSpRlmnKjEfw8dpbZhik9Sms2N9+ZaK9G3H0LR+eNqxkuzw8+iT +ETYxU9VRXJp640RJX+Ek22joiHzozFWf/TvB6yxSLcr5G9NTLeztI4QFx7hNZf/G +NDdzAmJBCxvZkANAWAhR2NBS7iaxqFNN2/ClowIDAQABAoIBAQC1/STX6eFBWJMs +MhUHdePNMU5bWmqK1qOo9jgZV33l7T06Alit3M8f8JoA2LwEYT/jHtS3upi+cXP+ +vWIs6tAaqdoDEmff6FxSd1EXEYHwo3yf+ASQJ6z66nwC5KrhW6L6Uo6bxm4F5Hfw +jU0fyXeeFVCn7Nxw0SlxmA02Z70VFsL8BK9i3kajU18y6drf4VUm55oMEtdEmOh2 +eKn4qspBcNblbw+L0QJ+5kN1iRUyJHesQ1GpS+L3yeMVFCW7ctL4Bgw8Z7LE+z7i +C0Weyhul8vuT+7nfF2T37zsSa8iixqpkTokeYh96CZ5nDqa2IDx3oNHWSlkIsV6g +6EUEl9gBAoGBAPIw/M6fIDetMj8f1wG7mIRgJsxI817IS6aBSwB5HkoCJFfrR9Ua +jMNCFIWNs/Om8xeGhq/91hbnCYDNK06V5CUa/uk4CYRs2eQZ3FKoNowtp6u/ieuU +qg8bXM/vR2VWtWVixAMdouT3+KtvlgaVmSnrPiwO4pecGrwu5NW1oJCFAoGBANNb +aE3AcwTDYsqh0N/75G56Q5s1GZ6MCDQGQSh8IkxL6Vg59KnJiIKQ7AxNKFgJZMtY +zZHaqjazeHjOGTiYiC7MMVJtCcOBEfjCouIG8btNYv7Y3dWnOXRZni2telAsRrH9 +xS5LaFdCRTjVAwSsppMGwiQtyl6sGLMyz0SXoYoHAoGAKdkFFb6xFm26zOV3hTkg +9V6X1ZyVUL9TMwYMK5zB+w+7r+VbmBrqT6LPYPRHL8adImeARlCZ+YMaRUMuRHnp +3e94NFwWaOdWDu/Y/f9KzZXl7us9rZMWf12+/77cm0oMNeSG8fLg/qdKNHUneyPG +P1QCfiJkTMYQaIvBxpuHjvECgYAKlZ9JlYOtD2PZJfVh4il0ZucP1L7ts7GNeWq1 +7lGBZKPQ6UYZYqBVeZB4pTyJ/B5yGIZi8YJoruAvnJKixPC89zjZGeDNS59sx8KE +cziT2rJEdPPXCULVUs+bFf70GOOJcl33jYsyI3139SLrjwHghwwd57UkvJWYE8lR +dA6A7QKBgEfTC+NlzqLPhbB+HPl6CvcUczcXcI9M0heVz/DNMA+4pjxPnv2aeIwh +cL2wq2xr+g1wDBWGVGkVSuZhXm5E6gDetdyVeJnbIUhVjBblnbhHV6GrudjbXGnJ +W9cBgu6DswyHU2cOsqmimu8zLmG6/dQYFHt+kUWGxN8opCzVjgWa -----END RSA PRIVATE KEY----- diff --git a/funasr/runtime/websocket/funasr-wss-client.cpp b/funasr/runtime/websocket/funasr-wss-client.cpp index 4a3c7516d..8b59000fc 100644 --- a/funasr/runtime/websocket/funasr-wss-client.cpp +++ b/funasr/runtime/websocket/funasr-wss-client.cpp @@ -91,7 +91,6 @@ class WebsocketClient { using websocketpp::lib::placeholders::_1; m_client.set_open_handler(bind(&WebsocketClient::on_open, this, _1)); m_client.set_close_handler(bind(&WebsocketClient::on_close, this, _1)); - // m_client.set_close_handler(bind(&WebsocketClient::on_close, this, _1)); m_client.set_message_handler( [this](websocketpp::connection_hdl hdl, message_ptr msg) { @@ -218,7 +217,7 @@ class WebsocketClient { } } if (wait) { - LOG(INFO) << "wait.." << m_open; + // LOG(INFO) << "wait.." 
+        // LOG(INFO) << "wait.." << m_open;
         WaitABit();
         continue;
       }
@@ -292,7 +291,7 @@ int main(int argc, char* argv[]) {
                                    false, 1, "int");
     TCLAP::ValueArg is_ssl_(
         "", "is-ssl", "is-ssl is 1 means use wss connection, or use ws connection",
-        false, 0, "int");
+        false, 1, "int");
 
     cmd.add(server_ip_);
     cmd.add(port_);
diff --git a/funasr/tasks/asr.py b/funasr/tasks/asr.py
index 824485631..7338513cb 100644
--- a/funasr/tasks/asr.py
+++ b/funasr/tasks/asr.py
@@ -38,6 +38,7 @@ from funasr.models.decoder.transformer_decoder import (
 from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
 from funasr.models.decoder.transformer_decoder import TransformerDecoder
 from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
+from funasr.models.decoder.transformer_decoder import SAAsrTransformerDecoder
 from funasr.models.e2e_asr import ASRModel
 from funasr.models.decoder.rnnt_decoder import RNNTDecoder
 from funasr.models.joint_net.joint_network import JointNetwork
@@ -45,6 +46,7 @@ from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerOnline, Paraf
 from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
 from funasr.models.e2e_tp import TimestampPredictor
 from funasr.models.e2e_asr_mfcca import MFCCA
+from funasr.models.e2e_sa_asr import SAASRModel
 from funasr.models.e2e_uni_asr import UniASR
 from funasr.models.e2e_asr_transducer import TransducerModel, UnifiedTransducerModel
 from funasr.models.encoder.abs_encoder import AbsEncoder
@@ -54,6 +56,7 @@ from funasr.models.encoder.rnn_encoder import RNNEncoder
 from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
 from funasr.models.encoder.transformer_encoder import TransformerEncoder
 from funasr.models.encoder.mfcca_encoder import MFCCAEncoder
+from funasr.models.encoder.resnet34_encoder import ResNet34Diar
 from funasr.models.frontend.abs_frontend import AbsFrontend
 from funasr.models.frontend.default import DefaultFrontend
 from funasr.models.frontend.default import MultiChannelFrontend
@@ -134,6 +137,7 @@ model_choices = ClassChoices(
         timestamp_prediction=TimestampPredictor,
         rnnt=TransducerModel,
         rnnt_unified=UnifiedTransducerModel,
+        sa_asr=SAASRModel,
     ),
     type_check=FunASRModel,
     default="asr",
@@ -175,6 +179,27 @@ encoder_choices2 = ClassChoices(
     type_check=AbsEncoder,
     default="rnn",
 )
+asr_encoder_choices = ClassChoices(
+    "asr_encoder",
+    classes=dict(
+        conformer=ConformerEncoder,
+        transformer=TransformerEncoder,
+        rnn=RNNEncoder,
+        sanm=SANMEncoder,
+        sanm_chunk_opt=SANMEncoderChunkOpt,
+        data2vec_encoder=Data2VecEncoder,
+        mfcca_enc=MFCCAEncoder,
+    ),
+    type_check=AbsEncoder,
+    default="rnn",
+)
+spk_encoder_choices = ClassChoices(
+    "spk_encoder",
+    classes=dict(
+        resnet34_diar=ResNet34Diar,
+    ),
+    default="resnet34_diar",
+)
 postencoder_choices = ClassChoices(
     name="postencoder",
     classes=dict(
@@ -197,6 +222,7 @@ decoder_choices = ClassChoices(
         paraformer_decoder_sanm=ParaformerSANMDecoder,
         paraformer_decoder_san=ParaformerDecoderSAN,
         contextual_paraformer_decoder=ContextualParaformerDecoder,
+        sa_decoder=SAAsrTransformerDecoder,
     ),
     type_check=AbsDecoder,
     default="rnn",
@@ -329,6 +355,12 @@ class ASRTask(AbsTask):
             default=True,
             help="whether to split text using <space>",
         )
+        group.add_argument(
+            "--max_spk_num",
+            type=int_or_none,
+            default=None,
+            help="Maximum number of speakers",
+        )
         group.add_argument(
             "--seg_dict_file",
             type=str,
@@ -1495,3 +1527,121 @@ class ASRTransducerTask(ASRTask):
         #assert check_return_type(model)
 
         return model
+
+
+class ASRTaskSAASR(ASRTask):
+    # If you need more than one optimizer, change this value
+    num_optimizers: int = 1
+
+    # Add variable objects configurations
+    class_choices_list = [
+        # --frontend and --frontend_conf
+        frontend_choices,
+        # --specaug and --specaug_conf
+        specaug_choices,
+        # --normalize and --normalize_conf
+        normalize_choices,
+        # --model and --model_conf
+        model_choices,
+        # --preencoder and --preencoder_conf
+        preencoder_choices,
+        # --asr_encoder and --asr_encoder_conf
+        asr_encoder_choices,
+        # --spk_encoder and --spk_encoder_conf
+        spk_encoder_choices,
+        # --decoder and --decoder_conf
+        decoder_choices,
+    ]
+
+    # If you need to modify train() or eval() procedures, change Trainer class here
+    trainer = Trainer
+
+    @classmethod
+    def build_model(cls, args: argparse.Namespace):
+        assert check_argument_types()
+        if isinstance(args.token_list, str):
+            with open(args.token_list, encoding="utf-8") as f:
+                token_list = [line.rstrip() for line in f]
+
+            # Overwriting token_list to keep it as "portable".
+            args.token_list = list(token_list)
+        elif isinstance(args.token_list, (tuple, list)):
+            token_list = list(args.token_list)
+        else:
+            raise RuntimeError("token_list must be str or list")
+        vocab_size = len(token_list)
+        logging.info(f"Vocabulary size: {vocab_size}")
+
+        # 1. frontend
+        if args.input_size is None:
+            # Extract features in the model
+            frontend_class = frontend_choices.get_class(args.frontend)
+            if args.frontend == 'wav_frontend' or args.frontend == "multichannelfrontend":
+                frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
+            else:
+                frontend = frontend_class(**args.frontend_conf)
+            input_size = frontend.output_size()
+        else:
+            # Give features from data-loader
+            args.frontend = None
+            args.frontend_conf = {}
+            frontend = None
+            input_size = args.input_size
+
+        # 2. Data augmentation for spectrogram
+        if args.specaug is not None:
+            specaug_class = specaug_choices.get_class(args.specaug)
+            specaug = specaug_class(**args.specaug_conf)
+        else:
+            specaug = None
+
+        # 3. Normalization layer
+        if args.normalize is not None:
+            normalize_class = normalize_choices.get_class(args.normalize)
+            normalize = normalize_class(**args.normalize_conf)
+        else:
+            normalize = None
+
+        # 5. Encoder
+        asr_encoder_class = asr_encoder_choices.get_class(args.asr_encoder)
+        asr_encoder = asr_encoder_class(input_size=input_size, **args.asr_encoder_conf)
+        spk_encoder_class = spk_encoder_choices.get_class(args.spk_encoder)
+        spk_encoder = spk_encoder_class(input_size=input_size, **args.spk_encoder_conf)
+
+        # 7. Decoder
+        decoder_class = decoder_choices.get_class(args.decoder)
+        decoder = decoder_class(
+            vocab_size=vocab_size,
+            encoder_output_size=asr_encoder.output_size(),
+            **args.decoder_conf,
+        )
+
+        # 8. CTC
+        ctc = CTC(
+            odim=vocab_size, encoder_output_size=asr_encoder.output_size(), **args.ctc_conf
+        )
+
+        # 9. Build model
+        try:
+            model_class = model_choices.get_class(args.model)
+        except AttributeError:
+            model_class = model_choices.get_class("asr")
+        model = model_class(
+            vocab_size=vocab_size,
+            frontend=frontend,
+            specaug=specaug,
+            normalize=normalize,
+            asr_encoder=asr_encoder,
+            spk_encoder=spk_encoder,
+            decoder=decoder,
+            ctc=ctc,
+            token_list=token_list,
+            **args.model_conf,
+        )
+
+        # 10. Initialize
+        if args.init is not None:
+            initialize(model, args.init)
+
+        assert check_return_type(model)
+
+        return model
diff --git a/funasr/tasks/sa_asr.py b/funasr/tasks/sa_asr.py
index 4769758b3..957948336 100644
--- a/funasr/tasks/sa_asr.py
+++ b/funasr/tasks/sa_asr.py
@@ -39,7 +39,7 @@ from funasr.models.decoder.transformer_decoder import (
 from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
 from funasr.models.decoder.transformer_decoder import TransformerDecoder
 from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
-from funasr.models.e2e_sa_asr import ESPnetASRModel
+from funasr.models.e2e_sa_asr import SAASRModel
 from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerBert, BiCifParaformer, ContextualParaformer
 from funasr.models.e2e_tp import TimestampPredictor
 from funasr.models.e2e_asr_mfcca import MFCCA
@@ -120,7 +120,7 @@ normalize_choices = ClassChoices(
 model_choices = ClassChoices(
     "model",
     classes=dict(
-        asr=ESPnetASRModel,
+        asr=SAASRModel,
         uniasr=UniASR,
         paraformer=Paraformer,
         paraformer_bert=ParaformerBert,
@@ -620,4 +620,4 @@ class ASRTask(AbsTask):
             initialize(model, args.init)
 
         assert check_return_type(model)
-        return model
+        return model
\ No newline at end of file
diff --git a/funasr/version.txt b/funasr/version.txt
index ee6cdce3c..b61604874 100644
--- a/funasr/version.txt
+++ b/funasr/version.txt
@@ -1 +1 @@
-0.6.1
+0.6.2
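A note on the batching change in funasr/bin/asr_inference_launch.py above: VAD segments are sorted by duration and packed into batches under a millisecond budget of batch_size_token * 60, with the budget dropped to a single segment on CPU (and never below the shortest segment, so every batch is non-empty). The packing loop itself is not part of this diff, so the following is only a simplified sketch of that rule, not the shipped implementation:

```python
# Simplified sketch: greedy duration-sorted packing of (start_ms, end_ms)
# VAD segments under a token-ms budget, as set up before the loop above.
def pack_segments(segments, batch_size_token=5000, on_cpu=False):
    budget_ms = 0 if on_cpu else batch_size_token * 60
    ordered = sorted(segments, key=lambda s: s[1] - s[0])
    # keep the budget at least one segment wide; on CPU this degrades to batch=1
    budget_ms = max(budget_ms, ordered[0][1] - ordered[0][0])
    batches, cur, cur_ms = [], [], 0
    for seg in ordered:
        dur = seg[1] - seg[0]
        if cur and cur_ms + dur > budget_ms:
            batches.append(cur)
            cur, cur_ms = [], 0
        cur.append(seg)
        cur_ms += dur
    if cur:
        batches.append(cur)
    return batches

# 1800 ms budget -> [[(2000, 2300), (0, 400)], [(400, 2000)]]
print(pack_segments([(0, 400), (400, 2000), (2000, 2300)], batch_size_token=30))
```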