diff --git a/README.md b/README.md
index 7c289e05a..76e33019b 100644
--- a/README.md
+++ b/README.md
@@ -72,8 +72,8 @@ If you have any questions about FunASR, please contact us by
## Contributors
-|
|
|
|
-|:---------------------------------------------------------------:|:---------------------------------------------------------------:|:--------------------------------------------------------------:|:-------------------------------------------------------:|:-----------------------------------------------------------:|:-----------------------------------------------------------:|
+| |
|
+|:---------------------------------------------------------------:|:---------------------------------------------------------------:|:--------------------------------------------------------------:|:-------------------------------------------------------:|:-----------------------------------------------------------:|
## Acknowledge
@@ -82,7 +82,6 @@ If you have any questions about FunASR, please contact us by
3. We referred to [Wenet](https://github.com/wenet-e2e/wenet) when building the dataloader for large-scale data training.
4. We acknowledge [ChinaTelecom](https://github.com/zhuzizyf/damo-fsmn-vad-infer-httpserver) for contributing the VAD runtime.
5. We acknowledge [RapidAI](https://github.com/RapidAI) for contributing the Paraformer and CT_Transformer-punc runtime.
-6. We acknowledge [DeepScience](https://www.deepscience.cn) for contributing the grpc service.
6. We acknowledge [AiHealthx](http://www.aihealthx.com/) for contributing the websocket service and html5.
## License
diff --git a/egs/alimeeting/sa-asr/conf/train_lm_transformer.yaml b/egs/alimeeting/sa-asr/conf/train_lm_transformer.yaml
deleted file mode 100644
index 68520ae23..000000000
--- a/egs/alimeeting/sa-asr/conf/train_lm_transformer.yaml
+++ /dev/null
@@ -1,29 +0,0 @@
-lm: transformer
-lm_conf:
- pos_enc: null
- embed_unit: 128
- att_unit: 512
- head: 8
- unit: 2048
- layer: 16
- dropout_rate: 0.1
-
-# optimization related
-grad_clip: 5.0
-batch_type: numel
-batch_bins: 500000 # 4gpus * 500000
-accum_grad: 1
-max_epoch: 15 # 15epoch is enougth
-
-optim: adam
-optim_conf:
- lr: 0.001
-scheduler: warmuplr
-scheduler_conf:
- warmup_steps: 25000
-
-best_model_criterion:
-- - valid
- - loss
- - min
-keep_nbest_models: 10 # 10 is good.
diff --git a/egs/alimeeting/sa_asr/README.md b/egs/alimeeting/sa_asr/README.md
new file mode 100644
index 000000000..1ae023a84
--- /dev/null
+++ b/egs/alimeeting/sa_asr/README.md
@@ -0,0 +1,86 @@
+# Get Started
+Speaker Attributed Automatic Speech Recognition (SA-ASR) is a task proposed to answer the question "who spoke what". Specifically, the goal of SA-ASR is not only to obtain multi-speaker transcriptions, but also to identify the speaker of each utterance. The method used in this example follows the paper [End-to-End Speaker-Attributed ASR with Transformer](https://www.isca-speech.org/archive/pdfs/interspeech_2021/kanda21b_interspeech.pdf).
+# Train
+First, install FunASR and ModelScope ([installation](https://github.com/alibaba-damo-academy/FunASR#installation)).
+After FunASR and ModelScope are installed, manually download and unpack the [AliMeeting](http://www.openslr.org/119/) corpus and place it in the `./dataset` directory, organized as follows:
+```shell
+dataset
+├── Eval_Ali_far
+├── Eval_Ali_near
+├── Test_Ali_far
+├── Test_Ali_near
+├── Train_Ali_far
+└── Train_Ali_near
+```
+Then you can run this recipe with:
+```shell
+bash run.sh --stage 0 --stop-stage 6
+```
+There are 8 stages in `run.sh`:
+```shell
+stage 0: Data preparation; remove audio that is too long or too short.
+stage 1: Speaker profile and CMVN generation.
+stage 2: Dictionary preparation.
+stage 3: LM training (not supported).
+stage 4: ASR training.
+stage 5: SA-ASR training.
+stage 6: Inference on the Eval/Test sets.
+stage 7: Inference on Test_2023_Ali_far.
+```
+
+# Infer
+1. Download the final test set and extract it.
+2. Put the audio files in `./dataset/Test_2023_Ali_far/` and put `wav.scp`, `segments`, `utt2spk`, `spk2utt` in `./data/org/Test_2023_Ali_far/`.
+3. Set `test_2023` in `run.sh` to `Test_2023_Ali_far`.
+4. Run `run.sh` as follows:
+```shell
+# Prepare test_2023 set
+bash run.sh --stage 0 --stop-stage 1
+# Decode test_2023 set
+bash run.sh --stage 7 --stop-stage 7
+```
+# Format of Final Submission
+Finally, you need to submit a file called `text_spk_merge` with the following format:
+```shell
+Meeting_1 text_spk_1_A$text_spk_1_B$text_spk_1_C ...
+Meeting_2 text_spk_2_A$text_spk_2_B$text_spk_2_C ...
+...
+```
+Here, `text_spk_1_A` represents the full transcription of speaker_A in Meeting_1 (merged in chronological order), and `$` is the separator symbol. There is no need to worry about the speaker permutation, as the optimal permutation is computed during scoring. For more information, please refer to the results generated after running the baseline code.
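+
+As an illustration only (the helper below is a hypothetical sketch, not the official generator), one way to build such a line from per-speaker, time-stamped utterances:
+```python
+from collections import defaultdict
+
+def merge_meeting(meeting_id, utts):
+    """utts: list of (speaker, start_time, text) tuples for one meeting."""
+    by_spk = defaultdict(list)
+    for spk, _, text in sorted(utts, key=lambda u: u[1]):
+        by_spk[spk].append(text)
+    # one field per speaker, fields joined by the '$' separator
+    return meeting_id + " " + "$".join("".join(t) for t in by_spk.values())
+
+print(merge_meeting("Meeting_1", [("A", 0.0, "你好"), ("B", 1.2, "在吗"), ("A", 2.5, "在的")]))
+# Meeting_1 你好在的$在吗
+```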
+# Baseline Results
+The results of the baseline system are listed below. They include the speaker-independent character error rate (SI-CER) and the concatenated minimum-permutation character error rate (cpCER); the former ignores speaker labels, while the latter is speaker-dependent. During training, the speaker profile uses oracle speaker embeddings. Because oracle speaker labels are unavailable during evaluation, the speaker profile is instead produced by an additional spectral clustering step. Results with the oracle speaker profile on the Test set are also reported to show the impact of speaker-profile accuracy.
+
+| |SI-CER (%) |cpCER (%) |
+|:---------------|:------------:|:---------:|
+|oracle profile |32.72 |42.92 |
+|cluster profile |32.73 |49.37 |
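+
+For intuition, cpCER can be sketched as the minimum, over all assignments of hypothesis speakers to reference speakers, of the total edit distance between the concatenated per-speaker texts, divided by the total reference length (a minimal sketch; `local/compute_cpcer.py` is the authoritative scorer):
+```python
+import itertools
+
+def edit_distance(hyp, ref):
+    # character-level Levenshtein distance, single-row dynamic programming
+    d = list(range(len(ref) + 1))
+    for i, h in enumerate(hyp, 1):
+        prev, d[0] = d[0], i
+        for j, r in enumerate(ref, 1):
+            prev, d[j] = d[j], min(d[j] + 1, d[j - 1] + 1, prev + (h != r))
+    return d[len(ref)]
+
+def cp_cer(hyp_spks, ref_spks):
+    # pad the shorter side with empty "speakers" so both lists match in length
+    n = max(len(hyp_spks), len(ref_spks))
+    hyp = hyp_spks + [""] * (n - len(hyp_spks))
+    ref = ref_spks + [""] * (n - len(ref_spks))
+    total = sum(len(r) for r in ref)
+    best = min(
+        sum(edit_distance(hyp[p[i]], ref[i]) for i in range(n))
+        for p in itertools.permutations(range(n))
+    )
+    return best / total
+```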
+
+
+# Reference
+N. Kanda, G. Ye, Y. Gaur, X. Wang, Z. Meng, Z. Chen, and T. Yoshioka, "End-to-end speaker-attributed ASR with transformer," in Interspeech. ISCA, 2021, pp. 4413–4417.
\ No newline at end of file
diff --git a/egs/alimeeting/sa-asr/conf/decode_asr_rnn.yaml b/egs/alimeeting/sa_asr/conf/decode_asr_rnn.yaml
similarity index 100%
rename from egs/alimeeting/sa-asr/conf/decode_asr_rnn.yaml
rename to egs/alimeeting/sa_asr/conf/decode_asr_rnn.yaml
diff --git a/egs/alimeeting/sa_asr/conf/train_asr_conformer.yaml b/egs/alimeeting/sa_asr/conf/train_asr_conformer.yaml
new file mode 100644
index 000000000..507ad3061
--- /dev/null
+++ b/egs/alimeeting/sa_asr/conf/train_asr_conformer.yaml
@@ -0,0 +1,102 @@
+# network architecture
+frontend: multichannelfrontend
+frontend_conf:
+ fs: 16000
+ window: hann
+ n_fft: 400
+ n_mels: 80
+ frame_length: 25
+ frame_shift: 10
+ lfr_m: 1
+ lfr_n: 1
+ use_channel: 0
+
+# encoder related
+encoder: conformer
+encoder_conf:
+ output_size: 256 # dimension of attention
+ attention_heads: 4
+ linear_units: 2048 # the number of units of position-wise feed forward
+ num_blocks: 12 # the number of encoder blocks
+ dropout_rate: 0.1
+ positional_dropout_rate: 0.1
+ attention_dropout_rate: 0.0
+ input_layer: conv2d # encoder architecture type
+ normalize_before: true
+ rel_pos_type: latest
+ pos_enc_layer_type: rel_pos
+ selfattention_layer_type: rel_selfattn
+ activation_type: swish
+ macaron_style: true
+ use_cnn_module: true
+ cnn_module_kernel: 15
+
+# decoder related
+decoder: transformer
+decoder_conf:
+ attention_heads: 4
+ linear_units: 2048
+ num_blocks: 6
+ dropout_rate: 0.1
+ positional_dropout_rate: 0.1
+ self_attention_dropout_rate: 0.0
+ src_attention_dropout_rate: 0.0
+
+# ctc related
+ctc_conf:
+ ignore_nan_grad: true
+
+# hybrid CTC/attention
+model_conf:
+ ctc_weight: 0.3
+ lsm_weight: 0.1 # label smoothing option
+ length_normalized_loss: false
+
+
+dataset_conf:
+ data_names: speech,text
+ data_types: sound,text
+ shuffle: True
+ shuffle_conf:
+ shuffle_size: 2048
+ sort_size: 500
+ batch_conf:
+ batch_type: token
+ batch_size: 7000
+ num_workers: 8
+
+# optimization related
+accum_grad: 1
+grad_clip: 5
+max_epoch: 100
+val_scheduler_criterion:
+ - valid
+ - acc
+best_model_criterion:
+- - valid
+ - acc
+ - max
+keep_nbest_models: 10
+
+optim: adam
+optim_conf:
+ lr: 0.001
+scheduler: warmuplr
+scheduler_conf:
+ warmup_steps: 25000
+
+specaug: specaug
+specaug_conf:
+ apply_time_warp: true
+ time_warp_window: 5
+ time_warp_mode: bicubic
+ apply_freq_mask: true
+ freq_mask_width_range:
+ - 0
+ - 30
+ num_freq_mask: 2
+ apply_time_mask: true
+ time_mask_width_range:
+ - 0
+ - 40
+ num_time_mask: 2
diff --git a/egs/alimeeting/sa_asr/conf/train_sa_asr_conformer.yaml b/egs/alimeeting/sa_asr/conf/train_sa_asr_conformer.yaml
new file mode 100644
index 000000000..47bc6bdb6
--- /dev/null
+++ b/egs/alimeeting/sa_asr/conf/train_sa_asr_conformer.yaml
@@ -0,0 +1,131 @@
+# network architecture
+frontend: multichannelfrontend
+frontend_conf:
+ fs: 16000
+ window: hann
+ n_fft: 400
+ n_mels: 80
+ frame_length: 25
+ frame_shift: 10
+ lfr_m: 1
+ lfr_n: 1
+ use_channel: 0
+
+# encoder related
+asr_encoder: conformer
+asr_encoder_conf:
+ output_size: 256 # dimension of attention
+ attention_heads: 4
+ linear_units: 2048 # the number of units of position-wise feed forward
+ num_blocks: 12 # the number of encoder blocks
+ dropout_rate: 0.1
+ positional_dropout_rate: 0.1
+ attention_dropout_rate: 0.0
+ input_layer: conv2d # encoder architecture type
+ normalize_before: true
+ pos_enc_layer_type: rel_pos
+ selfattention_layer_type: rel_selfattn
+ activation_type: swish
+ macaron_style: true
+ use_cnn_module: true
+ cnn_module_kernel: 15
+
+spk_encoder: resnet34_diar
+spk_encoder_conf:
+ use_head_conv: true
+ batchnorm_momentum: 0.5
+ use_head_maxpool: false
+ num_nodes_pooling_layer: 256
+ layers_in_block:
+ - 3
+ - 4
+ - 6
+ - 3
+ filters_in_block:
+ - 32
+ - 64
+ - 128
+ - 256
+ pooling_type: statistic
+ num_nodes_resnet1: 256
+ num_nodes_last_layer: 256
+
+# decoder related
+decoder: sa_decoder
+decoder_conf:
+ attention_heads: 4
+ linear_units: 2048
+ asr_num_blocks: 6
+ spk_num_blocks: 3
+ dropout_rate: 0.1
+ positional_dropout_rate: 0.1
+ self_attention_dropout_rate: 0.0
+ src_attention_dropout_rate: 0.0
+
+# hybrid CTC/attention
+model_conf:
+ spk_weight: 0.5
+ ctc_weight: 0.3
+ lsm_weight: 0.1 # label smoothing option
+ length_normalized_loss: false
+ max_spk_num: 4
+
+ctc_conf:
+ ignore_nan_grad: true
+
+# minibatch related
+dataset_conf:
+ data_names: speech,text,profile,text_id
+ data_types: sound,text,npy,text_int
+ shuffle: True
+ shuffle_conf:
+ shuffle_size: 2048
+ sort_size: 500
+ batch_conf:
+ batch_type: token
+ batch_size: 7000
+ num_workers: 8
+
+# optimization related
+accum_grad: 1
+grad_clip: 5
+max_epoch: 60
+val_scheduler_criterion:
+ - valid
+ - loss
+best_model_criterion:
+- - valid
+ - acc
+ - max
+- - valid
+ - acc_spk
+ - max
+- - valid
+ - loss
+ - min
+keep_nbest_models: 10
+
+optim: adam
+optim_conf:
+ lr: 0.0005
+scheduler: warmuplr
+scheduler_conf:
+ warmup_steps: 8000
+
+specaug: specaug
+specaug_conf:
+ apply_time_warp: true
+ time_warp_window: 5
+ time_warp_mode: bicubic
+ apply_freq_mask: true
+ freq_mask_width_range:
+ - 0
+ - 30
+ num_freq_mask: 2
+ apply_time_mask: true
+ time_mask_width_range:
+ - 0
+ - 40
+ num_time_mask: 2
+
diff --git a/egs/alimeeting/sa-asr/local/alimeeting_data_prep.sh b/egs/alimeeting/sa_asr/local/alimeeting_data_prep.sh
similarity index 74%
rename from egs/alimeeting/sa-asr/local/alimeeting_data_prep.sh
rename to egs/alimeeting/sa_asr/local/alimeeting_data_prep.sh
index c13ee429e..fd76837b1 100755
--- a/egs/alimeeting/sa-asr/local/alimeeting_data_prep.sh
+++ b/egs/alimeeting/sa_asr/local/alimeeting_data_prep.sh
@@ -21,6 +21,8 @@ EOF
SECONDS=0
tgt=Train #Train or Eval
+min_wav_duration=0.1
+max_wav_duration=20
log "$0 $*"
@@ -57,27 +59,24 @@ stage=1
stop_stage=4
mkdir -p $far_dir
mkdir -p $near_dir
+mkdir -p data/org
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
log "stage 1:process alimeeting near dir"
find -L $near_raw_dir/audio_dir -iname "*.wav" | sort > $near_dir/wavlist
- awk -F '/' '{print $NF}' $near_dir/wavlist | awk -F '.' '{print $1}' > $near_dir/uttid
- find -L $near_raw_dir/textgrid_dir -iname "*.TextGrid" | sort > $near_dir/textgrid.flist
+ awk -F '/' '{print $NF}' $near_dir/wavlist | awk -F '.' '{print $1}' | sort > $near_dir/uttid
+ find -L $near_raw_dir/textgrid_dir -iname "*.TextGrid" > $near_dir/textgrid.flist
n1_wav=$(wc -l < $near_dir/wavlist)
n2_text=$(wc -l < $near_dir/textgrid.flist)
log near file found $n1_wav wav and $n2_text text.
- paste $near_dir/uttid $near_dir/wavlist > $near_dir/wav_raw.scp
-
- # cat $near_dir/wav_raw.scp | awk '{printf("%s sox -t wav %s -r 16000 -b 16 -c 1 -t wav - |\n", $1, $2)}' > $near_dir/wav.scp
- cat $near_dir/wav_raw.scp | awk '{printf("%s sox -t wav %s -r 16000 -b 16 -t wav - |\n", $1, $2)}' > $near_dir/wav.scp
+ paste $near_dir/uttid $near_dir/wavlist -d " " > $near_dir/wav.scp
python local/alimeeting_process_textgrid.py --path $near_dir --no-overlap False
cat $near_dir/text_all | local/text_normalize.pl | local/text_format.pl | sort -u > $near_dir/text
utils/filter_scp.pl -f 1 $near_dir/text $near_dir/utt2spk_all | sort -u > $near_dir/utt2spk
- #sed -e 's/ [a-z,A-Z,_,0-9,-]\+SPK/ SPK/' $near_dir/utt2spk_old >$near_dir/tmp1
- #sed -e 's/-[a-z,A-Z,0-9]\+$//' $near_dir/tmp1 | sort -u > $near_dir/utt2spk
+
local/utt2spk_to_spk2utt.pl $near_dir/utt2spk > $near_dir/spk2utt
utils/filter_scp.pl -f 1 $near_dir/text $near_dir/segments_all | sort -u > $near_dir/segments
sed -e 's/ $//g' $near_dir/text> $near_dir/tmp1
@@ -97,9 +96,7 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
n2_text=$(wc -l < $far_dir/textgrid.flist)
log far file found $n1_wav wav and $n2_text text.
- paste $far_dir/uttid $far_dir/wavlist > $far_dir/wav_raw.scp
-
- cat $far_dir/wav_raw.scp | awk '{printf("%s sox -t wav %s -r 16000 -b 16 -t wav - |\n", $1, $2)}' > $far_dir/wav.scp
+ paste $far_dir/uttid $far_dir/wavlist -d " " > $far_dir/wav.scp
python local/alimeeting_process_overlap_force.py --path $far_dir \
--no-overlap false --mars True \
@@ -119,28 +116,28 @@ fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
- log "stage 3: finali data process"
+ log "stage 3: final data process"
local/fix_data_dir.sh $near_dir
local/fix_data_dir.sh $far_dir
- local/copy_data_dir.sh $near_dir data/${tgt}_Ali_near
- local/copy_data_dir.sh $far_dir data/${tgt}_Ali_far
+ local/copy_data_dir.sh $near_dir data/org/${tgt}_Ali_near
+ local/copy_data_dir.sh $far_dir data/org/${tgt}_Ali_far
- sort $far_dir/utt2spk_all_fifo > data/${tgt}_Ali_far/utt2spk_all_fifo
- sed -i "s/src/$/g" data/${tgt}_Ali_far/utt2spk_all_fifo
+ sort $far_dir/utt2spk_all_fifo > data/org/${tgt}_Ali_far/utt2spk_all_fifo
+ sed -i "s/src/$/g" data/org/${tgt}_Ali_far/utt2spk_all_fifo
# remove space in text
for x in ${tgt}_Ali_near ${tgt}_Ali_far; do
- cp data/${x}/text data/${x}/text.org
- paste -d " " <(cut -f 1 -d" " data/${x}/text.org) <(cut -f 2- -d" " data/${x}/text.org | tr -d " ") \
- > data/${x}/text
- rm data/${x}/text.org
+ cp data/org/${x}/text data/org/${x}/text.org
+ paste -d " " <(cut -f 1 -d" " data/org/${x}/text.org) <(cut -f 2- -d" " data/org/${x}/text.org | tr -d " ") \
+ > data/org/${x}/text
+ rm data/org/${x}/text.org
done
log "Successfully finished. [elapsed=${SECONDS}s]"
fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
- log "stage 4: process alimeeting far dir (single speaker by oracle time strap)"
+ log "stage 4: process alimeeting far dir (single speaker by oracle time stamp)"
cp -r $far_dir/* $far_single_speaker_dir
mv $far_single_speaker_dir/textgrid.flist $far_single_speaker_dir/textgrid_oldpath
paste -d " " $far_single_speaker_dir/uttid $far_single_speaker_dir/textgrid_oldpath > $far_single_speaker_dir/textgrid.flist
@@ -150,14 +147,15 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
local/utt2spk_to_spk2utt.pl $far_single_speaker_dir/utt2spk > $far_single_speaker_dir/spk2utt
./local/fix_data_dir.sh $far_single_speaker_dir
- local/copy_data_dir.sh $far_single_speaker_dir data/${tgt}_Ali_far_single_speaker
+ local/copy_data_dir.sh $far_single_speaker_dir data/org/${tgt}_Ali_far_single_speaker
# remove space in text
for x in ${tgt}_Ali_far_single_speaker; do
- cp data/${x}/text data/${x}/text.org
- paste -d " " <(cut -f 1 -d" " data/${x}/text.org) <(cut -f 2- -d" " data/${x}/text.org | tr -d " ") \
- > data/${x}/text
- rm data/${x}/text.org
+ cp data/org/${x}/text data/org/${x}/text.org
+ paste -d " " <(cut -f 1 -d" " data/org/${x}/text.org) <(cut -f 2- -d" " data/org/${x}/text.org | tr -d " ") \
+ > data/org/${x}/text
+ rm data/org/${x}/text.org
done
+ rm -rf data/local
log "Successfully finished. [elapsed=${SECONDS}s]"
fi
\ No newline at end of file
diff --git a/egs/alimeeting/sa-asr/local/alimeeting_data_prep_test_2023.sh b/egs/alimeeting/sa_asr/local/alimeeting_data_prep_test_2023.sh
similarity index 100%
rename from egs/alimeeting/sa-asr/local/alimeeting_data_prep_test_2023.sh
rename to egs/alimeeting/sa_asr/local/alimeeting_data_prep_test_2023.sh
diff --git a/egs/alimeeting/sa-asr/local/alimeeting_process_overlap_force.py b/egs/alimeeting/sa_asr/local/alimeeting_process_overlap_force.py
similarity index 100%
rename from egs/alimeeting/sa-asr/local/alimeeting_process_overlap_force.py
rename to egs/alimeeting/sa_asr/local/alimeeting_process_overlap_force.py
diff --git a/egs/alimeeting/sa-asr/local/alimeeting_process_textgrid.py b/egs/alimeeting/sa_asr/local/alimeeting_process_textgrid.py
similarity index 100%
rename from egs/alimeeting/sa-asr/local/alimeeting_process_textgrid.py
rename to egs/alimeeting/sa_asr/local/alimeeting_process_textgrid.py
diff --git a/egs/alimeeting/sa-asr/local/apply_map.pl b/egs/alimeeting/sa_asr/local/apply_map.pl
similarity index 100%
rename from egs/alimeeting/sa-asr/local/apply_map.pl
rename to egs/alimeeting/sa_asr/local/apply_map.pl
diff --git a/egs/alimeeting/sa-asr/local/combine_data.sh b/egs/alimeeting/sa_asr/local/combine_data.sh
similarity index 100%
rename from egs/alimeeting/sa-asr/local/combine_data.sh
rename to egs/alimeeting/sa_asr/local/combine_data.sh
diff --git a/egs/alimeeting/sa_asr/local/compute_cmvn.py b/egs/alimeeting/sa_asr/local/compute_cmvn.py
new file mode 100755
index 000000000..d16563a96
--- /dev/null
+++ b/egs/alimeeting/sa_asr/local/compute_cmvn.py
@@ -0,0 +1,134 @@
+import argparse
+import json
+import os
+
+import numpy as np
+import torchaudio
+import torchaudio.compliance.kaldi as kaldi
+import yaml
+from funasr.models.frontend.default import DefaultFrontend
+import torch
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ description="computer global cmvn",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+ parser.add_argument(
+ "--dim",
+ default=80,
+ type=int,
+ help="feature dimension",
+ )
+ parser.add_argument(
+ "--wav_path",
+ default=False,
+ required=True,
+ type=str,
+ help="the path of wav scps",
+ )
+ parser.add_argument(
+ "--config_file",
+ type=str,
+ help="the config file for computing cmvn",
+ )
+ parser.add_argument(
+ "--idx",
+ default=1,
+ required=True,
+ type=int,
+ help="index",
+ )
+ return parser
+
+
+def compute_fbank(wav_file,
+ num_mel_bins=80,
+ frame_length=25,
+ frame_shift=10,
+ dither=0.0,
+ resample_rate=16000,
+ speed=1.0,
+ window_type="hamming"):
+ waveform, sample_rate = torchaudio.load(wav_file)
+ if resample_rate != sample_rate:
+ waveform = torchaudio.transforms.Resample(orig_freq=sample_rate,
+ new_freq=resample_rate)(waveform)
+ if speed != 1.0:
+ waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
+ waveform, resample_rate,
+ [['speed', str(speed)], ['rate', str(resample_rate)]]
+ )
+
+ waveform = waveform * (1 << 15)
+ mat = kaldi.fbank(waveform,
+ num_mel_bins=num_mel_bins,
+ frame_length=frame_length,
+ frame_shift=frame_shift,
+ dither=dither,
+ energy_floor=0.0,
+ window_type=window_type,
+ sample_frequency=resample_rate)
+
+ return mat.numpy()
+
+
+def main():
+ parser = get_parser()
+ args = parser.parse_args()
+
+ wav_scp_file = os.path.join(args.wav_path, "wav.{}.scp".format(args.idx))
+ cmvn_file = os.path.join(args.wav_path, "cmvn.{}.json".format(args.idx))
+
+ mean_stats = np.zeros(args.dim)
+ var_stats = np.zeros(args.dim)
+ total_frames = 0
+
+ with open(args.config_file) as f:
+ configs = yaml.safe_load(f)
+ frontend_configs = configs.get("frontend_conf", {})
+ num_mel_bins = frontend_configs.get("n_mels", 80)
+ frame_length = frontend_configs.get("frame_length", 25)
+ frame_shift = frontend_configs.get("frame_shift", 10)
+ window_type = frontend_configs.get("window", "hamming")
+ resample_rate = frontend_configs.get("fs", 16000)
+    n_fft = frontend_configs.get("n_fft", 400)
+ use_channel = frontend_configs.get("use_channel", None)
+ assert num_mel_bins == args.dim
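+    # frame_length/frame_shift are in ms; the "* 16" below converts them to
+    # samples, assuming fs = 16000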
+ frontend = DefaultFrontend(
+ fs=resample_rate,
+ n_fft=n_fft,
+ win_length=frame_length * 16,
+ hop_length=frame_shift * 16,
+ window=window_type,
+ n_mels=num_mel_bins,
+ use_channel=use_channel,
+ )
+ with open(wav_scp_file) as f:
+ lines = f.readlines()
+ for line in lines:
+ _, wav_file = line.strip().split()
+            waveform, _ = torchaudio.load(wav_file)
+            fbank, _ = frontend(waveform.transpose(0, 1).unsqueeze(0), torch.tensor([waveform.shape[1]]))
+ fbank = fbank.squeeze(0).numpy()
+ mean_stats += np.sum(fbank, axis=0)
+ var_stats += np.sum(np.square(fbank), axis=0)
+ total_frames += fbank.shape[0]
+
+ cmvn_info = {
+        'mean_stats': mean_stats.tolist(),
+        'var_stats': var_stats.tolist(),
+ 'total_frames': total_frames
+ }
+ with open(cmvn_file, 'w') as fout:
+ fout.write(json.dumps(cmvn_info))
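+
+# Note (illustrative, not part of this script): utils/combine_cmvn_file.py sums
+# these per-split stats, and the global mean/std then follow as
+#   mean = mean_stats / total_frames
+#   std  = sqrt(var_stats / total_frames - mean ** 2)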
+
+
+if __name__ == '__main__':
+ main()
\ No newline at end of file
diff --git a/egs/alimeeting/sa_asr/local/compute_cmvn.sh b/egs/alimeeting/sa_asr/local/compute_cmvn.sh
new file mode 100755
index 000000000..00d08d14c
--- /dev/null
+++ b/egs/alimeeting/sa_asr/local/compute_cmvn.sh
@@ -0,0 +1,39 @@
+#!/usr/bin/env bash
+
+. ./path.sh || exit 1;
+# Begin configuration section.
+fbankdir=
+nj=32
+cmd=./utils/run.pl
+feats_dim=80
+config_file=
+scale=1.0
+
+echo "$0 $@"
+
+. utils/parse_options.sh || exit 1;
+
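+# Subsample wav.scp: keep the first scale*100% of entries for CMVN estimation.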
+# shellcheck disable=SC2046
+head -n $(awk -v lines="$(wc -l < ${fbankdir}/wav.scp)" -v scale="$scale" 'BEGIN { printf "%.0f\n", lines*scale }') ${fbankdir}/wav.scp > ${fbankdir}/wav.scp.scale
+
+split_dir=${fbankdir}/cmvn/split_${nj};
+mkdir -p $split_dir
+split_scps=""
+for n in $(seq $nj); do
+ split_scps="$split_scps $split_dir/wav.$n.scp"
+done
+utils/split_scp.pl ${fbankdir}/wav.scp.scale $split_scps || exit 1;
+
+logdir=${fbankdir}/cmvn/log
+$cmd JOB=1:$nj $logdir/cmvn.JOB.log \
+ python local/compute_cmvn.py \
+ --dim ${feats_dim} \
+ --wav_path $split_dir \
+ --config_file $config_file \
+        --idx JOB
+
+python utils/combine_cmvn_file.py --dim ${feats_dim} --cmvn_dir $split_dir --nj $nj --output_dir ${fbankdir}/cmvn
+
+python utils/cmvn_converter.py --cmvn_json ${fbankdir}/cmvn/cmvn.json --am_mvn ${fbankdir}/cmvn/am.mvn
+
+echo "$0: Succeeded compute global cmvn"
diff --git a/egs/alimeeting/sa-asr/local/compute_cpcer.py b/egs/alimeeting/sa_asr/local/compute_cpcer.py
similarity index 100%
rename from egs/alimeeting/sa-asr/local/compute_cpcer.py
rename to egs/alimeeting/sa_asr/local/compute_cpcer.py
diff --git a/egs/alimeeting/sa_asr/local/convert_model.py b/egs/alimeeting/sa_asr/local/convert_model.py
new file mode 100644
index 000000000..f0f7997fc
--- /dev/null
+++ b/egs/alimeeting/sa_asr/local/convert_model.py
@@ -0,0 +1,29 @@
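+# Remap the vocabulary-dependent parameters of a checkpoint (CTC output layer,
+# decoder output layer and embedding) from the token order in char1 to the
+# token order in char2.
+# Usage: python convert_model.py <char1_tokens> <char2_tokens> <src_model> <dst_model>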
+import codecs
+import sys
+
+import torch
+
+char1 = sys.argv[1]
+char2 = sys.argv[2]
+model1 = torch.load(sys.argv[3], map_location='cpu')
+model2_path = sys.argv[4]
+
+d_new = model1
+char1_list = []
+map_list = []
+
+
+with codecs.open(char1) as f:
+ for line in f.readlines():
+ char1_list.append(line.strip())
+
+with codecs.open(char2) as f:
+ for line in f.readlines():
+ map_list.append(char1_list.index(line.strip()))
+print(map_list)
+
+for k, v in d_new.items():
+ if k == 'ctc.ctc_lo.weight' or k == 'ctc.ctc_lo.bias' or k == 'decoder.output_layer.weight' or k == 'decoder.output_layer.bias' or k == 'decoder.embed.0.weight':
+ d_new[k] = v[map_list]
+
+torch.save(d_new, model2_path)
diff --git a/egs/alimeeting/sa-asr/local/copy_data_dir.sh b/egs/alimeeting/sa_asr/local/copy_data_dir.sh
similarity index 100%
rename from egs/alimeeting/sa-asr/local/copy_data_dir.sh
rename to egs/alimeeting/sa_asr/local/copy_data_dir.sh
diff --git a/egs/alimeeting/sa-asr/local/data/get_reco2dur.sh b/egs/alimeeting/sa_asr/local/data/get_reco2dur.sh
similarity index 100%
rename from egs/alimeeting/sa-asr/local/data/get_reco2dur.sh
rename to egs/alimeeting/sa_asr/local/data/get_reco2dur.sh
diff --git a/egs/alimeeting/sa-asr/local/data/get_segments_for_data.sh b/egs/alimeeting/sa_asr/local/data/get_segments_for_data.sh
similarity index 100%
rename from egs/alimeeting/sa-asr/local/data/get_segments_for_data.sh
rename to egs/alimeeting/sa_asr/local/data/get_segments_for_data.sh
diff --git a/egs/alimeeting/sa-asr/local/data/get_utt2dur.sh b/egs/alimeeting/sa_asr/local/data/get_utt2dur.sh
similarity index 100%
rename from egs/alimeeting/sa-asr/local/data/get_utt2dur.sh
rename to egs/alimeeting/sa_asr/local/data/get_utt2dur.sh
diff --git a/egs/alimeeting/sa-asr/local/data/split_data.sh b/egs/alimeeting/sa_asr/local/data/split_data.sh
similarity index 100%
rename from egs/alimeeting/sa-asr/local/data/split_data.sh
rename to egs/alimeeting/sa_asr/local/data/split_data.sh
diff --git a/egs/alimeeting/sa_asr/local/download_and_untar.sh b/egs/alimeeting/sa_asr/local/download_and_untar.sh
new file mode 100755
index 000000000..d98255915
--- /dev/null
+++ b/egs/alimeeting/sa_asr/local/download_and_untar.sh
@@ -0,0 +1,105 @@
+#!/usr/bin/env bash
+
+# Copyright 2014 Johns Hopkins University (author: Daniel Povey)
+# 2017 Xingyu Na
+# Apache 2.0
+
+remove_archive=false
+
+if [ "$1" == --remove-archive ]; then
+ remove_archive=true
+ shift
+fi
+
+if [ $# -ne 3 ]; then
+  echo "Usage: $0 [--remove-archive] <data-base> <url-base> <corpus-part>"
+  echo "e.g.: $0 /export/a05/xna/data www.openslr.org/resources/33 data_aishell"
+  echo "With --remove-archive it will remove the archive after successfully un-tarring it."
+  echo "<corpus-part> can be one of: data_aishell, resource_aishell."
+  exit 1;
+fi
+
+data=$1
+url=$2
+part=$3
+
+if [ ! -d "$data" ]; then
+ echo "$0: no such directory $data"
+ exit 1;
+fi
+
+part_ok=false
+list="data_aishell resource_aishell"
+for x in $list; do
+ if [ "$part" == $x ]; then part_ok=true; fi
+done
+if ! $part_ok; then
+ echo "$0: expected to be one of $list, but got '$part'"
+ exit 1;
+fi
+
+if [ -z "$url" ]; then
+ echo "$0: empty URL base."
+ exit 1;
+fi
+
+if [ -f $data/$part/.complete ]; then
+ echo "$0: data part $part was already successfully extracted, nothing to do."
+ exit 0;
+fi
+
+# sizes of the archive files in bytes.
+sizes="15582913665 1246920"
+
+if [ -f $data/$part.tgz ]; then
+ size=$(/bin/ls -l $data/$part.tgz | awk '{print $5}')
+ size_ok=false
+ for s in $sizes; do if [ $s == $size ]; then size_ok=true; fi; done
+ if ! $size_ok; then
+ echo "$0: removing existing file $data/$part.tgz because its size in bytes $size"
+ echo "does not equal the size of one of the archives."
+ rm $data/$part.tgz
+ else
+ echo "$data/$part.tgz exists and appears to be complete."
+ fi
+fi
+
+if [ ! -f $data/$part.tgz ]; then
+ if ! command -v wget >/dev/null; then
+ echo "$0: wget is not installed."
+ exit 1;
+ fi
+ full_url=$url/$part.tgz
+ echo "$0: downloading data from $full_url. This may take some time, please be patient."
+
+ cd $data || exit 1
+ if ! wget --no-check-certificate $full_url; then
+ echo "$0: error executing wget $full_url"
+ exit 1;
+ fi
+fi
+
+cd $data || exit 1
+
+if ! tar -xvzf $part.tgz; then
+ echo "$0: error un-tarring archive $data/$part.tgz"
+ exit 1;
+fi
+
+touch $data/$part/.complete
+
+if [ $part == "data_aishell" ]; then
+ cd $data/$part/wav || exit 1
+ for wav in ./*.tar.gz; do
+ echo "Extracting wav from $wav"
+ tar -zxf $wav && rm $wav
+ done
+fi
+
+echo "$0: Successfully downloaded and un-tarred $data/$part.tgz"
+
+if $remove_archive; then
+ echo "$0: removing $data/$part.tgz file since --remove-archive option was supplied."
+ rm $data/$part.tgz
+fi
+
+exit 0;
diff --git a/egs/alimeeting/sa-asr/local/download_pretrained_model_from_modelscope.py b/egs/alimeeting/sa_asr/local/download_pretrained_model_from_modelscope.py
similarity index 100%
rename from egs/alimeeting/sa-asr/local/download_pretrained_model_from_modelscope.py
rename to egs/alimeeting/sa_asr/local/download_pretrained_model_from_modelscope.py
diff --git a/egs/alimeeting/sa-asr/local/download_xvector_model.py b/egs/alimeeting/sa_asr/local/download_xvector_model.py
similarity index 100%
rename from egs/alimeeting/sa-asr/local/download_xvector_model.py
rename to egs/alimeeting/sa_asr/local/download_xvector_model.py
diff --git a/egs/alimeeting/sa-asr/local/filter_utt2spk_all_fifo.py b/egs/alimeeting/sa_asr/local/filter_utt2spk_all_fifo.py
similarity index 100%
rename from egs/alimeeting/sa-asr/local/filter_utt2spk_all_fifo.py
rename to egs/alimeeting/sa_asr/local/filter_utt2spk_all_fifo.py
diff --git a/egs/alimeeting/sa-asr/local/fix_data_dir.sh b/egs/alimeeting/sa_asr/local/fix_data_dir.sh
similarity index 100%
rename from egs/alimeeting/sa-asr/local/fix_data_dir.sh
rename to egs/alimeeting/sa_asr/local/fix_data_dir.sh
diff --git a/egs/alimeeting/sa-asr/local/format_wav_scp.py b/egs/alimeeting/sa_asr/local/format_wav_scp.py
similarity index 100%
rename from egs/alimeeting/sa-asr/local/format_wav_scp.py
rename to egs/alimeeting/sa_asr/local/format_wav_scp.py
diff --git a/egs/alimeeting/sa-asr/local/format_wav_scp.sh b/egs/alimeeting/sa_asr/local/format_wav_scp.sh
similarity index 100%
rename from egs/alimeeting/sa-asr/local/format_wav_scp.sh
rename to egs/alimeeting/sa_asr/local/format_wav_scp.sh
diff --git a/egs/alimeeting/sa-asr/local/gen_cluster_profile_infer.py b/egs/alimeeting/sa_asr/local/gen_cluster_profile_infer.py
similarity index 97%
rename from egs/alimeeting/sa-asr/local/gen_cluster_profile_infer.py
rename to egs/alimeeting/sa_asr/local/gen_cluster_profile_infer.py
index c37abf9a0..859b72fce 100644
--- a/egs/alimeeting/sa-asr/local/gen_cluster_profile_infer.py
+++ b/egs/alimeeting/sa_asr/local/gen_cluster_profile_infer.py
@@ -63,7 +63,7 @@ if __name__ == "__main__":
wav_scp_file = open(path+'/wav.scp', 'r')
wav_scp = wav_scp_file.readlines()
wav_scp_file.close()
- raw_meeting_scp_file = open(raw_path + '/wav_raw.scp', 'r')
+ raw_meeting_scp_file = open(raw_path + '/wav.scp', 'r')
raw_meeting_scp = raw_meeting_scp_file.readlines()
raw_meeting_scp_file.close()
segments_scp_file = open(raw_path + '/segments', 'r')
@@ -92,8 +92,8 @@ if __name__ == "__main__":
cluster_spk_num_file = open(path + '/cluster_spk_num', 'w')
meeting_map = {}
for line in raw_meeting_scp:
- meeting = line.strip().split('\t')[0]
- wav_path = line.strip().split('\t')[1]
+ meeting = line.strip().split(' ')[0]
+ wav_path = line.strip().split(' ')[1]
wav = soundfile.read(wav_path)[0]
# take the first channel
if wav.ndim == 2:
diff --git a/egs/alimeeting/sa-asr/local/gen_oracle_embedding.py b/egs/alimeeting/sa_asr/local/gen_oracle_embedding.py
similarity index 94%
rename from egs/alimeeting/sa-asr/local/gen_oracle_embedding.py
rename to egs/alimeeting/sa_asr/local/gen_oracle_embedding.py
index 18286b42d..2a99b2b6b 100644
--- a/egs/alimeeting/sa-asr/local/gen_oracle_embedding.py
+++ b/egs/alimeeting/sa_asr/local/gen_oracle_embedding.py
@@ -9,7 +9,7 @@ import soundfile
if __name__=="__main__":
path = sys.argv[1] # dump2/raw/Eval_Ali_far
raw_path = sys.argv[2] # data/local/Eval_Ali_far_correct_single_speaker
- raw_meeting_scp_file = open(raw_path + '/wav_raw.scp', 'r')
+ raw_meeting_scp_file = open(raw_path + '/wav.scp', 'r')
raw_meeting_scp = raw_meeting_scp_file.readlines()
raw_meeting_scp_file.close()
segments_scp_file = open(raw_path + '/segments', 'r')
@@ -22,8 +22,8 @@ if __name__=="__main__":
raw_wav_map = {}
for line in raw_meeting_scp:
- meeting = line.strip().split('\t')[0]
- wav_path = line.strip().split('\t')[1]
+ meeting = line.strip().split(' ')[0]
+ wav_path = line.strip().split(' ')[1]
raw_wav_map[meeting] = wav_path
spk_map = {}
diff --git a/egs/alimeeting/sa-asr/local/gen_oracle_profile_nopadding.py b/egs/alimeeting/sa_asr/local/gen_oracle_profile_nopadding.py
similarity index 100%
rename from egs/alimeeting/sa-asr/local/gen_oracle_profile_nopadding.py
rename to egs/alimeeting/sa_asr/local/gen_oracle_profile_nopadding.py
diff --git a/egs/alimeeting/sa-asr/local/gen_oracle_profile_padding.py b/egs/alimeeting/sa_asr/local/gen_oracle_profile_padding.py
similarity index 96%
rename from egs/alimeeting/sa-asr/local/gen_oracle_profile_padding.py
rename to egs/alimeeting/sa_asr/local/gen_oracle_profile_padding.py
index 186f1de9f..ff65a1f90 100644
--- a/egs/alimeeting/sa-asr/local/gen_oracle_profile_padding.py
+++ b/egs/alimeeting/sa_asr/local/gen_oracle_profile_padding.py
@@ -5,7 +5,7 @@ import sys
if __name__=="__main__":
- path = sys.argv[1] # dump2/raw/Train_Ali_far
+ path = sys.argv[1]
wav_scp_file = open(path+"/wav.scp", 'r')
wav_scp = wav_scp_file.readlines()
wav_scp_file.close()
@@ -29,7 +29,7 @@ if __name__=="__main__":
line_list = line.strip().split(' ')
meeting = line_list[0].split('-')[0]
spk_id = line_list[0].split('-')[-1].split('_')[-1]
- spk = meeting+'_' + spk_id
+ spk = meeting + '_' + spk_id
global_spk_list.append(spk)
if meeting in meeting_map_tmp.keys():
meeting_map_tmp[meeting].append(spk)
diff --git a/egs/alimeeting/sa-asr/local/perturb_data_dir_speed.sh b/egs/alimeeting/sa_asr/local/perturb_data_dir_speed.sh
similarity index 100%
rename from egs/alimeeting/sa-asr/local/perturb_data_dir_speed.sh
rename to egs/alimeeting/sa_asr/local/perturb_data_dir_speed.sh
diff --git a/egs/alimeeting/sa-asr/local/process_sot_fifo_textchar2spk.py b/egs/alimeeting/sa_asr/local/process_sot_fifo_textchar2spk.py
similarity index 94%
rename from egs/alimeeting/sa-asr/local/process_sot_fifo_textchar2spk.py
rename to egs/alimeeting/sa_asr/local/process_sot_fifo_textchar2spk.py
index d900bb17a..488344fb0 100755
--- a/egs/alimeeting/sa-asr/local/process_sot_fifo_textchar2spk.py
+++ b/egs/alimeeting/sa_asr/local/process_sot_fifo_textchar2spk.py
@@ -30,8 +30,7 @@ def main(args):
meetingid_map = {}
for line in spk2utt:
spkid = line.strip().split(" ")[0]
- meeting_id_list = spkid.split("_")[:3]
- meeting_id = meeting_id_list[0] + "_" + meeting_id_list[1] + "_" + meeting_id_list[2]
+ meeting_id = spkid.split("-")[0]
if meeting_id not in meetingid_map:
meetingid_map[meeting_id] = 1
else:
diff --git a/egs/alimeeting/sa-asr/local/process_text_id.py b/egs/alimeeting/sa_asr/local/process_text_id.py
similarity index 100%
rename from egs/alimeeting/sa-asr/local/process_text_id.py
rename to egs/alimeeting/sa_asr/local/process_text_id.py
diff --git a/egs/alimeeting/sa-asr/local/process_text_spk_merge.py b/egs/alimeeting/sa_asr/local/process_text_spk_merge.py
similarity index 100%
rename from egs/alimeeting/sa-asr/local/process_text_spk_merge.py
rename to egs/alimeeting/sa_asr/local/process_text_spk_merge.py
diff --git a/egs/alimeeting/sa-asr/local/process_textgrid_to_single_speaker_wav.py b/egs/alimeeting/sa_asr/local/process_textgrid_to_single_speaker_wav.py
similarity index 100%
rename from egs/alimeeting/sa-asr/local/process_textgrid_to_single_speaker_wav.py
rename to egs/alimeeting/sa_asr/local/process_textgrid_to_single_speaker_wav.py
diff --git a/egs/alimeeting/sa-asr/local/spk2utt_to_utt2spk.pl b/egs/alimeeting/sa_asr/local/spk2utt_to_utt2spk.pl
similarity index 100%
rename from egs/alimeeting/sa-asr/local/spk2utt_to_utt2spk.pl
rename to egs/alimeeting/sa_asr/local/spk2utt_to_utt2spk.pl
diff --git a/egs/alimeeting/sa-asr/local/text_format.pl b/egs/alimeeting/sa_asr/local/text_format.pl
similarity index 100%
rename from egs/alimeeting/sa-asr/local/text_format.pl
rename to egs/alimeeting/sa_asr/local/text_format.pl
diff --git a/egs/alimeeting/sa-asr/local/text_normalize.pl b/egs/alimeeting/sa_asr/local/text_normalize.pl
similarity index 100%
rename from egs/alimeeting/sa-asr/local/text_normalize.pl
rename to egs/alimeeting/sa_asr/local/text_normalize.pl
diff --git a/egs/alimeeting/sa-asr/local/utt2spk_to_spk2utt.pl b/egs/alimeeting/sa_asr/local/utt2spk_to_spk2utt.pl
similarity index 100%
rename from egs/alimeeting/sa-asr/local/utt2spk_to_spk2utt.pl
rename to egs/alimeeting/sa_asr/local/utt2spk_to_spk2utt.pl
diff --git a/egs/alimeeting/sa-asr/local/validate_data_dir.sh b/egs/alimeeting/sa_asr/local/validate_data_dir.sh
similarity index 100%
rename from egs/alimeeting/sa-asr/local/validate_data_dir.sh
rename to egs/alimeeting/sa_asr/local/validate_data_dir.sh
diff --git a/egs/alimeeting/sa-asr/local/validate_text.pl b/egs/alimeeting/sa_asr/local/validate_text.pl
similarity index 100%
rename from egs/alimeeting/sa-asr/local/validate_text.pl
rename to egs/alimeeting/sa_asr/local/validate_text.pl
diff --git a/egs/alimeeting/sa_asr/path.sh b/egs/alimeeting/sa_asr/path.sh
new file mode 100755
index 000000000..83ae507b7
--- /dev/null
+++ b/egs/alimeeting/sa_asr/path.sh
@@ -0,0 +1,6 @@
+export FUNASR_DIR=$PWD/../../..
+
+# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
+export PYTHONIOENCODING=UTF-8
+export PATH=$FUNASR_DIR/funasr/bin:./utils:$FUNASR_DIR:$PATH
+export PYTHONPATH=$FUNASR_DIR:$PYTHONPATH
diff --git a/egs/alimeeting/sa_asr/run.sh b/egs/alimeeting/sa_asr/run.sh
new file mode 100755
index 000000000..43d0da13f
--- /dev/null
+++ b/egs/alimeeting/sa_asr/run.sh
@@ -0,0 +1,435 @@
+#!/usr/bin/env bash
+
+. ./path.sh || exit 1;
+
+# machines configuration
+CUDA_VISIBLE_DEVICES="6,7"
+gpu_num=2
+count=1
+gpu_inference=true # Whether to perform gpu decoding, set false for cpu decoding
+# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob
+njob=8
+train_cmd=utils/run.pl
+infer_cmd=utils/run.pl
+
+# general configuration
+feats_dir="data" #feature output dictionary
+exp_dir="exp"
+lang=zh
+token_type=char
+type=sound
+scp=wav.scp
+speed_perturb="1.0"
+min_wav_duration=0.1
+max_wav_duration=20
+profile_modes="cluster oracle"
+stage=7
+stop_stage=7
+
+# feature configuration
+feats_dim=80
+nj=32
+
+# data
+raw_data=
+data_url=
+
+# exp tag
+tag=""
+
+. utils/parse_options.sh || exit 1;
+
+# Set bash to 'debug' mode, it will exit on :
+# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
+set -e
+set -u
+set -o pipefail
+
+train_set=Train_Ali_far
+valid_set=Eval_Ali_far
+test_sets="Test_Ali_far Eval_Ali_far"
+test_2023="Test_2023_Ali_far_release"
+
+asr_config=conf/train_asr_conformer.yaml
+sa_asr_config=conf/train_sa_asr_conformer.yaml
+asr_model_dir="baseline_$(basename "${asr_config}" .yaml)_${lang}_${token_type}_${tag}"
+sa_asr_model_dir="baseline_$(basename "${sa_asr_config}" .yaml)_${lang}_${token_type}_${tag}"
+inference_config=conf/decode_asr_rnn.yaml
+inference_sa_asr_model=valid.acc_spk.ave.pb
+
+# you can set gpu num for decoding here
+gpuid_list=$CUDA_VISIBLE_DEVICES # set gpus for decoding, the same as training stage by default
+ngpu=$(echo $gpuid_list | awk -F "," '{print NF}')
+
+if ${gpu_inference}; then
+    inference_nj=$((ngpu * njob))
+ _ngpu=1
+else
+ inference_nj=$njob
+ _ngpu=0
+fi
+
+
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+ echo "stage 0: Data preparation"
+ # Data preparation
+ ./local/alimeeting_data_prep.sh --tgt Test --min_wav_duration $min_wav_duration --max_wav_duration $max_wav_duration
+ ./local/alimeeting_data_prep.sh --tgt Eval --min_wav_duration $min_wav_duration --max_wav_duration $max_wav_duration
+ ./local/alimeeting_data_prep.sh --tgt Train --min_wav_duration $min_wav_duration --max_wav_duration $max_wav_duration
+ # remove long/short data
+ for x in ${train_set} ${valid_set} ${test_sets}; do
+ cp -r ${feats_dir}/org/${x} ${feats_dir}/${x}
+ rm ${feats_dir}/"${x}"/wav.scp ${feats_dir}/"${x}"/segments
+ local/format_wav_scp.sh --nj "${nj}" --cmd "${train_cmd}" \
+ --audio-format wav --segments ${feats_dir}/org/${x}/segments \
+ "${feats_dir}/org/${x}/${scp}" "${feats_dir}/${x}"
+ _min_length=$(python3 -c "print(int(${min_wav_duration} * 16000))")
+ _max_length=$(python3 -c "print(int(${max_wav_duration} * 16000))")
+ <"${feats_dir}/${x}/utt2num_samples" \
+ awk '{if($2 > '$_min_length' && $2 < '$_max_length')print $0;}' \
+ >"${feats_dir}/${x}/utt2num_samples_rmls"
+ mv ${feats_dir}/${x}/utt2num_samples_rmls ${feats_dir}/${x}/utt2num_samples
+ <"${feats_dir}/${x}/wav.scp" \
+ utils/filter_scp.pl "${feats_dir}/${x}/utt2num_samples" \
+ >"${feats_dir}/${x}/wav.scp_rmls"
+ mv ${feats_dir}/${x}/wav.scp_rmls ${feats_dir}/${x}/wav.scp
+ <"${feats_dir}/${x}/text" \
+ awk '{ if( NF != 1 ) print $0; }' >"${feats_dir}/${x}/text_rmblank"
+ mv ${feats_dir}/${x}/text_rmblank ${feats_dir}/${x}/text
+        local/fix_data_dir.sh "${feats_dir}/${x}"
+ <"${feats_dir}/${x}/utt2spk_all_fifo" \
+ utils/filter_scp.pl "${feats_dir}/${x}/text" \
+ >"${feats_dir}/${x}/utt2spk_all_fifo_rmls"
+ mv "${feats_dir}/${x}/utt2spk_all_fifo_rmls" "${feats_dir}/${x}/utt2spk_all_fifo"
+ done
+ for x in ${test_2023}; do
+ local/format_wav_scp.sh --nj "${nj}" --cmd "${train_cmd}" \
+ --audio-format wav --segments ${feats_dir}/org/${x}/segments \
+ "${feats_dir}/org/${x}/${scp}" "${feats_dir}/${x}"
+ cut -d " " -f1 ${feats_dir}/${x}/wav.scp > ${feats_dir}/${x}/uttid
+ paste -d " " ${feats_dir}/${x}/uttid ${feats_dir}/${x}/uttid > ${feats_dir}/${x}/utt2spk
+ cp ${feats_dir}/${x}/utt2spk ${feats_dir}/${x}/spk2utt
+ done
+fi
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+ echo "stage 1: Speaker profile and CMVN Generation"
+
+ mkdir -p "profile_log"
+ for x in "${train_set}" "${valid_set}" "${test_sets}"; do
+ # generate text_id spk2id
+ python local/process_sot_fifo_textchar2spk.py --path ${feats_dir}/${x}
+ echo "Successfully generate ${feats_dir}/${x}/text_id ${feats_dir}/${x}/spk2id"
+ # generate text_id_train for sot
+ python local/process_text_id.py ${feats_dir}/${x}
+ echo "Successfully generate ${feats_dir}/${x}/text_id_train"
+ # generate oracle_embedding from single-speaker audio segment
+ echo "oracle_embedding is being generated in the background, and the log is profile_log/gen_oracle_embedding_${x}.log"
+ python local/gen_oracle_embedding.py "${feats_dir}/${x}" "data/org/${x}_single_speaker" &> "profile_log/gen_oracle_embedding_${x}.log"
+ echo "Successfully generate oracle embedding for ${x} (${feats_dir}/${x}/oracle_embedding.scp)"
+ # generate oracle_profile and cluster_profile from oracle_embedding and cluster_embedding (padding the speaker during training)
+ if [ "${x}" = "${train_set}" ]; then
+ python local/gen_oracle_profile_padding.py ${feats_dir}/${x}
+ echo "Successfully generate oracle profile for ${x} (${feats_dir}/${x}/oracle_profile_padding.scp)"
+ else
+ python local/gen_oracle_profile_nopadding.py ${feats_dir}/${x}
+ echo "Successfully generate oracle profile for ${x} (${feats_dir}/${x}/oracle_profile_nopadding.scp)"
+ fi
+        # generate cluster_profile directly with spectral clustering (for inference, without oracle information)
+        if [ "${x}" != "${train_set}" ]; then
+            echo "cluster_profile is being generated; the log is profile_log/gen_cluster_profile_infer_${x}.log"
+ python local/gen_cluster_profile_infer.py "${feats_dir}/${x}" "${feats_dir}/org/${x}" 0.996 0.815 &> "profile_log/gen_cluster_profile_infer_${x}.log"
+ echo "Successfully generate cluster profile for ${x} (${feats_dir}/${x}/cluster_profile_infer.scp)"
+ fi
+ # compute CMVN
+ if [ "${x}" = "${train_set}" ]; then
+ local/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --fbankdir ${feats_dir}/${train_set} --feats_dim ${feats_dim} --config_file "$asr_config" --scale 1.0
+ fi
+ done
+
+ for x in "${test_2023}"; do
+        # generate cluster_profile directly with spectral clustering (for inference, without oracle information)
+ python local/gen_cluster_profile_infer.py "${feats_dir}/${x}" "${feats_dir}/org/${x}" 0.996 0.815 &> "profile_log/gen_cluster_profile_infer_${x}.log"
+ echo "Successfully generate cluster profile for ${x} (${feats_dir}/${x}/cluster_profile_infer.scp)"
+ done
+fi
+
+token_list=${feats_dir}/${lang}_token_list/char/tokens.txt
+echo "dictionary: ${token_list}"
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+ echo "stage 2: Dictionary Preparation"
+ mkdir -p ${feats_dir}/${lang}_token_list/char/
+
+ echo "make a dictionary"
+ echo "" > ${token_list}
+ echo "" >> ${token_list}
+ echo "" >> ${token_list}
+ utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/$train_set/text | cut -f 2- -d" " | tr " " "\n" \
+ | sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list}
+ echo "" >> ${token_list}
+fi
+
+# LM Training Stage
+world_size=$gpu_num # run on one machine
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+ echo "stage 3: LM Training"
+fi
+
+# ASR Training Stage
+if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+ echo "Stage 4: ASR Training"
+ asr_exp=${exp_dir}/${asr_model_dir}
+ mkdir -p ${asr_exp}
+ mkdir -p ${asr_exp}/log
+ INIT_FILE=${asr_exp}/ddp_init
+ if [ -f $INIT_FILE ];then
+ rm -f $INIT_FILE
+ fi
+ init_method=file://$(readlink -f $INIT_FILE)
+ echo "$0: init method is $init_method"
+ for ((i = 0; i < $ngpu; ++i)); do
+ {
+ # i=0
+ rank=$i
+ local_rank=$i
+ gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
+ train.py \
+ --task_name asr \
+ --model asr \
+ --gpu_id $gpu_id \
+ --use_preprocessor true \
+ --split_with_space false \
+ --token_type char \
+ --token_list $token_list \
+ --data_dir ${feats_dir} \
+ --train_set ${train_set} \
+ --valid_set ${valid_set} \
+ --data_file_names "wav.scp,text" \
+ --cmvn_file ${feats_dir}/${train_set}/cmvn/cmvn.mvn \
+ --speed_perturb ${speed_perturb} \
+ --resume true \
+ --output_dir ${exp_dir}/${asr_model_dir} \
+ --config $asr_config \
+ --ngpu $gpu_num \
+ --num_worker_count $count \
+ --dist_init_method $init_method \
+ --dist_world_size $world_size \
+ --dist_rank $rank \
+ --local_rank $local_rank 1> ${exp_dir}/${asr_model_dir}/log/train.log.$i 2>&1
+ } &
+ done
+ wait
+
+fi
+
+
+
+if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
+ echo "SA-ASR training"
+ asr_exp=${exp_dir}/${asr_model_dir}
+ sa_asr_exp=${exp_dir}/${sa_asr_model_dir}
+ mkdir -p ${sa_asr_exp}
+ mkdir -p ${sa_asr_exp}/log
+ INIT_FILE=${sa_asr_exp}/ddp_init
+ if [ ! -L ${feats_dir}/${train_set}/profile.scp ]; then
+ ln -sr ${feats_dir}/${train_set}/oracle_profile_padding.scp ${feats_dir}/${train_set}/profile.scp
+ ln -sr ${feats_dir}/${valid_set}/oracle_profile_nopadding.scp ${feats_dir}/${valid_set}/profile.scp
+ fi
+
+ if [ ! -f "${exp_dir}/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pth" ]; then
+ # download xvector extractor model file
+ python local/download_xvector_model.py ${exp_dir}
+ echo "Successfully download the pretrained xvector extractor to exp/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pth"
+ fi
+
+ if [ -f $INIT_FILE ];then
+ rm -f $INIT_FILE
+ fi
+ init_method=file://$(readlink -f $INIT_FILE)
+ echo "$0: init method is $init_method"
+ for ((i = 0; i < $ngpu; ++i)); do
+ {
+ rank=$i
+ local_rank=$i
+ gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
+ train.py \
+ --task_name asr \
+ --model sa_asr \
+ --gpu_id $gpu_id \
+ --use_preprocessor true \
+ --split_with_space false \
+ --unused_parameters true \
+ --token_type char \
+ --resume true \
+ --token_list $token_list \
+ --data_dir ${feats_dir} \
+ --train_set ${train_set} \
+ --valid_set ${valid_set} \
+ --data_file_names "wav.scp,text,profile.scp,text_id_train" \
+ --cmvn_file ${feats_dir}/${train_set}/cmvn/cmvn.mvn \
+ --speed_perturb ${speed_perturb} \
+ --init_param "${asr_exp}/valid.acc.ave.pb:encoder:asr_encoder" \
+ --init_param "${asr_exp}/valid.acc.ave.pb:ctc:ctc" \
+ --init_param "${asr_exp}/valid.acc.ave.pb:decoder.embed:decoder.embed" \
+ --init_param "${asr_exp}/valid.acc.ave.pb:decoder.output_layer:decoder.asr_output_layer" \
+ --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.0.self_attn:decoder.decoder1.self_attn" \
+ --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.0.src_attn:decoder.decoder3.src_attn" \
+ --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.0.feed_forward:decoder.decoder3.feed_forward" \
+ --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.1:decoder.decoder4.0" \
+ --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.2:decoder.decoder4.1" \
+ --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.3:decoder.decoder4.2" \
+ --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.4:decoder.decoder4.3" \
+ --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.5:decoder.decoder4.4" \
+ --init_param "exp/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pth:encoder:spk_encoder" \
+ --init_param "exp/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pth:decoder:spk_encoder:decoder.output_dense" \
+ --output_dir ${exp_dir}/${sa_asr_model_dir} \
+ --config $sa_asr_config \
+ --ngpu $gpu_num \
+ --num_worker_count $count \
+ --dist_init_method $init_method \
+ --dist_world_size $world_size \
+ --dist_rank $rank \
+ --local_rank $local_rank 1> ${exp_dir}/${sa_asr_model_dir}/log/train.log.$i 2>&1
+ } &
+ done
+ wait
+fi
+
+
+if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
+ echo "stage 6: Inference test sets"
+ for x in ${test_sets}; do
+ for profile_mode in ${profile_modes}; do
+ echo "decoding ${x} with ${profile_mode} profile"
+ sa_asr_exp=${exp_dir}/${sa_asr_model_dir}
+ inference_tag="$(basename "${inference_config}" .yaml)"
+ _dir="${sa_asr_exp}/${inference_tag}_${profile_mode}/${inference_sa_asr_model}/${x}"
+ _logdir="${_dir}/logdir"
+ if [ -d ${_dir} ]; then
+ echo "${_dir} is already exists. if you want to decode again, please delete this dir first."
+ exit 0
+ fi
+ mkdir -p "${_logdir}"
+ _data="${feats_dir}/${x}"
+ key_file=${_data}/${scp}
+ num_scp_file="$(<${key_file} wc -l)"
+ _nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
+ split_scps=
+ for n in $(seq "${_nj}"); do
+ split_scps+=" ${_logdir}/keys.${n}.scp"
+ done
+ # shellcheck disable=SC2086
+ utils/split_scp.pl "${key_file}" ${split_scps}
+ _opts=
+ if [ -n "${inference_config}" ]; then
+ _opts+="--config ${inference_config} "
+ fi
+ if [ $profile_mode = "oracle" ]; then
+ profile_scp=${profile_mode}_profile_nopadding.scp
+ else
+ profile_scp=${profile_mode}_profile_infer.scp
+ fi
+ ${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
+ python -m funasr.bin.asr_inference_launch \
+ --batch_size 1 \
+ --mc True \
+ --ngpu "${_ngpu}" \
+ --njob ${njob} \
+ --nbest 1 \
+ --gpuid_list ${gpuid_list} \
+ --allow_variable_data_keys true \
+ --cmvn_file ${feats_dir}/${train_set}/cmvn/cmvn.mvn \
+ --data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
+ --data_path_and_name_and_type "${_data}/$profile_scp,profile,npy" \
+ --key_file "${_logdir}"/keys.JOB.scp \
+ --asr_train_config "${sa_asr_exp}"/config.yaml \
+ --asr_model_file "${sa_asr_exp}"/"${inference_sa_asr_model}" \
+ --output_dir "${_logdir}"/output.JOB \
+ --mode sa_asr \
+ ${_opts}
+
+ for f in token token_int score text text_id; do
+ if [ -f "${_logdir}/output.1/1best_recog/${f}" ]; then
+ for i in $(seq "${_nj}"); do
+ cat "${_logdir}/output.${i}/1best_recog/${f}"
+ done | sort -k1 >"${_dir}/${f}"
+ fi
+ done
+ sed 's/\$//g' ${_data}/text > ${_data}/text_nosrc
+ sed 's/\$//g' ${_dir}/text > ${_dir}/text_nosrc
+ python utils/proce_text.py ${_data}/text_nosrc ${_data}/text.proc
+ python utils/proce_text.py ${_dir}/text_nosrc ${_dir}/text.proc
+
+ python utils/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
+ tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
+ cat ${_dir}/text.cer.txt
+
+ python local/process_text_spk_merge.py ${_dir}
+ python local/process_text_spk_merge.py ${_data}
+
+ python local/compute_cpcer.py ${_data}/text_spk_merge ${_dir}/text_spk_merge ${_dir}/text.cpcer
+ tail -n 1 ${_dir}/text.cpcer > ${_dir}/text.cpcer.txt
+ cat ${_dir}/text.cpcer.txt
+ done
+ done
+fi
+
+if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
+ echo "stage 7: Inference test 2023"
+ for x in ${test_2023}; do
+ sa_asr_exp=${exp_dir}/${sa_asr_model_dir}
+ inference_tag="$(basename "${inference_config}" .yaml)"
+ _dir="${sa_asr_exp}/${inference_tag}_cluster/${inference_sa_asr_model}/${x}"
+ _logdir="${_dir}/logdir"
+ if [ -d ${_dir} ]; then
+ echo "${_dir} is already exists. if you want to decode again, please delete this dir first."
+ exit 0
+ fi
+ mkdir -p "${_logdir}"
+ _data="${feats_dir}/${x}"
+ key_file=${_data}/${scp}
+ num_scp_file="$(<${key_file} wc -l)"
+ _nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
+ split_scps=
+ for n in $(seq "${_nj}"); do
+ split_scps+=" ${_logdir}/keys.${n}.scp"
+ done
+ # shellcheck disable=SC2086
+ utils/split_scp.pl "${key_file}" ${split_scps}
+ _opts=
+ if [ -n "${inference_config}" ]; then
+ _opts+="--config ${inference_config} "
+ fi
+ ${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
+ python -m funasr.bin.asr_inference_launch \
+ --batch_size 1 \
+ --mc True \
+ --ngpu "${_ngpu}" \
+ --njob ${njob} \
+ --nbest 1 \
+ --gpuid_list ${gpuid_list} \
+ --allow_variable_data_keys true \
+ --data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
+ --data_path_and_name_and_type "${_data}/cluster_profile_infer.scp,profile,npy" \
+ --cmvn_file ${feats_dir}/${train_set}/cmvn/cmvn.mvn \
+ --key_file "${_logdir}"/keys.JOB.scp \
+ --asr_train_config "${sa_asr_exp}"/config.yaml \
+ --asr_model_file "${sa_asr_exp}"/"${inference_sa_asr_model}" \
+ --output_dir "${_logdir}"/output.JOB \
+ --mode sa_asr \
+ ${_opts}
+
+ for f in token token_int score text text_id; do
+ if [ -f "${_logdir}/output.1/1best_recog/${f}" ]; then
+ for i in $(seq "${_nj}"); do
+ cat "${_logdir}/output.${i}/1best_recog/${f}"
+ done | sort -k1 >"${_dir}/${f}"
+ fi
+ done
+
+ python local/process_text_spk_merge.py ${_dir}
+
+ done
+fi
+
+
diff --git a/egs/alimeeting/sa-asr/utils b/egs/alimeeting/sa_asr/utils
similarity index 100%
rename from egs/alimeeting/sa-asr/utils
rename to egs/alimeeting/sa_asr/utils
diff --git a/egs/alimeeting/sa-asr/README.md b/egs/alimeeting/sa_asr_deprecated/README.md
similarity index 100%
rename from egs/alimeeting/sa-asr/README.md
rename to egs/alimeeting/sa_asr_deprecated/README.md
diff --git a/egs/alimeeting/sa-asr/asr_local.sh b/egs/alimeeting/sa_asr_deprecated/asr_local.sh
similarity index 100%
rename from egs/alimeeting/sa-asr/asr_local.sh
rename to egs/alimeeting/sa_asr_deprecated/asr_local.sh
diff --git a/egs/alimeeting/sa-asr/asr_local_m2met_2023_infer.sh b/egs/alimeeting/sa_asr_deprecated/asr_local_m2met_2023_infer.sh
similarity index 100%
rename from egs/alimeeting/sa-asr/asr_local_m2met_2023_infer.sh
rename to egs/alimeeting/sa_asr_deprecated/asr_local_m2met_2023_infer.sh
diff --git a/egs/alimeeting/sa_asr_deprecated/conf/decode_asr_rnn.yaml b/egs/alimeeting/sa_asr_deprecated/conf/decode_asr_rnn.yaml
new file mode 100644
index 000000000..88fdbc20b
--- /dev/null
+++ b/egs/alimeeting/sa_asr_deprecated/conf/decode_asr_rnn.yaml
@@ -0,0 +1,6 @@
+beam_size: 20
+penalty: 0.0
+maxlenratio: 0.0
+minlenratio: 0.0
+ctc_weight: 0.6
+lm_weight: 0.3
diff --git a/egs/alimeeting/sa-asr/conf/train_asr_conformer.yaml b/egs/alimeeting/sa_asr_deprecated/conf/train_asr_conformer.yaml
similarity index 100%
rename from egs/alimeeting/sa-asr/conf/train_asr_conformer.yaml
rename to egs/alimeeting/sa_asr_deprecated/conf/train_asr_conformer.yaml
diff --git a/egs/alimeeting/sa-asr/conf/train_sa_asr_conformer.yaml b/egs/alimeeting/sa_asr_deprecated/conf/train_sa_asr_conformer.yaml
similarity index 100%
rename from egs/alimeeting/sa-asr/conf/train_sa_asr_conformer.yaml
rename to egs/alimeeting/sa_asr_deprecated/conf/train_sa_asr_conformer.yaml
diff --git a/egs/alimeeting/sa_asr_deprecated/local b/egs/alimeeting/sa_asr_deprecated/local
new file mode 120000
index 000000000..2ef6217d9
--- /dev/null
+++ b/egs/alimeeting/sa_asr_deprecated/local
@@ -0,0 +1 @@
+../sa_asr/local/
\ No newline at end of file
diff --git a/egs/alimeeting/sa-asr/path.sh b/egs/alimeeting/sa_asr_deprecated/path.sh
similarity index 100%
rename from egs/alimeeting/sa-asr/path.sh
rename to egs/alimeeting/sa_asr_deprecated/path.sh
diff --git a/egs/alimeeting/sa-asr/run.sh b/egs/alimeeting/sa_asr_deprecated/run.sh
similarity index 100%
rename from egs/alimeeting/sa-asr/run.sh
rename to egs/alimeeting/sa_asr_deprecated/run.sh
diff --git a/egs/alimeeting/sa-asr/run_m2met_2023_infer.sh b/egs/alimeeting/sa_asr_deprecated/run_m2met_2023_infer.sh
similarity index 100%
rename from egs/alimeeting/sa-asr/run_m2met_2023_infer.sh
rename to egs/alimeeting/sa_asr_deprecated/run_m2met_2023_infer.sh
diff --git a/egs/alimeeting/sa_asr_deprecated/utils b/egs/alimeeting/sa_asr_deprecated/utils
new file mode 120000
index 000000000..fe070dd3a
--- /dev/null
+++ b/egs/alimeeting/sa_asr_deprecated/utils
@@ -0,0 +1 @@
+../../aishell/transformer/utils
\ No newline at end of file
diff --git a/funasr/bin/asr_infer.py b/funasr/bin/asr_infer.py
index 140b4245c..c722ebc3a 100644
--- a/funasr/bin/asr_infer.py
+++ b/funasr/bin/asr_infer.py
@@ -1636,8 +1636,10 @@ class Speech2TextSAASR:
)
frontend = None
if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
- if asr_train_args.frontend == 'wav_frontend':
- frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
+ from funasr.tasks.sa_asr import frontend_choices
+ if asr_train_args.frontend == 'wav_frontend' or asr_train_args.frontend == "multichannelfrontend":
+ frontend_class = frontend_choices.get_class(asr_train_args.frontend)
+ frontend = frontend_class(cmvn_file=cmvn_file, **asr_train_args.frontend_conf).eval()
else:
frontend_class = frontend_choices.get_class(asr_train_args.frontend)
frontend = frontend_class(**asr_train_args.frontend_conf).eval()
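This hunk replaces the hard-coded `WavFrontend` constructor with a lookup through the task's frontend registry, so both cmvn-aware frontends are built the same way. Schematically (a sketch of the dispatch; the helper name is illustrative):

```python
# frontend_choices maps a config name to a frontend class; only the
# cmvn-aware frontends receive cmvn_file at construction time.
def build_inference_frontend(name, conf, cmvn_file, frontend_choices):
    frontend_class = frontend_choices.get_class(name)
    if name in ("wav_frontend", "multichannelfrontend"):
        return frontend_class(cmvn_file=cmvn_file, **conf).eval()
    return frontend_class(**conf).eval()
```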
diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py
index 367b9a815..656a9657a 100644
--- a/funasr/bin/asr_inference_launch.py
+++ b/funasr/bin/asr_inference_launch.py
@@ -619,7 +619,12 @@ def inference_paraformer_vad_punc(
data_with_index = [(vadsegments[i], i) for i in range(n)]
sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0])
results_sorted = []
- batch_size_token_ms = batch_size_token * 60
+
+    batch_size_token_ms = batch_size_token * 60
+ if speech2text.device == "cpu":
+ batch_size_token_ms = 0
+ batch_size_token_ms = max(batch_size_token_ms, sorted_data[0][0][1] - sorted_data[0][0][0])
+
batch_size_token_ms_cum = 0
beg_idx = 0
for j, _ in enumerate(range(0, n)):
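For context, the surrounding loop packs the duration-sorted VAD segments into batches whose summed length stays under the token budget; the new lines drop the budget to a single segment on CPU while still guaranteeing that the shortest segment fits. A simplified sketch of that packing, with made-up segment boundaries:

```python
# Segments as (start_ms, end_ms), already sorted by duration.
segments = [(0, 800), (1000, 1900), (2000, 3500), (4000, 7000)]
batch_budget_ms = 3000  # stand-in for batch_size_token * 60
# never let the budget fall below the shortest segment, as the patch does
batch_budget_ms = max(batch_budget_ms, segments[0][1] - segments[0][0])

batches, cur, cur_ms = [], [], 0
for seg in segments:
    dur = seg[1] - seg[0]
    if cur and cur_ms + dur > batch_budget_ms:
        batches.append(cur)
        cur, cur_ms = [], 0
    cur.append(seg)
    cur_ms += dur
if cur:
    batches.append(cur)
print(batches)  # [[(0, 800), (1000, 1900)], [(2000, 3500)], [(4000, 7000)]]
```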
diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index f4fc0a7e6..1dc3fb523 100755
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -301,7 +301,7 @@ def get_parser():
"--freeze_param",
type=str,
default=[],
- nargs="*",
+ action="append",
help="Freeze parameters",
)
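Note that this is a CLI semantics change: with `nargs="*"` all parameter names followed a single flag, while `action="append"` collects one name per occurrence of the flag, so existing command lines need updating. A minimal demonstration:

```python
import argparse

old = argparse.ArgumentParser()
old.add_argument("--freeze_param", type=str, default=[], nargs="*")
print(old.parse_args(["--freeze_param", "encoder", "decoder"]).freeze_param)
# ['encoder', 'decoder']

new = argparse.ArgumentParser()
new.add_argument("--freeze_param", type=str, default=[], action="append")
args = new.parse_args(["--freeze_param", "encoder", "--freeze_param", "decoder"])
print(args.freeze_param)
# ['encoder', 'decoder']
```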
diff --git a/funasr/build_utils/build_asr_model.py b/funasr/build_utils/build_asr_model.py
index d4a954cb6..200395d28 100644
--- a/funasr/build_utils/build_asr_model.py
+++ b/funasr/build_utils/build_asr_model.py
@@ -6,7 +6,6 @@ from funasr.models.ctc import CTC
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
from funasr.models.decoder.rnn_decoder import RNNDecoder
-from funasr.models.decoder.rnnt_decoder import RNNTDecoder
from funasr.models.decoder.sanm_decoder import ParaformerSANMDecoder, FsmnDecoderSCAMAOpt
from funasr.models.decoder.transformer_decoder import (
DynamicConvolution2DTransformerDecoder, # noqa: H301
@@ -20,17 +19,23 @@ from funasr.models.decoder.transformer_decoder import (
)
from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
from funasr.models.decoder.transformer_decoder import TransformerDecoder
+from funasr.models.decoder.rnnt_decoder import RNNTDecoder
+from funasr.models.decoder.transformer_decoder import SAAsrTransformerDecoder
from funasr.models.e2e_asr import ASRModel
from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
from funasr.models.e2e_asr_mfcca import MFCCA
-from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerOnline, ParaformerBert, BiCifParaformer, \
- ContextualParaformer
+
from funasr.models.e2e_asr_transducer import TransducerModel, UnifiedTransducerModel
+
+from funasr.models.e2e_sa_asr import SAASRModel
+from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerOnline, ParaformerBert, BiCifParaformer, ContextualParaformer
+
from funasr.models.e2e_tp import TimestampPredictor
from funasr.models.e2e_uni_asr import UniASR
from funasr.models.encoder.conformer_encoder import ConformerEncoder, ConformerChunkEncoder
from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
from funasr.models.encoder.mfcca_encoder import MFCCAEncoder
+from funasr.models.encoder.resnet34_encoder import ResNet34Diar
from funasr.models.encoder.rnn_encoder import RNNEncoder
from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
from funasr.models.encoder.transformer_encoder import TransformerEncoder
@@ -93,6 +98,8 @@ model_choices = ClassChoices(
timestamp_prediction=TimestampPredictor,
rnnt=TransducerModel,
rnnt_unified=UnifiedTransducerModel,
+ sa_asr=SAASRModel,
+
),
default="asr",
)
@@ -110,6 +117,27 @@ encoder_choices = ClassChoices(
),
default="rnn",
)
+asr_encoder_choices = ClassChoices(
+ "asr_encoder",
+ classes=dict(
+ conformer=ConformerEncoder,
+ transformer=TransformerEncoder,
+ rnn=RNNEncoder,
+ sanm=SANMEncoder,
+ sanm_chunk_opt=SANMEncoderChunkOpt,
+ data2vec_encoder=Data2VecEncoder,
+ mfcca_enc=MFCCAEncoder,
+ ),
+ default="rnn",
+)
+
+spk_encoder_choices = ClassChoices(
+ "spk_encoder",
+ classes=dict(
+ resnet34_diar=ResNet34Diar,
+ ),
+ default="resnet34_diar",
+)
encoder_choices2 = ClassChoices(
"encoder2",
classes=dict(
@@ -134,6 +162,7 @@ decoder_choices = ClassChoices(
paraformer_decoder_sanm=ParaformerSANMDecoder,
paraformer_decoder_san=ParaformerDecoderSAN,
contextual_paraformer_decoder=ContextualParaformerDecoder,
+ sa_decoder=SAAsrTransformerDecoder,
),
default="rnn",
)
@@ -225,6 +254,10 @@ class_choices_list = [
rnnt_decoder_choices,
# --joint_network and --joint_network_conf
joint_network_choices,
+ # --asr_encoder and --asr_encoder_conf
+ asr_encoder_choices,
+ # --spk_encoder and --spk_encoder_conf
+ spk_encoder_choices,
]
@@ -247,7 +280,7 @@ def build_asr_model(args):
# frontend
if hasattr(args, "input_size") and args.input_size is None:
frontend_class = frontend_choices.get_class(args.frontend)
- if args.frontend == 'wav_frontend':
+ if args.frontend == 'wav_frontend' or args.frontend == 'multichannelfrontend':
frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
else:
frontend = frontend_class(**args.frontend_conf)
@@ -425,6 +458,33 @@ def build_asr_model(args):
joint_network=joint_network,
**args.model_conf,
)
+ elif args.model == "sa_asr":
+ asr_encoder_class = asr_encoder_choices.get_class(args.asr_encoder)
+ asr_encoder = asr_encoder_class(input_size=input_size, **args.asr_encoder_conf)
+ spk_encoder_class = spk_encoder_choices.get_class(args.spk_encoder)
+ spk_encoder = spk_encoder_class(input_size=input_size, **args.spk_encoder_conf)
+ decoder = decoder_class(
+ vocab_size=vocab_size,
+ encoder_output_size=asr_encoder.output_size(),
+ **args.decoder_conf,
+ )
+ ctc = CTC(
+ odim=vocab_size, encoder_output_size=asr_encoder.output_size(), **args.ctc_conf
+ )
+
+ model_class = model_choices.get_class(args.model)
+ model = model_class(
+ vocab_size=vocab_size,
+ frontend=frontend,
+ specaug=specaug,
+ normalize=normalize,
+ asr_encoder=asr_encoder,
+ spk_encoder=spk_encoder,
+ decoder=decoder,
+ ctc=ctc,
+ token_list=token_list,
+ **args.model_conf,
+ )
else:
raise NotImplementedError("Not supported model: {}".format(args.model))
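The new `sa_asr` branch wires four configurable pieces together: an ASR encoder, a speaker encoder, the speaker-aware decoder, and CTC over the ASR encoder output. Hypothetically, a training config exercising it would carry keys along these lines (option names taken from the `ClassChoices` above; the nested `*_conf` dicts are placeholders, not a tested configuration):

```python
# Illustrative argument shape for the sa_asr branch of build_asr_model.
sa_asr_args = {
    "model": "sa_asr",               # -> SAASRModel
    "asr_encoder": "conformer",      # from asr_encoder_choices
    "asr_encoder_conf": {},          # e.g. num_blocks / attention settings
    "spk_encoder": "resnet34_diar",  # from spk_encoder_choices
    "spk_encoder_conf": {},
    "decoder": "sa_decoder",         # -> SAAsrTransformerDecoder
    "decoder_conf": {},
    "ctc_conf": {},
    "model_conf": {},                # forwarded as SAASRModel(**model_conf)
}
```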
diff --git a/funasr/models/e2e_sa_asr.py b/funasr/models/e2e_sa_asr.py
index 830460721..e209d5192 100644
--- a/funasr/models/e2e_sa_asr.py
+++ b/funasr/models/e2e_sa_asr.py
@@ -40,7 +40,7 @@ else:
yield
-class ESPnetASRModel(FunASRModel):
+class SAASRModel(FunASRModel):
"""CTC-attention hybrid Encoder-Decoder model"""
def __init__(
@@ -51,10 +51,8 @@ class ESPnetASRModel(FunASRModel):
frontend: Optional[AbsFrontend],
specaug: Optional[AbsSpecAug],
normalize: Optional[AbsNormalize],
- preencoder: Optional[AbsPreEncoder],
asr_encoder: AbsEncoder,
spk_encoder: torch.nn.Module,
- postencoder: Optional[AbsPostEncoder],
decoder: AbsDecoder,
ctc: CTC,
spk_weight: float = 0.5,
@@ -89,8 +87,6 @@ class ESPnetASRModel(FunASRModel):
self.frontend = frontend
self.specaug = specaug
self.normalize = normalize
- self.preencoder = preencoder
- self.postencoder = postencoder
self.asr_encoder = asr_encoder
self.spk_encoder = spk_encoder
@@ -293,10 +289,6 @@ class ESPnetASRModel(FunASRModel):
if self.normalize is not None:
feats, feats_lengths = self.normalize(feats, feats_lengths)
- # Pre-encoder, e.g. used for raw input data
- if self.preencoder is not None:
- feats, feats_lengths = self.preencoder(feats, feats_lengths)
-
# 4. Forward encoder
# feats: (Batch, Length, Dim)
# -> encoder_out: (Batch, Length2, Dim2)
@@ -317,11 +309,6 @@ class ESPnetASRModel(FunASRModel):
encoder_out_spk=F.interpolate(encoder_out_spk_ori.transpose(-2,-1), size=(encoder_out.size(1)), mode='nearest').transpose(-2,-1)
else:
encoder_out_spk=encoder_out_spk_ori
- # Post-encoder, e.g. NLU
- if self.postencoder is not None:
- encoder_out, encoder_out_lens = self.postencoder(
- encoder_out, encoder_out_lens
- )
assert encoder_out.size(0) == speech.size(0), (
encoder_out.size(),
@@ -337,7 +324,7 @@ class ESPnetASRModel(FunASRModel):
)
if intermediate_outs is not None:
- return (encoder_out, intermediate_outs), encoder_out_lens
+ return (encoder_out, intermediate_outs), encoder_out_lens, encoder_out_spk
return encoder_out, encoder_out_lens, encoder_out_spk
diff --git a/funasr/models/frontend/default.py b/funasr/models/frontend/default.py
index 19994f0e6..6718f3f6c 100644
--- a/funasr/models/frontend/default.py
+++ b/funasr/models/frontend/default.py
@@ -2,7 +2,7 @@ import copy
from typing import Optional
from typing import Tuple
from typing import Union
-
+import logging
import humanfriendly
import numpy as np
import torch
@@ -14,6 +14,7 @@ from funasr.layers.stft import Stft
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.modules.frontends.frontend import Frontend
from funasr.utils.get_default_kwargs import get_default_kwargs
+from funasr.modules.nets_utils import make_pad_mask
class DefaultFrontend(AbsFrontend):
@@ -137,8 +138,6 @@ class DefaultFrontend(AbsFrontend):
return input_stft, feats_lens
-
-
class MultiChannelFrontend(AbsFrontend):
"""Conventional frontend structure for ASR.
Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
@@ -147,9 +146,9 @@ class MultiChannelFrontend(AbsFrontend):
def __init__(
self,
fs: Union[int, str] = 16000,
- n_fft: int = 512,
- win_length: int = None,
- hop_length: int = 128,
+ n_fft: int = 400,
+ frame_length: int = 25,
+ frame_shift: int = 10,
window: Optional[str] = "hann",
center: bool = True,
normalized: bool = False,
@@ -160,10 +159,10 @@ class MultiChannelFrontend(AbsFrontend):
htk: bool = False,
frontend_conf: Optional[dict] = get_default_kwargs(Frontend),
apply_stft: bool = True,
- frame_length: int = None,
- frame_shift: int = None,
- lfr_m: int = None,
- lfr_n: int = None,
+ use_channel: int = None,
+ lfr_m: int = 1,
+ lfr_n: int = 1,
+ cmvn_file: str = None
):
assert check_argument_types()
super().__init__()
@@ -172,13 +171,14 @@ class MultiChannelFrontend(AbsFrontend):
# Deepcopy (In general, dict shouldn't be used as default arg)
frontend_conf = copy.deepcopy(frontend_conf)
- self.hop_length = hop_length
+ self.win_length = frame_length * 16
+ self.hop_length = frame_shift * 16
if apply_stft:
self.stft = Stft(
n_fft=n_fft,
- win_length=win_length,
- hop_length=hop_length,
+ win_length=self.win_length,
+ hop_length=self.hop_length,
center=center,
window=window,
normalized=normalized,
@@ -202,7 +202,17 @@ class MultiChannelFrontend(AbsFrontend):
htk=htk,
)
self.n_mels = n_mels
- self.frontend_type = "multichannelfrontend"
+ self.frontend_type = "default"
+ self.use_channel = use_channel
+ if self.use_channel is not None:
+ logging.info("use the channel %d" % (self.use_channel))
+ else:
+ logging.info("random select channel")
+ self.cmvn_file = cmvn_file
+ if self.cmvn_file is not None:
+ mean, std = self._load_cmvn(self.cmvn_file)
+ self.register_buffer("mean", torch.from_numpy(mean))
+ self.register_buffer("std", torch.from_numpy(std))
def output_size(self) -> int:
return self.n_mels
@@ -215,16 +225,29 @@ class MultiChannelFrontend(AbsFrontend):
if self.stft is not None:
input_stft, feats_lens = self._compute_stft(input, input_lengths)
else:
- if isinstance(input, ComplexTensor):
- input_stft = input
- else:
- input_stft = ComplexTensor(input[..., 0], input[..., 1])
+ input_stft = ComplexTensor(input[..., 0], input[..., 1])
feats_lens = input_lengths
# 2. [Option] Speech enhancement
if self.frontend is not None:
assert isinstance(input_stft, ComplexTensor), type(input_stft)
# input_stft: (Batch, Length, [Channel], Freq)
input_stft, _, mask = self.frontend(input_stft, feats_lens)
+
+ # 3. [Multi channel case]: Select a channel
+ if input_stft.dim() == 4:
+ # h: (B, T, C, F) -> h: (B, T, F)
+ if self.training:
+ if self.use_channel is not None:
+ input_stft = input_stft[:, :, self.use_channel, :]
+
+ else:
+ # Select 1ch randomly
+ ch = np.random.randint(input_stft.size(2))
+ input_stft = input_stft[:, :, ch, :]
+ else:
+ # Use the first channel
+ input_stft = input_stft[:, :, 0, :]
+
# 4. STFT -> Power spectrum
# h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F)
input_power = input_stft.real ** 2 + input_stft.imag ** 2
@@ -233,18 +256,27 @@ class MultiChannelFrontend(AbsFrontend):
# input_power: (Batch, [Channel,] Length, Freq)
# -> input_feats: (Batch, Length, Dim)
input_feats, _ = self.logmel(input_power, feats_lens)
- bt = input_feats.size(0)
- if input_feats.dim() ==4:
- channel_size = input_feats.size(2)
- # batch * channel * T * D
- #pdb.set_trace()
- input_feats = input_feats.transpose(1,2).reshape(bt*channel_size,-1,80).contiguous()
- # input_feats = input_feats.transpose(1,2)
- # batch * channel
- feats_lens = feats_lens.repeat(1,channel_size).squeeze()
- else:
- channel_size = 1
- return input_feats, feats_lens, channel_size
+
+ # 6. Apply CMVN
+ if self.cmvn_file is not None:
+ if feats_lens is None:
+ feats_lens = input_feats.new_full([input_feats.size(0)], input_feats.size(1))
+ self.mean = self.mean.to(input_feats.device, input_feats.dtype)
+ self.std = self.std.to(input_feats.device, input_feats.dtype)
+ mask = make_pad_mask(feats_lens, input_feats, 1)
+
+ if input_feats.requires_grad:
+ input_feats = input_feats + self.mean
+ else:
+ input_feats += self.mean
+ if input_feats.requires_grad:
+ input_feats = input_feats.masked_fill(mask, 0.0)
+ else:
+ input_feats.masked_fill_(mask, 0.0)
+
+ input_feats *= self.std
+
+ return input_feats, feats_lens
def _compute_stft(
self, input: torch.Tensor, input_lengths: torch.Tensor
@@ -258,4 +290,27 @@ class MultiChannelFrontend(AbsFrontend):
# Change torch.Tensor to ComplexTensor
# input_stft: (..., F, 2) -> (..., F)
input_stft = ComplexTensor(input_stft[..., 0], input_stft[..., 1])
- return input_stft, feats_lens
\ No newline at end of file
+ return input_stft, feats_lens
+
+    def _load_cmvn(self, cmvn_file):
+        with open(cmvn_file, 'r', encoding='utf-8') as f:
+            lines = f.readlines()
+        means_list = []
+        vars_list = []
+        for i in range(len(lines)):
+            line_item = lines[i].split()
+            if line_item[0] == '<AddShift>':
+                line_item = lines[i + 1].split()
+                if line_item[0] == '<LearnRateCoef>':
+                    add_shift_line = line_item[3:(len(line_item) - 1)]
+                    means_list = list(add_shift_line)
+                    continue
+            elif line_item[0] == '<Rescale>':
+                line_item = lines[i + 1].split()
+                if line_item[0] == '<LearnRateCoef>':
+                    rescale_line = line_item[3:(len(line_item) - 1)]
+                    vars_list = list(rescale_line)
+                    continue
+        means = np.array(means_list).astype(float)
+        vars = np.array(vars_list).astype(float)
+        return means, vars
\ No newline at end of file
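For reference, `_load_cmvn` reads a Kaldi nnet-style global CMVN file such as the recipe's `cmvn.mvn`: an `<AddShift>` component whose `<LearnRateCoef>` row holds the values added to the features (the negated means), followed by a `<Rescale>` component holding the values they are multiplied by (the inverse standard deviations). A small self-contained sketch with illustrative numbers (real files carry one value per feature dimension):

```python
import numpy as np

sample = """<Nnet>
<AddShift> 4 4
<LearnRateCoef> 0 [ -8.3 -8.2 -8.1 -8.0 ]
<Rescale> 4 4
<LearnRateCoef> 0 [ 0.15 0.16 0.17 0.18 ]
</Nnet>
"""

def load_cmvn_lines(lines):
    # mirrors MultiChannelFrontend._load_cmvn: keep the bracketed values,
    # dropping the "<LearnRateCoef> 0 [" prefix and the trailing "]"
    means, scales = [], []
    for i, line in enumerate(lines):
        item = line.split()
        if item and item[0] == "<AddShift>":
            nxt = lines[i + 1].split()
            if nxt[0] == "<LearnRateCoef>":
                means = nxt[3:-1]
        elif item and item[0] == "<Rescale>":
            nxt = lines[i + 1].split()
            if nxt[0] == "<LearnRateCoef>":
                scales = nxt[3:-1]
    return np.asarray(means, dtype=float), np.asarray(scales, dtype=float)

print(load_cmvn_lines(sample.splitlines()))
```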
diff --git a/funasr/runtime/java/FunasrWsClient.java b/funasr/runtime/java/FunasrWsClient.java
new file mode 100644
index 000000000..ec55c9425
--- /dev/null
+++ b/funasr/runtime/java/FunasrWsClient.java
@@ -0,0 +1,344 @@
+//
+// Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
+// Reserved. MIT License (https://opensource.org/licenses/MIT)
+//
+/*
+ * // 2022-2023 by zhaomingwork@qq.com
+ */
+// java FunasrWsClient
+// usage: FunasrWsClient [-h] [--port PORT] [--host HOST] [--audio_in AUDIO_IN] [--num_threads NUM_THREADS]
+// [--chunk_size CHUNK_SIZE] [--chunk_interval CHUNK_INTERVAL] [--mode MODE]
+package websocket;
+
+import java.io.*;
+import java.net.URI;
+import java.net.URISyntaxException;
+import java.nio.*;
+import java.util.Map;
+import net.sourceforge.argparse4j.ArgumentParsers;
+import net.sourceforge.argparse4j.inf.ArgumentParser;
+import net.sourceforge.argparse4j.inf.ArgumentParserException;
+import net.sourceforge.argparse4j.inf.Namespace;
+import org.java_websocket.client.WebSocketClient;
+import org.java_websocket.drafts.Draft;
+import org.java_websocket.handshake.ServerHandshake;
+import org.json.simple.JSONArray;
+import org.json.simple.JSONObject;
+import org.json.simple.parser.JSONParser;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/** This example demonstrates how to connect to a websocket server. */
+public class FunasrWsClient extends WebSocketClient {
+
+ public class RecWavThread extends Thread {
+ private FunasrWsClient funasrClient;
+
+ public RecWavThread(FunasrWsClient funasrClient) {
+ this.funasrClient = funasrClient;
+ }
+
+ public void run() {
+ this.funasrClient.recWav();
+ }
+ }
+
+ private static final Logger logger = LoggerFactory.getLogger(FunasrWsClient.class);
+
+ public FunasrWsClient(URI serverUri, Draft draft) {
+ super(serverUri, draft);
+ }
+
+ public FunasrWsClient(URI serverURI) {
+ super(serverURI);
+ }
+
+ public FunasrWsClient(URI serverUri, Map httpHeaders) {
+ super(serverUri, httpHeaders);
+ }
+
+ public void getSslContext(String keyfile, String certfile) {
+ // TODO
+ return;
+ }
+
+ // send json at first time
+ public void sendJson(
+ String mode, String strChunkSize, int chunkInterval, String wavName, boolean isSpeaking) {
+ try {
+
+ JSONObject obj = new JSONObject();
+ obj.put("mode", mode);
+ JSONArray array = new JSONArray();
+ String[] chunkList = strChunkSize.split(",");
+ for (int i = 0; i < chunkList.length; i++) {
+ array.add(Integer.valueOf(chunkList[i].trim()));
+ }
+
+ obj.put("chunk_size", array);
+ obj.put("chunk_interval", new Integer(chunkInterval));
+ obj.put("wav_name", wavName);
+ if (isSpeaking) {
+ obj.put("is_speaking", new Boolean(true));
+ } else {
+ obj.put("is_speaking", new Boolean(false));
+ }
+ logger.info("sendJson: " + obj);
+ // return;
+
+ send(obj.toString());
+
+ return;
+ } catch (Exception e) {
+ e.printStackTrace();
+ }
+ }
+
+ // send json at end of wav
+ public void sendEof() {
+ try {
+ JSONObject obj = new JSONObject();
+
+ obj.put("is_speaking", new Boolean(false));
+
+ logger.info("sendEof: " + obj);
+ // return;
+
+ send(obj.toString());
+ iseof = true;
+ return;
+ } catch (Exception e) {
+ e.printStackTrace();
+ }
+ }
+
+ // function for rec wav file
+ public void recWav() {
+ sendJson(mode, strChunkSize, chunkInterval, wavName, true);
+ File file = new File(FunasrWsClient.wavPath);
+
+ int chunkSize = sendChunkSize;
+ byte[] bytes = new byte[chunkSize];
+
+ int readSize = 0;
+ try (FileInputStream fis = new FileInputStream(file)) {
+ if (FunasrWsClient.wavPath.endsWith(".wav")) {
+        fis.read(bytes, 0, 44); // skip the 44-byte wav header
+ }
+ readSize = fis.read(bytes, 0, chunkSize);
+ while (readSize > 0) {
+ // send when it is chunk size
+ if (readSize == chunkSize) {
+ send(bytes); // send buf to server
+
+ } else {
+          // send the final partial chunk at the end of the file
+ byte[] tmpBytes = new byte[readSize];
+ for (int i = 0; i < readSize; i++) {
+ tmpBytes[i] = bytes[i];
+ }
+ send(tmpBytes);
+ }
+        // if not in offline mode, simulate an online stream by sleeping between chunks
+ if (!mode.equals("offline")) {
+ Thread.sleep(Integer.valueOf(chunkSize / 32));
+ }
+
+ readSize = fis.read(bytes, 0, chunkSize);
+ }
+
+ if (!mode.equals("offline")) {
+ // if not offline, we send eof and wait for 3 seconds to close
+ Thread.sleep(2000);
+ sendEof();
+ Thread.sleep(3000);
+ close();
+ } else {
+ // if offline, just send eof
+ sendEof();
+ }
+
+ } catch (Exception e) {
+ e.printStackTrace();
+ }
+ }
+
+ @Override
+ public void onOpen(ServerHandshake handshakedata) {
+
+ RecWavThread thread = new RecWavThread(this);
+ thread.start();
+ }
+
+ @Override
+ public void onMessage(String message) {
+ JSONObject jsonObject = new JSONObject();
+ JSONParser jsonParser = new JSONParser();
+ logger.info("received: " + message);
+ try {
+ jsonObject = (JSONObject) jsonParser.parse(message);
+ logger.info("text: " + jsonObject.get("text"));
+ } catch (org.json.simple.parser.ParseException e) {
+ e.printStackTrace();
+ }
+ if (iseof && mode.equals("offline")) {
+ close();
+ }
+ }
+
+ @Override
+ public void onClose(int code, String reason, boolean remote) {
+
+ logger.info(
+ "Connection closed by "
+ + (remote ? "remote peer" : "us")
+ + " Code: "
+ + code
+ + " Reason: "
+ + reason);
+ }
+
+ @Override
+ public void onError(Exception ex) {
+ logger.info("ex: " + ex);
+ ex.printStackTrace();
+ // if the error is fatal then onClose will be called additionally
+ }
+
+ private boolean iseof = false;
+ public static String wavPath;
+ static String mode = "online";
+ static String strChunkSize = "5,10,5";
+ static int chunkInterval = 10;
+ static int sendChunkSize = 1920;
+
+ String wavName = "javatest";
+
+ public static void main(String[] args) throws URISyntaxException {
+ ArgumentParser parser = ArgumentParsers.newArgumentParser("ws client").defaultHelp(true);
+ parser
+ .addArgument("--port")
+ .help("Port on which to listen.")
+ .setDefault("8889")
+ .type(String.class)
+ .required(false);
+ parser
+ .addArgument("--host")
+ .help("the IP address of server.")
+ .setDefault("127.0.0.1")
+ .type(String.class)
+ .required(false);
+ parser
+ .addArgument("--audio_in")
+ .help("wav path for decoding.")
+ .setDefault("asr_example.wav")
+ .type(String.class)
+ .required(false);
+ parser
+ .addArgument("--num_threads")
+ .help("num of threads for test.")
+ .setDefault(1)
+ .type(Integer.class)
+ .required(false);
+ parser
+ .addArgument("--chunk_size")
+ .help("chunk size for asr.")
+ .setDefault("5, 10, 5")
+ .type(String.class)
+ .required(false);
+ parser
+ .addArgument("--chunk_interval")
+ .help("chunk for asr.")
+ .setDefault(10)
+ .type(Integer.class)
+ .required(false);
+
+ parser
+ .addArgument("--mode")
+ .help("mode for asr.")
+ .setDefault("offline")
+ .type(String.class)
+ .required(false);
+ String srvIp = "";
+ String srvPort = "";
+ String wavPath = "";
+ int numThreads = 1;
+ String chunk_size = "";
+ int chunk_interval = 10;
+ String strmode = "offline";
+
+ try {
+ Namespace ns = parser.parseArgs(args);
+ srvIp = ns.get("host");
+ srvPort = ns.get("port");
+ wavPath = ns.get("audio_in");
+ numThreads = ns.get("num_threads");
+ chunk_size = ns.get("chunk_size");
+ chunk_interval = ns.get("chunk_interval");
+ strmode = ns.get("mode");
+ System.out.println(srvPort);
+
+ } catch (ArgumentParserException ex) {
+ ex.getParser().handleError(ex);
+ return;
+ }
+
+ FunasrWsClient.strChunkSize = chunk_size;
+ FunasrWsClient.chunkInterval = chunk_interval;
+ FunasrWsClient.wavPath = wavPath;
+ FunasrWsClient.mode = strmode;
+ System.out.println(
+        "srvIp="
+            + srvIp
+            + ",srvPort="
+            + srvPort
+            + ",wavPath="
+            + wavPath
+            + ",strChunkSize="
+ + strChunkSize);
+
+ class ClientThread implements Runnable {
+
+ String srvIp;
+ String srvPort;
+
+ ClientThread(String srvIp, String srvPort, String wavPath) {
+ this.srvIp = srvIp;
+ this.srvPort = srvPort;
+ }
+
+ public void run() {
+ try {
+
+ int RATE = 16000;
+ String[] chunkList = strChunkSize.split(",");
+ int int_chunk_size = 60 * Integer.valueOf(chunkList[1].trim()) / chunkInterval;
+ int CHUNK = Integer.valueOf(RATE / 1000 * int_chunk_size);
+          int stride =
+              (int)
+                  (60.0 * Integer.valueOf(chunkList[1].trim()) / chunkInterval / 1000 * 16000 * 2);
+ System.out.println("chunk_size:" + String.valueOf(int_chunk_size));
+ System.out.println("CHUNK:" + CHUNK);
+ System.out.println("stride:" + String.valueOf(stride));
+ FunasrWsClient.sendChunkSize = CHUNK * 2;
+
+ String wsAddress = "ws://" + srvIp + ":" + srvPort;
+
+ FunasrWsClient c = new FunasrWsClient(new URI(wsAddress));
+
+ c.connect();
+
+ System.out.println("wsAddress:" + wsAddress);
+ } catch (Exception e) {
+ e.printStackTrace();
+ System.out.println("e:" + e);
+ }
+ }
+ }
+ for (int i = 0; i < numThreads; i++) {
+      System.out.println("Thread " + i + " is running...");
+ Thread t = new Thread(new ClientThread(srvIp, srvPort, wavPath));
+ t.start();
+ }
+ }
+}
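The chunk arithmetic in `ClientThread.run()` is easier to follow with the defaults plugged in: with `--chunk_size "5,10,5"` and `--chunk_interval 10`, each audio chunk covers 60 ms of 16 kHz 16-bit PCM, which is where the default `sendChunkSize = 1920` comes from. The same computation as a quick Python check:

```python
RATE = 16000             # samples per second
chunk_list = [5, 10, 5]  # --chunk_size
chunk_interval = 10      # --chunk_interval

chunk_ms = 60 * chunk_list[1] // chunk_interval  # 60 ms of audio per chunk
chunk_samples = RATE // 1000 * chunk_ms          # 960 samples
send_chunk_bytes = chunk_samples * 2             # 16-bit PCM -> 1920 bytes
stride = int(60 * chunk_list[1] / chunk_interval / 1000 * 16000 * 2)  # 1920

assert (chunk_ms, chunk_samples, send_chunk_bytes, stride) == (60, 960, 1920, 1920)
```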
diff --git a/funasr/runtime/java/Makefile b/funasr/runtime/java/Makefile
new file mode 100644
index 000000000..9a70ca5fc
--- /dev/null
+++ b/funasr/runtime/java/Makefile
@@ -0,0 +1,76 @@
+
+ENTRY_POINT = ./
+
+
+
+
+WEBSOCKET_DIR:= ./
+WEBSOCKET_FILES = \
+ $(WEBSOCKET_DIR)/FunasrWsClient.java \
+
+
+
+LIB_BUILD_DIR = ./lib
+
+
+
+
+JAVAC = javac
+
+BUILD_DIR = build
+
+
+RUNJFLAGS = -Dfile.encoding=utf-8
+
+
+vpath %.class $(BUILD_DIR)
+vpath %.java src
+
+
+
+
+rebuild: clean all
+
+.PHONY: clean run downjar
+
+downjar:
+ wget https://repo1.maven.org/maven2/org/slf4j/slf4j-api/1.7.25/slf4j-api-1.7.25.jar -P ./lib/
+ wget https://repo1.maven.org/maven2/org/slf4j/slf4j-simple/1.7.25/slf4j-simple-1.7.25.jar -P ./lib/
+ #wget https://github.com/TooTallNate/Java-WebSocket/releases/download/v1.5.3/Java-WebSocket-1.5.3.jar -P ./lib/
+ wget https://repo1.maven.org/maven2/org/java-websocket/Java-WebSocket/1.5.3/Java-WebSocket-1.5.3.jar -P ./lib/
+ wget https://storage.googleapis.com/google-code-archive-downloads/v2/code.google.com/json-simple/json-simple-1.1.1.jar -P ./lib/
+ wget https://github.com/argparse4j/argparse4j/releases/download/argparse4j-0.9.0/argparse4j-0.9.0.jar -P ./lib/
+ rm -frv build
+ mkdir build
+clean:
+ rm -frv $(BUILD_DIR)/*
+ rm -frv $(LIB_BUILD_DIR)/*
+ mkdir -p $(BUILD_DIR)
+ mkdir -p ./lib
+
+
+
+
+
+
+runclient:
+ java -cp $(BUILD_DIR):lib/Java-WebSocket-1.5.3.jar:lib/slf4j-simple-1.7.25.jar:lib/slf4j-api-1.7.25.jar:lib/json-simple-1.1.1.jar:lib/argparse4j-0.9.0.jar $(RUNJFLAGS) websocket.FunasrWsClient --host localhost --port 8889 --audio_in ./asr_example.wav --num_threads 1 --mode 2pass
+
+
+
+buildwebsocket: $(WEBSOCKET_FILES:.java=.class)
+
+
+%.class: %.java
+
+ $(JAVAC) -cp $(BUILD_DIR):lib/slf4j-simple-1.7.25.jar:lib/slf4j-api-1.7.25.jar:lib/Java-WebSocket-1.5.3.jar:lib/json-simple-1.1.1.jar:lib/argparse4j-0.9.0.jar -d $(BUILD_DIR) -encoding UTF-8 $<
+
+packjar:
+ jar cvfe lib/funasrclient.jar . -C $(BUILD_DIR) .
+
+all: clean downjar buildwebsocket packjar
+
+
+
+
+
diff --git a/funasr/runtime/java/readme.md b/funasr/runtime/java/readme.md
new file mode 100644
index 000000000..406a21a0a
--- /dev/null
+++ b/funasr/runtime/java/readme.md
@@ -0,0 +1,66 @@
+# Java websocket client example
+
+
+
+## Building for Linux/Unix
+
+### Install the Java environment
+```shell
+# on Ubuntu
+apt-get install openjdk-11-jdk
+```
+
+
+
+### Build and run with make
+
+
+```shell
+cd funasr/runtime/java
+# download java lib
+make downjar
+# compile
+make buildwebsocket
+# run client
+make runclient
+
+```
+
+## Run the java websocket client from the shell
+
+```shell
+# for the full command, see the runclient target in the Makefile
+usage: FunasrWsClient [-h] [--port PORT] [--host HOST] [--audio_in AUDIO_IN] [--num_threads NUM_THREADS]
+ [--chunk_size CHUNK_SIZE] [--chunk_interval CHUNK_INTERVAL] [--mode MODE]
+
+Where:
+  --host
+    the server ip (default: 127.0.0.1)
+
+  --port
+    the server port (default: 8889)
+
+  --audio_in
+    the wav or pcm file path (default: asr_example.wav)
+
+  --num_threads
+    number of threads for the test (default: 1)
+
+  --mode
+    asr mode; supports "offline", "online" and "2pass" (default: offline)
+
+
+
+Example:
+FunasrWsClient --host localhost --port 8889 --audio_in ./asr_example.wav --num_threads 1 --mode 2pass
+
+The result is json, for example:
+{"mode":"offline","text":"欢迎大家来体验达摩院推出的语音识别模型","wav_name":"javatest"}
+```
+
+
+## Acknowledge
+1. This project is maintained by [FunASR community](https://github.com/alibaba-damo-academy/FunASR).
+2. We acknowledge [zhaoming](https://github.com/zhaomingwork/FunASR/tree/java-ws-client-support/funasr/runtime/java) for contributing the java websocket client example.
+
+
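As a companion to the usage example above, the control messages the client exchanges with the server are plain JSON, built in `sendJson` and `sendEof` of `FunasrWsClient.java`. The same payloads sketched in Python:

```python
import json

# First message of a session (FunasrWsClient.sendJson with the defaults).
handshake = {
    "mode": "offline",         # "offline" | "online" | "2pass"
    "chunk_size": [5, 10, 5],  # parsed from --chunk_size "5,10,5"
    "chunk_interval": 10,
    "wav_name": "javatest",
    "is_speaking": True,
}
# Sent after the last audio bytes to signal end-of-stream (sendEof).
eof = {"is_speaking": False}

print(json.dumps(handshake))
print(json.dumps(eof))
```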
diff --git a/funasr/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp b/funasr/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp
index d2692ceb5..a4ee7f7df 100644
--- a/funasr/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp
+++ b/funasr/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp
@@ -65,7 +65,7 @@ void runReg(FUNASR_HANDLE asr_handle, vector<string> wav_list, vector<string> wa
n_total_length += snippet_time;
FunASRFreeResult(result);
}else{
- LOG(ERROR) << ("No return data!\n");
+ LOG(ERROR) << wav_ids[i] << (": No return data!\n");
}
}
{
diff --git a/funasr/runtime/onnxruntime/src/ct-transformer.cpp b/funasr/runtime/onnxruntime/src/ct-transformer.cpp
index 58eec2540..2ee41140f 100644
--- a/funasr/runtime/onnxruntime/src/ct-transformer.cpp
+++ b/funasr/runtime/onnxruntime/src/ct-transformer.cpp
@@ -18,6 +18,7 @@ void CTTransformer::InitPunc(const std::string &punc_model, const std::string &p
try{
m_session = std::make_unique<Ort::Session>(env_, punc_model.c_str(), session_options);
+        LOG(INFO) << "Successfully loaded model from " << punc_model;
}
catch (std::exception const &e) {
LOG(ERROR) << "Error when load punc onnx model: " << e.what();
diff --git a/funasr/runtime/onnxruntime/src/paraformer.cpp b/funasr/runtime/onnxruntime/src/paraformer.cpp
index 1957a12d4..b605ffff6 100644
--- a/funasr/runtime/onnxruntime/src/paraformer.cpp
+++ b/funasr/runtime/onnxruntime/src/paraformer.cpp
@@ -33,6 +33,7 @@ void Paraformer::InitAsr(const std::string &am_model, const std::string &am_cmvn
try {
m_session = std::make_unique<Ort::Session>(env_, am_model.c_str(), session_options);
+        LOG(INFO) << "Successfully loaded model from " << am_model;
} catch (std::exception const &e) {
LOG(ERROR) << "Error when load am onnx model: " << e.what();
exit(0);
diff --git a/funasr/runtime/python/onnxruntime/setup.py b/funasr/runtime/python/onnxruntime/setup.py
index 64e363f81..246d67895 100644
--- a/funasr/runtime/python/onnxruntime/setup.py
+++ b/funasr/runtime/python/onnxruntime/setup.py
@@ -13,7 +13,7 @@ def get_readme():
MODULE_NAME = 'funasr_onnx'
-VERSION_NUM = '0.1.0'
+VERSION_NUM = '0.1.1'
setuptools.setup(
name=MODULE_NAME,
@@ -31,7 +31,7 @@ setuptools.setup(
"onnxruntime>=1.7.0",
"scipy",
"numpy>=1.19.3",
- "typeguard",
+ "typeguard==2.13.3",
"kaldi-native-fbank",
"PyYAML>=5.1.2",
"funasr",
diff --git a/funasr/runtime/ssl_key/readme.md b/funasr/runtime/ssl_key/readme.md
index a5989e6ec..8a48dd3e2 100644
--- a/funasr/runtime/ssl_key/readme.md
+++ b/funasr/runtime/ssl_key/readme.md
@@ -3,7 +3,7 @@ generated certificate may not be suitable for all browsers due to security concerns
```shell
### 1) Generate a private key
-openssl genrsa -des3 -out server.key 1024
+openssl genrsa -des3 -out server.key 2048
### 2) Generate a csr file
openssl req -new -key server.key -out server.csr
@@ -14,4 +14,4 @@ openssl rsa -in server.key.org -out server.key
### 4) Generate a crt file, valid for 1 year
openssl x509 -req -days 365 -in server.csr -signkey server.key -out server.crt
-```
\ No newline at end of file
+```
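After step 4 it is worth confirming what was generated, e.g. by printing the subject and validity window of `server.crt`. A quick check that shells out to openssl (assuming it is on the PATH):

```python
import subprocess

# Show who the certificate was issued to and its notBefore/notAfter dates.
out = subprocess.run(
    ["openssl", "x509", "-in", "server.crt", "-noout", "-subject", "-dates"],
    capture_output=True, text=True, check=True,
)
print(out.stdout)
```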
diff --git a/funasr/runtime/ssl_key/server.crt b/funasr/runtime/ssl_key/server.crt
index 808b73e6e..5a5079d08 100644
--- a/funasr/runtime/ssl_key/server.crt
+++ b/funasr/runtime/ssl_key/server.crt
@@ -1,15 +1,21 @@
-----BEGIN CERTIFICATE-----
-MIICSDCCAbECFCObiVAMkMlCGmMDGDFx5Nx3XYvOMA0GCSqGSIb3DQEBCwUAMGMx
-CzAJBgNVBAYTAkNOMRAwDgYDVQQIDAdCZWlqaW5nMRAwDgYDVQQHDAdCZWlqaW5n
-MRAwDgYDVQQKDAdhbGliYWJhMQwwCgYDVQQLDANhc3IxEDAOBgNVBAMMB2FsaWJh
-YmEwHhcNMjMwNTEyMTQzNjAxWhcNMjQwNTExMTQzNjAxWjBjMQswCQYDVQQGEwJD
-TjEQMA4GA1UECAwHQmVpamluZzEQMA4GA1UEBwwHQmVpamluZzEQMA4GA1UECgwH
-YWxpYmFiYTEMMAoGA1UECwwDYXNyMRAwDgYDVQQDDAdhbGliYWJhMIGfMA0GCSqG
-SIb3DQEBAQUAA4GNADCBiQKBgQDEINLLMasJtJQPoesCfcwJsjiUkx3hLnoUyETS
-NBrrRfjbBv6ucAgZIF+/V15IfJZR6u2ULpJN0wUg8xNQReu4kdpjSdNGuQ0aoWbc
-38+VLo9UjjsoOeoeCro6b0u+GosPoEuI4t7Ky09zw+FBibD95daJ3GDY1DGCbDdL
-mV/toQIDAQABMA0GCSqGSIb3DQEBCwUAA4GBAB5KNWF1XIIYD1geMsyT6/ZRnGNA
-dmeUyMcwYvIlQG3boSipNk/JI4W5fFOg1O2sAqflYHmwZfmasAQsC2e5bSzHZ+PB
-uMJhKYxfj81p175GumHTw5Lbp2CvFSLrnuVB0ThRdcCqEh1MDt0D3QBuBr/ZKgGS
-hXtozVCgkSJzX6uD
+MIIDhTCCAm0CFGB0Po2IZ0hESavFpcSGRNb9xrNXMA0GCSqGSIb3DQEBCwUAMH8x
+CzAJBgNVBAYTAkNOMRAwDgYDVQQIDAdiZWlqaW5nMRAwDgYDVQQHDAdiZWlqaW5n
+MRAwDgYDVQQKDAdhbGliYWJhMRAwDgYDVQQLDAdhbGliYWJhMRAwDgYDVQQDDAdh
+bGliYWJhMRYwFAYJKoZIhvcNAQkBFgdhbGliYWJhMB4XDTIzMDYxODA2NTcxM1oX
+DTI0MDYxNzA2NTcxM1owfzELMAkGA1UEBhMCQ04xEDAOBgNVBAgMB2JlaWppbmcx
+EDAOBgNVBAcMB2JlaWppbmcxEDAOBgNVBAoMB2FsaWJhYmExEDAOBgNVBAsMB2Fs
+aWJhYmExEDAOBgNVBAMMB2FsaWJhYmExFjAUBgkqhkiG9w0BCQEWB2FsaWJhYmEw
+ggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDH9Np1oBunQKMt5M/nU2nD
+qVHojXwKKwyiK9DSeGikKwArH2S9NUZNu5RDg46u0iWmT+Vz+toQhkJnfatOVskW
+f2bsI54n5eOvmoWOKDXYm2MscvjkuNiYRbqzgUuP9ZSx8k3uyRs++wvmwIoU+PV1
+EYFcjk1P2jUGUvKaUlmIDsjs1wOMIbKO6I0UX20FNKlGWacqMR/Dx2ltmGKT1Kaz
+Y335lor0bcfQtH542rGS7PDz6JMRNjFT1VFcmnrjRElf4STbaOiIfOjMVZ/9O8Hr
+LFItyvkb01Mt7O0jhAXHuE1l/8Y0N3MCYkELG9mQA0BYCFHY0FLuJrGoU03b8KWj
+AgMBAAEwDQYJKoZIhvcNAQELBQADggEBAEjC9jB1WZe2ki2JgCS+eAMFsFegiNEz
+D0klVB3kiCPK0g7DCxvfWR6kAgEynxRxVX6TN9QcLr4paZItC1Fu2gUMTteNqEuc
+dcixJdu9jumuUMBlAKgL5Yyk3alSErsn9ZVF/Q8Kx5arMO/TW3Ulsd8SWQL5C/vq
+Fe0SRhpKKoADPfl8MT/XMfB/MwNxVhYDSHzJ1EiN8O5ce6q2tTdi1mlGquzNxhjC
+7Q0F36V1HksfzolrlRWRKYP16isnaKUdFfeAzaJsYw33o6VRbk6fo2fTQDHS0wOs
+Q48Moc5UxKMLaMMCqLPpWu0TZse+kIw1nTWXk7yJtK0HK5PN3rTocEw=
-----END CERTIFICATE-----
diff --git a/funasr/runtime/ssl_key/server.key b/funasr/runtime/ssl_key/server.key
index aac8b2646..8efdcb832 100644
--- a/funasr/runtime/ssl_key/server.key
+++ b/funasr/runtime/ssl_key/server.key
@@ -1,15 +1,27 @@
-----BEGIN RSA PRIVATE KEY-----
-MIICXQIBAAKBgQDEINLLMasJtJQPoesCfcwJsjiUkx3hLnoUyETSNBrrRfjbBv6u
-cAgZIF+/V15IfJZR6u2ULpJN0wUg8xNQReu4kdpjSdNGuQ0aoWbc38+VLo9Ujjso
-OeoeCro6b0u+GosPoEuI4t7Ky09zw+FBibD95daJ3GDY1DGCbDdLmV/toQIDAQAB
-AoGARpA0pwygp+ZDWvh7kDLoZRitCK+BkZHiNHX1ZNeAU+Oh7FOw79u43ilqqXHq
-pxPEFYb7oVO8Kanhb4BlE32EmApBlvhd3SW07kn0dS7WVGsTvPFwKKpF88W8E+pc
-2i8At5tr2O1DZhvqNdIN7r8FRrGQ/Hpm3ItypUdz2lZnMwECQQD3dILOMJ84O2JE
-NxUwk8iOYefMJftQUO57Gm7XBVke/i3r9uajSqB2xmOvUaSyaHoJfx/mmfgfxYcD
-M+Re6mERAkEAyuaV5+eD82eG2I8PgxJ2p5SOb1x5F5qpb4KuKAlfHEkdolttMwN3
-7vl1ZWUZLVu2rHnUmvbYV2gkQO1os7/DkQJBAIDYfbN2xbC12vjB5ZqhmG/qspMt
-w6mSOlqG7OewtTLaDncq2/RySxMNQaJr1GHA3KpNMwMTcIq6gw472tFBIMECQF0z
-fjiASEROkcp4LI/ws0BXJPZSa+1DxgDK7mTFqUK88zfY91gvh6/mNt7UibQkJM0l
-SVvFd6ru03hflXC77YECQQDDQrB9ApwVOMGQw+pwbxn9p8tPYVi3oBiUfYgd1RDO
-uhcRgxv7gT4BSiyI4nFBMCYyI28azTLlUiJhMr9MNUpB
+MIIEowIBAAKCAQEAx/TadaAbp0CjLeTP51Npw6lR6I18CisMoivQ0nhopCsAKx9k
+vTVGTbuUQ4OOrtIlpk/lc/raEIZCZ32rTlbJFn9m7COeJ+Xjr5qFjig12JtjLHL4
+5LjYmEW6s4FLj/WUsfJN7skbPvsL5sCKFPj1dRGBXI5NT9o1BlLymlJZiA7I7NcD
+jCGyjuiNFF9tBTSpRlmnKjEfw8dpbZhik9Sms2N9+ZaK9G3H0LR+eNqxkuzw8+iT
+ETYxU9VRXJp640RJX+Ek22joiHzozFWf/TvB6yxSLcr5G9NTLeztI4QFx7hNZf/G
+NDdzAmJBCxvZkANAWAhR2NBS7iaxqFNN2/ClowIDAQABAoIBAQC1/STX6eFBWJMs
+MhUHdePNMU5bWmqK1qOo9jgZV33l7T06Alit3M8f8JoA2LwEYT/jHtS3upi+cXP+
+vWIs6tAaqdoDEmff6FxSd1EXEYHwo3yf+ASQJ6z66nwC5KrhW6L6Uo6bxm4F5Hfw
+jU0fyXeeFVCn7Nxw0SlxmA02Z70VFsL8BK9i3kajU18y6drf4VUm55oMEtdEmOh2
+eKn4qspBcNblbw+L0QJ+5kN1iRUyJHesQ1GpS+L3yeMVFCW7ctL4Bgw8Z7LE+z7i
+C0Weyhul8vuT+7nfF2T37zsSa8iixqpkTokeYh96CZ5nDqa2IDx3oNHWSlkIsV6g
+6EUEl9gBAoGBAPIw/M6fIDetMj8f1wG7mIRgJsxI817IS6aBSwB5HkoCJFfrR9Ua
+jMNCFIWNs/Om8xeGhq/91hbnCYDNK06V5CUa/uk4CYRs2eQZ3FKoNowtp6u/ieuU
+qg8bXM/vR2VWtWVixAMdouT3+KtvlgaVmSnrPiwO4pecGrwu5NW1oJCFAoGBANNb
+aE3AcwTDYsqh0N/75G56Q5s1GZ6MCDQGQSh8IkxL6Vg59KnJiIKQ7AxNKFgJZMtY
+zZHaqjazeHjOGTiYiC7MMVJtCcOBEfjCouIG8btNYv7Y3dWnOXRZni2telAsRrH9
+xS5LaFdCRTjVAwSsppMGwiQtyl6sGLMyz0SXoYoHAoGAKdkFFb6xFm26zOV3hTkg
+9V6X1ZyVUL9TMwYMK5zB+w+7r+VbmBrqT6LPYPRHL8adImeARlCZ+YMaRUMuRHnp
+3e94NFwWaOdWDu/Y/f9KzZXl7us9rZMWf12+/77cm0oMNeSG8fLg/qdKNHUneyPG
+P1QCfiJkTMYQaIvBxpuHjvECgYAKlZ9JlYOtD2PZJfVh4il0ZucP1L7ts7GNeWq1
+7lGBZKPQ6UYZYqBVeZB4pTyJ/B5yGIZi8YJoruAvnJKixPC89zjZGeDNS59sx8KE
+cziT2rJEdPPXCULVUs+bFf70GOOJcl33jYsyI3139SLrjwHghwwd57UkvJWYE8lR
+dA6A7QKBgEfTC+NlzqLPhbB+HPl6CvcUczcXcI9M0heVz/DNMA+4pjxPnv2aeIwh
+cL2wq2xr+g1wDBWGVGkVSuZhXm5E6gDetdyVeJnbIUhVjBblnbhHV6GrudjbXGnJ
+W9cBgu6DswyHU2cOsqmimu8zLmG6/dQYFHt+kUWGxN8opCzVjgWa
-----END RSA PRIVATE KEY-----
diff --git a/funasr/runtime/websocket/funasr-wss-client.cpp b/funasr/runtime/websocket/funasr-wss-client.cpp
index 4a3c7516d..8b59000fc 100644
--- a/funasr/runtime/websocket/funasr-wss-client.cpp
+++ b/funasr/runtime/websocket/funasr-wss-client.cpp
@@ -91,7 +91,6 @@ class WebsocketClient {
using websocketpp::lib::placeholders::_1;
m_client.set_open_handler(bind(&WebsocketClient::on_open, this, _1));
m_client.set_close_handler(bind(&WebsocketClient::on_close, this, _1));
- // m_client.set_close_handler(bind(&WebsocketClient::on_close, this, _1));
m_client.set_message_handler(
[this](websocketpp::connection_hdl hdl, message_ptr msg) {
@@ -218,7 +217,7 @@ class WebsocketClient {
}
}
if (wait) {
- LOG(INFO) << "wait.." << m_open;
+ // LOG(INFO) << "wait.." << m_open;
WaitABit();
continue;
}
@@ -292,7 +291,7 @@ int main(int argc, char* argv[]) {
false, 1, "int");
TCLAP::ValueArg<int> is_ssl_(
"", "is-ssl", "is-ssl is 1 means use wss connection, or use ws connection",
- false, 0, "int");
+ false, 1, "int");
cmd.add(server_ip_);
cmd.add(port_);
diff --git a/funasr/tasks/asr.py b/funasr/tasks/asr.py
index 824485631..7338513cb 100644
--- a/funasr/tasks/asr.py
+++ b/funasr/tasks/asr.py
@@ -38,6 +38,7 @@ from funasr.models.decoder.transformer_decoder import (
from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
from funasr.models.decoder.transformer_decoder import TransformerDecoder
from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
+from funasr.models.decoder.transformer_decoder import SAAsrTransformerDecoder
from funasr.models.e2e_asr import ASRModel
from funasr.models.decoder.rnnt_decoder import RNNTDecoder
from funasr.models.joint_net.joint_network import JointNetwork
@@ -45,6 +46,7 @@ from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerOnline, Paraf
from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
from funasr.models.e2e_tp import TimestampPredictor
from funasr.models.e2e_asr_mfcca import MFCCA
+from funasr.models.e2e_sa_asr import SAASRModel
from funasr.models.e2e_uni_asr import UniASR
from funasr.models.e2e_asr_transducer import TransducerModel, UnifiedTransducerModel
from funasr.models.encoder.abs_encoder import AbsEncoder
@@ -54,6 +56,7 @@ from funasr.models.encoder.rnn_encoder import RNNEncoder
from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
from funasr.models.encoder.transformer_encoder import TransformerEncoder
from funasr.models.encoder.mfcca_encoder import MFCCAEncoder
+from funasr.models.encoder.resnet34_encoder import ResNet34Diar
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.frontend.default import DefaultFrontend
from funasr.models.frontend.default import MultiChannelFrontend
@@ -134,6 +137,7 @@ model_choices = ClassChoices(
timestamp_prediction=TimestampPredictor,
rnnt=TransducerModel,
rnnt_unified=UnifiedTransducerModel,
+ sa_asr=SAASRModel,
),
type_check=FunASRModel,
default="asr",
@@ -175,6 +179,27 @@ encoder_choices2 = ClassChoices(
type_check=AbsEncoder,
default="rnn",
)
+asr_encoder_choices = ClassChoices(
+ "asr_encoder",
+ classes=dict(
+ conformer=ConformerEncoder,
+ transformer=TransformerEncoder,
+ rnn=RNNEncoder,
+ sanm=SANMEncoder,
+ sanm_chunk_opt=SANMEncoderChunkOpt,
+ data2vec_encoder=Data2VecEncoder,
+ mfcca_enc=MFCCAEncoder,
+ ),
+ type_check=AbsEncoder,
+ default="rnn",
+)
+spk_encoder_choices = ClassChoices(
+ "spk_encoder",
+ classes=dict(
+ resnet34_diar=ResNet34Diar,
+ ),
+ default="resnet34_diar",
+)
postencoder_choices = ClassChoices(
name="postencoder",
classes=dict(
@@ -197,6 +222,7 @@ decoder_choices = ClassChoices(
paraformer_decoder_sanm=ParaformerSANMDecoder,
paraformer_decoder_san=ParaformerDecoderSAN,
contextual_paraformer_decoder=ContextualParaformerDecoder,
+ sa_decoder=SAAsrTransformerDecoder,
),
type_check=AbsDecoder,
default="rnn",
@@ -329,6 +355,12 @@ class ASRTask(AbsTask):
default=True,
help="whether to split text using ",
)
+ group.add_argument(
+ "--max_spk_num",
+ type=int_or_none,
+ default=None,
+        help="Maximum number of speakers",
+ )
group.add_argument(
"--seg_dict_file",
type=str,
@@ -1495,3 +1527,123 @@ class ASRTransducerTask(ASRTask):
#assert check_return_type(model)
return model
+
+
+class ASRTaskSAASR(ASRTask):
+ # If you need more than one optimizers, change this value
+ num_optimizers: int = 1
+
+ # Add variable objects configurations
+ class_choices_list = [
+ # --frontend and --frontend_conf
+ frontend_choices,
+ # --specaug and --specaug_conf
+ specaug_choices,
+ # --normalize and --normalize_conf
+ normalize_choices,
+ # --model and --model_conf
+ model_choices,
+ # --preencoder and --preencoder_conf
+ preencoder_choices,
+ # --encoder and --encoder_conf
+ # --asr_encoder and --asr_encoder_conf
+ asr_encoder_choices,
+ # --spk_encoder and --spk_encoder_conf
+ spk_encoder_choices,
+ # --decoder and --decoder_conf
+ decoder_choices,
+ ]
+
+ # If you need to modify train() or eval() procedures, change Trainer class here
+ trainer = Trainer
+
+ @classmethod
+ def build_model(cls, args: argparse.Namespace):
+ assert check_argument_types()
+ if isinstance(args.token_list, str):
+ with open(args.token_list, encoding="utf-8") as f:
+ token_list = [line.rstrip() for line in f]
+
+ # Overwriting token_list to keep it as "portable".
+ args.token_list = list(token_list)
+ elif isinstance(args.token_list, (tuple, list)):
+ token_list = list(args.token_list)
+ else:
+ raise RuntimeError("token_list must be str or list")
+ vocab_size = len(token_list)
+ logging.info(f"Vocabulary size: {vocab_size}")
+
+ # 1. frontend
+ if args.input_size is None:
+ # Extract features in the model
+ frontend_class = frontend_choices.get_class(args.frontend)
+ if args.frontend == 'wav_frontend' or args.frontend == "multichannelfrontend":
+ frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
+ else:
+ frontend = frontend_class(**args.frontend_conf)
+ input_size = frontend.output_size()
+ else:
+ # Give features from data-loader
+ args.frontend = None
+ args.frontend_conf = {}
+ frontend = None
+ input_size = args.input_size
+
+ # 2. Data augmentation for spectrogram
+ if args.specaug is not None:
+ specaug_class = specaug_choices.get_class(args.specaug)
+ specaug = specaug_class(**args.specaug_conf)
+ else:
+ specaug = None
+
+ # 3. Normalization layer
+ if args.normalize is not None:
+ normalize_class = normalize_choices.get_class(args.normalize)
+ normalize = normalize_class(**args.normalize_conf)
+ else:
+ normalize = None
+
+ # 5. Encoder
+ asr_encoder_class = asr_encoder_choices.get_class(args.asr_encoder)
+ asr_encoder = asr_encoder_class(input_size=input_size, **args.asr_encoder_conf)
+ spk_encoder_class = spk_encoder_choices.get_class(args.spk_encoder)
+ spk_encoder = spk_encoder_class(input_size=input_size, **args.spk_encoder_conf)
+
+ # 7. Decoder
+ decoder_class = decoder_choices.get_class(args.decoder)
+ decoder = decoder_class(
+ vocab_size=vocab_size,
+ encoder_output_size=asr_encoder.output_size(),
+ **args.decoder_conf,
+ )
+
+ # 8. CTC
+ ctc = CTC(
+ odim=vocab_size, encoder_output_size=asr_encoder.output_size(), **args.ctc_conf
+ )
+
+ # import ipdb;ipdb.set_trace()
+ # 9. Build model
+ try:
+ model_class = model_choices.get_class(args.model)
+ except AttributeError:
+ model_class = model_choices.get_class("asr")
+ model = model_class(
+ vocab_size=vocab_size,
+ frontend=frontend,
+ specaug=specaug,
+ normalize=normalize,
+ asr_encoder=asr_encoder,
+ spk_encoder=spk_encoder,
+ decoder=decoder,
+ ctc=ctc,
+ token_list=token_list,
+ **args.model_conf,
+ )
+
+ # 10. Initialize
+ if args.init is not None:
+ initialize(model, args.init)
+
+ assert check_return_type(model)
+ return model
diff --git a/funasr/tasks/sa_asr.py b/funasr/tasks/sa_asr.py
index 4769758b3..957948336 100644
--- a/funasr/tasks/sa_asr.py
+++ b/funasr/tasks/sa_asr.py
@@ -39,7 +39,7 @@ from funasr.models.decoder.transformer_decoder import (
from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
from funasr.models.decoder.transformer_decoder import TransformerDecoder
from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
-from funasr.models.e2e_sa_asr import ESPnetASRModel
+from funasr.models.e2e_sa_asr import SAASRModel
from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerBert, BiCifParaformer, ContextualParaformer
from funasr.models.e2e_tp import TimestampPredictor
from funasr.models.e2e_asr_mfcca import MFCCA
@@ -120,7 +120,7 @@ normalize_choices = ClassChoices(
model_choices = ClassChoices(
"model",
classes=dict(
- asr=ESPnetASRModel,
+ asr=SAASRModel,
uniasr=UniASR,
paraformer=Paraformer,
paraformer_bert=ParaformerBert,
@@ -620,4 +620,4 @@ class ASRTask(AbsTask):
initialize(model, args.init)
assert check_return_type(model)
- return model
+ return model
\ No newline at end of file
diff --git a/funasr/version.txt b/funasr/version.txt
index ee6cdce3c..b61604874 100644
--- a/funasr/version.txt
+++ b/funasr/version.txt
@@ -1 +1 @@
-0.6.1
+0.6.2