mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
Merge branch 'main' into dev_wjm_infer
This commit is contained in:
commit
4e0fcee2a9
@ -72,8 +72,8 @@ If you have any questions about FunASR, please contact us by
|
||||
|
||||
## Contributors
|
||||
|
||||
| <div align="left"><img src="docs/images/damo.png" width="180"/> | <div align="left"><img src="docs/images/nwpu.png" width="260"/> | <img src="docs/images/China_Telecom.png" width="200"/> </div> | <img src="docs/images/RapidAI.png" width="200"/> </div> | <img src="docs/images/DeepScience.png" width="200"/> </div> | <img src="docs/images/aihealthx.png" width="200"/> </div> |
|
||||
|:---------------------------------------------------------------:|:---------------------------------------------------------------:|:--------------------------------------------------------------:|:-------------------------------------------------------:|:-----------------------------------------------------------:|:-----------------------------------------------------------:|
|
||||
| <div align="left"><img src="docs/images/damo.png" width="180"/> | <div align="left"><img src="docs/images/nwpu.png" width="260"/> | <img src="docs/images/China_Telecom.png" width="200"/> </div> | <img src="docs/images/RapidAI.png" width="200"/> </div> | <img src="docs/images/aihealthx.png" width="200"/> </div> |
|
||||
|:---------------------------------------------------------------:|:---------------------------------------------------------------:|:--------------------------------------------------------------:|:-------------------------------------------------------:|:-----------------------------------------------------------:|
|
||||
|
||||
## Acknowledge
|
||||
|
||||
@ -82,7 +82,6 @@ If you have any questions about FunASR, please contact us by
|
||||
3. We referred [Wenet](https://github.com/wenet-e2e/wenet) for building dataloader for large scale data training.
|
||||
4. We acknowledge [ChinaTelecom](https://github.com/zhuzizyf/damo-fsmn-vad-infer-httpserver) for contributing the VAD runtime.
|
||||
5. We acknowledge [RapidAI](https://github.com/RapidAI) for contributing the Paraformer and CT_Transformer-punc runtime.
|
||||
6. We acknowledge [DeepScience](https://www.deepscience.cn) for contributing the grpc service.
|
||||
6. We acknowledge [AiHealthx](http://www.aihealthx.com/) for contributing the websocket service and html5.
|
||||
|
||||
## License
|
||||
|
||||
@ -1,29 +0,0 @@
|
||||
lm: transformer
|
||||
lm_conf:
|
||||
pos_enc: null
|
||||
embed_unit: 128
|
||||
att_unit: 512
|
||||
head: 8
|
||||
unit: 2048
|
||||
layer: 16
|
||||
dropout_rate: 0.1
|
||||
|
||||
# optimization related
|
||||
grad_clip: 5.0
|
||||
batch_type: numel
|
||||
batch_bins: 500000 # 4gpus * 500000
|
||||
accum_grad: 1
|
||||
max_epoch: 15 # 15epoch is enougth
|
||||
|
||||
optim: adam
|
||||
optim_conf:
|
||||
lr: 0.001
|
||||
scheduler: warmuplr
|
||||
scheduler_conf:
|
||||
warmup_steps: 25000
|
||||
|
||||
best_model_criterion:
|
||||
- - valid
|
||||
- loss
|
||||
- min
|
||||
keep_nbest_models: 10 # 10 is good.
|
||||
86
egs/alimeeting/sa_asr/README.md
Normal file
86
egs/alimeeting/sa_asr/README.md
Normal file
@ -0,0 +1,86 @@
|
||||
# Get Started
|
||||
Speaker Attributed Automatic Speech Recognition (SA-ASR) is a task proposed to solve "who spoke what". Specifically, the goal of SA-ASR is not only to obtain multi-speaker transcriptions, but also to identify the corresponding speaker for each utterance. The method used in this example is referenced in the paper: [End-to-End Speaker-Attributed ASR with Transformer](https://www.isca-speech.org/archive/pdfs/interspeech_2021/kanda21b_interspeech.pdf).
|
||||
# Train
|
||||
First you need to install the FunASR and ModelScope. ([installation](https://github.com/alibaba-damo-academy/FunASR#installation))
|
||||
After the FunASR and ModelScope is installed, you must manually download and unpack the [AliMeeting](http://www.openslr.org/119/) corpus and place it in the `./dataset` directory. The `.dataset` should organized as follow:
|
||||
```shell
|
||||
dataset
|
||||
|—— Eval_Ali_far
|
||||
|—— Eval_Ali_near
|
||||
|—— Test_Ali_far
|
||||
|—— Test_Ali_near
|
||||
|—— Train_Ali_far
|
||||
|—— Train_Ali_near
|
||||
```
|
||||
Then you can run this receipe by running:
|
||||
```shell
|
||||
bash run.sh --stage 0 --stop-stage 6
|
||||
```
|
||||
There are 8 stages in `run.sh`:
|
||||
```shell
|
||||
stage 0: Data preparation and remove the audio which is too long or too short.
|
||||
stage 1: Speaker profile and CMVN Generation.
|
||||
stage 2: Dictionary preparation.
|
||||
stage 3: LM training (not supported).
|
||||
stage 4: ASR Training.
|
||||
stage 5: SA-ASR Training.
|
||||
stage 6: Inference
|
||||
stage 7: Inference with Test_2023_Ali_far
|
||||
```
|
||||
<!-- The baseline model is available on [ModelScope](https://www.modelscope.cn/models/damo/speech_saasr_asr-zh-cn-16k-alimeeting/summary). -->
|
||||
# Infer
|
||||
1. Download the final test set and extracted
|
||||
2. Put the audios in `./dataset/Test_2023_Ali_far/` and put the `wav.scp`, `segments`, `utt2spk`, `spk2utt` in `./data/org/Test_2023_Ali_far/`.
|
||||
3. Set the `test_2023` in `run.sh` should be to `Test_2023_Ali_far`.
|
||||
4. Run the `run.sh` as follow.
|
||||
```shell
|
||||
# Prepare test_2023 set
|
||||
bash run.sh --stage 0 --stop-stage 1
|
||||
# Decode test_2023 set
|
||||
bash run.sh --stage 7 --stop-stage 7
|
||||
```
|
||||
# Format of Final Submission
|
||||
Finally, you need to submit a file called `text_spk_merge` with the following format:
|
||||
```shell
|
||||
Meeting_1 text_spk_1_A$text_spk_1_B$text_spk_1_C ...
|
||||
Meeting_2 text_spk_2_A$text_spk_2_B$text_spk_2_C ...
|
||||
...
|
||||
```
|
||||
Here, text_spk_1_A represents the full transcription of speaker_A of Meeting_1 (merged in chronological order), and $ represents the separator symbol. There's no need to worry about the speaker permutation as the optimal permutation will be computed in the end. For more information, please refer to the results generated after executing the baseline code.
|
||||
# Baseline Results
|
||||
The results of the baseline system are as follows. The baseline results include speaker independent character error rate (SI-CER) and concatenated minimum permutation character error rate (cpCER), the former is speaker independent and the latter is speaker dependent. The speaker profile adopts the oracle speaker embedding during training. However, due to the lack of oracle speaker label during evaluation, the speaker profile provided by an additional spectral clustering is used. Meanwhile, the results of using the oracle speaker profile on Test Set are also provided to show the impact of speaker profile accuracy.
|
||||
<!-- <table>
|
||||
<tr >
|
||||
<td rowspan="2"></td>
|
||||
<td colspan="2">SI-CER(%)</td>
|
||||
<td colspan="2">cpCER(%)</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>Eval</td>
|
||||
<td>Test</td>
|
||||
<td>Eval</td>
|
||||
<td>Test</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>oracle profile</td>
|
||||
<td>32.05</td>
|
||||
<td>32.72</td>
|
||||
<td>47.40</td>
|
||||
<td>42.92</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>cluster profile</td>
|
||||
<td>32.05</td>
|
||||
<td>32.73</td>
|
||||
<td>53.76</td>
|
||||
<td>49.37</td>
|
||||
</tr>
|
||||
</table> -->
|
||||
| |SI-CER(%) |cp-CER(%) |
|
||||
|:---------------|:------------:|----------:|
|
||||
|oracle profile |32.72 |42.92 |
|
||||
|cluster profile|32.73 |49.37 |
|
||||
|
||||
|
||||
# Reference
|
||||
N. Kanda, G. Ye, Y. Gaur, X. Wang, Z. Meng, Z. Chen, and T. Yoshioka, "End-to-end speaker-attributed ASR with transformer," in Interspeech. ISCA, 2021, pp. 4413–4417.
|
||||
102
egs/alimeeting/sa_asr/conf/train_asr_conformer.yaml
Normal file
102
egs/alimeeting/sa_asr/conf/train_asr_conformer.yaml
Normal file
@ -0,0 +1,102 @@
|
||||
# network architecture
|
||||
frontend: multichannelfrontend
|
||||
frontend_conf:
|
||||
fs: 16000
|
||||
window: hann
|
||||
n_fft: 400
|
||||
n_mels: 80
|
||||
frame_length: 25
|
||||
frame_shift: 10
|
||||
lfr_m: 1
|
||||
lfr_n: 1
|
||||
use_channel: 0
|
||||
|
||||
# encoder related
|
||||
encoder: conformer
|
||||
encoder_conf:
|
||||
output_size: 256 # dimension of attention
|
||||
attention_heads: 4
|
||||
linear_units: 2048 # the number of units of position-wise feed forward
|
||||
num_blocks: 12 # the number of encoder blocks
|
||||
dropout_rate: 0.1
|
||||
positional_dropout_rate: 0.1
|
||||
attention_dropout_rate: 0.0
|
||||
input_layer: conv2d # encoder architecture type
|
||||
normalize_before: true
|
||||
rel_pos_type: latest
|
||||
pos_enc_layer_type: rel_pos
|
||||
selfattention_layer_type: rel_selfattn
|
||||
activation_type: swish
|
||||
macaron_style: true
|
||||
use_cnn_module: true
|
||||
cnn_module_kernel: 15
|
||||
|
||||
# decoder related
|
||||
decoder: transformer
|
||||
decoder_conf:
|
||||
attention_heads: 4
|
||||
linear_units: 2048
|
||||
num_blocks: 6
|
||||
dropout_rate: 0.1
|
||||
positional_dropout_rate: 0.1
|
||||
self_attention_dropout_rate: 0.0
|
||||
src_attention_dropout_rate: 0.0
|
||||
|
||||
# ctc related
|
||||
ctc_conf:
|
||||
ignore_nan_grad: true
|
||||
|
||||
# hybrid CTC/attention
|
||||
model_conf:
|
||||
ctc_weight: 0.3
|
||||
lsm_weight: 0.1 # label smoothing option
|
||||
length_normalized_loss: false
|
||||
|
||||
|
||||
dataset_conf:
|
||||
data_names: speech,text
|
||||
data_types: sound,text
|
||||
shuffle: True
|
||||
shuffle_conf:
|
||||
shuffle_size: 2048
|
||||
sort_size: 500
|
||||
batch_conf:
|
||||
batch_type: token
|
||||
batch_size: 7000
|
||||
num_workers: 8
|
||||
|
||||
# optimization related
|
||||
accum_grad: 1
|
||||
grad_clip: 5
|
||||
max_epoch: 100
|
||||
val_scheduler_criterion:
|
||||
- valid
|
||||
- acc
|
||||
best_model_criterion:
|
||||
- - valid
|
||||
- acc
|
||||
- max
|
||||
keep_nbest_models: 10
|
||||
|
||||
optim: adam
|
||||
optim_conf:
|
||||
lr: 0.001
|
||||
scheduler: warmuplr
|
||||
scheduler_conf:
|
||||
warmup_steps: 25000
|
||||
|
||||
specaug: specaug
|
||||
specaug_conf:
|
||||
apply_time_warp: true
|
||||
time_warp_window: 5
|
||||
time_warp_mode: bicubic
|
||||
apply_freq_mask: true
|
||||
freq_mask_width_range:
|
||||
- 0
|
||||
- 30
|
||||
num_freq_mask: 2
|
||||
apply_time_mask: true
|
||||
time_mask_width_range:
|
||||
- 0
|
||||
- 40
|
||||
num_time_mask: 2
|
||||
131
egs/alimeeting/sa_asr/conf/train_sa_asr_conformer.yaml
Normal file
131
egs/alimeeting/sa_asr/conf/train_sa_asr_conformer.yaml
Normal file
@ -0,0 +1,131 @@
|
||||
# network architecture
|
||||
frontend: multichannelfrontend
|
||||
frontend_conf:
|
||||
fs: 16000
|
||||
window: hann
|
||||
n_fft: 400
|
||||
n_mels: 80
|
||||
frame_length: 25
|
||||
frame_shift: 10
|
||||
lfr_m: 1
|
||||
lfr_n: 1
|
||||
use_channel: 0
|
||||
|
||||
# encoder related
|
||||
asr_encoder: conformer
|
||||
asr_encoder_conf:
|
||||
output_size: 256 # dimension of attention
|
||||
attention_heads: 4
|
||||
linear_units: 2048 # the number of units of position-wise feed forward
|
||||
num_blocks: 12 # the number of encoder blocks
|
||||
dropout_rate: 0.1
|
||||
positional_dropout_rate: 0.1
|
||||
attention_dropout_rate: 0.0
|
||||
input_layer: conv2d # encoder architecture type
|
||||
normalize_before: true
|
||||
pos_enc_layer_type: rel_pos
|
||||
selfattention_layer_type: rel_selfattn
|
||||
activation_type: swish
|
||||
macaron_style: true
|
||||
use_cnn_module: true
|
||||
cnn_module_kernel: 15
|
||||
|
||||
spk_encoder: resnet34_diar
|
||||
spk_encoder_conf:
|
||||
use_head_conv: true
|
||||
batchnorm_momentum: 0.5
|
||||
use_head_maxpool: false
|
||||
num_nodes_pooling_layer: 256
|
||||
layers_in_block:
|
||||
- 3
|
||||
- 4
|
||||
- 6
|
||||
- 3
|
||||
filters_in_block:
|
||||
- 32
|
||||
- 64
|
||||
- 128
|
||||
- 256
|
||||
pooling_type: statistic
|
||||
num_nodes_resnet1: 256
|
||||
num_nodes_last_layer: 256
|
||||
batchnorm_momentum: 0.5
|
||||
|
||||
# decoder related
|
||||
decoder: sa_decoder
|
||||
decoder_conf:
|
||||
attention_heads: 4
|
||||
linear_units: 2048
|
||||
asr_num_blocks: 6
|
||||
spk_num_blocks: 3
|
||||
dropout_rate: 0.1
|
||||
positional_dropout_rate: 0.1
|
||||
self_attention_dropout_rate: 0.0
|
||||
src_attention_dropout_rate: 0.0
|
||||
|
||||
# hybrid CTC/attention
|
||||
model_conf:
|
||||
spk_weight: 0.5
|
||||
ctc_weight: 0.3
|
||||
lsm_weight: 0.1 # label smoothing option
|
||||
length_normalized_loss: false
|
||||
max_spk_num: 4
|
||||
|
||||
ctc_conf:
|
||||
ignore_nan_grad: true
|
||||
|
||||
# minibatch related
|
||||
dataset_conf:
|
||||
data_names: speech,text,profile,text_id
|
||||
data_types: sound,text,npy,text_int
|
||||
shuffle: True
|
||||
shuffle_conf:
|
||||
shuffle_size: 2048
|
||||
sort_size: 500
|
||||
batch_conf:
|
||||
batch_type: token
|
||||
batch_size: 7000
|
||||
num_workers: 8
|
||||
|
||||
# optimization related
|
||||
accum_grad: 1
|
||||
grad_clip: 5
|
||||
max_epoch: 60
|
||||
val_scheduler_criterion:
|
||||
- valid
|
||||
- loss
|
||||
best_model_criterion:
|
||||
- - valid
|
||||
- acc
|
||||
- max
|
||||
- - valid
|
||||
- acc_spk
|
||||
- max
|
||||
- - valid
|
||||
- loss
|
||||
- min
|
||||
keep_nbest_models: 10
|
||||
|
||||
optim: adam
|
||||
optim_conf:
|
||||
lr: 0.0005
|
||||
scheduler: warmuplr
|
||||
scheduler_conf:
|
||||
warmup_steps: 8000
|
||||
|
||||
specaug: specaug
|
||||
specaug_conf:
|
||||
apply_time_warp: true
|
||||
time_warp_window: 5
|
||||
time_warp_mode: bicubic
|
||||
apply_freq_mask: true
|
||||
freq_mask_width_range:
|
||||
- 0
|
||||
- 30
|
||||
num_freq_mask: 2
|
||||
apply_time_mask: true
|
||||
time_mask_width_range:
|
||||
- 0
|
||||
- 40
|
||||
num_time_mask: 2
|
||||
|
||||
@ -21,6 +21,8 @@ EOF
|
||||
|
||||
SECONDS=0
|
||||
tgt=Train #Train or Eval
|
||||
min_wav_duration=0.1
|
||||
max_wav_duration=20
|
||||
|
||||
|
||||
log "$0 $*"
|
||||
@ -57,27 +59,24 @@ stage=1
|
||||
stop_stage=4
|
||||
mkdir -p $far_dir
|
||||
mkdir -p $near_dir
|
||||
mkdir -p data/org
|
||||
|
||||
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
|
||||
log "stage 1:process alimeeting near dir"
|
||||
|
||||
find -L $near_raw_dir/audio_dir -iname "*.wav" | sort > $near_dir/wavlist
|
||||
awk -F '/' '{print $NF}' $near_dir/wavlist | awk -F '.' '{print $1}' > $near_dir/uttid
|
||||
find -L $near_raw_dir/textgrid_dir -iname "*.TextGrid" | sort > $near_dir/textgrid.flist
|
||||
awk -F '/' '{print $NF}' $near_dir/wavlist | awk -F '.' '{print $1}' | sort > $near_dir/uttid
|
||||
find -L $near_raw_dir/textgrid_dir -iname "*.TextGrid" > $near_dir/textgrid.flist
|
||||
n1_wav=$(wc -l < $near_dir/wavlist)
|
||||
n2_text=$(wc -l < $near_dir/textgrid.flist)
|
||||
log near file found $n1_wav wav and $n2_text text.
|
||||
|
||||
paste $near_dir/uttid $near_dir/wavlist > $near_dir/wav_raw.scp
|
||||
|
||||
# cat $near_dir/wav_raw.scp | awk '{printf("%s sox -t wav %s -r 16000 -b 16 -c 1 -t wav - |\n", $1, $2)}' > $near_dir/wav.scp
|
||||
cat $near_dir/wav_raw.scp | awk '{printf("%s sox -t wav %s -r 16000 -b 16 -t wav - |\n", $1, $2)}' > $near_dir/wav.scp
|
||||
paste $near_dir/uttid $near_dir/wavlist -d " " > $near_dir/wav.scp
|
||||
|
||||
python local/alimeeting_process_textgrid.py --path $near_dir --no-overlap False
|
||||
cat $near_dir/text_all | local/text_normalize.pl | local/text_format.pl | sort -u > $near_dir/text
|
||||
utils/filter_scp.pl -f 1 $near_dir/text $near_dir/utt2spk_all | sort -u > $near_dir/utt2spk
|
||||
#sed -e 's/ [a-z,A-Z,_,0-9,-]\+SPK/ SPK/' $near_dir/utt2spk_old >$near_dir/tmp1
|
||||
#sed -e 's/-[a-z,A-Z,0-9]\+$//' $near_dir/tmp1 | sort -u > $near_dir/utt2spk
|
||||
|
||||
local/utt2spk_to_spk2utt.pl $near_dir/utt2spk > $near_dir/spk2utt
|
||||
utils/filter_scp.pl -f 1 $near_dir/text $near_dir/segments_all | sort -u > $near_dir/segments
|
||||
sed -e 's/ $//g' $near_dir/text> $near_dir/tmp1
|
||||
@ -97,9 +96,7 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
|
||||
n2_text=$(wc -l < $far_dir/textgrid.flist)
|
||||
log far file found $n1_wav wav and $n2_text text.
|
||||
|
||||
paste $far_dir/uttid $far_dir/wavlist > $far_dir/wav_raw.scp
|
||||
|
||||
cat $far_dir/wav_raw.scp | awk '{printf("%s sox -t wav %s -r 16000 -b 16 -t wav - |\n", $1, $2)}' > $far_dir/wav.scp
|
||||
paste $far_dir/uttid $far_dir/wavlist -d " " > $far_dir/wav.scp
|
||||
|
||||
python local/alimeeting_process_overlap_force.py --path $far_dir \
|
||||
--no-overlap false --mars True \
|
||||
@ -119,28 +116,28 @@ fi
|
||||
|
||||
|
||||
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
||||
log "stage 3: finali data process"
|
||||
log "stage 3: final data process"
|
||||
local/fix_data_dir.sh $near_dir
|
||||
local/fix_data_dir.sh $far_dir
|
||||
local/copy_data_dir.sh $near_dir data/${tgt}_Ali_near
|
||||
local/copy_data_dir.sh $far_dir data/${tgt}_Ali_far
|
||||
local/copy_data_dir.sh $near_dir data/org/${tgt}_Ali_near
|
||||
local/copy_data_dir.sh $far_dir data/org/${tgt}_Ali_far
|
||||
|
||||
sort $far_dir/utt2spk_all_fifo > data/${tgt}_Ali_far/utt2spk_all_fifo
|
||||
sed -i "s/src/$/g" data/${tgt}_Ali_far/utt2spk_all_fifo
|
||||
sort $far_dir/utt2spk_all_fifo > data/org/${tgt}_Ali_far/utt2spk_all_fifo
|
||||
sed -i "s/src/$/g" data/org/${tgt}_Ali_far/utt2spk_all_fifo
|
||||
|
||||
# remove space in text
|
||||
for x in ${tgt}_Ali_near ${tgt}_Ali_far; do
|
||||
cp data/${x}/text data/${x}/text.org
|
||||
paste -d " " <(cut -f 1 -d" " data/${x}/text.org) <(cut -f 2- -d" " data/${x}/text.org | tr -d " ") \
|
||||
> data/${x}/text
|
||||
rm data/${x}/text.org
|
||||
cp data/org/${x}/text data/org/${x}/text.org
|
||||
paste -d " " <(cut -f 1 -d" " data/org/${x}/text.org) <(cut -f 2- -d" " data/org/${x}/text.org | tr -d " ") \
|
||||
> data/org/${x}/text
|
||||
rm data/org/${x}/text.org
|
||||
done
|
||||
|
||||
log "Successfully finished. [elapsed=${SECONDS}s]"
|
||||
fi
|
||||
|
||||
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
|
||||
log "stage 4: process alimeeting far dir (single speaker by oracle time strap)"
|
||||
log "stage 4: process alimeeting far dir (single speaker by oracle time stamp)"
|
||||
cp -r $far_dir/* $far_single_speaker_dir
|
||||
mv $far_single_speaker_dir/textgrid.flist $far_single_speaker_dir/textgrid_oldpath
|
||||
paste -d " " $far_single_speaker_dir/uttid $far_single_speaker_dir/textgrid_oldpath > $far_single_speaker_dir/textgrid.flist
|
||||
@ -150,14 +147,15 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
|
||||
local/utt2spk_to_spk2utt.pl $far_single_speaker_dir/utt2spk > $far_single_speaker_dir/spk2utt
|
||||
|
||||
./local/fix_data_dir.sh $far_single_speaker_dir
|
||||
local/copy_data_dir.sh $far_single_speaker_dir data/${tgt}_Ali_far_single_speaker
|
||||
local/copy_data_dir.sh $far_single_speaker_dir data/org/${tgt}_Ali_far_single_speaker
|
||||
|
||||
# remove space in text
|
||||
for x in ${tgt}_Ali_far_single_speaker; do
|
||||
cp data/${x}/text data/${x}/text.org
|
||||
paste -d " " <(cut -f 1 -d" " data/${x}/text.org) <(cut -f 2- -d" " data/${x}/text.org | tr -d " ") \
|
||||
> data/${x}/text
|
||||
rm data/${x}/text.org
|
||||
cp data/org/${x}/text data/org/${x}/text.org
|
||||
paste -d " " <(cut -f 1 -d" " data/org/${x}/text.org) <(cut -f 2- -d" " data/org/${x}/text.org | tr -d " ") \
|
||||
> data/org/${x}/text
|
||||
rm data/org/${x}/text.org
|
||||
done
|
||||
rm -rf data/local
|
||||
log "Successfully finished. [elapsed=${SECONDS}s]"
|
||||
fi
|
||||
134
egs/alimeeting/sa_asr/local/compute_cmvn.py
Executable file
134
egs/alimeeting/sa_asr/local/compute_cmvn.py
Executable file
@ -0,0 +1,134 @@
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import torchaudio
|
||||
import torchaudio.compliance.kaldi as kaldi
|
||||
import yaml
|
||||
from funasr.models.frontend.default import DefaultFrontend
|
||||
import torch
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="computer global cmvn",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dim",
|
||||
default=80,
|
||||
type=int,
|
||||
help="feature dimension",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--wav_path",
|
||||
default=False,
|
||||
required=True,
|
||||
type=str,
|
||||
help="the path of wav scps",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config_file",
|
||||
type=str,
|
||||
help="the config file for computing cmvn",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--idx",
|
||||
default=1,
|
||||
required=True,
|
||||
type=int,
|
||||
help="index",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def compute_fbank(wav_file,
|
||||
num_mel_bins=80,
|
||||
frame_length=25,
|
||||
frame_shift=10,
|
||||
dither=0.0,
|
||||
resample_rate=16000,
|
||||
speed=1.0,
|
||||
window_type="hamming"):
|
||||
waveform, sample_rate = torchaudio.load(wav_file)
|
||||
if resample_rate != sample_rate:
|
||||
waveform = torchaudio.transforms.Resample(orig_freq=sample_rate,
|
||||
new_freq=resample_rate)(waveform)
|
||||
if speed != 1.0:
|
||||
waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
|
||||
waveform, resample_rate,
|
||||
[['speed', str(speed)], ['rate', str(resample_rate)]]
|
||||
)
|
||||
|
||||
waveform = waveform * (1 << 15)
|
||||
mat = kaldi.fbank(waveform,
|
||||
num_mel_bins=num_mel_bins,
|
||||
frame_length=frame_length,
|
||||
frame_shift=frame_shift,
|
||||
dither=dither,
|
||||
energy_floor=0.0,
|
||||
window_type=window_type,
|
||||
sample_frequency=resample_rate)
|
||||
|
||||
return mat.numpy()
|
||||
|
||||
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
wav_scp_file = os.path.join(args.wav_path, "wav.{}.scp".format(args.idx))
|
||||
cmvn_file = os.path.join(args.wav_path, "cmvn.{}.json".format(args.idx))
|
||||
|
||||
mean_stats = np.zeros(args.dim)
|
||||
var_stats = np.zeros(args.dim)
|
||||
total_frames = 0
|
||||
|
||||
# with ReadHelper('ark:{}'.format(ark_file)) as ark_reader:
|
||||
# for key, mat in ark_reader:
|
||||
# mean_stats += np.sum(mat, axis=0)
|
||||
# var_stats += np.sum(np.square(mat), axis=0)
|
||||
# total_frames += mat.shape[0]
|
||||
|
||||
with open(args.config_file) as f:
|
||||
configs = yaml.safe_load(f)
|
||||
frontend_configs = configs.get("frontend_conf", {})
|
||||
num_mel_bins = frontend_configs.get("n_mels", 80)
|
||||
frame_length = frontend_configs.get("frame_length", 25)
|
||||
frame_shift = frontend_configs.get("frame_shift", 10)
|
||||
window_type = frontend_configs.get("window", "hamming")
|
||||
resample_rate = frontend_configs.get("fs", 16000)
|
||||
n_fft = frontend_configs.get("n_fft", "400")
|
||||
use_channel = frontend_configs.get("use_channel", None)
|
||||
assert num_mel_bins == args.dim
|
||||
frontend = DefaultFrontend(
|
||||
fs=resample_rate,
|
||||
n_fft=n_fft,
|
||||
win_length=frame_length * 16,
|
||||
hop_length=frame_shift * 16,
|
||||
window=window_type,
|
||||
n_mels=num_mel_bins,
|
||||
use_channel=use_channel,
|
||||
)
|
||||
with open(wav_scp_file) as f:
|
||||
lines = f.readlines()
|
||||
for line in lines:
|
||||
_, wav_file = line.strip().split()
|
||||
wavform, _ = torchaudio.load(wav_file)
|
||||
fbank, _ = frontend(wavform.transpose(0, 1).unsqueeze(0), torch.tensor([wavform.shape[1]]))
|
||||
fbank = fbank.squeeze(0).numpy()
|
||||
mean_stats += np.sum(fbank, axis=0)
|
||||
var_stats += np.sum(np.square(fbank), axis=0)
|
||||
total_frames += fbank.shape[0]
|
||||
|
||||
cmvn_info = {
|
||||
'mean_stats': list(mean_stats.tolist()),
|
||||
'var_stats': list(var_stats.tolist()),
|
||||
'total_frames': total_frames
|
||||
}
|
||||
with open(cmvn_file, 'w') as fout:
|
||||
fout.write(json.dumps(cmvn_info))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
39
egs/alimeeting/sa_asr/local/compute_cmvn.sh
Executable file
39
egs/alimeeting/sa_asr/local/compute_cmvn.sh
Executable file
@ -0,0 +1,39 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
. ./path.sh || exit 1;
|
||||
# Begin configuration section.
|
||||
fbankdir=
|
||||
nj=32
|
||||
cmd=./utils/run.pl
|
||||
feats_dim=80
|
||||
config_file=
|
||||
scale=1.0
|
||||
|
||||
echo "$0 $@"
|
||||
|
||||
. utils/parse_options.sh || exit 1;
|
||||
|
||||
# shellcheck disable=SC2046
|
||||
head -n $(awk -v lines="$(wc -l < ${fbankdir}/wav.scp)" -v scale="$scale" 'BEGIN { printf "%.0f\n", lines*scale }') ${fbankdir}/wav.scp > ${fbankdir}/wav.scp.scale
|
||||
|
||||
split_dir=${fbankdir}/cmvn/split_${nj};
|
||||
mkdir -p $split_dir
|
||||
split_scps=""
|
||||
for n in $(seq $nj); do
|
||||
split_scps="$split_scps $split_dir/wav.$n.scp"
|
||||
done
|
||||
utils/split_scp.pl ${fbankdir}/wav.scp.scale $split_scps || exit 1;
|
||||
|
||||
logdir=${fbankdir}/cmvn/log
|
||||
$cmd JOB=1:$nj $logdir/cmvn.JOB.log \
|
||||
python local/compute_cmvn.py \
|
||||
--dim ${feats_dim} \
|
||||
--wav_path $split_dir \
|
||||
--config_file $config_file \
|
||||
--idx JOB \
|
||||
|
||||
python utils/combine_cmvn_file.py --dim ${feats_dim} --cmvn_dir $split_dir --nj $nj --output_dir ${fbankdir}/cmvn
|
||||
|
||||
python utils/cmvn_converter.py --cmvn_json ${fbankdir}/cmvn/cmvn.json --am_mvn ${fbankdir}/cmvn/am.mvn
|
||||
|
||||
echo "$0: Succeeded compute global cmvn"
|
||||
29
egs/alimeeting/sa_asr/local/convert_model.py
Normal file
29
egs/alimeeting/sa_asr/local/convert_model.py
Normal file
@ -0,0 +1,29 @@
|
||||
import codecs
|
||||
import pdb
|
||||
import sys
|
||||
import torch
|
||||
|
||||
char1 = sys.argv[1]
|
||||
char2 = sys.argv[2]
|
||||
model1 = torch.load(sys.argv[3], map_location='cpu')
|
||||
model2_path = sys.argv[4]
|
||||
|
||||
d_new = model1
|
||||
char1_list = []
|
||||
map_list = []
|
||||
|
||||
|
||||
with codecs.open(char1) as f:
|
||||
for line in f.readlines():
|
||||
char1_list.append(line.strip())
|
||||
|
||||
with codecs.open(char2) as f:
|
||||
for line in f.readlines():
|
||||
map_list.append(char1_list.index(line.strip()))
|
||||
print(map_list)
|
||||
|
||||
for k, v in d_new.items():
|
||||
if k == 'ctc.ctc_lo.weight' or k == 'ctc.ctc_lo.bias' or k == 'decoder.output_layer.weight' or k == 'decoder.output_layer.bias' or k == 'decoder.embed.0.weight':
|
||||
d_new[k] = v[map_list]
|
||||
|
||||
torch.save(d_new, model2_path)
|
||||
105
egs/alimeeting/sa_asr/local/download_and_untar.sh
Executable file
105
egs/alimeeting/sa_asr/local/download_and_untar.sh
Executable file
@ -0,0 +1,105 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
# Copyright 2014 Johns Hopkins University (author: Daniel Povey)
|
||||
# 2017 Xingyu Na
|
||||
# Apache 2.0
|
||||
|
||||
remove_archive=false
|
||||
|
||||
if [ "$1" == --remove-archive ]; then
|
||||
remove_archive=true
|
||||
shift
|
||||
fi
|
||||
|
||||
if [ $# -ne 3 ]; then
|
||||
echo "Usage: $0 [--remove-archive] <data-base> <url-base> <corpus-part>"
|
||||
echo "e.g.: $0 /export/a05/xna/data www.openslr.org/resources/33 data_aishell"
|
||||
echo "With --remove-archive it will remove the archive after successfully un-tarring it."
|
||||
echo "<corpus-part> can be one of: data_aishell, resource_aishell."
|
||||
fi
|
||||
|
||||
data=$1
|
||||
url=$2
|
||||
part=$3
|
||||
|
||||
if [ ! -d "$data" ]; then
|
||||
echo "$0: no such directory $data"
|
||||
exit 1;
|
||||
fi
|
||||
|
||||
part_ok=false
|
||||
list="data_aishell resource_aishell"
|
||||
for x in $list; do
|
||||
if [ "$part" == $x ]; then part_ok=true; fi
|
||||
done
|
||||
if ! $part_ok; then
|
||||
echo "$0: expected <corpus-part> to be one of $list, but got '$part'"
|
||||
exit 1;
|
||||
fi
|
||||
|
||||
if [ -z "$url" ]; then
|
||||
echo "$0: empty URL base."
|
||||
exit 1;
|
||||
fi
|
||||
|
||||
if [ -f $data/$part/.complete ]; then
|
||||
echo "$0: data part $part was already successfully extracted, nothing to do."
|
||||
exit 0;
|
||||
fi
|
||||
|
||||
# sizes of the archive files in bytes.
|
||||
sizes="15582913665 1246920"
|
||||
|
||||
if [ -f $data/$part.tgz ]; then
|
||||
size=$(/bin/ls -l $data/$part.tgz | awk '{print $5}')
|
||||
size_ok=false
|
||||
for s in $sizes; do if [ $s == $size ]; then size_ok=true; fi; done
|
||||
if ! $size_ok; then
|
||||
echo "$0: removing existing file $data/$part.tgz because its size in bytes $size"
|
||||
echo "does not equal the size of one of the archives."
|
||||
rm $data/$part.tgz
|
||||
else
|
||||
echo "$data/$part.tgz exists and appears to be complete."
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ ! -f $data/$part.tgz ]; then
|
||||
if ! command -v wget >/dev/null; then
|
||||
echo "$0: wget is not installed."
|
||||
exit 1;
|
||||
fi
|
||||
full_url=$url/$part.tgz
|
||||
echo "$0: downloading data from $full_url. This may take some time, please be patient."
|
||||
|
||||
cd $data || exit 1
|
||||
if ! wget --no-check-certificate $full_url; then
|
||||
echo "$0: error executing wget $full_url"
|
||||
exit 1;
|
||||
fi
|
||||
fi
|
||||
|
||||
cd $data || exit 1
|
||||
|
||||
if ! tar -xvzf $part.tgz; then
|
||||
echo "$0: error un-tarring archive $data/$part.tgz"
|
||||
exit 1;
|
||||
fi
|
||||
|
||||
touch $data/$part/.complete
|
||||
|
||||
if [ $part == "data_aishell" ]; then
|
||||
cd $data/$part/wav || exit 1
|
||||
for wav in ./*.tar.gz; do
|
||||
echo "Extracting wav from $wav"
|
||||
tar -zxf $wav && rm $wav
|
||||
done
|
||||
fi
|
||||
|
||||
echo "$0: Successfully downloaded and un-tarred $data/$part.tgz"
|
||||
|
||||
if $remove_archive; then
|
||||
echo "$0: removing $data/$part.tgz file since --remove-archive option was supplied."
|
||||
rm $data/$part.tgz
|
||||
fi
|
||||
|
||||
exit 0;
|
||||
@ -63,7 +63,7 @@ if __name__ == "__main__":
|
||||
wav_scp_file = open(path+'/wav.scp', 'r')
|
||||
wav_scp = wav_scp_file.readlines()
|
||||
wav_scp_file.close()
|
||||
raw_meeting_scp_file = open(raw_path + '/wav_raw.scp', 'r')
|
||||
raw_meeting_scp_file = open(raw_path + '/wav.scp', 'r')
|
||||
raw_meeting_scp = raw_meeting_scp_file.readlines()
|
||||
raw_meeting_scp_file.close()
|
||||
segments_scp_file = open(raw_path + '/segments', 'r')
|
||||
@ -92,8 +92,8 @@ if __name__ == "__main__":
|
||||
cluster_spk_num_file = open(path + '/cluster_spk_num', 'w')
|
||||
meeting_map = {}
|
||||
for line in raw_meeting_scp:
|
||||
meeting = line.strip().split('\t')[0]
|
||||
wav_path = line.strip().split('\t')[1]
|
||||
meeting = line.strip().split(' ')[0]
|
||||
wav_path = line.strip().split(' ')[1]
|
||||
wav = soundfile.read(wav_path)[0]
|
||||
# take the first channel
|
||||
if wav.ndim == 2:
|
||||
@ -9,7 +9,7 @@ import soundfile
|
||||
if __name__=="__main__":
|
||||
path = sys.argv[1] # dump2/raw/Eval_Ali_far
|
||||
raw_path = sys.argv[2] # data/local/Eval_Ali_far_correct_single_speaker
|
||||
raw_meeting_scp_file = open(raw_path + '/wav_raw.scp', 'r')
|
||||
raw_meeting_scp_file = open(raw_path + '/wav.scp', 'r')
|
||||
raw_meeting_scp = raw_meeting_scp_file.readlines()
|
||||
raw_meeting_scp_file.close()
|
||||
segments_scp_file = open(raw_path + '/segments', 'r')
|
||||
@ -22,8 +22,8 @@ if __name__=="__main__":
|
||||
|
||||
raw_wav_map = {}
|
||||
for line in raw_meeting_scp:
|
||||
meeting = line.strip().split('\t')[0]
|
||||
wav_path = line.strip().split('\t')[1]
|
||||
meeting = line.strip().split(' ')[0]
|
||||
wav_path = line.strip().split(' ')[1]
|
||||
raw_wav_map[meeting] = wav_path
|
||||
|
||||
spk_map = {}
|
||||
@ -5,7 +5,7 @@ import sys
|
||||
|
||||
|
||||
if __name__=="__main__":
|
||||
path = sys.argv[1] # dump2/raw/Train_Ali_far
|
||||
path = sys.argv[1]
|
||||
wav_scp_file = open(path+"/wav.scp", 'r')
|
||||
wav_scp = wav_scp_file.readlines()
|
||||
wav_scp_file.close()
|
||||
@ -29,7 +29,7 @@ if __name__=="__main__":
|
||||
line_list = line.strip().split(' ')
|
||||
meeting = line_list[0].split('-')[0]
|
||||
spk_id = line_list[0].split('-')[-1].split('_')[-1]
|
||||
spk = meeting+'_' + spk_id
|
||||
spk = meeting + '_' + spk_id
|
||||
global_spk_list.append(spk)
|
||||
if meeting in meeting_map_tmp.keys():
|
||||
meeting_map_tmp[meeting].append(spk)
|
||||
@ -30,8 +30,7 @@ def main(args):
|
||||
meetingid_map = {}
|
||||
for line in spk2utt:
|
||||
spkid = line.strip().split(" ")[0]
|
||||
meeting_id_list = spkid.split("_")[:3]
|
||||
meeting_id = meeting_id_list[0] + "_" + meeting_id_list[1] + "_" + meeting_id_list[2]
|
||||
meeting_id = spkid.split("-")[0]
|
||||
if meeting_id not in meetingid_map:
|
||||
meetingid_map[meeting_id] = 1
|
||||
else:
|
||||
6
egs/alimeeting/sa_asr/path.sh
Executable file
6
egs/alimeeting/sa_asr/path.sh
Executable file
@ -0,0 +1,6 @@
|
||||
export FUNASR_DIR=$PWD/../../..
|
||||
|
||||
# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
|
||||
export PYTHONIOENCODING=UTF-8
|
||||
export PATH=$FUNASR_DIR/funasr/bin:./utils:$FUNASR_DIR:$PATH
|
||||
export PYTHONPATH=$FUNASR_DIR:$PYTHONPATH
|
||||
435
egs/alimeeting/sa_asr/run.sh
Executable file
435
egs/alimeeting/sa_asr/run.sh
Executable file
@ -0,0 +1,435 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
. ./path.sh || exit 1;
|
||||
|
||||
# machines configuration
|
||||
CUDA_VISIBLE_DEVICES="6,7"
|
||||
gpu_num=2
|
||||
count=1
|
||||
gpu_inference=true # Whether to perform gpu decoding, set false for cpu decoding
|
||||
# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob
|
||||
njob=8
|
||||
train_cmd=utils/run.pl
|
||||
infer_cmd=utils/run.pl
|
||||
|
||||
# general configuration
|
||||
feats_dir="data" #feature output dictionary
|
||||
exp_dir="exp"
|
||||
lang=zh
|
||||
token_type=char
|
||||
type=sound
|
||||
scp=wav.scp
|
||||
speed_perturb="1.0"
|
||||
min_wav_duration=0.1
|
||||
max_wav_duration=20
|
||||
profile_modes="cluster oracle"
|
||||
stage=7
|
||||
stop_stage=7
|
||||
|
||||
# feature configuration
|
||||
feats_dim=80
|
||||
nj=32
|
||||
|
||||
# data
|
||||
raw_data=
|
||||
data_url=
|
||||
|
||||
# exp tag
|
||||
tag=""
|
||||
|
||||
. utils/parse_options.sh || exit 1;
|
||||
|
||||
# Set bash to 'debug' mode, it will exit on :
|
||||
# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
|
||||
set -e
|
||||
set -u
|
||||
set -o pipefail
|
||||
|
||||
train_set=Train_Ali_far
|
||||
valid_set=Eval_Ali_far
|
||||
test_sets="Test_Ali_far Eval_Ali_far"
|
||||
test_2023="Test_2023_Ali_far_release"
|
||||
|
||||
asr_config=conf/train_asr_conformer.yaml
|
||||
sa_asr_config=conf/train_sa_asr_conformer.yaml
|
||||
asr_model_dir="baseline_$(basename "${asr_config}" .yaml)_${lang}_${token_type}_${tag}"
|
||||
sa_asr_model_dir="baseline_$(basename "${sa_asr_config}" .yaml)_${lang}_${token_type}_${tag}"
|
||||
inference_config=conf/decode_asr_rnn.yaml
|
||||
inference_sa_asr_model=valid.acc_spk.ave.pb
|
||||
|
||||
# you can set gpu num for decoding here
|
||||
gpuid_list=$CUDA_VISIBLE_DEVICES # set gpus for decoding, the same as training stage by default
|
||||
ngpu=$(echo $gpuid_list | awk -F "," '{print NF}')
|
||||
|
||||
if ${gpu_inference}; then
|
||||
inference_nj=$[${ngpu}*${njob}]
|
||||
_ngpu=1
|
||||
else
|
||||
inference_nj=$njob
|
||||
_ngpu=0
|
||||
fi
|
||||
|
||||
|
||||
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
|
||||
echo "stage 0: Data preparation"
|
||||
# Data preparation
|
||||
./local/alimeeting_data_prep.sh --tgt Test --min_wav_duration $min_wav_duration --max_wav_duration $max_wav_duration
|
||||
./local/alimeeting_data_prep.sh --tgt Eval --min_wav_duration $min_wav_duration --max_wav_duration $max_wav_duration
|
||||
./local/alimeeting_data_prep.sh --tgt Train --min_wav_duration $min_wav_duration --max_wav_duration $max_wav_duration
|
||||
# remove long/short data
|
||||
for x in ${train_set} ${valid_set} ${test_sets}; do
|
||||
cp -r ${feats_dir}/org/${x} ${feats_dir}/${x}
|
||||
rm ${feats_dir}/"${x}"/wav.scp ${feats_dir}/"${x}"/segments
|
||||
local/format_wav_scp.sh --nj "${nj}" --cmd "${train_cmd}" \
|
||||
--audio-format wav --segments ${feats_dir}/org/${x}/segments \
|
||||
"${feats_dir}/org/${x}/${scp}" "${feats_dir}/${x}"
|
||||
_min_length=$(python3 -c "print(int(${min_wav_duration} * 16000))")
|
||||
_max_length=$(python3 -c "print(int(${max_wav_duration} * 16000))")
|
||||
<"${feats_dir}/${x}/utt2num_samples" \
|
||||
awk '{if($2 > '$_min_length' && $2 < '$_max_length')print $0;}' \
|
||||
>"${feats_dir}/${x}/utt2num_samples_rmls"
|
||||
mv ${feats_dir}/${x}/utt2num_samples_rmls ${feats_dir}/${x}/utt2num_samples
|
||||
<"${feats_dir}/${x}/wav.scp" \
|
||||
utils/filter_scp.pl "${feats_dir}/${x}/utt2num_samples" \
|
||||
>"${feats_dir}/${x}/wav.scp_rmls"
|
||||
mv ${feats_dir}/${x}/wav.scp_rmls ${feats_dir}/${x}/wav.scp
|
||||
<"${feats_dir}/${x}/text" \
|
||||
awk '{ if( NF != 1 ) print $0; }' >"${feats_dir}/${x}/text_rmblank"
|
||||
mv ${feats_dir}/${x}/text_rmblank ${feats_dir}/${x}/text
|
||||
local/fix_${feats_dir}_dir.sh "${feats_dir}/${x}"
|
||||
<"${feats_dir}/${x}/utt2spk_all_fifo" \
|
||||
utils/filter_scp.pl "${feats_dir}/${x}/text" \
|
||||
>"${feats_dir}/${x}/utt2spk_all_fifo_rmls"
|
||||
mv "${feats_dir}/${x}/utt2spk_all_fifo_rmls" "${feats_dir}/${x}/utt2spk_all_fifo"
|
||||
done
|
||||
for x in ${test_2023}; do
|
||||
local/format_wav_scp.sh --nj "${nj}" --cmd "${train_cmd}" \
|
||||
--audio-format wav --segments ${feats_dir}/org/${x}/segments \
|
||||
"${feats_dir}/org/${x}/${scp}" "${feats_dir}/${x}"
|
||||
cut -d " " -f1 ${feats_dir}/${x}/wav.scp > ${feats_dir}/${x}/uttid
|
||||
paste -d " " ${feats_dir}/${x}/uttid ${feats_dir}/${x}/uttid > ${feats_dir}/${x}/utt2spk
|
||||
cp ${feats_dir}/${x}/utt2spk ${feats_dir}/${x}/spk2utt
|
||||
done
|
||||
fi
|
||||
|
||||
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
|
||||
echo "stage 1: Speaker profile and CMVN Generation"
|
||||
|
||||
mkdir -p "profile_log"
|
||||
for x in "${train_set}" "${valid_set}" "${test_sets}"; do
|
||||
# generate text_id spk2id
|
||||
python local/process_sot_fifo_textchar2spk.py --path ${feats_dir}/${x}
|
||||
echo "Successfully generate ${feats_dir}/${x}/text_id ${feats_dir}/${x}/spk2id"
|
||||
# generate text_id_train for sot
|
||||
python local/process_text_id.py ${feats_dir}/${x}
|
||||
echo "Successfully generate ${feats_dir}/${x}/text_id_train"
|
||||
# generate oracle_embedding from single-speaker audio segment
|
||||
echo "oracle_embedding is being generated in the background, and the log is profile_log/gen_oracle_embedding_${x}.log"
|
||||
python local/gen_oracle_embedding.py "${feats_dir}/${x}" "data/org/${x}_single_speaker" &> "profile_log/gen_oracle_embedding_${x}.log"
|
||||
echo "Successfully generate oracle embedding for ${x} (${feats_dir}/${x}/oracle_embedding.scp)"
|
||||
# generate oracle_profile and cluster_profile from oracle_embedding and cluster_embedding (padding the speaker during training)
|
||||
if [ "${x}" = "${train_set}" ]; then
|
||||
python local/gen_oracle_profile_padding.py ${feats_dir}/${x}
|
||||
echo "Successfully generate oracle profile for ${x} (${feats_dir}/${x}/oracle_profile_padding.scp)"
|
||||
else
|
||||
python local/gen_oracle_profile_nopadding.py ${feats_dir}/${x}
|
||||
echo "Successfully generate oracle profile for ${x} (${feats_dir}/${x}/oracle_profile_nopadding.scp)"
|
||||
fi
|
||||
# generate cluster_profile with spectral-cluster directly (for infering and without oracle information)
|
||||
if [ "${x}" = "${valid_set}" ] || [ "${x}" = "${test_sets}" ]; then
|
||||
echo "cluster_profile is being generated in the background, and the log is profile_log/gen_cluster_profile_infer_${x}.log"
|
||||
python local/gen_cluster_profile_infer.py "${feats_dir}/${x}" "${feats_dir}/org/${x}" 0.996 0.815 &> "profile_log/gen_cluster_profile_infer_${x}.log"
|
||||
echo "Successfully generate cluster profile for ${x} (${feats_dir}/${x}/cluster_profile_infer.scp)"
|
||||
fi
|
||||
# compute CMVN
|
||||
if [ "${x}" = "${train_set}" ]; then
|
||||
local/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --fbankdir ${feats_dir}/${train_set} --feats_dim ${feats_dim} --config_file "$asr_config" --scale 1.0
|
||||
fi
|
||||
done
|
||||
|
||||
for x in "${test_2023}"; do
|
||||
# generate cluster_profile with spectral-cluster directly (for infering and without oracle information)
|
||||
python local/gen_cluster_profile_infer.py "${feats_dir}/${x}" "${feats_dir}/org/${x}" 0.996 0.815 &> "profile_log/gen_cluster_profile_infer_${x}.log"
|
||||
echo "Successfully generate cluster profile for ${x} (${feats_dir}/${x}/cluster_profile_infer.scp)"
|
||||
done
|
||||
fi
|
||||
|
||||
token_list=${feats_dir}/${lang}_token_list/char/tokens.txt
|
||||
echo "dictionary: ${token_list}"
|
||||
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
|
||||
echo "stage 2: Dictionary Preparation"
|
||||
mkdir -p ${feats_dir}/${lang}_token_list/char/
|
||||
|
||||
echo "make a dictionary"
|
||||
echo "<blank>" > ${token_list}
|
||||
echo "<s>" >> ${token_list}
|
||||
echo "</s>" >> ${token_list}
|
||||
utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/$train_set/text | cut -f 2- -d" " | tr " " "\n" \
|
||||
| sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list}
|
||||
echo "<unk>" >> ${token_list}
|
||||
fi
|
||||
|
||||
# LM Training Stage
|
||||
world_size=$gpu_num # run on one machine
|
||||
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
||||
echo "stage 3: LM Training"
|
||||
fi
|
||||
|
||||
# ASR Training Stage
|
||||
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
|
||||
echo "Stage 4: ASR Training"
|
||||
asr_exp=${exp_dir}/${asr_model_dir}
|
||||
mkdir -p ${asr_exp}
|
||||
mkdir -p ${asr_exp}/log
|
||||
INIT_FILE=${asr_exp}/ddp_init
|
||||
if [ -f $INIT_FILE ];then
|
||||
rm -f $INIT_FILE
|
||||
fi
|
||||
init_method=file://$(readlink -f $INIT_FILE)
|
||||
echo "$0: init method is $init_method"
|
||||
for ((i = 0; i < $ngpu; ++i)); do
|
||||
{
|
||||
# i=0
|
||||
rank=$i
|
||||
local_rank=$i
|
||||
gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
|
||||
train.py \
|
||||
--task_name asr \
|
||||
--model asr \
|
||||
--gpu_id $gpu_id \
|
||||
--use_preprocessor true \
|
||||
--split_with_space false \
|
||||
--token_type char \
|
||||
--token_list $token_list \
|
||||
--data_dir ${feats_dir} \
|
||||
--train_set ${train_set} \
|
||||
--valid_set ${valid_set} \
|
||||
--data_file_names "wav.scp,text" \
|
||||
--cmvn_file ${feats_dir}/${train_set}/cmvn/cmvn.mvn \
|
||||
--speed_perturb ${speed_perturb} \
|
||||
--resume true \
|
||||
--output_dir ${exp_dir}/${asr_model_dir} \
|
||||
--config $asr_config \
|
||||
--ngpu $gpu_num \
|
||||
--num_worker_count $count \
|
||||
--dist_init_method $init_method \
|
||||
--dist_world_size $world_size \
|
||||
--dist_rank $rank \
|
||||
--local_rank $local_rank 1> ${exp_dir}/${asr_model_dir}/log/train.log.$i 2>&1
|
||||
} &
|
||||
done
|
||||
wait
|
||||
|
||||
fi
|
||||
|
||||
|
||||
|
||||
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
|
||||
echo "SA-ASR training"
|
||||
asr_exp=${exp_dir}/${asr_model_dir}
|
||||
sa_asr_exp=${exp_dir}/${sa_asr_model_dir}
|
||||
mkdir -p ${sa_asr_exp}
|
||||
mkdir -p ${sa_asr_exp}/log
|
||||
INIT_FILE=${sa_asr_exp}/ddp_init
|
||||
if [ ! -L ${feats_dir}/${train_set}/profile.scp ]; then
|
||||
ln -sr ${feats_dir}/${train_set}/oracle_profile_padding.scp ${feats_dir}/${train_set}/profile.scp
|
||||
ln -sr ${feats_dir}/${valid_set}/oracle_profile_nopadding.scp ${feats_dir}/${valid_set}/profile.scp
|
||||
fi
|
||||
|
||||
if [ ! -f "${exp_dir}/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pth" ]; then
|
||||
# download xvector extractor model file
|
||||
python local/download_xvector_model.py ${exp_dir}
|
||||
echo "Successfully download the pretrained xvector extractor to exp/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pth"
|
||||
fi
|
||||
|
||||
if [ -f $INIT_FILE ];then
|
||||
rm -f $INIT_FILE
|
||||
fi
|
||||
init_method=file://$(readlink -f $INIT_FILE)
|
||||
echo "$0: init method is $init_method"
|
||||
for ((i = 0; i < $ngpu; ++i)); do
|
||||
{
|
||||
rank=$i
|
||||
local_rank=$i
|
||||
gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
|
||||
train.py \
|
||||
--task_name asr \
|
||||
--model sa_asr \
|
||||
--gpu_id $gpu_id \
|
||||
--use_preprocessor true \
|
||||
--split_with_space false \
|
||||
--unused_parameters true \
|
||||
--token_type char \
|
||||
--resume true \
|
||||
--token_list $token_list \
|
||||
--data_dir ${feats_dir} \
|
||||
--train_set ${train_set} \
|
||||
--valid_set ${valid_set} \
|
||||
--data_file_names "wav.scp,text,profile.scp,text_id_train" \
|
||||
--cmvn_file ${feats_dir}/${train_set}/cmvn/cmvn.mvn \
|
||||
--speed_perturb ${speed_perturb} \
|
||||
--init_param "${asr_exp}/valid.acc.ave.pb:encoder:asr_encoder" \
|
||||
--init_param "${asr_exp}/valid.acc.ave.pb:ctc:ctc" \
|
||||
--init_param "${asr_exp}/valid.acc.ave.pb:decoder.embed:decoder.embed" \
|
||||
--init_param "${asr_exp}/valid.acc.ave.pb:decoder.output_layer:decoder.asr_output_layer" \
|
||||
--init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.0.self_attn:decoder.decoder1.self_attn" \
|
||||
--init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.0.src_attn:decoder.decoder3.src_attn" \
|
||||
--init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.0.feed_forward:decoder.decoder3.feed_forward" \
|
||||
--init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.1:decoder.decoder4.0" \
|
||||
--init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.2:decoder.decoder4.1" \
|
||||
--init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.3:decoder.decoder4.2" \
|
||||
--init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.4:decoder.decoder4.3" \
|
||||
--init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.5:decoder.decoder4.4" \
|
||||
--init_param "exp/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pth:encoder:spk_encoder" \
|
||||
--init_param "exp/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pth:decoder:spk_encoder:decoder.output_dense" \
|
||||
--output_dir ${exp_dir}/${sa_asr_model_dir} \
|
||||
--config $sa_asr_config \
|
||||
--ngpu $gpu_num \
|
||||
--num_worker_count $count \
|
||||
--dist_init_method $init_method \
|
||||
--dist_world_size $world_size \
|
||||
--dist_rank $rank \
|
||||
--local_rank $local_rank 1> ${exp_dir}/${sa_asr_model_dir}/log/train.log.$i 2>&1
|
||||
} &
|
||||
done
|
||||
wait
|
||||
fi
|
||||
|
||||
|
||||
if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
|
||||
echo "stage 6: Inference test sets"
|
||||
for x in ${test_sets}; do
|
||||
for profile_mode in ${profile_modes}; do
|
||||
echo "decoding ${x} with ${profile_mode} profile"
|
||||
sa_asr_exp=${exp_dir}/${sa_asr_model_dir}
|
||||
inference_tag="$(basename "${inference_config}" .yaml)"
|
||||
_dir="${sa_asr_exp}/${inference_tag}_${profile_mode}/${inference_sa_asr_model}/${x}"
|
||||
_logdir="${_dir}/logdir"
|
||||
if [ -d ${_dir} ]; then
|
||||
echo "${_dir} is already exists. if you want to decode again, please delete this dir first."
|
||||
exit 0
|
||||
fi
|
||||
mkdir -p "${_logdir}"
|
||||
_data="${feats_dir}/${x}"
|
||||
key_file=${_data}/${scp}
|
||||
num_scp_file="$(<${key_file} wc -l)"
|
||||
_nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
|
||||
split_scps=
|
||||
for n in $(seq "${_nj}"); do
|
||||
split_scps+=" ${_logdir}/keys.${n}.scp"
|
||||
done
|
||||
# shellcheck disable=SC2086
|
||||
utils/split_scp.pl "${key_file}" ${split_scps}
|
||||
_opts=
|
||||
if [ -n "${inference_config}" ]; then
|
||||
_opts+="--config ${inference_config} "
|
||||
fi
|
||||
if [ $profile_mode = "oracle" ]; then
|
||||
profile_scp=${profile_mode}_profile_nopadding.scp
|
||||
else
|
||||
profile_scp=${profile_mode}_profile_infer.scp
|
||||
fi
|
||||
${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
|
||||
python -m funasr.bin.asr_inference_launch \
|
||||
--batch_size 1 \
|
||||
--mc True \
|
||||
--ngpu "${_ngpu}" \
|
||||
--njob ${njob} \
|
||||
--nbest 1 \
|
||||
--gpuid_list ${gpuid_list} \
|
||||
--allow_variable_data_keys true \
|
||||
--cmvn_file ${feats_dir}/${train_set}/cmvn/cmvn.mvn \
|
||||
--data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
|
||||
--data_path_and_name_and_type "${_data}/$profile_scp,profile,npy" \
|
||||
--key_file "${_logdir}"/keys.JOB.scp \
|
||||
--asr_train_config "${sa_asr_exp}"/config.yaml \
|
||||
--asr_model_file "${sa_asr_exp}"/"${inference_sa_asr_model}" \
|
||||
--output_dir "${_logdir}"/output.JOB \
|
||||
--mode sa_asr \
|
||||
${_opts}
|
||||
|
||||
for f in token token_int score text text_id; do
|
||||
if [ -f "${_logdir}/output.1/1best_recog/${f}" ]; then
|
||||
for i in $(seq "${_nj}"); do
|
||||
cat "${_logdir}/output.${i}/1best_recog/${f}"
|
||||
done | sort -k1 >"${_dir}/${f}"
|
||||
fi
|
||||
done
|
||||
sed 's/\$//g' ${_data}/text > ${_data}/text_nosrc
|
||||
sed 's/\$//g' ${_dir}/text > ${_dir}/text_nosrc
|
||||
python utils/proce_text.py ${_data}/text_nosrc ${_data}/text.proc
|
||||
python utils/proce_text.py ${_dir}/text_nosrc ${_dir}/text.proc
|
||||
|
||||
python utils/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
|
||||
tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
|
||||
cat ${_dir}/text.cer.txt
|
||||
|
||||
python local/process_text_spk_merge.py ${_dir}
|
||||
python local/process_text_spk_merge.py ${_data}
|
||||
|
||||
python local/compute_cpcer.py ${_data}/text_spk_merge ${_dir}/text_spk_merge ${_dir}/text.cpcer
|
||||
tail -n 1 ${_dir}/text.cpcer > ${_dir}/text.cpcer.txt
|
||||
cat ${_dir}/text.cpcer.txt
|
||||
done
|
||||
done
|
||||
fi
|
||||
|
||||
if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
|
||||
echo "stage 7: Inference test 2023"
|
||||
for x in ${test_2023}; do
|
||||
sa_asr_exp=${exp_dir}/${sa_asr_model_dir}
|
||||
inference_tag="$(basename "${inference_config}" .yaml)"
|
||||
_dir="${sa_asr_exp}/${inference_tag}_cluster/${inference_sa_asr_model}/${x}"
|
||||
_logdir="${_dir}/logdir"
|
||||
if [ -d ${_dir} ]; then
|
||||
echo "${_dir} is already exists. if you want to decode again, please delete this dir first."
|
||||
exit 0
|
||||
fi
|
||||
mkdir -p "${_logdir}"
|
||||
_data="${feats_dir}/${x}"
|
||||
key_file=${_data}/${scp}
|
||||
num_scp_file="$(<${key_file} wc -l)"
|
||||
_nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
|
||||
split_scps=
|
||||
for n in $(seq "${_nj}"); do
|
||||
split_scps+=" ${_logdir}/keys.${n}.scp"
|
||||
done
|
||||
# shellcheck disable=SC2086
|
||||
utils/split_scp.pl "${key_file}" ${split_scps}
|
||||
_opts=
|
||||
if [ -n "${inference_config}" ]; then
|
||||
_opts+="--config ${inference_config} "
|
||||
fi
|
||||
${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
|
||||
python -m funasr.bin.asr_inference_launch \
|
||||
--batch_size 1 \
|
||||
--mc True \
|
||||
--ngpu "${_ngpu}" \
|
||||
--njob ${njob} \
|
||||
--nbest 1 \
|
||||
--gpuid_list ${gpuid_list} \
|
||||
--allow_variable_data_keys true \
|
||||
--data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
|
||||
--data_path_and_name_and_type "${_data}/cluster_profile_infer.scp,profile,npy" \
|
||||
--cmvn_file ${feats_dir}/${train_set}/cmvn/cmvn.mvn \
|
||||
--key_file "${_logdir}"/keys.JOB.scp \
|
||||
--asr_train_config "${sa_asr_exp}"/config.yaml \
|
||||
--asr_model_file "${sa_asr_exp}"/"${inference_sa_asr_model}" \
|
||||
--output_dir "${_logdir}"/output.JOB \
|
||||
--mode sa_asr \
|
||||
${_opts}
|
||||
|
||||
for f in token token_int score text text_id; do
|
||||
if [ -f "${_logdir}/output.1/1best_recog/${f}" ]; then
|
||||
for i in $(seq "${_nj}"); do
|
||||
cat "${_logdir}/output.${i}/1best_recog/${f}"
|
||||
done | sort -k1 >"${_dir}/${f}"
|
||||
fi
|
||||
done
|
||||
|
||||
python local/process_text_spk_merge.py ${_dir}
|
||||
|
||||
done
|
||||
fi
|
||||
|
||||
|
||||
@ -0,0 +1,6 @@
|
||||
beam_size: 20
|
||||
penalty: 0.0
|
||||
maxlenratio: 0.0
|
||||
minlenratio: 0.0
|
||||
ctc_weight: 0.6
|
||||
lm_weight: 0.3
|
||||
1
egs/alimeeting/sa_asr_deprecated/local
Symbolic link
1
egs/alimeeting/sa_asr_deprecated/local
Symbolic link
@ -0,0 +1 @@
|
||||
../sa_asr/local/
|
||||
1
egs/alimeeting/sa_asr_deprecated/utils
Symbolic link
1
egs/alimeeting/sa_asr_deprecated/utils
Symbolic link
@ -0,0 +1 @@
|
||||
../../aishell/transformer/utils
|
||||
@ -1636,8 +1636,10 @@ class Speech2TextSAASR:
|
||||
)
|
||||
frontend = None
|
||||
if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
|
||||
if asr_train_args.frontend == 'wav_frontend':
|
||||
frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
|
||||
from funasr.tasks.sa_asr import frontend_choices
|
||||
if asr_train_args.frontend == 'wav_frontend' or asr_train_args.frontend == "multichannelfrontend":
|
||||
frontend_class = frontend_choices.get_class(asr_train_args.frontend)
|
||||
frontend = frontend_class(cmvn_file=cmvn_file, **asr_train_args.frontend_conf).eval()
|
||||
else:
|
||||
frontend_class = frontend_choices.get_class(asr_train_args.frontend)
|
||||
frontend = frontend_class(**asr_train_args.frontend_conf).eval()
|
||||
|
||||
@ -619,7 +619,12 @@ def inference_paraformer_vad_punc(
|
||||
data_with_index = [(vadsegments[i], i) for i in range(n)]
|
||||
sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0])
|
||||
results_sorted = []
|
||||
batch_size_token_ms = batch_size_token * 60
|
||||
|
||||
batch_size_token_ms = batch_size_token*60
|
||||
if speech2text.device == "cpu":
|
||||
batch_size_token_ms = 0
|
||||
batch_size_token_ms = max(batch_size_token_ms, sorted_data[0][0][1] - sorted_data[0][0][0])
|
||||
|
||||
batch_size_token_ms_cum = 0
|
||||
beg_idx = 0
|
||||
for j, _ in enumerate(range(0, n)):
|
||||
|
||||
@ -301,7 +301,7 @@ def get_parser():
|
||||
"--freeze_param",
|
||||
type=str,
|
||||
default=[],
|
||||
nargs="*",
|
||||
action="append",
|
||||
help="Freeze parameters",
|
||||
)
|
||||
|
||||
|
||||
@ -6,7 +6,6 @@ from funasr.models.ctc import CTC
|
||||
from funasr.models.decoder.abs_decoder import AbsDecoder
|
||||
from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
|
||||
from funasr.models.decoder.rnn_decoder import RNNDecoder
|
||||
from funasr.models.decoder.rnnt_decoder import RNNTDecoder
|
||||
from funasr.models.decoder.sanm_decoder import ParaformerSANMDecoder, FsmnDecoderSCAMAOpt
|
||||
from funasr.models.decoder.transformer_decoder import (
|
||||
DynamicConvolution2DTransformerDecoder, # noqa: H301
|
||||
@ -20,17 +19,23 @@ from funasr.models.decoder.transformer_decoder import (
|
||||
)
|
||||
from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
|
||||
from funasr.models.decoder.transformer_decoder import TransformerDecoder
|
||||
from funasr.models.decoder.rnnt_decoder import RNNTDecoder
|
||||
from funasr.models.decoder.transformer_decoder import SAAsrTransformerDecoder
|
||||
from funasr.models.e2e_asr import ASRModel
|
||||
from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
|
||||
from funasr.models.e2e_asr_mfcca import MFCCA
|
||||
from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerOnline, ParaformerBert, BiCifParaformer, \
|
||||
ContextualParaformer
|
||||
|
||||
from funasr.models.e2e_asr_transducer import TransducerModel, UnifiedTransducerModel
|
||||
|
||||
from funasr.models.e2e_sa_asr import SAASRModel
|
||||
from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerOnline, ParaformerBert, BiCifParaformer, ContextualParaformer
|
||||
|
||||
from funasr.models.e2e_tp import TimestampPredictor
|
||||
from funasr.models.e2e_uni_asr import UniASR
|
||||
from funasr.models.encoder.conformer_encoder import ConformerEncoder, ConformerChunkEncoder
|
||||
from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
|
||||
from funasr.models.encoder.mfcca_encoder import MFCCAEncoder
|
||||
from funasr.models.encoder.resnet34_encoder import ResNet34Diar
|
||||
from funasr.models.encoder.rnn_encoder import RNNEncoder
|
||||
from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
|
||||
from funasr.models.encoder.transformer_encoder import TransformerEncoder
|
||||
@ -93,6 +98,8 @@ model_choices = ClassChoices(
|
||||
timestamp_prediction=TimestampPredictor,
|
||||
rnnt=TransducerModel,
|
||||
rnnt_unified=UnifiedTransducerModel,
|
||||
sa_asr=SAASRModel,
|
||||
|
||||
),
|
||||
default="asr",
|
||||
)
|
||||
@ -110,6 +117,27 @@ encoder_choices = ClassChoices(
|
||||
),
|
||||
default="rnn",
|
||||
)
|
||||
asr_encoder_choices = ClassChoices(
|
||||
"asr_encoder",
|
||||
classes=dict(
|
||||
conformer=ConformerEncoder,
|
||||
transformer=TransformerEncoder,
|
||||
rnn=RNNEncoder,
|
||||
sanm=SANMEncoder,
|
||||
sanm_chunk_opt=SANMEncoderChunkOpt,
|
||||
data2vec_encoder=Data2VecEncoder,
|
||||
mfcca_enc=MFCCAEncoder,
|
||||
),
|
||||
default="rnn",
|
||||
)
|
||||
|
||||
spk_encoder_choices = ClassChoices(
|
||||
"spk_encoder",
|
||||
classes=dict(
|
||||
resnet34_diar=ResNet34Diar,
|
||||
),
|
||||
default="resnet34_diar",
|
||||
)
|
||||
encoder_choices2 = ClassChoices(
|
||||
"encoder2",
|
||||
classes=dict(
|
||||
@ -134,6 +162,7 @@ decoder_choices = ClassChoices(
|
||||
paraformer_decoder_sanm=ParaformerSANMDecoder,
|
||||
paraformer_decoder_san=ParaformerDecoderSAN,
|
||||
contextual_paraformer_decoder=ContextualParaformerDecoder,
|
||||
sa_decoder=SAAsrTransformerDecoder,
|
||||
),
|
||||
default="rnn",
|
||||
)
|
||||
@ -225,6 +254,10 @@ class_choices_list = [
|
||||
rnnt_decoder_choices,
|
||||
# --joint_network and --joint_network_conf
|
||||
joint_network_choices,
|
||||
# --asr_encoder and --asr_encoder_conf
|
||||
asr_encoder_choices,
|
||||
# --spk_encoder and --spk_encoder_conf
|
||||
spk_encoder_choices,
|
||||
]
|
||||
|
||||
|
||||
@ -247,7 +280,7 @@ def build_asr_model(args):
|
||||
# frontend
|
||||
if hasattr(args, "input_size") and args.input_size is None:
|
||||
frontend_class = frontend_choices.get_class(args.frontend)
|
||||
if args.frontend == 'wav_frontend':
|
||||
if args.frontend == 'wav_frontend' or args.frontend == 'multichannelfrontend':
|
||||
frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
|
||||
else:
|
||||
frontend = frontend_class(**args.frontend_conf)
|
||||
@ -425,6 +458,33 @@ def build_asr_model(args):
|
||||
joint_network=joint_network,
|
||||
**args.model_conf,
|
||||
)
|
||||
elif args.model == "sa_asr":
|
||||
asr_encoder_class = asr_encoder_choices.get_class(args.asr_encoder)
|
||||
asr_encoder = asr_encoder_class(input_size=input_size, **args.asr_encoder_conf)
|
||||
spk_encoder_class = spk_encoder_choices.get_class(args.spk_encoder)
|
||||
spk_encoder = spk_encoder_class(input_size=input_size, **args.spk_encoder_conf)
|
||||
decoder = decoder_class(
|
||||
vocab_size=vocab_size,
|
||||
encoder_output_size=asr_encoder.output_size(),
|
||||
**args.decoder_conf,
|
||||
)
|
||||
ctc = CTC(
|
||||
odim=vocab_size, encoder_output_size=asr_encoder.output_size(), **args.ctc_conf
|
||||
)
|
||||
|
||||
model_class = model_choices.get_class(args.model)
|
||||
model = model_class(
|
||||
vocab_size=vocab_size,
|
||||
frontend=frontend,
|
||||
specaug=specaug,
|
||||
normalize=normalize,
|
||||
asr_encoder=asr_encoder,
|
||||
spk_encoder=spk_encoder,
|
||||
decoder=decoder,
|
||||
ctc=ctc,
|
||||
token_list=token_list,
|
||||
**args.model_conf,
|
||||
)
|
||||
|
||||
else:
|
||||
raise NotImplementedError("Not supported model: {}".format(args.model))
|
||||
|
||||
@ -40,7 +40,7 @@ else:
|
||||
yield
|
||||
|
||||
|
||||
class ESPnetASRModel(FunASRModel):
|
||||
class SAASRModel(FunASRModel):
|
||||
"""CTC-attention hybrid Encoder-Decoder model"""
|
||||
|
||||
def __init__(
|
||||
@ -51,10 +51,8 @@ class ESPnetASRModel(FunASRModel):
|
||||
frontend: Optional[AbsFrontend],
|
||||
specaug: Optional[AbsSpecAug],
|
||||
normalize: Optional[AbsNormalize],
|
||||
preencoder: Optional[AbsPreEncoder],
|
||||
asr_encoder: AbsEncoder,
|
||||
spk_encoder: torch.nn.Module,
|
||||
postencoder: Optional[AbsPostEncoder],
|
||||
decoder: AbsDecoder,
|
||||
ctc: CTC,
|
||||
spk_weight: float = 0.5,
|
||||
@ -89,8 +87,6 @@ class ESPnetASRModel(FunASRModel):
|
||||
self.frontend = frontend
|
||||
self.specaug = specaug
|
||||
self.normalize = normalize
|
||||
self.preencoder = preencoder
|
||||
self.postencoder = postencoder
|
||||
self.asr_encoder = asr_encoder
|
||||
self.spk_encoder = spk_encoder
|
||||
|
||||
@ -293,10 +289,6 @@ class ESPnetASRModel(FunASRModel):
|
||||
if self.normalize is not None:
|
||||
feats, feats_lengths = self.normalize(feats, feats_lengths)
|
||||
|
||||
# Pre-encoder, e.g. used for raw input data
|
||||
if self.preencoder is not None:
|
||||
feats, feats_lengths = self.preencoder(feats, feats_lengths)
|
||||
|
||||
# 4. Forward encoder
|
||||
# feats: (Batch, Length, Dim)
|
||||
# -> encoder_out: (Batch, Length2, Dim2)
|
||||
@ -317,11 +309,6 @@ class ESPnetASRModel(FunASRModel):
|
||||
encoder_out_spk=F.interpolate(encoder_out_spk_ori.transpose(-2,-1), size=(encoder_out.size(1)), mode='nearest').transpose(-2,-1)
|
||||
else:
|
||||
encoder_out_spk=encoder_out_spk_ori
|
||||
# Post-encoder, e.g. NLU
|
||||
if self.postencoder is not None:
|
||||
encoder_out, encoder_out_lens = self.postencoder(
|
||||
encoder_out, encoder_out_lens
|
||||
)
|
||||
|
||||
assert encoder_out.size(0) == speech.size(0), (
|
||||
encoder_out.size(),
|
||||
@ -337,7 +324,7 @@ class ESPnetASRModel(FunASRModel):
|
||||
)
|
||||
|
||||
if intermediate_outs is not None:
|
||||
return (encoder_out, intermediate_outs), encoder_out_lens
|
||||
return (encoder_out, intermediate_outs), encoder_out_lens, encoder_out_spk
|
||||
|
||||
return encoder_out, encoder_out_lens, encoder_out_spk
|
||||
|
||||
|
||||
@ -2,7 +2,7 @@ import copy
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
|
||||
import logging
|
||||
import humanfriendly
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -14,6 +14,7 @@ from funasr.layers.stft import Stft
|
||||
from funasr.models.frontend.abs_frontend import AbsFrontend
|
||||
from funasr.modules.frontends.frontend import Frontend
|
||||
from funasr.utils.get_default_kwargs import get_default_kwargs
|
||||
from funasr.modules.nets_utils import make_pad_mask
|
||||
|
||||
|
||||
class DefaultFrontend(AbsFrontend):
|
||||
@ -137,8 +138,6 @@ class DefaultFrontend(AbsFrontend):
|
||||
return input_stft, feats_lens
|
||||
|
||||
|
||||
|
||||
|
||||
class MultiChannelFrontend(AbsFrontend):
|
||||
"""Conventional frontend structure for ASR.
|
||||
Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
|
||||
@ -147,9 +146,9 @@ class MultiChannelFrontend(AbsFrontend):
|
||||
def __init__(
|
||||
self,
|
||||
fs: Union[int, str] = 16000,
|
||||
n_fft: int = 512,
|
||||
win_length: int = None,
|
||||
hop_length: int = 128,
|
||||
n_fft: int = 400,
|
||||
frame_length: int = 25,
|
||||
frame_shift: int = 10,
|
||||
window: Optional[str] = "hann",
|
||||
center: bool = True,
|
||||
normalized: bool = False,
|
||||
@ -160,10 +159,10 @@ class MultiChannelFrontend(AbsFrontend):
|
||||
htk: bool = False,
|
||||
frontend_conf: Optional[dict] = get_default_kwargs(Frontend),
|
||||
apply_stft: bool = True,
|
||||
frame_length: int = None,
|
||||
frame_shift: int = None,
|
||||
lfr_m: int = None,
|
||||
lfr_n: int = None,
|
||||
use_channel: int = None,
|
||||
lfr_m: int = 1,
|
||||
lfr_n: int = 1,
|
||||
cmvn_file: str = None
|
||||
):
|
||||
assert check_argument_types()
|
||||
super().__init__()
|
||||
@ -172,13 +171,14 @@ class MultiChannelFrontend(AbsFrontend):
|
||||
|
||||
# Deepcopy (In general, dict shouldn't be used as default arg)
|
||||
frontend_conf = copy.deepcopy(frontend_conf)
|
||||
self.hop_length = hop_length
|
||||
self.win_length = frame_length * 16
|
||||
self.hop_length = frame_shift * 16
|
||||
|
||||
if apply_stft:
|
||||
self.stft = Stft(
|
||||
n_fft=n_fft,
|
||||
win_length=win_length,
|
||||
hop_length=hop_length,
|
||||
win_length=self.win_length,
|
||||
hop_length=self.hop_length,
|
||||
center=center,
|
||||
window=window,
|
||||
normalized=normalized,
|
||||
@ -202,7 +202,17 @@ class MultiChannelFrontend(AbsFrontend):
|
||||
htk=htk,
|
||||
)
|
||||
self.n_mels = n_mels
|
||||
self.frontend_type = "multichannelfrontend"
|
||||
self.frontend_type = "default"
|
||||
self.use_channel = use_channel
|
||||
if self.use_channel is not None:
|
||||
logging.info("use the channel %d" % (self.use_channel))
|
||||
else:
|
||||
logging.info("random select channel")
|
||||
self.cmvn_file = cmvn_file
|
||||
if self.cmvn_file is not None:
|
||||
mean, std = self._load_cmvn(self.cmvn_file)
|
||||
self.register_buffer("mean", torch.from_numpy(mean))
|
||||
self.register_buffer("std", torch.from_numpy(std))
|
||||
|
||||
def output_size(self) -> int:
|
||||
return self.n_mels
|
||||
@ -215,16 +225,29 @@ class MultiChannelFrontend(AbsFrontend):
|
||||
if self.stft is not None:
|
||||
input_stft, feats_lens = self._compute_stft(input, input_lengths)
|
||||
else:
|
||||
if isinstance(input, ComplexTensor):
|
||||
input_stft = input
|
||||
else:
|
||||
input_stft = ComplexTensor(input[..., 0], input[..., 1])
|
||||
input_stft = ComplexTensor(input[..., 0], input[..., 1])
|
||||
feats_lens = input_lengths
|
||||
# 2. [Option] Speech enhancement
|
||||
if self.frontend is not None:
|
||||
assert isinstance(input_stft, ComplexTensor), type(input_stft)
|
||||
# input_stft: (Batch, Length, [Channel], Freq)
|
||||
input_stft, _, mask = self.frontend(input_stft, feats_lens)
|
||||
|
||||
# 3. [Multi channel case]: Select a channel
|
||||
if input_stft.dim() == 4:
|
||||
# h: (B, T, C, F) -> h: (B, T, F)
|
||||
if self.training:
|
||||
if self.use_channel is not None:
|
||||
input_stft = input_stft[:, :, self.use_channel, :]
|
||||
|
||||
else:
|
||||
# Select 1ch randomly
|
||||
ch = np.random.randint(input_stft.size(2))
|
||||
input_stft = input_stft[:, :, ch, :]
|
||||
else:
|
||||
# Use the first channel
|
||||
input_stft = input_stft[:, :, 0, :]
|
||||
|
||||
# 4. STFT -> Power spectrum
|
||||
# h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F)
|
||||
input_power = input_stft.real ** 2 + input_stft.imag ** 2
|
||||
@ -233,18 +256,27 @@ class MultiChannelFrontend(AbsFrontend):
|
||||
# input_power: (Batch, [Channel,] Length, Freq)
|
||||
# -> input_feats: (Batch, Length, Dim)
|
||||
input_feats, _ = self.logmel(input_power, feats_lens)
|
||||
bt = input_feats.size(0)
|
||||
if input_feats.dim() ==4:
|
||||
channel_size = input_feats.size(2)
|
||||
# batch * channel * T * D
|
||||
#pdb.set_trace()
|
||||
input_feats = input_feats.transpose(1,2).reshape(bt*channel_size,-1,80).contiguous()
|
||||
# input_feats = input_feats.transpose(1,2)
|
||||
# batch * channel
|
||||
feats_lens = feats_lens.repeat(1,channel_size).squeeze()
|
||||
else:
|
||||
channel_size = 1
|
||||
return input_feats, feats_lens, channel_size
|
||||
|
||||
# 6. Apply CMVN
|
||||
if self.cmvn_file is not None:
|
||||
if feats_lens is None:
|
||||
feats_lens = input_feats.new_full([input_feats.size(0)], input_feats.size(1))
|
||||
self.mean = self.mean.to(input_feats.device, input_feats.dtype)
|
||||
self.std = self.std.to(input_feats.device, input_feats.dtype)
|
||||
mask = make_pad_mask(feats_lens, input_feats, 1)
|
||||
|
||||
if input_feats.requires_grad:
|
||||
input_feats = input_feats + self.mean
|
||||
else:
|
||||
input_feats += self.mean
|
||||
if input_feats.requires_grad:
|
||||
input_feats = input_feats.masked_fill(mask, 0.0)
|
||||
else:
|
||||
input_feats.masked_fill_(mask, 0.0)
|
||||
|
||||
input_feats *= self.std
|
||||
|
||||
return input_feats, feats_lens
|
||||
|
||||
def _compute_stft(
|
||||
self, input: torch.Tensor, input_lengths: torch.Tensor
|
||||
@ -258,4 +290,27 @@ class MultiChannelFrontend(AbsFrontend):
|
||||
# Change torch.Tensor to ComplexTensor
|
||||
# input_stft: (..., F, 2) -> (..., F)
|
||||
input_stft = ComplexTensor(input_stft[..., 0], input_stft[..., 1])
|
||||
return input_stft, feats_lens
|
||||
return input_stft, feats_lens
|
||||
|
||||
def _load_cmvn(self, cmvn_file):
|
||||
with open(cmvn_file, 'r', encoding='utf-8') as f:
|
||||
lines = f.readlines()
|
||||
means_list = []
|
||||
vars_list = []
|
||||
for i in range(len(lines)):
|
||||
line_item = lines[i].split()
|
||||
if line_item[0] == '<AddShift>':
|
||||
line_item = lines[i + 1].split()
|
||||
if line_item[0] == '<LearnRateCoef>':
|
||||
add_shift_line = line_item[3:(len(line_item) - 1)]
|
||||
means_list = list(add_shift_line)
|
||||
continue
|
||||
elif line_item[0] == '<Rescale>':
|
||||
line_item = lines[i + 1].split()
|
||||
if line_item[0] == '<LearnRateCoef>':
|
||||
rescale_line = line_item[3:(len(line_item) - 1)]
|
||||
vars_list = list(rescale_line)
|
||||
continue
|
||||
means = np.array(means_list).astype(np.float)
|
||||
vars = np.array(vars_list).astype(np.float)
|
||||
return means, vars
|
||||
344
funasr/runtime/java/FunasrWsClient.java
Normal file
344
funasr/runtime/java/FunasrWsClient.java
Normal file
@ -0,0 +1,344 @@
|
||||
//
|
||||
// Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
|
||||
// Reserved. MIT License (https://opensource.org/licenses/MIT)
|
||||
//
|
||||
/*
|
||||
* // 2022-2023 by zhaomingwork@qq.com
|
||||
*/
|
||||
// java FunasrWsClient
|
||||
// usage: FunasrWsClient [-h] [--port PORT] [--host HOST] [--audio_in AUDIO_IN] [--num_threads NUM_THREADS]
|
||||
// [--chunk_size CHUNK_SIZE] [--chunk_interval CHUNK_INTERVAL] [--mode MODE]
|
||||
package websocket;
|
||||
|
||||
import java.io.*;
|
||||
import java.net.URI;
|
||||
import java.net.URISyntaxException;
|
||||
import java.nio.*;
|
||||
import java.util.Map;
|
||||
import net.sourceforge.argparse4j.ArgumentParsers;
|
||||
import net.sourceforge.argparse4j.inf.ArgumentParser;
|
||||
import net.sourceforge.argparse4j.inf.ArgumentParserException;
|
||||
import net.sourceforge.argparse4j.inf.Namespace;
|
||||
import org.java_websocket.client.WebSocketClient;
|
||||
import org.java_websocket.drafts.Draft;
|
||||
import org.java_websocket.handshake.ServerHandshake;
|
||||
import org.json.simple.JSONArray;
|
||||
import org.json.simple.JSONObject;
|
||||
import org.json.simple.parser.JSONParser;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
/** This example demonstrates how to connect to websocket server. */
|
||||
public class FunasrWsClient extends WebSocketClient {
|
||||
|
||||
public class RecWavThread extends Thread {
|
||||
private FunasrWsClient funasrClient;
|
||||
|
||||
public RecWavThread(FunasrWsClient funasrClient) {
|
||||
this.funasrClient = funasrClient;
|
||||
}
|
||||
|
||||
public void run() {
|
||||
this.funasrClient.recWav();
|
||||
}
|
||||
}
|
||||
|
||||
private static final Logger logger = LoggerFactory.getLogger(FunasrWsClient.class);
|
||||
|
||||
public FunasrWsClient(URI serverUri, Draft draft) {
|
||||
super(serverUri, draft);
|
||||
}
|
||||
|
||||
public FunasrWsClient(URI serverURI) {
|
||||
super(serverURI);
|
||||
}
|
||||
|
||||
public FunasrWsClient(URI serverUri, Map<String, String> httpHeaders) {
|
||||
super(serverUri, httpHeaders);
|
||||
}
|
||||
|
||||
public void getSslContext(String keyfile, String certfile) {
|
||||
// TODO
|
||||
return;
|
||||
}
|
||||
|
||||
// send json at first time
|
||||
public void sendJson(
|
||||
String mode, String strChunkSize, int chunkInterval, String wavName, boolean isSpeaking) {
|
||||
try {
|
||||
|
||||
JSONObject obj = new JSONObject();
|
||||
obj.put("mode", mode);
|
||||
JSONArray array = new JSONArray();
|
||||
String[] chunkList = strChunkSize.split(",");
|
||||
for (int i = 0; i < chunkList.length; i++) {
|
||||
array.add(Integer.valueOf(chunkList[i].trim()));
|
||||
}
|
||||
|
||||
obj.put("chunk_size", array);
|
||||
obj.put("chunk_interval", new Integer(chunkInterval));
|
||||
obj.put("wav_name", wavName);
|
||||
if (isSpeaking) {
|
||||
obj.put("is_speaking", new Boolean(true));
|
||||
} else {
|
||||
obj.put("is_speaking", new Boolean(false));
|
||||
}
|
||||
logger.info("sendJson: " + obj);
|
||||
// return;
|
||||
|
||||
send(obj.toString());
|
||||
|
||||
return;
|
||||
} catch (Exception e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
|
||||
// send json at end of wav
|
||||
public void sendEof() {
|
||||
try {
|
||||
JSONObject obj = new JSONObject();
|
||||
|
||||
obj.put("is_speaking", new Boolean(false));
|
||||
|
||||
logger.info("sendEof: " + obj);
|
||||
// return;
|
||||
|
||||
send(obj.toString());
|
||||
iseof = true;
|
||||
return;
|
||||
} catch (Exception e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
|
||||
// function for rec wav file
|
||||
public void recWav() {
|
||||
sendJson(mode, strChunkSize, chunkInterval, wavName, true);
|
||||
File file = new File(FunasrWsClient.wavPath);
|
||||
|
||||
int chunkSize = sendChunkSize;
|
||||
byte[] bytes = new byte[chunkSize];
|
||||
|
||||
int readSize = 0;
|
||||
try (FileInputStream fis = new FileInputStream(file)) {
|
||||
if (FunasrWsClient.wavPath.endsWith(".wav")) {
|
||||
fis.read(bytes, 0, 44); //skip first 44 wav header
|
||||
}
|
||||
readSize = fis.read(bytes, 0, chunkSize);
|
||||
while (readSize > 0) {
|
||||
// send when it is chunk size
|
||||
if (readSize == chunkSize) {
|
||||
send(bytes); // send buf to server
|
||||
|
||||
} else {
|
||||
// send when at last or not is chunk size
|
||||
byte[] tmpBytes = new byte[readSize];
|
||||
for (int i = 0; i < readSize; i++) {
|
||||
tmpBytes[i] = bytes[i];
|
||||
}
|
||||
send(tmpBytes);
|
||||
}
|
||||
// if not in offline mode, we simulate online stream by sleep
|
||||
if (!mode.equals("offline")) {
|
||||
Thread.sleep(Integer.valueOf(chunkSize / 32));
|
||||
}
|
||||
|
||||
readSize = fis.read(bytes, 0, chunkSize);
|
||||
}
|
||||
|
||||
if (!mode.equals("offline")) {
|
||||
// if not offline, we send eof and wait for 3 seconds to close
|
||||
Thread.sleep(2000);
|
||||
sendEof();
|
||||
Thread.sleep(3000);
|
||||
close();
|
||||
} else {
|
||||
// if offline, just send eof
|
||||
sendEof();
|
||||
}
|
||||
|
||||
} catch (Exception e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onOpen(ServerHandshake handshakedata) {
|
||||
|
||||
RecWavThread thread = new RecWavThread(this);
|
||||
thread.start();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onMessage(String message) {
|
||||
JSONObject jsonObject = new JSONObject();
|
||||
JSONParser jsonParser = new JSONParser();
|
||||
logger.info("received: " + message);
|
||||
try {
|
||||
jsonObject = (JSONObject) jsonParser.parse(message);
|
||||
logger.info("text: " + jsonObject.get("text"));
|
||||
} catch (org.json.simple.parser.ParseException e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
if (iseof && mode.equals("offline")) {
|
||||
close();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onClose(int code, String reason, boolean remote) {
|
||||
|
||||
logger.info(
|
||||
"Connection closed by "
|
||||
+ (remote ? "remote peer" : "us")
|
||||
+ " Code: "
|
||||
+ code
|
||||
+ " Reason: "
|
||||
+ reason);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onError(Exception ex) {
|
||||
logger.info("ex: " + ex);
|
||||
ex.printStackTrace();
|
||||
// if the error is fatal then onClose will be called additionally
|
||||
}
|
||||
|
||||
private boolean iseof = false;
|
||||
public static String wavPath;
|
||||
static String mode = "online";
|
||||
static String strChunkSize = "5,10,5";
|
||||
static int chunkInterval = 10;
|
||||
static int sendChunkSize = 1920;
|
||||
|
||||
String wavName = "javatest";
|
||||
|
||||
public static void main(String[] args) throws URISyntaxException {
|
||||
ArgumentParser parser = ArgumentParsers.newArgumentParser("ws client").defaultHelp(true);
|
||||
parser
|
||||
.addArgument("--port")
|
||||
.help("Port on which to listen.")
|
||||
.setDefault("8889")
|
||||
.type(String.class)
|
||||
.required(false);
|
||||
parser
|
||||
.addArgument("--host")
|
||||
.help("the IP address of server.")
|
||||
.setDefault("127.0.0.1")
|
||||
.type(String.class)
|
||||
.required(false);
|
||||
parser
|
||||
.addArgument("--audio_in")
|
||||
.help("wav path for decoding.")
|
||||
.setDefault("asr_example.wav")
|
||||
.type(String.class)
|
||||
.required(false);
|
||||
parser
|
||||
.addArgument("--num_threads")
|
||||
.help("num of threads for test.")
|
||||
.setDefault(1)
|
||||
.type(Integer.class)
|
||||
.required(false);
|
||||
parser
|
||||
.addArgument("--chunk_size")
|
||||
.help("chunk size for asr.")
|
||||
.setDefault("5, 10, 5")
|
||||
.type(String.class)
|
||||
.required(false);
|
||||
parser
|
||||
.addArgument("--chunk_interval")
|
||||
.help("chunk for asr.")
|
||||
.setDefault(10)
|
||||
.type(Integer.class)
|
||||
.required(false);
|
||||
|
||||
parser
|
||||
.addArgument("--mode")
|
||||
.help("mode for asr.")
|
||||
.setDefault("offline")
|
||||
.type(String.class)
|
||||
.required(false);
|
||||
String srvIp = "";
|
||||
String srvPort = "";
|
||||
String wavPath = "";
|
||||
int numThreads = 1;
|
||||
String chunk_size = "";
|
||||
int chunk_interval = 10;
|
||||
String strmode = "offline";
|
||||
|
||||
try {
|
||||
Namespace ns = parser.parseArgs(args);
|
||||
srvIp = ns.get("host");
|
||||
srvPort = ns.get("port");
|
||||
wavPath = ns.get("audio_in");
|
||||
numThreads = ns.get("num_threads");
|
||||
chunk_size = ns.get("chunk_size");
|
||||
chunk_interval = ns.get("chunk_interval");
|
||||
strmode = ns.get("mode");
|
||||
System.out.println(srvPort);
|
||||
|
||||
} catch (ArgumentParserException ex) {
|
||||
ex.getParser().handleError(ex);
|
||||
return;
|
||||
}
|
||||
|
||||
FunasrWsClient.strChunkSize = chunk_size;
|
||||
FunasrWsClient.chunkInterval = chunk_interval;
|
||||
FunasrWsClient.wavPath = wavPath;
|
||||
FunasrWsClient.mode = strmode;
|
||||
System.out.println(
|
||||
"serIp="
|
||||
+ srvIp
|
||||
+ ",srvPort="
|
||||
+ srvPort
|
||||
+ ",wavPath="
|
||||
+ wavPath
|
||||
+ ",strChunkSize"
|
||||
+ strChunkSize);
|
||||
|
||||
class ClientThread implements Runnable {
|
||||
|
||||
String srvIp;
|
||||
String srvPort;
|
||||
|
||||
ClientThread(String srvIp, String srvPort, String wavPath) {
|
||||
this.srvIp = srvIp;
|
||||
this.srvPort = srvPort;
|
||||
}
|
||||
|
||||
public void run() {
|
||||
try {
|
||||
|
||||
int RATE = 16000;
|
||||
String[] chunkList = strChunkSize.split(",");
|
||||
int int_chunk_size = 60 * Integer.valueOf(chunkList[1].trim()) / chunkInterval;
|
||||
int CHUNK = Integer.valueOf(RATE / 1000 * int_chunk_size);
|
||||
int stride =
|
||||
Integer.valueOf(
|
||||
60 * Integer.valueOf(chunkList[1].trim()) / chunkInterval / 1000 * 16000 * 2);
|
||||
System.out.println("chunk_size:" + String.valueOf(int_chunk_size));
|
||||
System.out.println("CHUNK:" + CHUNK);
|
||||
System.out.println("stride:" + String.valueOf(stride));
|
||||
FunasrWsClient.sendChunkSize = CHUNK * 2;
|
||||
|
||||
String wsAddress = "ws://" + srvIp + ":" + srvPort;
|
||||
|
||||
FunasrWsClient c = new FunasrWsClient(new URI(wsAddress));
|
||||
|
||||
c.connect();
|
||||
|
||||
System.out.println("wsAddress:" + wsAddress);
|
||||
} catch (Exception e) {
|
||||
e.printStackTrace();
|
||||
System.out.println("e:" + e);
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < numThreads; i++) {
|
||||
System.out.println("Thread1 is running...");
|
||||
Thread t = new Thread(new ClientThread(srvIp, srvPort, wavPath));
|
||||
t.start();
|
||||
}
|
||||
}
|
||||
}
|
||||
76
funasr/runtime/java/Makefile
Normal file
76
funasr/runtime/java/Makefile
Normal file
@ -0,0 +1,76 @@
|
||||
|
||||
ENTRY_POINT = ./
|
||||
|
||||
|
||||
|
||||
|
||||
WEBSOCKET_DIR:= ./
|
||||
WEBSOCKET_FILES = \
|
||||
$(WEBSOCKET_DIR)/FunasrWsClient.java \
|
||||
|
||||
|
||||
|
||||
LIB_BUILD_DIR = ./lib
|
||||
|
||||
|
||||
|
||||
|
||||
JAVAC = javac
|
||||
|
||||
BUILD_DIR = build
|
||||
|
||||
|
||||
RUNJFLAGS = -Dfile.encoding=utf-8
|
||||
|
||||
|
||||
vpath %.class $(BUILD_DIR)
|
||||
vpath %.java src
|
||||
|
||||
|
||||
|
||||
|
||||
rebuild: clean all
|
||||
|
||||
.PHONY: clean run downjar
|
||||
|
||||
downjar:
|
||||
wget https://repo1.maven.org/maven2/org/slf4j/slf4j-api/1.7.25/slf4j-api-1.7.25.jar -P ./lib/
|
||||
wget https://repo1.maven.org/maven2/org/slf4j/slf4j-simple/1.7.25/slf4j-simple-1.7.25.jar -P ./lib/
|
||||
#wget https://github.com/TooTallNate/Java-WebSocket/releases/download/v1.5.3/Java-WebSocket-1.5.3.jar -P ./lib/
|
||||
wget https://repo1.maven.org/maven2/org/java-websocket/Java-WebSocket/1.5.3/Java-WebSocket-1.5.3.jar -P ./lib/
|
||||
wget https://storage.googleapis.com/google-code-archive-downloads/v2/code.google.com/json-simple/json-simple-1.1.1.jar -P ./lib/
|
||||
wget https://github.com/argparse4j/argparse4j/releases/download/argparse4j-0.9.0/argparse4j-0.9.0.jar -P ./lib/
|
||||
rm -frv build
|
||||
mkdir build
|
||||
clean:
|
||||
rm -frv $(BUILD_DIR)/*
|
||||
rm -frv $(LIB_BUILD_DIR)/*
|
||||
mkdir -p $(BUILD_DIR)
|
||||
mkdir -p ./lib
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
runclient:
|
||||
java -cp $(BUILD_DIR):lib/Java-WebSocket-1.5.3.jar:lib/slf4j-simple-1.7.25.jar:lib/slf4j-api-1.7.25.jar:lib/json-simple-1.1.1.jar:lib/argparse4j-0.9.0.jar $(RUNJFLAGS) websocket.FunasrWsClient --host localhost --port 8889 --audio_in ./asr_example.wav --num_threads 1 --mode 2pass
|
||||
|
||||
|
||||
|
||||
buildwebsocket: $(WEBSOCKET_FILES:.java=.class)
|
||||
|
||||
|
||||
%.class: %.java
|
||||
|
||||
$(JAVAC) -cp $(BUILD_DIR):lib/slf4j-simple-1.7.25.jar:lib/slf4j-api-1.7.25.jar:lib/Java-WebSocket-1.5.3.jar:lib/json-simple-1.1.1.jar:lib/argparse4j-0.9.0.jar -d $(BUILD_DIR) -encoding UTF-8 $<
|
||||
|
||||
packjar:
|
||||
jar cvfe lib/funasrclient.jar . -C $(BUILD_DIR) .
|
||||
|
||||
all: clean buildlib packjar buildfile buildmic downjar buildwebsocket
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
66
funasr/runtime/java/readme.md
Normal file
66
funasr/runtime/java/readme.md
Normal file
@ -0,0 +1,66 @@
|
||||
# Client for java websocket example
|
||||
|
||||
|
||||
|
||||
## Building for Linux/Unix
|
||||
|
||||
### install java environment
|
||||
```shell
|
||||
# in ubuntu
|
||||
apt-get install openjdk-11-jdk
|
||||
```
|
||||
|
||||
|
||||
|
||||
### Build and run by make
|
||||
|
||||
|
||||
```shell
|
||||
cd funasr/runtime/java
|
||||
# download java lib
|
||||
make downjar
|
||||
# compile
|
||||
make buildwebsocket
|
||||
# run client
|
||||
make runclient
|
||||
|
||||
```
|
||||
|
||||
## Run java websocket client by shell
|
||||
|
||||
```shell
|
||||
# full command refer to Makefile runclient
|
||||
usage: FunasrWsClient [-h] [--port PORT] [--host HOST] [--audio_in AUDIO_IN] [--num_threads NUM_THREADS]
|
||||
[--chunk_size CHUNK_SIZE] [--chunk_interval CHUNK_INTERVAL] [--mode MODE]
|
||||
|
||||
Where:
|
||||
--host <string>
|
||||
(required) server-ip
|
||||
|
||||
--port <int>
|
||||
(required) port
|
||||
|
||||
--audio_in <string>
|
||||
(required) the wav or pcm file path
|
||||
|
||||
--num_threads <int>
|
||||
thread number for test
|
||||
|
||||
--mode
|
||||
asr mode, support "offline" "online" "2pass"
|
||||
|
||||
|
||||
|
||||
example:
|
||||
FunasrWsClient --host localhost --port 8889 --audio_in ./asr_example.wav --num_threads 1 --mode 2pass
|
||||
|
||||
result json, example like:
|
||||
{"mode":"offline","text":"欢迎大家来体验达摩院推出的语音识别模型","wav_name":"javatest"}
|
||||
```
|
||||
|
||||
|
||||
## Acknowledge
|
||||
1. This project is maintained by [FunASR community](https://github.com/alibaba-damo-academy/FunASR).
|
||||
2. We acknowledge [zhaoming](https://github.com/zhaomingwork/FunASR/tree/java-ws-client-support/funasr/runtime/java) for contributing the java websocket client example.
|
||||
|
||||
|
||||
@ -65,7 +65,7 @@ void runReg(FUNASR_HANDLE asr_handle, vector<string> wav_list, vector<string> wa
|
||||
n_total_length += snippet_time;
|
||||
FunASRFreeResult(result);
|
||||
}else{
|
||||
LOG(ERROR) << ("No return data!\n");
|
||||
LOG(ERROR) << wav_ids[i] << (": No return data!\n");
|
||||
}
|
||||
}
|
||||
{
|
||||
|
||||
@ -18,6 +18,7 @@ void CTTransformer::InitPunc(const std::string &punc_model, const std::string &p
|
||||
|
||||
try{
|
||||
m_session = std::make_unique<Ort::Session>(env_, punc_model.c_str(), session_options);
|
||||
LOG(INFO) << "Successfully load model from " << punc_model;
|
||||
}
|
||||
catch (std::exception const &e) {
|
||||
LOG(ERROR) << "Error when load punc onnx model: " << e.what();
|
||||
|
||||
@ -33,6 +33,7 @@ void Paraformer::InitAsr(const std::string &am_model, const std::string &am_cmvn
|
||||
|
||||
try {
|
||||
m_session = std::make_unique<Ort::Session>(env_, am_model.c_str(), session_options);
|
||||
LOG(INFO) << "Successfully load model from " << am_model;
|
||||
} catch (std::exception const &e) {
|
||||
LOG(ERROR) << "Error when load am onnx model: " << e.what();
|
||||
exit(0);
|
||||
|
||||
@ -13,7 +13,7 @@ def get_readme():
|
||||
|
||||
|
||||
MODULE_NAME = 'funasr_onnx'
|
||||
VERSION_NUM = '0.1.0'
|
||||
VERSION_NUM = '0.1.1'
|
||||
|
||||
setuptools.setup(
|
||||
name=MODULE_NAME,
|
||||
@ -31,7 +31,7 @@ setuptools.setup(
|
||||
"onnxruntime>=1.7.0",
|
||||
"scipy",
|
||||
"numpy>=1.19.3",
|
||||
"typeguard",
|
||||
"typeguard==2.13.3",
|
||||
"kaldi-native-fbank",
|
||||
"PyYAML>=5.1.2",
|
||||
"funasr",
|
||||
|
||||
@ -3,7 +3,7 @@ generated certificate may not suitable for all browsers due to security concerns
|
||||
|
||||
```shell
|
||||
### 1) Generate a private key
|
||||
openssl genrsa -des3 -out server.key 1024
|
||||
openssl genrsa -des3 -out server.key 2048
|
||||
|
||||
### 2) Generate a csr file
|
||||
openssl req -new -key server.key -out server.csr
|
||||
@ -14,4 +14,4 @@ openssl rsa -in server.key.org -out server.key
|
||||
|
||||
### 4) Generated a crt file, valid for 1 year
|
||||
openssl x509 -req -days 365 -in server.csr -signkey server.key -out server.crt
|
||||
```
|
||||
```
|
||||
|
||||
@ -1,15 +1,21 @@
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIICSDCCAbECFCObiVAMkMlCGmMDGDFx5Nx3XYvOMA0GCSqGSIb3DQEBCwUAMGMx
|
||||
CzAJBgNVBAYTAkNOMRAwDgYDVQQIDAdCZWlqaW5nMRAwDgYDVQQHDAdCZWlqaW5n
|
||||
MRAwDgYDVQQKDAdhbGliYWJhMQwwCgYDVQQLDANhc3IxEDAOBgNVBAMMB2FsaWJh
|
||||
YmEwHhcNMjMwNTEyMTQzNjAxWhcNMjQwNTExMTQzNjAxWjBjMQswCQYDVQQGEwJD
|
||||
TjEQMA4GA1UECAwHQmVpamluZzEQMA4GA1UEBwwHQmVpamluZzEQMA4GA1UECgwH
|
||||
YWxpYmFiYTEMMAoGA1UECwwDYXNyMRAwDgYDVQQDDAdhbGliYWJhMIGfMA0GCSqG
|
||||
SIb3DQEBAQUAA4GNADCBiQKBgQDEINLLMasJtJQPoesCfcwJsjiUkx3hLnoUyETS
|
||||
NBrrRfjbBv6ucAgZIF+/V15IfJZR6u2ULpJN0wUg8xNQReu4kdpjSdNGuQ0aoWbc
|
||||
38+VLo9UjjsoOeoeCro6b0u+GosPoEuI4t7Ky09zw+FBibD95daJ3GDY1DGCbDdL
|
||||
mV/toQIDAQABMA0GCSqGSIb3DQEBCwUAA4GBAB5KNWF1XIIYD1geMsyT6/ZRnGNA
|
||||
dmeUyMcwYvIlQG3boSipNk/JI4W5fFOg1O2sAqflYHmwZfmasAQsC2e5bSzHZ+PB
|
||||
uMJhKYxfj81p175GumHTw5Lbp2CvFSLrnuVB0ThRdcCqEh1MDt0D3QBuBr/ZKgGS
|
||||
hXtozVCgkSJzX6uD
|
||||
MIIDhTCCAm0CFGB0Po2IZ0hESavFpcSGRNb9xrNXMA0GCSqGSIb3DQEBCwUAMH8x
|
||||
CzAJBgNVBAYTAkNOMRAwDgYDVQQIDAdiZWlqaW5nMRAwDgYDVQQHDAdiZWlqaW5n
|
||||
MRAwDgYDVQQKDAdhbGliYWJhMRAwDgYDVQQLDAdhbGliYWJhMRAwDgYDVQQDDAdh
|
||||
bGliYWJhMRYwFAYJKoZIhvcNAQkBFgdhbGliYWJhMB4XDTIzMDYxODA2NTcxM1oX
|
||||
DTI0MDYxNzA2NTcxM1owfzELMAkGA1UEBhMCQ04xEDAOBgNVBAgMB2JlaWppbmcx
|
||||
EDAOBgNVBAcMB2JlaWppbmcxEDAOBgNVBAoMB2FsaWJhYmExEDAOBgNVBAsMB2Fs
|
||||
aWJhYmExEDAOBgNVBAMMB2FsaWJhYmExFjAUBgkqhkiG9w0BCQEWB2FsaWJhYmEw
|
||||
ggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDH9Np1oBunQKMt5M/nU2nD
|
||||
qVHojXwKKwyiK9DSeGikKwArH2S9NUZNu5RDg46u0iWmT+Vz+toQhkJnfatOVskW
|
||||
f2bsI54n5eOvmoWOKDXYm2MscvjkuNiYRbqzgUuP9ZSx8k3uyRs++wvmwIoU+PV1
|
||||
EYFcjk1P2jUGUvKaUlmIDsjs1wOMIbKO6I0UX20FNKlGWacqMR/Dx2ltmGKT1Kaz
|
||||
Y335lor0bcfQtH542rGS7PDz6JMRNjFT1VFcmnrjRElf4STbaOiIfOjMVZ/9O8Hr
|
||||
LFItyvkb01Mt7O0jhAXHuE1l/8Y0N3MCYkELG9mQA0BYCFHY0FLuJrGoU03b8KWj
|
||||
AgMBAAEwDQYJKoZIhvcNAQELBQADggEBAEjC9jB1WZe2ki2JgCS+eAMFsFegiNEz
|
||||
D0klVB3kiCPK0g7DCxvfWR6kAgEynxRxVX6TN9QcLr4paZItC1Fu2gUMTteNqEuc
|
||||
dcixJdu9jumuUMBlAKgL5Yyk3alSErsn9ZVF/Q8Kx5arMO/TW3Ulsd8SWQL5C/vq
|
||||
Fe0SRhpKKoADPfl8MT/XMfB/MwNxVhYDSHzJ1EiN8O5ce6q2tTdi1mlGquzNxhjC
|
||||
7Q0F36V1HksfzolrlRWRKYP16isnaKUdFfeAzaJsYw33o6VRbk6fo2fTQDHS0wOs
|
||||
Q48Moc5UxKMLaMMCqLPpWu0TZse+kIw1nTWXk7yJtK0HK5PN3rTocEw=
|
||||
-----END CERTIFICATE-----
|
||||
|
||||
@ -1,15 +1,27 @@
|
||||
-----BEGIN RSA PRIVATE KEY-----
|
||||
MIICXQIBAAKBgQDEINLLMasJtJQPoesCfcwJsjiUkx3hLnoUyETSNBrrRfjbBv6u
|
||||
cAgZIF+/V15IfJZR6u2ULpJN0wUg8xNQReu4kdpjSdNGuQ0aoWbc38+VLo9Ujjso
|
||||
OeoeCro6b0u+GosPoEuI4t7Ky09zw+FBibD95daJ3GDY1DGCbDdLmV/toQIDAQAB
|
||||
AoGARpA0pwygp+ZDWvh7kDLoZRitCK+BkZHiNHX1ZNeAU+Oh7FOw79u43ilqqXHq
|
||||
pxPEFYb7oVO8Kanhb4BlE32EmApBlvhd3SW07kn0dS7WVGsTvPFwKKpF88W8E+pc
|
||||
2i8At5tr2O1DZhvqNdIN7r8FRrGQ/Hpm3ItypUdz2lZnMwECQQD3dILOMJ84O2JE
|
||||
NxUwk8iOYefMJftQUO57Gm7XBVke/i3r9uajSqB2xmOvUaSyaHoJfx/mmfgfxYcD
|
||||
M+Re6mERAkEAyuaV5+eD82eG2I8PgxJ2p5SOb1x5F5qpb4KuKAlfHEkdolttMwN3
|
||||
7vl1ZWUZLVu2rHnUmvbYV2gkQO1os7/DkQJBAIDYfbN2xbC12vjB5ZqhmG/qspMt
|
||||
w6mSOlqG7OewtTLaDncq2/RySxMNQaJr1GHA3KpNMwMTcIq6gw472tFBIMECQF0z
|
||||
fjiASEROkcp4LI/ws0BXJPZSa+1DxgDK7mTFqUK88zfY91gvh6/mNt7UibQkJM0l
|
||||
SVvFd6ru03hflXC77YECQQDDQrB9ApwVOMGQw+pwbxn9p8tPYVi3oBiUfYgd1RDO
|
||||
uhcRgxv7gT4BSiyI4nFBMCYyI28azTLlUiJhMr9MNUpB
|
||||
MIIEowIBAAKCAQEAx/TadaAbp0CjLeTP51Npw6lR6I18CisMoivQ0nhopCsAKx9k
|
||||
vTVGTbuUQ4OOrtIlpk/lc/raEIZCZ32rTlbJFn9m7COeJ+Xjr5qFjig12JtjLHL4
|
||||
5LjYmEW6s4FLj/WUsfJN7skbPvsL5sCKFPj1dRGBXI5NT9o1BlLymlJZiA7I7NcD
|
||||
jCGyjuiNFF9tBTSpRlmnKjEfw8dpbZhik9Sms2N9+ZaK9G3H0LR+eNqxkuzw8+iT
|
||||
ETYxU9VRXJp640RJX+Ek22joiHzozFWf/TvB6yxSLcr5G9NTLeztI4QFx7hNZf/G
|
||||
NDdzAmJBCxvZkANAWAhR2NBS7iaxqFNN2/ClowIDAQABAoIBAQC1/STX6eFBWJMs
|
||||
MhUHdePNMU5bWmqK1qOo9jgZV33l7T06Alit3M8f8JoA2LwEYT/jHtS3upi+cXP+
|
||||
vWIs6tAaqdoDEmff6FxSd1EXEYHwo3yf+ASQJ6z66nwC5KrhW6L6Uo6bxm4F5Hfw
|
||||
jU0fyXeeFVCn7Nxw0SlxmA02Z70VFsL8BK9i3kajU18y6drf4VUm55oMEtdEmOh2
|
||||
eKn4qspBcNblbw+L0QJ+5kN1iRUyJHesQ1GpS+L3yeMVFCW7ctL4Bgw8Z7LE+z7i
|
||||
C0Weyhul8vuT+7nfF2T37zsSa8iixqpkTokeYh96CZ5nDqa2IDx3oNHWSlkIsV6g
|
||||
6EUEl9gBAoGBAPIw/M6fIDetMj8f1wG7mIRgJsxI817IS6aBSwB5HkoCJFfrR9Ua
|
||||
jMNCFIWNs/Om8xeGhq/91hbnCYDNK06V5CUa/uk4CYRs2eQZ3FKoNowtp6u/ieuU
|
||||
qg8bXM/vR2VWtWVixAMdouT3+KtvlgaVmSnrPiwO4pecGrwu5NW1oJCFAoGBANNb
|
||||
aE3AcwTDYsqh0N/75G56Q5s1GZ6MCDQGQSh8IkxL6Vg59KnJiIKQ7AxNKFgJZMtY
|
||||
zZHaqjazeHjOGTiYiC7MMVJtCcOBEfjCouIG8btNYv7Y3dWnOXRZni2telAsRrH9
|
||||
xS5LaFdCRTjVAwSsppMGwiQtyl6sGLMyz0SXoYoHAoGAKdkFFb6xFm26zOV3hTkg
|
||||
9V6X1ZyVUL9TMwYMK5zB+w+7r+VbmBrqT6LPYPRHL8adImeARlCZ+YMaRUMuRHnp
|
||||
3e94NFwWaOdWDu/Y/f9KzZXl7us9rZMWf12+/77cm0oMNeSG8fLg/qdKNHUneyPG
|
||||
P1QCfiJkTMYQaIvBxpuHjvECgYAKlZ9JlYOtD2PZJfVh4il0ZucP1L7ts7GNeWq1
|
||||
7lGBZKPQ6UYZYqBVeZB4pTyJ/B5yGIZi8YJoruAvnJKixPC89zjZGeDNS59sx8KE
|
||||
cziT2rJEdPPXCULVUs+bFf70GOOJcl33jYsyI3139SLrjwHghwwd57UkvJWYE8lR
|
||||
dA6A7QKBgEfTC+NlzqLPhbB+HPl6CvcUczcXcI9M0heVz/DNMA+4pjxPnv2aeIwh
|
||||
cL2wq2xr+g1wDBWGVGkVSuZhXm5E6gDetdyVeJnbIUhVjBblnbhHV6GrudjbXGnJ
|
||||
W9cBgu6DswyHU2cOsqmimu8zLmG6/dQYFHt+kUWGxN8opCzVjgWa
|
||||
-----END RSA PRIVATE KEY-----
|
||||
|
||||
@ -91,7 +91,6 @@ class WebsocketClient {
|
||||
using websocketpp::lib::placeholders::_1;
|
||||
m_client.set_open_handler(bind(&WebsocketClient::on_open, this, _1));
|
||||
m_client.set_close_handler(bind(&WebsocketClient::on_close, this, _1));
|
||||
// m_client.set_close_handler(bind(&WebsocketClient::on_close, this, _1));
|
||||
|
||||
m_client.set_message_handler(
|
||||
[this](websocketpp::connection_hdl hdl, message_ptr msg) {
|
||||
@ -218,7 +217,7 @@ class WebsocketClient {
|
||||
}
|
||||
}
|
||||
if (wait) {
|
||||
LOG(INFO) << "wait.." << m_open;
|
||||
// LOG(INFO) << "wait.." << m_open;
|
||||
WaitABit();
|
||||
continue;
|
||||
}
|
||||
@ -292,7 +291,7 @@ int main(int argc, char* argv[]) {
|
||||
false, 1, "int");
|
||||
TCLAP::ValueArg<int> is_ssl_(
|
||||
"", "is-ssl", "is-ssl is 1 means use wss connection, or use ws connection",
|
||||
false, 0, "int");
|
||||
false, 1, "int");
|
||||
|
||||
cmd.add(server_ip_);
|
||||
cmd.add(port_);
|
||||
|
||||
@ -38,6 +38,7 @@ from funasr.models.decoder.transformer_decoder import (
|
||||
from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
|
||||
from funasr.models.decoder.transformer_decoder import TransformerDecoder
|
||||
from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
|
||||
from funasr.models.decoder.transformer_decoder import SAAsrTransformerDecoder
|
||||
from funasr.models.e2e_asr import ASRModel
|
||||
from funasr.models.decoder.rnnt_decoder import RNNTDecoder
|
||||
from funasr.models.joint_net.joint_network import JointNetwork
|
||||
@ -45,6 +46,7 @@ from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerOnline, Paraf
|
||||
from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
|
||||
from funasr.models.e2e_tp import TimestampPredictor
|
||||
from funasr.models.e2e_asr_mfcca import MFCCA
|
||||
from funasr.models.e2e_sa_asr import SAASRModel
|
||||
from funasr.models.e2e_uni_asr import UniASR
|
||||
from funasr.models.e2e_asr_transducer import TransducerModel, UnifiedTransducerModel
|
||||
from funasr.models.encoder.abs_encoder import AbsEncoder
|
||||
@ -54,6 +56,7 @@ from funasr.models.encoder.rnn_encoder import RNNEncoder
|
||||
from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
|
||||
from funasr.models.encoder.transformer_encoder import TransformerEncoder
|
||||
from funasr.models.encoder.mfcca_encoder import MFCCAEncoder
|
||||
from funasr.models.encoder.resnet34_encoder import ResNet34Diar
|
||||
from funasr.models.frontend.abs_frontend import AbsFrontend
|
||||
from funasr.models.frontend.default import DefaultFrontend
|
||||
from funasr.models.frontend.default import MultiChannelFrontend
|
||||
@ -134,6 +137,7 @@ model_choices = ClassChoices(
|
||||
timestamp_prediction=TimestampPredictor,
|
||||
rnnt=TransducerModel,
|
||||
rnnt_unified=UnifiedTransducerModel,
|
||||
sa_asr=SAASRModel,
|
||||
),
|
||||
type_check=FunASRModel,
|
||||
default="asr",
|
||||
@ -175,6 +179,27 @@ encoder_choices2 = ClassChoices(
|
||||
type_check=AbsEncoder,
|
||||
default="rnn",
|
||||
)
|
||||
asr_encoder_choices = ClassChoices(
|
||||
"asr_encoder",
|
||||
classes=dict(
|
||||
conformer=ConformerEncoder,
|
||||
transformer=TransformerEncoder,
|
||||
rnn=RNNEncoder,
|
||||
sanm=SANMEncoder,
|
||||
sanm_chunk_opt=SANMEncoderChunkOpt,
|
||||
data2vec_encoder=Data2VecEncoder,
|
||||
mfcca_enc=MFCCAEncoder,
|
||||
),
|
||||
type_check=AbsEncoder,
|
||||
default="rnn",
|
||||
)
|
||||
spk_encoder_choices = ClassChoices(
|
||||
"spk_encoder",
|
||||
classes=dict(
|
||||
resnet34_diar=ResNet34Diar,
|
||||
),
|
||||
default="resnet34_diar",
|
||||
)
|
||||
postencoder_choices = ClassChoices(
|
||||
name="postencoder",
|
||||
classes=dict(
|
||||
@ -197,6 +222,7 @@ decoder_choices = ClassChoices(
|
||||
paraformer_decoder_sanm=ParaformerSANMDecoder,
|
||||
paraformer_decoder_san=ParaformerDecoderSAN,
|
||||
contextual_paraformer_decoder=ContextualParaformerDecoder,
|
||||
sa_decoder=SAAsrTransformerDecoder,
|
||||
),
|
||||
type_check=AbsDecoder,
|
||||
default="rnn",
|
||||
@ -329,6 +355,12 @@ class ASRTask(AbsTask):
|
||||
default=True,
|
||||
help="whether to split text using <space>",
|
||||
)
|
||||
group.add_argument(
|
||||
"--max_spk_num",
|
||||
type=int_or_none,
|
||||
default=None,
|
||||
help="A text mapping int-id to token",
|
||||
)
|
||||
group.add_argument(
|
||||
"--seg_dict_file",
|
||||
type=str,
|
||||
@ -1495,3 +1527,123 @@ class ASRTransducerTask(ASRTask):
|
||||
#assert check_return_type(model)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
class ASRTaskSAASR(ASRTask):
|
||||
# If you need more than one optimizers, change this value
|
||||
num_optimizers: int = 1
|
||||
|
||||
# Add variable objects configurations
|
||||
class_choices_list = [
|
||||
# --frontend and --frontend_conf
|
||||
frontend_choices,
|
||||
# --specaug and --specaug_conf
|
||||
specaug_choices,
|
||||
# --normalize and --normalize_conf
|
||||
normalize_choices,
|
||||
# --model and --model_conf
|
||||
model_choices,
|
||||
# --preencoder and --preencoder_conf
|
||||
preencoder_choices,
|
||||
# --encoder and --encoder_conf
|
||||
# --asr_encoder and --asr_encoder_conf
|
||||
asr_encoder_choices,
|
||||
# --spk_encoder and --spk_encoder_conf
|
||||
spk_encoder_choices,
|
||||
# --decoder and --decoder_conf
|
||||
decoder_choices,
|
||||
]
|
||||
|
||||
# If you need to modify train() or eval() procedures, change Trainer class here
|
||||
trainer = Trainer
|
||||
|
||||
@classmethod
|
||||
def build_model(cls, args: argparse.Namespace):
|
||||
assert check_argument_types()
|
||||
if isinstance(args.token_list, str):
|
||||
with open(args.token_list, encoding="utf-8") as f:
|
||||
token_list = [line.rstrip() for line in f]
|
||||
|
||||
# Overwriting token_list to keep it as "portable".
|
||||
args.token_list = list(token_list)
|
||||
elif isinstance(args.token_list, (tuple, list)):
|
||||
token_list = list(args.token_list)
|
||||
else:
|
||||
raise RuntimeError("token_list must be str or list")
|
||||
vocab_size = len(token_list)
|
||||
logging.info(f"Vocabulary size: {vocab_size}")
|
||||
|
||||
# 1. frontend
|
||||
if args.input_size is None:
|
||||
# Extract features in the model
|
||||
frontend_class = frontend_choices.get_class(args.frontend)
|
||||
if args.frontend == 'wav_frontend' or args.frontend == "multichannelfrontend":
|
||||
frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
|
||||
else:
|
||||
frontend = frontend_class(**args.frontend_conf)
|
||||
input_size = frontend.output_size()
|
||||
else:
|
||||
# Give features from data-loader
|
||||
args.frontend = None
|
||||
args.frontend_conf = {}
|
||||
frontend = None
|
||||
input_size = args.input_size
|
||||
|
||||
# 2. Data augmentation for spectrogram
|
||||
if args.specaug is not None:
|
||||
specaug_class = specaug_choices.get_class(args.specaug)
|
||||
specaug = specaug_class(**args.specaug_conf)
|
||||
else:
|
||||
specaug = None
|
||||
|
||||
# 3. Normalization layer
|
||||
if args.normalize is not None:
|
||||
normalize_class = normalize_choices.get_class(args.normalize)
|
||||
normalize = normalize_class(**args.normalize_conf)
|
||||
else:
|
||||
normalize = None
|
||||
|
||||
# 5. Encoder
|
||||
asr_encoder_class = asr_encoder_choices.get_class(args.asr_encoder)
|
||||
asr_encoder = asr_encoder_class(input_size=input_size, **args.asr_encoder_conf)
|
||||
spk_encoder_class = spk_encoder_choices.get_class(args.spk_encoder)
|
||||
spk_encoder = spk_encoder_class(input_size=input_size, **args.spk_encoder_conf)
|
||||
|
||||
# 7. Decoder
|
||||
decoder_class = decoder_choices.get_class(args.decoder)
|
||||
decoder = decoder_class(
|
||||
vocab_size=vocab_size,
|
||||
encoder_output_size=asr_encoder.output_size(),
|
||||
**args.decoder_conf,
|
||||
)
|
||||
|
||||
# 8. CTC
|
||||
ctc = CTC(
|
||||
odim=vocab_size, encoder_output_size=asr_encoder.output_size(), **args.ctc_conf
|
||||
)
|
||||
|
||||
# import ipdb;ipdb.set_trace()
|
||||
# 9. Build model
|
||||
try:
|
||||
model_class = model_choices.get_class(args.model)
|
||||
except AttributeError:
|
||||
model_class = model_choices.get_class("asr")
|
||||
model = model_class(
|
||||
vocab_size=vocab_size,
|
||||
frontend=frontend,
|
||||
specaug=specaug,
|
||||
normalize=normalize,
|
||||
asr_encoder=asr_encoder,
|
||||
spk_encoder=spk_encoder,
|
||||
decoder=decoder,
|
||||
ctc=ctc,
|
||||
token_list=token_list,
|
||||
**args.model_conf,
|
||||
)
|
||||
|
||||
# 10. Initialize
|
||||
if args.init is not None:
|
||||
initialize(model, args.init)
|
||||
|
||||
assert check_return_type(model)
|
||||
return model
|
||||
|
||||
@ -39,7 +39,7 @@ from funasr.models.decoder.transformer_decoder import (
|
||||
from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
|
||||
from funasr.models.decoder.transformer_decoder import TransformerDecoder
|
||||
from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
|
||||
from funasr.models.e2e_sa_asr import ESPnetASRModel
|
||||
from funasr.models.e2e_sa_asr import SAASRModel
|
||||
from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerBert, BiCifParaformer, ContextualParaformer
|
||||
from funasr.models.e2e_tp import TimestampPredictor
|
||||
from funasr.models.e2e_asr_mfcca import MFCCA
|
||||
@ -120,7 +120,7 @@ normalize_choices = ClassChoices(
|
||||
model_choices = ClassChoices(
|
||||
"model",
|
||||
classes=dict(
|
||||
asr=ESPnetASRModel,
|
||||
asr=SAASRModel,
|
||||
uniasr=UniASR,
|
||||
paraformer=Paraformer,
|
||||
paraformer_bert=ParaformerBert,
|
||||
@ -620,4 +620,4 @@ class ASRTask(AbsTask):
|
||||
initialize(model, args.init)
|
||||
|
||||
assert check_return_type(model)
|
||||
return model
|
||||
return model
|
||||
@ -1 +1 @@
|
||||
0.6.1
|
||||
0.6.2
|
||||
|
||||
Loading…
Reference in New Issue
Block a user