Merge branch 'main' into dev_wjm_infer

This commit is contained in:
jmwang66 2023-06-19 20:28:23 +08:00 committed by GitHub
commit 4e0fcee2a9
77 changed files with 1968 additions and 160 deletions


@ -72,8 +72,8 @@ If you have any questions about FunASR, please contact us by
## Contributors
| <div align="left"><img src="docs/images/damo.png" width="180"/> | <div align="left"><img src="docs/images/nwpu.png" width="260"/> | <img src="docs/images/China_Telecom.png" width="200"/> </div> | <img src="docs/images/RapidAI.png" width="200"/> </div> | <img src="docs/images/DeepScience.png" width="200"/> </div> | <img src="docs/images/aihealthx.png" width="200"/> </div> |
|:---------------------------------------------------------------:|:---------------------------------------------------------------:|:--------------------------------------------------------------:|:-------------------------------------------------------:|:-----------------------------------------------------------:|:-----------------------------------------------------------:|
| <div align="left"><img src="docs/images/damo.png" width="180"/> | <div align="left"><img src="docs/images/nwpu.png" width="260"/> | <img src="docs/images/China_Telecom.png" width="200"/> </div> | <img src="docs/images/RapidAI.png" width="200"/> </div> | <img src="docs/images/aihealthx.png" width="200"/> </div> |
|:---------------------------------------------------------------:|:---------------------------------------------------------------:|:--------------------------------------------------------------:|:-------------------------------------------------------:|:-----------------------------------------------------------:|
## Acknowledge
@ -82,7 +82,6 @@ If you have any questions about FunASR, please contact us by
3. We referred [Wenet](https://github.com/wenet-e2e/wenet) for building dataloader for large scale data training.
4. We acknowledge [ChinaTelecom](https://github.com/zhuzizyf/damo-fsmn-vad-infer-httpserver) for contributing the VAD runtime.
5. We acknowledge [RapidAI](https://github.com/RapidAI) for contributing the Paraformer and CT_Transformer-punc runtime.
6. We acknowledge [DeepScience](https://www.deepscience.cn) for contributing the grpc service.
6. We acknowledge [AiHealthx](http://www.aihealthx.com/) for contributing the websocket service and html5.
## License


@ -1,29 +0,0 @@
lm: transformer
lm_conf:
pos_enc: null
embed_unit: 128
att_unit: 512
head: 8
unit: 2048
layer: 16
dropout_rate: 0.1
# optimization related
grad_clip: 5.0
batch_type: numel
batch_bins: 500000 # 4gpus * 500000
accum_grad: 1
max_epoch: 15 # 15epoch is enougth
optim: adam
optim_conf:
lr: 0.001
scheduler: warmuplr
scheduler_conf:
warmup_steps: 25000
best_model_criterion:
- - valid
- loss
- min
keep_nbest_models: 10 # 10 is good.


@ -0,0 +1,86 @@
# Get Started
Speaker-Attributed Automatic Speech Recognition (SA-ASR) is a task proposed to answer "who spoke what". Specifically, the goal of SA-ASR is not only to obtain multi-speaker transcriptions, but also to identify the corresponding speaker for each utterance. The method used in this example follows the paper: [End-to-End Speaker-Attributed ASR with Transformer](https://www.isca-speech.org/archive/pdfs/interspeech_2021/kanda21b_interspeech.pdf).
# Train
First, you need to install FunASR and ModelScope ([installation](https://github.com/alibaba-damo-academy/FunASR#installation)).
After FunASR and ModelScope are installed, you must manually download and unpack the [AliMeeting](http://www.openslr.org/119/) corpus and place it in the `./dataset` directory. The `./dataset` directory should be organized as follows:
```shell
dataset
|—— Eval_Ali_far
|—— Eval_Ali_near
|—— Test_Ali_far
|—— Test_Ali_near
|—— Train_Ali_far
|—— Train_Ali_near
```
Then you can run this recipe with:
```shell
bash run.sh --stage 0 --stop-stage 6
```
There are 8 stages in `run.sh`:
```shell
stage 0: Data preparation and removal of audio that is too long or too short.
stage 1: Speaker profile and CMVN generation.
stage 2: Dictionary preparation.
stage 3: LM training (not supported).
stage 4: ASR training.
stage 5: SA-ASR training.
stage 6: Inference on the test sets.
stage 7: Inference with Test_2023_Ali_far.
```
<!-- The baseline model is available on [ModelScope](https://www.modelscope.cn/models/damo/speech_saasr_asr-zh-cn-16k-alimeeting/summary). -->
# Infer
1. Download the final test set and extract it.
2. Put the audios in `./dataset/Test_2023_Ali_far/` and put the `wav.scp`, `segments`, `utt2spk`, `spk2utt` in `./data/org/Test_2023_Ali_far/`.
3. Set `test_2023` in `run.sh` to `Test_2023_Ali_far`.
4. Run `run.sh` as follows:
```shell
# Prepare test_2023 set
bash run.sh --stage 0 --stop-stage 1
# Decode test_2023 set
bash run.sh --stage 7 --stop-stage 7
```
# Format of Final Submission
Finally, you need to submit a file called `text_spk_merge` with the following format:
```shell
Meeting_1 text_spk_1_A$text_spk_1_B$text_spk_1_C ...
Meeting_2 text_spk_2_A$text_spk_2_B$text_spk_2_C ...
...
```
Here, `text_spk_1_A` represents the full transcription of speaker_A of Meeting_1 (merged in chronological order), and `$` is the separator symbol. There is no need to worry about the speaker permutation, as the optimal permutation is computed during scoring. For more information, please refer to the results generated after executing the baseline code.
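For concreteness, here is a minimal sketch of how such a `text_spk_merge` file could be assembled from utterance-level recognition results. The input key format (`<meeting>-<speaker>-<start>-<end>`) and the file names below are assumptions for illustration only; in this recipe the actual merging is done by `local/process_text_spk_merge.py` on the decoding outputs.
```python
# Hypothetical sketch: build a text_spk_merge file from utterance-level results.
# Assumes each input line looks like "<meeting>-<speaker>-<start>-<end> <text>";
# the key layout and file names are illustrative, adapt them to your own outputs.
import sys
from collections import defaultdict

def build_text_spk_merge(utt_text_path: str, out_path: str) -> None:
    per_spk = defaultdict(list)  # (meeting, speaker) -> [(start, text), ...]
    with open(utt_text_path, encoding="utf-8") as f:
        for line in f:
            parts = line.strip().split(" ", 1)
            key, text = parts[0], parts[1] if len(parts) > 1 else ""
            meeting, speaker, start, _end = key.rsplit("-", 3)
            per_spk[(meeting, speaker)].append((float(start), text))

    merged = defaultdict(list)  # meeting -> per-speaker transcriptions
    for (meeting, _speaker), utts in sorted(per_spk.items()):
        utts.sort(key=lambda x: x[0])  # chronological order within one speaker
        merged[meeting].append("".join(text for _, text in utts))

    with open(out_path, "w", encoding="utf-8") as f:
        for meeting, spk_texts in sorted(merged.items()):
            f.write(f"{meeting} {'$'.join(spk_texts)}\n")  # '$' separates speakers

if __name__ == "__main__":
    build_text_spk_merge(sys.argv[1], sys.argv[2])
```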
# Baseline Results
The results of the baseline system are as follows. They are reported as the speaker-independent character error rate (SI-CER) and the concatenated minimum-permutation character error rate (cpCER); the former ignores speaker identity while the latter is speaker-dependent (a minimal sketch of the cpCER computation is given after the table below). During training, the speaker profile uses the oracle speaker embeddings. However, since oracle speaker labels are not available during evaluation, the speaker profile produced by an additional spectral clustering step is used instead. The results with the oracle speaker profile on the Test set are also provided to show the impact of speaker profile accuracy.
<!-- <table>
<tr >
<td rowspan="2"></td>
<td colspan="2">SI-CER(%)</td>
<td colspan="2">cpCER(%)</td>
</tr>
<tr>
<td>Eval</td>
<td>Test</td>
<td>Eval</td>
<td>Test</td>
</tr>
<tr>
<td>oracle profile</td>
<td>32.05</td>
<td>32.72</td>
<td>47.40</td>
<td>42.92</td>
</tr>
<tr>
<td>cluster profile</td>
<td>32.05</td>
<td>32.73</td>
<td>53.76</td>
<td>49.37</td>
</tr>
</table> -->
| Profile         | SI-CER (%) | cpCER (%) |
|:----------------|:----------:|:---------:|
| oracle profile  | 32.72      | 42.92     |
| cluster profile | 32.73      | 49.37     |
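To make the speaker-dependent metric concrete, below is a minimal, illustrative sketch of a cpCER computation: each hypothesis speaker stream is matched to the reference streams under the permutation that minimizes the total character edit distance, and the summed errors are normalized by the number of reference characters. This is a simplified stand-in (brute-force over permutations, plain Levenshtein distance), not the scoring script of the recipe; the reported numbers come from `utils/compute_cpcer.py` invoked in `run.sh`.
```python
# Illustrative cpCER: concatenated minimum-permutation character error rate.
# Brute-force over speaker permutations; fine for <= 4 speakers per meeting.
from itertools import permutations

def edit_distance(ref: str, hyp: str) -> int:
    # Standard Levenshtein distance over characters.
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, 1):
        cur = [i]
        for j, h in enumerate(hyp, 1):
            cur.append(min(prev[j] + 1,               # deletion
                           cur[j - 1] + 1,            # insertion
                           prev[j - 1] + (r != h)))   # substitution
        prev = cur
    return prev[-1]

def cpcer(ref_streams, hyp_streams):
    """ref_streams / hyp_streams: lists of per-speaker transcriptions (strings)."""
    # Pad the shorter side with empty streams so every speaker can be matched.
    n = max(len(ref_streams), len(hyp_streams))
    refs = list(ref_streams) + [""] * (n - len(ref_streams))
    hyps = list(hyp_streams) + [""] * (n - len(hyp_streams))
    total_ref_chars = sum(len(r) for r in refs)
    best = min(
        sum(edit_distance(r, h) for r, h in zip(refs, perm))
        for perm in permutations(hyps)
    )
    return best / max(total_ref_chars, 1)

# Toy example: the best permutation swaps the two hypothesis streams.
print(cpcer(["今天天气不错", "我们开会吧"], ["我们开会", "今天天气不错"]))  # ≈ 0.09
```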
# Reference
N. Kanda, G. Ye, Y. Gaur, X. Wang, Z. Meng, Z. Chen, and T. Yoshioka, "End-to-end speaker-attributed ASR with transformer," in Interspeech. ISCA, 2021, pp. 4413–4417.


@ -0,0 +1,102 @@
# network architecture
frontend: multichannelfrontend
frontend_conf:
fs: 16000
window: hann
n_fft: 400
n_mels: 80
frame_length: 25
frame_shift: 10
lfr_m: 1
lfr_n: 1
use_channel: 0
# encoder related
encoder: conformer
encoder_conf:
output_size: 256 # dimension of attention
attention_heads: 4
linear_units: 2048 # the number of units of position-wise feed forward
num_blocks: 12 # the number of encoder blocks
dropout_rate: 0.1
positional_dropout_rate: 0.1
attention_dropout_rate: 0.0
input_layer: conv2d # encoder architecture type
normalize_before: true
rel_pos_type: latest
pos_enc_layer_type: rel_pos
selfattention_layer_type: rel_selfattn
activation_type: swish
macaron_style: true
use_cnn_module: true
cnn_module_kernel: 15
# decoder related
decoder: transformer
decoder_conf:
attention_heads: 4
linear_units: 2048
num_blocks: 6
dropout_rate: 0.1
positional_dropout_rate: 0.1
self_attention_dropout_rate: 0.0
src_attention_dropout_rate: 0.0
# ctc related
ctc_conf:
ignore_nan_grad: true
# hybrid CTC/attention
model_conf:
ctc_weight: 0.3
lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false
dataset_conf:
data_names: speech,text
data_types: sound,text
shuffle: True
shuffle_conf:
shuffle_size: 2048
sort_size: 500
batch_conf:
batch_type: token
batch_size: 7000
num_workers: 8
# optimization related
accum_grad: 1
grad_clip: 5
max_epoch: 100
val_scheduler_criterion:
- valid
- acc
best_model_criterion:
- - valid
- acc
- max
keep_nbest_models: 10
optim: adam
optim_conf:
lr: 0.001
scheduler: warmuplr
scheduler_conf:
warmup_steps: 25000
specaug: specaug
specaug_conf:
apply_time_warp: true
time_warp_window: 5
time_warp_mode: bicubic
apply_freq_mask: true
freq_mask_width_range:
- 0
- 30
num_freq_mask: 2
apply_time_mask: true
time_mask_width_range:
- 0
- 40
num_time_mask: 2


@ -0,0 +1,131 @@
# network architecture
frontend: multichannelfrontend
frontend_conf:
fs: 16000
window: hann
n_fft: 400
n_mels: 80
frame_length: 25
frame_shift: 10
lfr_m: 1
lfr_n: 1
use_channel: 0
# encoder related
asr_encoder: conformer
asr_encoder_conf:
output_size: 256 # dimension of attention
attention_heads: 4
linear_units: 2048 # the number of units of position-wise feed forward
num_blocks: 12 # the number of encoder blocks
dropout_rate: 0.1
positional_dropout_rate: 0.1
attention_dropout_rate: 0.0
input_layer: conv2d # encoder architecture type
normalize_before: true
pos_enc_layer_type: rel_pos
selfattention_layer_type: rel_selfattn
activation_type: swish
macaron_style: true
use_cnn_module: true
cnn_module_kernel: 15
spk_encoder: resnet34_diar
spk_encoder_conf:
use_head_conv: true
batchnorm_momentum: 0.5
use_head_maxpool: false
num_nodes_pooling_layer: 256
layers_in_block:
- 3
- 4
- 6
- 3
filters_in_block:
- 32
- 64
- 128
- 256
pooling_type: statistic
num_nodes_resnet1: 256
num_nodes_last_layer: 256
batchnorm_momentum: 0.5
# decoder related
decoder: sa_decoder
decoder_conf:
attention_heads: 4
linear_units: 2048
asr_num_blocks: 6
spk_num_blocks: 3
dropout_rate: 0.1
positional_dropout_rate: 0.1
self_attention_dropout_rate: 0.0
src_attention_dropout_rate: 0.0
# hybrid CTC/attention
model_conf:
spk_weight: 0.5
ctc_weight: 0.3
lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false
max_spk_num: 4
ctc_conf:
ignore_nan_grad: true
# minibatch related
dataset_conf:
data_names: speech,text,profile,text_id
data_types: sound,text,npy,text_int
shuffle: True
shuffle_conf:
shuffle_size: 2048
sort_size: 500
batch_conf:
batch_type: token
batch_size: 7000
num_workers: 8
# optimization related
accum_grad: 1
grad_clip: 5
max_epoch: 60
val_scheduler_criterion:
- valid
- loss
best_model_criterion:
- - valid
- acc
- max
- - valid
- acc_spk
- max
- - valid
- loss
- min
keep_nbest_models: 10
optim: adam
optim_conf:
lr: 0.0005
scheduler: warmuplr
scheduler_conf:
warmup_steps: 8000
specaug: specaug
specaug_conf:
apply_time_warp: true
time_warp_window: 5
time_warp_mode: bicubic
apply_freq_mask: true
freq_mask_width_range:
- 0
- 30
num_freq_mask: 2
apply_time_mask: true
time_mask_width_range:
- 0
- 40
num_time_mask: 2


@ -21,6 +21,8 @@ EOF
SECONDS=0
tgt=Train #Train or Eval
min_wav_duration=0.1
max_wav_duration=20
log "$0 $*"
@ -57,27 +59,24 @@ stage=1
stop_stage=4
mkdir -p $far_dir
mkdir -p $near_dir
mkdir -p data/org
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
log "stage 1:process alimeeting near dir"
find -L $near_raw_dir/audio_dir -iname "*.wav" | sort > $near_dir/wavlist
awk -F '/' '{print $NF}' $near_dir/wavlist | awk -F '.' '{print $1}' > $near_dir/uttid
find -L $near_raw_dir/textgrid_dir -iname "*.TextGrid" | sort > $near_dir/textgrid.flist
awk -F '/' '{print $NF}' $near_dir/wavlist | awk -F '.' '{print $1}' | sort > $near_dir/uttid
find -L $near_raw_dir/textgrid_dir -iname "*.TextGrid" > $near_dir/textgrid.flist
n1_wav=$(wc -l < $near_dir/wavlist)
n2_text=$(wc -l < $near_dir/textgrid.flist)
log near file found $n1_wav wav and $n2_text text.
paste $near_dir/uttid $near_dir/wavlist > $near_dir/wav_raw.scp
# cat $near_dir/wav_raw.scp | awk '{printf("%s sox -t wav %s -r 16000 -b 16 -c 1 -t wav - |\n", $1, $2)}' > $near_dir/wav.scp
cat $near_dir/wav_raw.scp | awk '{printf("%s sox -t wav %s -r 16000 -b 16 -t wav - |\n", $1, $2)}' > $near_dir/wav.scp
paste $near_dir/uttid $near_dir/wavlist -d " " > $near_dir/wav.scp
python local/alimeeting_process_textgrid.py --path $near_dir --no-overlap False
cat $near_dir/text_all | local/text_normalize.pl | local/text_format.pl | sort -u > $near_dir/text
utils/filter_scp.pl -f 1 $near_dir/text $near_dir/utt2spk_all | sort -u > $near_dir/utt2spk
#sed -e 's/ [a-z,A-Z,_,0-9,-]\+SPK/ SPK/' $near_dir/utt2spk_old >$near_dir/tmp1
#sed -e 's/-[a-z,A-Z,0-9]\+$//' $near_dir/tmp1 | sort -u > $near_dir/utt2spk
local/utt2spk_to_spk2utt.pl $near_dir/utt2spk > $near_dir/spk2utt
utils/filter_scp.pl -f 1 $near_dir/text $near_dir/segments_all | sort -u > $near_dir/segments
sed -e 's/ $//g' $near_dir/text> $near_dir/tmp1
@ -97,9 +96,7 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
n2_text=$(wc -l < $far_dir/textgrid.flist)
log far file found $n1_wav wav and $n2_text text.
paste $far_dir/uttid $far_dir/wavlist > $far_dir/wav_raw.scp
cat $far_dir/wav_raw.scp | awk '{printf("%s sox -t wav %s -r 16000 -b 16 -t wav - |\n", $1, $2)}' > $far_dir/wav.scp
paste $far_dir/uttid $far_dir/wavlist -d " " > $far_dir/wav.scp
python local/alimeeting_process_overlap_force.py --path $far_dir \
--no-overlap false --mars True \
@ -119,28 +116,28 @@ fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
log "stage 3: finali data process"
log "stage 3: final data process"
local/fix_data_dir.sh $near_dir
local/fix_data_dir.sh $far_dir
local/copy_data_dir.sh $near_dir data/${tgt}_Ali_near
local/copy_data_dir.sh $far_dir data/${tgt}_Ali_far
local/copy_data_dir.sh $near_dir data/org/${tgt}_Ali_near
local/copy_data_dir.sh $far_dir data/org/${tgt}_Ali_far
sort $far_dir/utt2spk_all_fifo > data/${tgt}_Ali_far/utt2spk_all_fifo
sed -i "s/src/$/g" data/${tgt}_Ali_far/utt2spk_all_fifo
sort $far_dir/utt2spk_all_fifo > data/org/${tgt}_Ali_far/utt2spk_all_fifo
sed -i "s/src/$/g" data/org/${tgt}_Ali_far/utt2spk_all_fifo
# remove space in text
for x in ${tgt}_Ali_near ${tgt}_Ali_far; do
cp data/${x}/text data/${x}/text.org
paste -d " " <(cut -f 1 -d" " data/${x}/text.org) <(cut -f 2- -d" " data/${x}/text.org | tr -d " ") \
> data/${x}/text
rm data/${x}/text.org
cp data/org/${x}/text data/org/${x}/text.org
paste -d " " <(cut -f 1 -d" " data/org/${x}/text.org) <(cut -f 2- -d" " data/org/${x}/text.org | tr -d " ") \
> data/org/${x}/text
rm data/org/${x}/text.org
done
log "Successfully finished. [elapsed=${SECONDS}s]"
fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
log "stage 4: process alimeeting far dir (single speaker by oracle time strap)"
log "stage 4: process alimeeting far dir (single speaker by oracle time stamp)"
cp -r $far_dir/* $far_single_speaker_dir
mv $far_single_speaker_dir/textgrid.flist $far_single_speaker_dir/textgrid_oldpath
paste -d " " $far_single_speaker_dir/uttid $far_single_speaker_dir/textgrid_oldpath > $far_single_speaker_dir/textgrid.flist
@ -150,14 +147,15 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
local/utt2spk_to_spk2utt.pl $far_single_speaker_dir/utt2spk > $far_single_speaker_dir/spk2utt
./local/fix_data_dir.sh $far_single_speaker_dir
local/copy_data_dir.sh $far_single_speaker_dir data/${tgt}_Ali_far_single_speaker
local/copy_data_dir.sh $far_single_speaker_dir data/org/${tgt}_Ali_far_single_speaker
# remove space in text
for x in ${tgt}_Ali_far_single_speaker; do
cp data/${x}/text data/${x}/text.org
paste -d " " <(cut -f 1 -d" " data/${x}/text.org) <(cut -f 2- -d" " data/${x}/text.org | tr -d " ") \
> data/${x}/text
rm data/${x}/text.org
cp data/org/${x}/text data/org/${x}/text.org
paste -d " " <(cut -f 1 -d" " data/org/${x}/text.org) <(cut -f 2- -d" " data/org/${x}/text.org | tr -d " ") \
> data/org/${x}/text
rm data/org/${x}/text.org
done
rm -rf data/local
log "Successfully finished. [elapsed=${SECONDS}s]"
fi


@ -0,0 +1,134 @@
import argparse
import json
import os
import numpy as np
import torchaudio
import torchaudio.compliance.kaldi as kaldi
import yaml
from funasr.models.frontend.default import DefaultFrontend
import torch
def get_parser():
parser = argparse.ArgumentParser(
description="computer global cmvn",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--dim",
default=80,
type=int,
help="feature dimension",
)
parser.add_argument(
"--wav_path",
default=False,
required=True,
type=str,
help="the path of wav scps",
)
parser.add_argument(
"--config_file",
type=str,
help="the config file for computing cmvn",
)
parser.add_argument(
"--idx",
default=1,
required=True,
type=int,
help="index",
)
return parser
def compute_fbank(wav_file,
num_mel_bins=80,
frame_length=25,
frame_shift=10,
dither=0.0,
resample_rate=16000,
speed=1.0,
window_type="hamming"):
waveform, sample_rate = torchaudio.load(wav_file)
if resample_rate != sample_rate:
waveform = torchaudio.transforms.Resample(orig_freq=sample_rate,
new_freq=resample_rate)(waveform)
if speed != 1.0:
waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
waveform, resample_rate,
[['speed', str(speed)], ['rate', str(resample_rate)]]
)
waveform = waveform * (1 << 15)
mat = kaldi.fbank(waveform,
num_mel_bins=num_mel_bins,
frame_length=frame_length,
frame_shift=frame_shift,
dither=dither,
energy_floor=0.0,
window_type=window_type,
sample_frequency=resample_rate)
return mat.numpy()
def main():
parser = get_parser()
args = parser.parse_args()
wav_scp_file = os.path.join(args.wav_path, "wav.{}.scp".format(args.idx))
cmvn_file = os.path.join(args.wav_path, "cmvn.{}.json".format(args.idx))
mean_stats = np.zeros(args.dim)
var_stats = np.zeros(args.dim)
total_frames = 0
# with ReadHelper('ark:{}'.format(ark_file)) as ark_reader:
# for key, mat in ark_reader:
# mean_stats += np.sum(mat, axis=0)
# var_stats += np.sum(np.square(mat), axis=0)
# total_frames += mat.shape[0]
with open(args.config_file) as f:
configs = yaml.safe_load(f)
frontend_configs = configs.get("frontend_conf", {})
num_mel_bins = frontend_configs.get("n_mels", 80)
frame_length = frontend_configs.get("frame_length", 25)
frame_shift = frontend_configs.get("frame_shift", 10)
window_type = frontend_configs.get("window", "hamming")
resample_rate = frontend_configs.get("fs", 16000)
n_fft = frontend_configs.get("n_fft", 400)
use_channel = frontend_configs.get("use_channel", None)
assert num_mel_bins == args.dim
frontend = DefaultFrontend(
fs=resample_rate,
n_fft=n_fft,
win_length=frame_length * 16,
hop_length=frame_shift * 16,
window=window_type,
n_mels=num_mel_bins,
use_channel=use_channel,
)
with open(wav_scp_file) as f:
lines = f.readlines()
for line in lines:
_, wav_file = line.strip().split()
wavform, _ = torchaudio.load(wav_file)
fbank, _ = frontend(wavform.transpose(0, 1).unsqueeze(0), torch.tensor([wavform.shape[1]]))
fbank = fbank.squeeze(0).numpy()
mean_stats += np.sum(fbank, axis=0)
var_stats += np.sum(np.square(fbank), axis=0)
total_frames += fbank.shape[0]
cmvn_info = {
'mean_stats': list(mean_stats.tolist()),
'var_stats': list(var_stats.tolist()),
'total_frames': total_frames
}
with open(cmvn_file, 'w') as fout:
fout.write(json.dumps(cmvn_info))
if __name__ == '__main__':
main()


@ -0,0 +1,39 @@
#!/usr/bin/env bash
. ./path.sh || exit 1;
# Begin configuration section.
fbankdir=
nj=32
cmd=./utils/run.pl
feats_dim=80
config_file=
scale=1.0
echo "$0 $@"
. utils/parse_options.sh || exit 1;
# shellcheck disable=SC2046
head -n $(awk -v lines="$(wc -l < ${fbankdir}/wav.scp)" -v scale="$scale" 'BEGIN { printf "%.0f\n", lines*scale }') ${fbankdir}/wav.scp > ${fbankdir}/wav.scp.scale
split_dir=${fbankdir}/cmvn/split_${nj};
mkdir -p $split_dir
split_scps=""
for n in $(seq $nj); do
split_scps="$split_scps $split_dir/wav.$n.scp"
done
utils/split_scp.pl ${fbankdir}/wav.scp.scale $split_scps || exit 1;
logdir=${fbankdir}/cmvn/log
$cmd JOB=1:$nj $logdir/cmvn.JOB.log \
python local/compute_cmvn.py \
--dim ${feats_dim} \
--wav_path $split_dir \
--config_file $config_file \
--idx JOB \
python utils/combine_cmvn_file.py --dim ${feats_dim} --cmvn_dir $split_dir --nj $nj --output_dir ${fbankdir}/cmvn
python utils/cmvn_converter.py --cmvn_json ${fbankdir}/cmvn/cmvn.json --am_mvn ${fbankdir}/cmvn/am.mvn
echo "$0: Succeeded compute global cmvn"


@ -0,0 +1,29 @@
import codecs
import pdb
import sys
import torch
char1 = sys.argv[1]
char2 = sys.argv[2]
model1 = torch.load(sys.argv[3], map_location='cpu')
model2_path = sys.argv[4]
d_new = model1
char1_list = []
map_list = []
with codecs.open(char1) as f:
for line in f.readlines():
char1_list.append(line.strip())
with codecs.open(char2) as f:
for line in f.readlines():
map_list.append(char1_list.index(line.strip()))
print(map_list)
for k, v in d_new.items():
if k == 'ctc.ctc_lo.weight' or k == 'ctc.ctc_lo.bias' or k == 'decoder.output_layer.weight' or k == 'decoder.output_layer.bias' or k == 'decoder.embed.0.weight':
d_new[k] = v[map_list]
torch.save(d_new, model2_path)


@ -0,0 +1,105 @@
#!/usr/bin/env bash
# Copyright 2014 Johns Hopkins University (author: Daniel Povey)
# 2017 Xingyu Na
# Apache 2.0
remove_archive=false
if [ "$1" == --remove-archive ]; then
remove_archive=true
shift
fi
if [ $# -ne 3 ]; then
echo "Usage: $0 [--remove-archive] <data-base> <url-base> <corpus-part>"
echo "e.g.: $0 /export/a05/xna/data www.openslr.org/resources/33 data_aishell"
echo "With --remove-archive it will remove the archive after successfully un-tarring it."
echo "<corpus-part> can be one of: data_aishell, resource_aishell."
fi
data=$1
url=$2
part=$3
if [ ! -d "$data" ]; then
echo "$0: no such directory $data"
exit 1;
fi
part_ok=false
list="data_aishell resource_aishell"
for x in $list; do
if [ "$part" == $x ]; then part_ok=true; fi
done
if ! $part_ok; then
echo "$0: expected <corpus-part> to be one of $list, but got '$part'"
exit 1;
fi
if [ -z "$url" ]; then
echo "$0: empty URL base."
exit 1;
fi
if [ -f $data/$part/.complete ]; then
echo "$0: data part $part was already successfully extracted, nothing to do."
exit 0;
fi
# sizes of the archive files in bytes.
sizes="15582913665 1246920"
if [ -f $data/$part.tgz ]; then
size=$(/bin/ls -l $data/$part.tgz | awk '{print $5}')
size_ok=false
for s in $sizes; do if [ $s == $size ]; then size_ok=true; fi; done
if ! $size_ok; then
echo "$0: removing existing file $data/$part.tgz because its size in bytes $size"
echo "does not equal the size of one of the archives."
rm $data/$part.tgz
else
echo "$data/$part.tgz exists and appears to be complete."
fi
fi
if [ ! -f $data/$part.tgz ]; then
if ! command -v wget >/dev/null; then
echo "$0: wget is not installed."
exit 1;
fi
full_url=$url/$part.tgz
echo "$0: downloading data from $full_url. This may take some time, please be patient."
cd $data || exit 1
if ! wget --no-check-certificate $full_url; then
echo "$0: error executing wget $full_url"
exit 1;
fi
fi
cd $data || exit 1
if ! tar -xvzf $part.tgz; then
echo "$0: error un-tarring archive $data/$part.tgz"
exit 1;
fi
touch $data/$part/.complete
if [ $part == "data_aishell" ]; then
cd $data/$part/wav || exit 1
for wav in ./*.tar.gz; do
echo "Extracting wav from $wav"
tar -zxf $wav && rm $wav
done
fi
echo "$0: Successfully downloaded and un-tarred $data/$part.tgz"
if $remove_archive; then
echo "$0: removing $data/$part.tgz file since --remove-archive option was supplied."
rm $data/$part.tgz
fi
exit 0;


@ -63,7 +63,7 @@ if __name__ == "__main__":
wav_scp_file = open(path+'/wav.scp', 'r')
wav_scp = wav_scp_file.readlines()
wav_scp_file.close()
raw_meeting_scp_file = open(raw_path + '/wav_raw.scp', 'r')
raw_meeting_scp_file = open(raw_path + '/wav.scp', 'r')
raw_meeting_scp = raw_meeting_scp_file.readlines()
raw_meeting_scp_file.close()
segments_scp_file = open(raw_path + '/segments', 'r')
@ -92,8 +92,8 @@ if __name__ == "__main__":
cluster_spk_num_file = open(path + '/cluster_spk_num', 'w')
meeting_map = {}
for line in raw_meeting_scp:
meeting = line.strip().split('\t')[0]
wav_path = line.strip().split('\t')[1]
meeting = line.strip().split(' ')[0]
wav_path = line.strip().split(' ')[1]
wav = soundfile.read(wav_path)[0]
# take the first channel
if wav.ndim == 2:


@ -9,7 +9,7 @@ import soundfile
if __name__=="__main__":
path = sys.argv[1] # dump2/raw/Eval_Ali_far
raw_path = sys.argv[2] # data/local/Eval_Ali_far_correct_single_speaker
raw_meeting_scp_file = open(raw_path + '/wav_raw.scp', 'r')
raw_meeting_scp_file = open(raw_path + '/wav.scp', 'r')
raw_meeting_scp = raw_meeting_scp_file.readlines()
raw_meeting_scp_file.close()
segments_scp_file = open(raw_path + '/segments', 'r')
@ -22,8 +22,8 @@ if __name__=="__main__":
raw_wav_map = {}
for line in raw_meeting_scp:
meeting = line.strip().split('\t')[0]
wav_path = line.strip().split('\t')[1]
meeting = line.strip().split(' ')[0]
wav_path = line.strip().split(' ')[1]
raw_wav_map[meeting] = wav_path
spk_map = {}


@ -5,7 +5,7 @@ import sys
if __name__=="__main__":
path = sys.argv[1] # dump2/raw/Train_Ali_far
path = sys.argv[1]
wav_scp_file = open(path+"/wav.scp", 'r')
wav_scp = wav_scp_file.readlines()
wav_scp_file.close()
@ -29,7 +29,7 @@ if __name__=="__main__":
line_list = line.strip().split(' ')
meeting = line_list[0].split('-')[0]
spk_id = line_list[0].split('-')[-1].split('_')[-1]
spk = meeting+'_' + spk_id
spk = meeting + '_' + spk_id
global_spk_list.append(spk)
if meeting in meeting_map_tmp.keys():
meeting_map_tmp[meeting].append(spk)


@ -30,8 +30,7 @@ def main(args):
meetingid_map = {}
for line in spk2utt:
spkid = line.strip().split(" ")[0]
meeting_id_list = spkid.split("_")[:3]
meeting_id = meeting_id_list[0] + "_" + meeting_id_list[1] + "_" + meeting_id_list[2]
meeting_id = spkid.split("-")[0]
if meeting_id not in meetingid_map:
meetingid_map[meeting_id] = 1
else:

egs/alimeeting/sa_asr/path.sh Executable file

@ -0,0 +1,6 @@
export FUNASR_DIR=$PWD/../../..
# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PATH=$FUNASR_DIR/funasr/bin:./utils:$FUNASR_DIR:$PATH
export PYTHONPATH=$FUNASR_DIR:$PYTHONPATH

egs/alimeeting/sa_asr/run.sh Executable file

@ -0,0 +1,435 @@
#!/usr/bin/env bash
. ./path.sh || exit 1;
# machines configuration
CUDA_VISIBLE_DEVICES="6,7"
gpu_num=2
count=1
gpu_inference=true # Whether to perform gpu decoding, set false for cpu decoding
# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob
njob=8
train_cmd=utils/run.pl
infer_cmd=utils/run.pl
# general configuration
feats_dir="data" #feature output dictionary
exp_dir="exp"
lang=zh
token_type=char
type=sound
scp=wav.scp
speed_perturb="1.0"
min_wav_duration=0.1
max_wav_duration=20
profile_modes="cluster oracle"
stage=7
stop_stage=7
# feature configuration
feats_dim=80
nj=32
# data
raw_data=
data_url=
# exp tag
tag=""
. utils/parse_options.sh || exit 1;
# Set bash to 'debug' mode, it will exit on :
# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
set -e
set -u
set -o pipefail
train_set=Train_Ali_far
valid_set=Eval_Ali_far
test_sets="Test_Ali_far Eval_Ali_far"
test_2023="Test_2023_Ali_far_release"
asr_config=conf/train_asr_conformer.yaml
sa_asr_config=conf/train_sa_asr_conformer.yaml
asr_model_dir="baseline_$(basename "${asr_config}" .yaml)_${lang}_${token_type}_${tag}"
sa_asr_model_dir="baseline_$(basename "${sa_asr_config}" .yaml)_${lang}_${token_type}_${tag}"
inference_config=conf/decode_asr_rnn.yaml
inference_sa_asr_model=valid.acc_spk.ave.pb
# you can set gpu num for decoding here
gpuid_list=$CUDA_VISIBLE_DEVICES # set gpus for decoding, the same as training stage by default
ngpu=$(echo $gpuid_list | awk -F "," '{print NF}')
if ${gpu_inference}; then
inference_nj=$[${ngpu}*${njob}]
_ngpu=1
else
inference_nj=$njob
_ngpu=0
fi
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
echo "stage 0: Data preparation"
# Data preparation
./local/alimeeting_data_prep.sh --tgt Test --min_wav_duration $min_wav_duration --max_wav_duration $max_wav_duration
./local/alimeeting_data_prep.sh --tgt Eval --min_wav_duration $min_wav_duration --max_wav_duration $max_wav_duration
./local/alimeeting_data_prep.sh --tgt Train --min_wav_duration $min_wav_duration --max_wav_duration $max_wav_duration
# remove long/short data
for x in ${train_set} ${valid_set} ${test_sets}; do
cp -r ${feats_dir}/org/${x} ${feats_dir}/${x}
rm ${feats_dir}/"${x}"/wav.scp ${feats_dir}/"${x}"/segments
local/format_wav_scp.sh --nj "${nj}" --cmd "${train_cmd}" \
--audio-format wav --segments ${feats_dir}/org/${x}/segments \
"${feats_dir}/org/${x}/${scp}" "${feats_dir}/${x}"
_min_length=$(python3 -c "print(int(${min_wav_duration} * 16000))")
_max_length=$(python3 -c "print(int(${max_wav_duration} * 16000))")
<"${feats_dir}/${x}/utt2num_samples" \
awk '{if($2 > '$_min_length' && $2 < '$_max_length')print $0;}' \
>"${feats_dir}/${x}/utt2num_samples_rmls"
mv ${feats_dir}/${x}/utt2num_samples_rmls ${feats_dir}/${x}/utt2num_samples
<"${feats_dir}/${x}/wav.scp" \
utils/filter_scp.pl "${feats_dir}/${x}/utt2num_samples" \
>"${feats_dir}/${x}/wav.scp_rmls"
mv ${feats_dir}/${x}/wav.scp_rmls ${feats_dir}/${x}/wav.scp
<"${feats_dir}/${x}/text" \
awk '{ if( NF != 1 ) print $0; }' >"${feats_dir}/${x}/text_rmblank"
mv ${feats_dir}/${x}/text_rmblank ${feats_dir}/${x}/text
local/fix_${feats_dir}_dir.sh "${feats_dir}/${x}"
<"${feats_dir}/${x}/utt2spk_all_fifo" \
utils/filter_scp.pl "${feats_dir}/${x}/text" \
>"${feats_dir}/${x}/utt2spk_all_fifo_rmls"
mv "${feats_dir}/${x}/utt2spk_all_fifo_rmls" "${feats_dir}/${x}/utt2spk_all_fifo"
done
for x in ${test_2023}; do
local/format_wav_scp.sh --nj "${nj}" --cmd "${train_cmd}" \
--audio-format wav --segments ${feats_dir}/org/${x}/segments \
"${feats_dir}/org/${x}/${scp}" "${feats_dir}/${x}"
cut -d " " -f1 ${feats_dir}/${x}/wav.scp > ${feats_dir}/${x}/uttid
paste -d " " ${feats_dir}/${x}/uttid ${feats_dir}/${x}/uttid > ${feats_dir}/${x}/utt2spk
cp ${feats_dir}/${x}/utt2spk ${feats_dir}/${x}/spk2utt
done
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
echo "stage 1: Speaker profile and CMVN Generation"
mkdir -p "profile_log"
for x in "${train_set}" "${valid_set}" "${test_sets}"; do
# generate text_id spk2id
python local/process_sot_fifo_textchar2spk.py --path ${feats_dir}/${x}
echo "Successfully generate ${feats_dir}/${x}/text_id ${feats_dir}/${x}/spk2id"
# generate text_id_train for sot
python local/process_text_id.py ${feats_dir}/${x}
echo "Successfully generate ${feats_dir}/${x}/text_id_train"
# generate oracle_embedding from single-speaker audio segment
echo "oracle_embedding is being generated in the background, and the log is profile_log/gen_oracle_embedding_${x}.log"
python local/gen_oracle_embedding.py "${feats_dir}/${x}" "data/org/${x}_single_speaker" &> "profile_log/gen_oracle_embedding_${x}.log"
echo "Successfully generate oracle embedding for ${x} (${feats_dir}/${x}/oracle_embedding.scp)"
# generate oracle_profile and cluster_profile from oracle_embedding and cluster_embedding (padding the speaker during training)
if [ "${x}" = "${train_set}" ]; then
python local/gen_oracle_profile_padding.py ${feats_dir}/${x}
echo "Successfully generate oracle profile for ${x} (${feats_dir}/${x}/oracle_profile_padding.scp)"
else
python local/gen_oracle_profile_nopadding.py ${feats_dir}/${x}
echo "Successfully generate oracle profile for ${x} (${feats_dir}/${x}/oracle_profile_nopadding.scp)"
fi
# generate cluster_profile with spectral-cluster directly (for infering and without oracle information)
if [ "${x}" = "${valid_set}" ] || [ "${x}" = "${test_sets}" ]; then
echo "cluster_profile is being generated in the background, and the log is profile_log/gen_cluster_profile_infer_${x}.log"
python local/gen_cluster_profile_infer.py "${feats_dir}/${x}" "${feats_dir}/org/${x}" 0.996 0.815 &> "profile_log/gen_cluster_profile_infer_${x}.log"
echo "Successfully generate cluster profile for ${x} (${feats_dir}/${x}/cluster_profile_infer.scp)"
fi
# compute CMVN
if [ "${x}" = "${train_set}" ]; then
local/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --fbankdir ${feats_dir}/${train_set} --feats_dim ${feats_dim} --config_file "$asr_config" --scale 1.0
fi
done
for x in "${test_2023}"; do
# generate cluster_profile with spectral-cluster directly (for infering and without oracle information)
python local/gen_cluster_profile_infer.py "${feats_dir}/${x}" "${feats_dir}/org/${x}" 0.996 0.815 &> "profile_log/gen_cluster_profile_infer_${x}.log"
echo "Successfully generate cluster profile for ${x} (${feats_dir}/${x}/cluster_profile_infer.scp)"
done
fi
token_list=${feats_dir}/${lang}_token_list/char/tokens.txt
echo "dictionary: ${token_list}"
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
echo "stage 2: Dictionary Preparation"
mkdir -p ${feats_dir}/${lang}_token_list/char/
echo "make a dictionary"
echo "<blank>" > ${token_list}
echo "<s>" >> ${token_list}
echo "</s>" >> ${token_list}
utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/$train_set/text | cut -f 2- -d" " | tr " " "\n" \
| sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list}
echo "<unk>" >> ${token_list}
fi
# LM Training Stage
world_size=$gpu_num # run on one machine
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
echo "stage 3: LM Training"
fi
# ASR Training Stage
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
echo "Stage 4: ASR Training"
asr_exp=${exp_dir}/${asr_model_dir}
mkdir -p ${asr_exp}
mkdir -p ${asr_exp}/log
INIT_FILE=${asr_exp}/ddp_init
if [ -f $INIT_FILE ];then
rm -f $INIT_FILE
fi
init_method=file://$(readlink -f $INIT_FILE)
echo "$0: init method is $init_method"
for ((i = 0; i < $ngpu; ++i)); do
{
# i=0
rank=$i
local_rank=$i
gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
train.py \
--task_name asr \
--model asr \
--gpu_id $gpu_id \
--use_preprocessor true \
--split_with_space false \
--token_type char \
--token_list $token_list \
--data_dir ${feats_dir} \
--train_set ${train_set} \
--valid_set ${valid_set} \
--data_file_names "wav.scp,text" \
--cmvn_file ${feats_dir}/${train_set}/cmvn/cmvn.mvn \
--speed_perturb ${speed_perturb} \
--resume true \
--output_dir ${exp_dir}/${asr_model_dir} \
--config $asr_config \
--ngpu $gpu_num \
--num_worker_count $count \
--dist_init_method $init_method \
--dist_world_size $world_size \
--dist_rank $rank \
--local_rank $local_rank 1> ${exp_dir}/${asr_model_dir}/log/train.log.$i 2>&1
} &
done
wait
fi
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
echo "SA-ASR training"
asr_exp=${exp_dir}/${asr_model_dir}
sa_asr_exp=${exp_dir}/${sa_asr_model_dir}
mkdir -p ${sa_asr_exp}
mkdir -p ${sa_asr_exp}/log
INIT_FILE=${sa_asr_exp}/ddp_init
if [ ! -L ${feats_dir}/${train_set}/profile.scp ]; then
ln -sr ${feats_dir}/${train_set}/oracle_profile_padding.scp ${feats_dir}/${train_set}/profile.scp
ln -sr ${feats_dir}/${valid_set}/oracle_profile_nopadding.scp ${feats_dir}/${valid_set}/profile.scp
fi
if [ ! -f "${exp_dir}/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pth" ]; then
# download xvector extractor model file
python local/download_xvector_model.py ${exp_dir}
echo "Successfully download the pretrained xvector extractor to exp/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pth"
fi
if [ -f $INIT_FILE ];then
rm -f $INIT_FILE
fi
init_method=file://$(readlink -f $INIT_FILE)
echo "$0: init method is $init_method"
for ((i = 0; i < $ngpu; ++i)); do
{
rank=$i
local_rank=$i
gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
train.py \
--task_name asr \
--model sa_asr \
--gpu_id $gpu_id \
--use_preprocessor true \
--split_with_space false \
--unused_parameters true \
--token_type char \
--resume true \
--token_list $token_list \
--data_dir ${feats_dir} \
--train_set ${train_set} \
--valid_set ${valid_set} \
--data_file_names "wav.scp,text,profile.scp,text_id_train" \
--cmvn_file ${feats_dir}/${train_set}/cmvn/cmvn.mvn \
--speed_perturb ${speed_perturb} \
--init_param "${asr_exp}/valid.acc.ave.pb:encoder:asr_encoder" \
--init_param "${asr_exp}/valid.acc.ave.pb:ctc:ctc" \
--init_param "${asr_exp}/valid.acc.ave.pb:decoder.embed:decoder.embed" \
--init_param "${asr_exp}/valid.acc.ave.pb:decoder.output_layer:decoder.asr_output_layer" \
--init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.0.self_attn:decoder.decoder1.self_attn" \
--init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.0.src_attn:decoder.decoder3.src_attn" \
--init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.0.feed_forward:decoder.decoder3.feed_forward" \
--init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.1:decoder.decoder4.0" \
--init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.2:decoder.decoder4.1" \
--init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.3:decoder.decoder4.2" \
--init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.4:decoder.decoder4.3" \
--init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.5:decoder.decoder4.4" \
--init_param "exp/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pth:encoder:spk_encoder" \
--init_param "exp/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pth:decoder:spk_encoder:decoder.output_dense" \
--output_dir ${exp_dir}/${sa_asr_model_dir} \
--config $sa_asr_config \
--ngpu $gpu_num \
--num_worker_count $count \
--dist_init_method $init_method \
--dist_world_size $world_size \
--dist_rank $rank \
--local_rank $local_rank 1> ${exp_dir}/${sa_asr_model_dir}/log/train.log.$i 2>&1
} &
done
wait
fi
if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
echo "stage 6: Inference test sets"
for x in ${test_sets}; do
for profile_mode in ${profile_modes}; do
echo "decoding ${x} with ${profile_mode} profile"
sa_asr_exp=${exp_dir}/${sa_asr_model_dir}
inference_tag="$(basename "${inference_config}" .yaml)"
_dir="${sa_asr_exp}/${inference_tag}_${profile_mode}/${inference_sa_asr_model}/${x}"
_logdir="${_dir}/logdir"
if [ -d ${_dir} ]; then
echo "${_dir} is already exists. if you want to decode again, please delete this dir first."
exit 0
fi
mkdir -p "${_logdir}"
_data="${feats_dir}/${x}"
key_file=${_data}/${scp}
num_scp_file="$(<${key_file} wc -l)"
_nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
split_scps=
for n in $(seq "${_nj}"); do
split_scps+=" ${_logdir}/keys.${n}.scp"
done
# shellcheck disable=SC2086
utils/split_scp.pl "${key_file}" ${split_scps}
_opts=
if [ -n "${inference_config}" ]; then
_opts+="--config ${inference_config} "
fi
if [ $profile_mode = "oracle" ]; then
profile_scp=${profile_mode}_profile_nopadding.scp
else
profile_scp=${profile_mode}_profile_infer.scp
fi
${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
python -m funasr.bin.asr_inference_launch \
--batch_size 1 \
--mc True \
--ngpu "${_ngpu}" \
--njob ${njob} \
--nbest 1 \
--gpuid_list ${gpuid_list} \
--allow_variable_data_keys true \
--cmvn_file ${feats_dir}/${train_set}/cmvn/cmvn.mvn \
--data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
--data_path_and_name_and_type "${_data}/$profile_scp,profile,npy" \
--key_file "${_logdir}"/keys.JOB.scp \
--asr_train_config "${sa_asr_exp}"/config.yaml \
--asr_model_file "${sa_asr_exp}"/"${inference_sa_asr_model}" \
--output_dir "${_logdir}"/output.JOB \
--mode sa_asr \
${_opts}
for f in token token_int score text text_id; do
if [ -f "${_logdir}/output.1/1best_recog/${f}" ]; then
for i in $(seq "${_nj}"); do
cat "${_logdir}/output.${i}/1best_recog/${f}"
done | sort -k1 >"${_dir}/${f}"
fi
done
sed 's/\$//g' ${_data}/text > ${_data}/text_nosrc
sed 's/\$//g' ${_dir}/text > ${_dir}/text_nosrc
python utils/proce_text.py ${_data}/text_nosrc ${_data}/text.proc
python utils/proce_text.py ${_dir}/text_nosrc ${_dir}/text.proc
python utils/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
cat ${_dir}/text.cer.txt
python local/process_text_spk_merge.py ${_dir}
python local/process_text_spk_merge.py ${_data}
python local/compute_cpcer.py ${_data}/text_spk_merge ${_dir}/text_spk_merge ${_dir}/text.cpcer
tail -n 1 ${_dir}/text.cpcer > ${_dir}/text.cpcer.txt
cat ${_dir}/text.cpcer.txt
done
done
fi
if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
echo "stage 7: Inference test 2023"
for x in ${test_2023}; do
sa_asr_exp=${exp_dir}/${sa_asr_model_dir}
inference_tag="$(basename "${inference_config}" .yaml)"
_dir="${sa_asr_exp}/${inference_tag}_cluster/${inference_sa_asr_model}/${x}"
_logdir="${_dir}/logdir"
if [ -d ${_dir} ]; then
echo "${_dir} is already exists. if you want to decode again, please delete this dir first."
exit 0
fi
mkdir -p "${_logdir}"
_data="${feats_dir}/${x}"
key_file=${_data}/${scp}
num_scp_file="$(<${key_file} wc -l)"
_nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
split_scps=
for n in $(seq "${_nj}"); do
split_scps+=" ${_logdir}/keys.${n}.scp"
done
# shellcheck disable=SC2086
utils/split_scp.pl "${key_file}" ${split_scps}
_opts=
if [ -n "${inference_config}" ]; then
_opts+="--config ${inference_config} "
fi
${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
python -m funasr.bin.asr_inference_launch \
--batch_size 1 \
--mc True \
--ngpu "${_ngpu}" \
--njob ${njob} \
--nbest 1 \
--gpuid_list ${gpuid_list} \
--allow_variable_data_keys true \
--data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
--data_path_and_name_and_type "${_data}/cluster_profile_infer.scp,profile,npy" \
--cmvn_file ${feats_dir}/${train_set}/cmvn/cmvn.mvn \
--key_file "${_logdir}"/keys.JOB.scp \
--asr_train_config "${sa_asr_exp}"/config.yaml \
--asr_model_file "${sa_asr_exp}"/"${inference_sa_asr_model}" \
--output_dir "${_logdir}"/output.JOB \
--mode sa_asr \
${_opts}
for f in token token_int score text text_id; do
if [ -f "${_logdir}/output.1/1best_recog/${f}" ]; then
for i in $(seq "${_nj}"); do
cat "${_logdir}/output.${i}/1best_recog/${f}"
done | sort -k1 >"${_dir}/${f}"
fi
done
python local/process_text_spk_merge.py ${_dir}
done
fi


@ -0,0 +1,6 @@
beam_size: 20
penalty: 0.0
maxlenratio: 0.0
minlenratio: 0.0
ctc_weight: 0.6
lm_weight: 0.3


@ -0,0 +1 @@
../sa_asr/local/


@ -0,0 +1 @@
../../aishell/transformer/utils


@ -1636,8 +1636,10 @@ class Speech2TextSAASR:
)
frontend = None
if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
if asr_train_args.frontend == 'wav_frontend':
frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
from funasr.tasks.sa_asr import frontend_choices
if asr_train_args.frontend == 'wav_frontend' or asr_train_args.frontend == "multichannelfrontend":
frontend_class = frontend_choices.get_class(asr_train_args.frontend)
frontend = frontend_class(cmvn_file=cmvn_file, **asr_train_args.frontend_conf).eval()
else:
frontend_class = frontend_choices.get_class(asr_train_args.frontend)
frontend = frontend_class(**asr_train_args.frontend_conf).eval()


@ -619,7 +619,12 @@ def inference_paraformer_vad_punc(
data_with_index = [(vadsegments[i], i) for i in range(n)]
sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0])
results_sorted = []
batch_size_token_ms = batch_size_token * 60
batch_size_token_ms = batch_size_token*60
if speech2text.device == "cpu":
batch_size_token_ms = 0
batch_size_token_ms = max(batch_size_token_ms, sorted_data[0][0][1] - sorted_data[0][0][0])
batch_size_token_ms_cum = 0
beg_idx = 0
for j, _ in enumerate(range(0, n)):


@ -301,7 +301,7 @@ def get_parser():
"--freeze_param",
type=str,
default=[],
nargs="*",
action="append",
help="Freeze parameters",
)


@ -6,7 +6,6 @@ from funasr.models.ctc import CTC
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
from funasr.models.decoder.rnn_decoder import RNNDecoder
from funasr.models.decoder.rnnt_decoder import RNNTDecoder
from funasr.models.decoder.sanm_decoder import ParaformerSANMDecoder, FsmnDecoderSCAMAOpt
from funasr.models.decoder.transformer_decoder import (
DynamicConvolution2DTransformerDecoder, # noqa: H301
@ -20,17 +19,23 @@ from funasr.models.decoder.transformer_decoder import (
)
from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
from funasr.models.decoder.transformer_decoder import TransformerDecoder
from funasr.models.decoder.rnnt_decoder import RNNTDecoder
from funasr.models.decoder.transformer_decoder import SAAsrTransformerDecoder
from funasr.models.e2e_asr import ASRModel
from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
from funasr.models.e2e_asr_mfcca import MFCCA
from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerOnline, ParaformerBert, BiCifParaformer, \
ContextualParaformer
from funasr.models.e2e_asr_transducer import TransducerModel, UnifiedTransducerModel
from funasr.models.e2e_sa_asr import SAASRModel
from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerOnline, ParaformerBert, BiCifParaformer, ContextualParaformer
from funasr.models.e2e_tp import TimestampPredictor
from funasr.models.e2e_uni_asr import UniASR
from funasr.models.encoder.conformer_encoder import ConformerEncoder, ConformerChunkEncoder
from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
from funasr.models.encoder.mfcca_encoder import MFCCAEncoder
from funasr.models.encoder.resnet34_encoder import ResNet34Diar
from funasr.models.encoder.rnn_encoder import RNNEncoder
from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
from funasr.models.encoder.transformer_encoder import TransformerEncoder
@ -93,6 +98,8 @@ model_choices = ClassChoices(
timestamp_prediction=TimestampPredictor,
rnnt=TransducerModel,
rnnt_unified=UnifiedTransducerModel,
sa_asr=SAASRModel,
),
default="asr",
)
@ -110,6 +117,27 @@ encoder_choices = ClassChoices(
),
default="rnn",
)
asr_encoder_choices = ClassChoices(
"asr_encoder",
classes=dict(
conformer=ConformerEncoder,
transformer=TransformerEncoder,
rnn=RNNEncoder,
sanm=SANMEncoder,
sanm_chunk_opt=SANMEncoderChunkOpt,
data2vec_encoder=Data2VecEncoder,
mfcca_enc=MFCCAEncoder,
),
default="rnn",
)
spk_encoder_choices = ClassChoices(
"spk_encoder",
classes=dict(
resnet34_diar=ResNet34Diar,
),
default="resnet34_diar",
)
encoder_choices2 = ClassChoices(
"encoder2",
classes=dict(
@ -134,6 +162,7 @@ decoder_choices = ClassChoices(
paraformer_decoder_sanm=ParaformerSANMDecoder,
paraformer_decoder_san=ParaformerDecoderSAN,
contextual_paraformer_decoder=ContextualParaformerDecoder,
sa_decoder=SAAsrTransformerDecoder,
),
default="rnn",
)
@ -225,6 +254,10 @@ class_choices_list = [
rnnt_decoder_choices,
# --joint_network and --joint_network_conf
joint_network_choices,
# --asr_encoder and --asr_encoder_conf
asr_encoder_choices,
# --spk_encoder and --spk_encoder_conf
spk_encoder_choices,
]
@ -247,7 +280,7 @@ def build_asr_model(args):
# frontend
if hasattr(args, "input_size") and args.input_size is None:
frontend_class = frontend_choices.get_class(args.frontend)
if args.frontend == 'wav_frontend':
if args.frontend == 'wav_frontend' or args.frontend == 'multichannelfrontend':
frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
else:
frontend = frontend_class(**args.frontend_conf)
@ -425,6 +458,33 @@ def build_asr_model(args):
joint_network=joint_network,
**args.model_conf,
)
elif args.model == "sa_asr":
asr_encoder_class = asr_encoder_choices.get_class(args.asr_encoder)
asr_encoder = asr_encoder_class(input_size=input_size, **args.asr_encoder_conf)
spk_encoder_class = spk_encoder_choices.get_class(args.spk_encoder)
spk_encoder = spk_encoder_class(input_size=input_size, **args.spk_encoder_conf)
decoder = decoder_class(
vocab_size=vocab_size,
encoder_output_size=asr_encoder.output_size(),
**args.decoder_conf,
)
ctc = CTC(
odim=vocab_size, encoder_output_size=asr_encoder.output_size(), **args.ctc_conf
)
model_class = model_choices.get_class(args.model)
model = model_class(
vocab_size=vocab_size,
frontend=frontend,
specaug=specaug,
normalize=normalize,
asr_encoder=asr_encoder,
spk_encoder=spk_encoder,
decoder=decoder,
ctc=ctc,
token_list=token_list,
**args.model_conf,
)
else:
raise NotImplementedError("Not supported model: {}".format(args.model))


@ -40,7 +40,7 @@ else:
yield
class ESPnetASRModel(FunASRModel):
class SAASRModel(FunASRModel):
"""CTC-attention hybrid Encoder-Decoder model"""
def __init__(
@ -51,10 +51,8 @@ class ESPnetASRModel(FunASRModel):
frontend: Optional[AbsFrontend],
specaug: Optional[AbsSpecAug],
normalize: Optional[AbsNormalize],
preencoder: Optional[AbsPreEncoder],
asr_encoder: AbsEncoder,
spk_encoder: torch.nn.Module,
postencoder: Optional[AbsPostEncoder],
decoder: AbsDecoder,
ctc: CTC,
spk_weight: float = 0.5,
@ -89,8 +87,6 @@ class ESPnetASRModel(FunASRModel):
self.frontend = frontend
self.specaug = specaug
self.normalize = normalize
self.preencoder = preencoder
self.postencoder = postencoder
self.asr_encoder = asr_encoder
self.spk_encoder = spk_encoder
@ -293,10 +289,6 @@ class ESPnetASRModel(FunASRModel):
if self.normalize is not None:
feats, feats_lengths = self.normalize(feats, feats_lengths)
# Pre-encoder, e.g. used for raw input data
if self.preencoder is not None:
feats, feats_lengths = self.preencoder(feats, feats_lengths)
# 4. Forward encoder
# feats: (Batch, Length, Dim)
# -> encoder_out: (Batch, Length2, Dim2)
@ -317,11 +309,6 @@ class ESPnetASRModel(FunASRModel):
encoder_out_spk=F.interpolate(encoder_out_spk_ori.transpose(-2,-1), size=(encoder_out.size(1)), mode='nearest').transpose(-2,-1)
else:
encoder_out_spk=encoder_out_spk_ori
# Post-encoder, e.g. NLU
if self.postencoder is not None:
encoder_out, encoder_out_lens = self.postencoder(
encoder_out, encoder_out_lens
)
assert encoder_out.size(0) == speech.size(0), (
encoder_out.size(),
@ -337,7 +324,7 @@ class ESPnetASRModel(FunASRModel):
)
if intermediate_outs is not None:
return (encoder_out, intermediate_outs), encoder_out_lens
return (encoder_out, intermediate_outs), encoder_out_lens, encoder_out_spk
return encoder_out, encoder_out_lens, encoder_out_spk


@ -2,7 +2,7 @@ import copy
from typing import Optional
from typing import Tuple
from typing import Union
import logging
import humanfriendly
import numpy as np
import torch
@ -14,6 +14,7 @@ from funasr.layers.stft import Stft
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.modules.frontends.frontend import Frontend
from funasr.utils.get_default_kwargs import get_default_kwargs
from funasr.modules.nets_utils import make_pad_mask
class DefaultFrontend(AbsFrontend):
@ -137,8 +138,6 @@ class DefaultFrontend(AbsFrontend):
return input_stft, feats_lens
class MultiChannelFrontend(AbsFrontend):
"""Conventional frontend structure for ASR.
Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
@ -147,9 +146,9 @@ class MultiChannelFrontend(AbsFrontend):
def __init__(
self,
fs: Union[int, str] = 16000,
n_fft: int = 512,
win_length: int = None,
hop_length: int = 128,
n_fft: int = 400,
frame_length: int = 25,
frame_shift: int = 10,
window: Optional[str] = "hann",
center: bool = True,
normalized: bool = False,
@ -160,10 +159,10 @@ class MultiChannelFrontend(AbsFrontend):
htk: bool = False,
frontend_conf: Optional[dict] = get_default_kwargs(Frontend),
apply_stft: bool = True,
frame_length: int = None,
frame_shift: int = None,
lfr_m: int = None,
lfr_n: int = None,
use_channel: int = None,
lfr_m: int = 1,
lfr_n: int = 1,
cmvn_file: str = None
):
assert check_argument_types()
super().__init__()
@ -172,13 +171,14 @@ class MultiChannelFrontend(AbsFrontend):
# Deepcopy (In general, dict shouldn't be used as default arg)
frontend_conf = copy.deepcopy(frontend_conf)
self.hop_length = hop_length
self.win_length = frame_length * 16
self.hop_length = frame_shift * 16
if apply_stft:
self.stft = Stft(
n_fft=n_fft,
win_length=win_length,
hop_length=hop_length,
win_length=self.win_length,
hop_length=self.hop_length,
center=center,
window=window,
normalized=normalized,
@ -202,7 +202,17 @@ class MultiChannelFrontend(AbsFrontend):
htk=htk,
)
self.n_mels = n_mels
self.frontend_type = "multichannelfrontend"
self.frontend_type = "default"
self.use_channel = use_channel
if self.use_channel is not None:
logging.info("use the channel %d" % (self.use_channel))
else:
logging.info("random select channel")
self.cmvn_file = cmvn_file
if self.cmvn_file is not None:
mean, std = self._load_cmvn(self.cmvn_file)
self.register_buffer("mean", torch.from_numpy(mean))
self.register_buffer("std", torch.from_numpy(std))
def output_size(self) -> int:
return self.n_mels
@ -215,16 +225,29 @@ class MultiChannelFrontend(AbsFrontend):
if self.stft is not None:
input_stft, feats_lens = self._compute_stft(input, input_lengths)
else:
if isinstance(input, ComplexTensor):
input_stft = input
else:
input_stft = ComplexTensor(input[..., 0], input[..., 1])
input_stft = ComplexTensor(input[..., 0], input[..., 1])
feats_lens = input_lengths
# 2. [Option] Speech enhancement
if self.frontend is not None:
assert isinstance(input_stft, ComplexTensor), type(input_stft)
# input_stft: (Batch, Length, [Channel], Freq)
input_stft, _, mask = self.frontend(input_stft, feats_lens)
# 3. [Multi channel case]: Select a channel
if input_stft.dim() == 4:
# h: (B, T, C, F) -> h: (B, T, F)
if self.training:
if self.use_channel is not None:
input_stft = input_stft[:, :, self.use_channel, :]
else:
# Select 1ch randomly
ch = np.random.randint(input_stft.size(2))
input_stft = input_stft[:, :, ch, :]
else:
# Use the first channel
input_stft = input_stft[:, :, 0, :]
# 4. STFT -> Power spectrum
# h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F)
input_power = input_stft.real ** 2 + input_stft.imag ** 2
@ -233,18 +256,27 @@ class MultiChannelFrontend(AbsFrontend):
# input_power: (Batch, [Channel,] Length, Freq)
# -> input_feats: (Batch, Length, Dim)
input_feats, _ = self.logmel(input_power, feats_lens)
bt = input_feats.size(0)
if input_feats.dim() ==4:
channel_size = input_feats.size(2)
# batch * channel * T * D
#pdb.set_trace()
input_feats = input_feats.transpose(1,2).reshape(bt*channel_size,-1,80).contiguous()
# input_feats = input_feats.transpose(1,2)
# batch * channel
feats_lens = feats_lens.repeat(1,channel_size).squeeze()
else:
channel_size = 1
return input_feats, feats_lens, channel_size
# 6. Apply CMVN
if self.cmvn_file is not None:
if feats_lens is None:
feats_lens = input_feats.new_full([input_feats.size(0)], input_feats.size(1))
self.mean = self.mean.to(input_feats.device, input_feats.dtype)
self.std = self.std.to(input_feats.device, input_feats.dtype)
mask = make_pad_mask(feats_lens, input_feats, 1)
if input_feats.requires_grad:
input_feats = input_feats + self.mean
else:
input_feats += self.mean
if input_feats.requires_grad:
input_feats = input_feats.masked_fill(mask, 0.0)
else:
input_feats.masked_fill_(mask, 0.0)
input_feats *= self.std
return input_feats, feats_lens
def _compute_stft(
self, input: torch.Tensor, input_lengths: torch.Tensor
@ -258,4 +290,27 @@ class MultiChannelFrontend(AbsFrontend):
# Change torch.Tensor to ComplexTensor
# input_stft: (..., F, 2) -> (..., F)
input_stft = ComplexTensor(input_stft[..., 0], input_stft[..., 1])
return input_stft, feats_lens
return input_stft, feats_lens
def _load_cmvn(self, cmvn_file):
with open(cmvn_file, 'r', encoding='utf-8') as f:
lines = f.readlines()
means_list = []
vars_list = []
for i in range(len(lines)):
line_item = lines[i].split()
if line_item[0] == '<AddShift>':
line_item = lines[i + 1].split()
if line_item[0] == '<LearnRateCoef>':
add_shift_line = line_item[3:(len(line_item) - 1)]
means_list = list(add_shift_line)
continue
elif line_item[0] == '<Rescale>':
line_item = lines[i + 1].split()
if line_item[0] == '<LearnRateCoef>':
rescale_line = line_item[3:(len(line_item) - 1)]
vars_list = list(rescale_line)
continue
means = np.array(means_list).astype(float)
vars = np.array(vars_list).astype(float)
return means, vars


@ -0,0 +1,344 @@
//
// Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
// Reserved. MIT License (https://opensource.org/licenses/MIT)
//
/*
* // 2022-2023 by zhaomingwork@qq.com
*/
// java FunasrWsClient
// usage: FunasrWsClient [-h] [--port PORT] [--host HOST] [--audio_in AUDIO_IN] [--num_threads NUM_THREADS]
// [--chunk_size CHUNK_SIZE] [--chunk_interval CHUNK_INTERVAL] [--mode MODE]
package websocket;
import java.io.*;
import java.net.URI;
import java.net.URISyntaxException;
import java.nio.*;
import java.util.Map;
import net.sourceforge.argparse4j.ArgumentParsers;
import net.sourceforge.argparse4j.inf.ArgumentParser;
import net.sourceforge.argparse4j.inf.ArgumentParserException;
import net.sourceforge.argparse4j.inf.Namespace;
import org.java_websocket.client.WebSocketClient;
import org.java_websocket.drafts.Draft;
import org.java_websocket.handshake.ServerHandshake;
import org.json.simple.JSONArray;
import org.json.simple.JSONObject;
import org.json.simple.parser.JSONParser;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* Example websocket client: after connecting it sends a JSON handshake (mode, chunk_size,
* chunk_interval, wav_name, is_speaking), streams the wav/pcm file as binary chunks,
* sends an end-of-stream message and logs the JSON recognition results from the server.
*/
public class FunasrWsClient extends WebSocketClient {
public class RecWavThread extends Thread {
private FunasrWsClient funasrClient;
public RecWavThread(FunasrWsClient funasrClient) {
this.funasrClient = funasrClient;
}
public void run() {
this.funasrClient.recWav();
}
}
private static final Logger logger = LoggerFactory.getLogger(FunasrWsClient.class);
public FunasrWsClient(URI serverUri, Draft draft) {
super(serverUri, draft);
}
public FunasrWsClient(URI serverURI) {
super(serverURI);
}
public FunasrWsClient(URI serverUri, Map<String, String> httpHeaders) {
super(serverUri, httpHeaders);
}
public void getSslContext(String keyfile, String certfile) {
// TODO
return;
}
// send json at first time
public void sendJson(
String mode, String strChunkSize, int chunkInterval, String wavName, boolean isSpeaking) {
try {
JSONObject obj = new JSONObject();
obj.put("mode", mode);
JSONArray array = new JSONArray();
String[] chunkList = strChunkSize.split(",");
for (int i = 0; i < chunkList.length; i++) {
array.add(Integer.valueOf(chunkList[i].trim()));
}
obj.put("chunk_size", array);
obj.put("chunk_interval", new Integer(chunkInterval));
obj.put("wav_name", wavName);
if (isSpeaking) {
obj.put("is_speaking", new Boolean(true));
} else {
obj.put("is_speaking", new Boolean(false));
}
logger.info("sendJson: " + obj);
// return;
send(obj.toString());
return;
} catch (Exception e) {
e.printStackTrace();
}
}
// send json at end of wav
public void sendEof() {
try {
JSONObject obj = new JSONObject();
obj.put("is_speaking", new Boolean(false));
logger.info("sendEof: " + obj);
// return;
send(obj.toString());
iseof = true;
return;
} catch (Exception e) {
e.printStackTrace();
}
}
// function for rec wav file
public void recWav() {
sendJson(mode, strChunkSize, chunkInterval, wavName, true);
File file = new File(FunasrWsClient.wavPath);
int chunkSize = sendChunkSize;
byte[] bytes = new byte[chunkSize];
int readSize = 0;
try (FileInputStream fis = new FileInputStream(file)) {
if (FunasrWsClient.wavPath.endsWith(".wav")) {
fis.read(bytes, 0, 44); // skip the 44-byte RIFF/WAV header
}
readSize = fis.read(bytes, 0, chunkSize);
while (readSize > 0) {
// send when it is chunk size
if (readSize == chunkSize) {
send(bytes); // send buf to server
} else {
// send when at last or not is chunk size
byte[] tmpBytes = new byte[readSize];
for (int i = 0; i < readSize; i++) {
tmpBytes[i] = bytes[i];
}
send(tmpBytes);
}
// in streaming modes, pace the sends to simulate a real-time audio stream
// (16 kHz * 16-bit mono = 32 bytes of PCM per millisecond)
if (!mode.equals("offline")) {
Thread.sleep(chunkSize / 32);
}
readSize = fis.read(bytes, 0, chunkSize);
}
if (!mode.equals("offline")) {
// in streaming modes, wait briefly, send eof, then give the server time to flush results before closing
Thread.sleep(2000);
sendEof();
Thread.sleep(3000);
close();
} else {
// if offline, just send eof
sendEof();
}
} catch (Exception e) {
e.printStackTrace();
}
}
@Override
public void onOpen(ServerHandshake handshakedata) {
RecWavThread thread = new RecWavThread(this);
thread.start();
}
@Override
public void onMessage(String message) {
JSONObject jsonObject = new JSONObject();
JSONParser jsonParser = new JSONParser();
logger.info("received: " + message);
try {
jsonObject = (JSONObject) jsonParser.parse(message);
logger.info("text: " + jsonObject.get("text"));
} catch (org.json.simple.parser.ParseException e) {
e.printStackTrace();
}
if (iseof && mode.equals("offline")) {
close();
}
}
@Override
public void onClose(int code, String reason, boolean remote) {
logger.info(
"Connection closed by "
+ (remote ? "remote peer" : "us")
+ " Code: "
+ code
+ " Reason: "
+ reason);
}
@Override
public void onError(Exception ex) {
logger.info("ex: " + ex);
ex.printStackTrace();
// if the error is fatal then onClose will be called additionally
}
private boolean iseof = false;
public static String wavPath;
static String mode = "online";
static String strChunkSize = "5,10,5";
static int chunkInterval = 10;
static int sendChunkSize = 1920;
String wavName = "javatest";
public static void main(String[] args) throws URISyntaxException {
ArgumentParser parser = ArgumentParsers.newArgumentParser("ws client").defaultHelp(true);
parser
.addArgument("--port")
.help("Port on which to listen.")
.setDefault("8889")
.type(String.class)
.required(false);
parser
.addArgument("--host")
.help("the IP address of server.")
.setDefault("127.0.0.1")
.type(String.class)
.required(false);
parser
.addArgument("--audio_in")
.help("wav path for decoding.")
.setDefault("asr_example.wav")
.type(String.class)
.required(false);
parser
.addArgument("--num_threads")
.help("num of threads for test.")
.setDefault(1)
.type(Integer.class)
.required(false);
parser
.addArgument("--chunk_size")
.help("chunk size for asr.")
.setDefault("5, 10, 5")
.type(String.class)
.required(false);
parser
.addArgument("--chunk_interval")
.help("chunk for asr.")
.setDefault(10)
.type(Integer.class)
.required(false);
parser
.addArgument("--mode")
.help("mode for asr.")
.setDefault("offline")
.type(String.class)
.required(false);
String srvIp = "";
String srvPort = "";
String wavPath = "";
int numThreads = 1;
String chunk_size = "";
int chunk_interval = 10;
String strmode = "offline";
try {
Namespace ns = parser.parseArgs(args);
srvIp = ns.get("host");
srvPort = ns.get("port");
wavPath = ns.get("audio_in");
numThreads = ns.get("num_threads");
chunk_size = ns.get("chunk_size");
chunk_interval = ns.get("chunk_interval");
strmode = ns.get("mode");
System.out.println(srvPort);
} catch (ArgumentParserException ex) {
ex.getParser().handleError(ex);
return;
}
FunasrWsClient.strChunkSize = chunk_size;
FunasrWsClient.chunkInterval = chunk_interval;
FunasrWsClient.wavPath = wavPath;
FunasrWsClient.mode = strmode;
System.out.println(
"serIp="
+ srvIp
+ ",srvPort="
+ srvPort
+ ",wavPath="
+ wavPath
+ ",strChunkSize"
+ strChunkSize);
class ClientThread implements Runnable {
String srvIp;
String srvPort;
ClientThread(String srvIp, String srvPort, String wavPath) {
this.srvIp = srvIp;
this.srvPort = srvPort;
}
public void run() {
try {
int RATE = 16000;
String[] chunkList = strChunkSize.split(",");
// duration (ms) of one audio chunk, derived from chunk_size and chunk_interval
int int_chunk_size = 60 * Integer.parseInt(chunkList[1].trim()) / chunkInterval;
// samples per chunk at 16 kHz
int CHUNK = RATE / 1000 * int_chunk_size;
// bytes per chunk of 16-bit mono PCM (arithmetic reordered to avoid integer-division truncation)
int stride = 60 * Integer.parseInt(chunkList[1].trim()) * 16000 * 2 / chunkInterval / 1000;
System.out.println("chunk_size:" + int_chunk_size);
System.out.println("CHUNK:" + CHUNK);
System.out.println("stride:" + stride);
// e.g. chunk_size "5,10,5" with chunk_interval 10 gives 60 ms -> 960 samples -> 1920 bytes
FunasrWsClient.sendChunkSize = CHUNK * 2;
String wsAddress = "ws://" + srvIp + ":" + srvPort;
FunasrWsClient c = new FunasrWsClient(new URI(wsAddress));
c.connect();
System.out.println("wsAddress:" + wsAddress);
} catch (Exception e) {
e.printStackTrace();
System.out.println("e:" + e);
}
}
}
for (int i = 0; i < numThreads; i++) {
System.out.println("Thread1 is running...");
Thread t = new Thread(new ClientThread(srvIp, srvPort, wavPath));
t.start();
}
}
}

View File

@ -0,0 +1,76 @@
ENTRY_POINT = ./
WEBSOCKET_DIR:= ./
WEBSOCKET_FILES = \
$(WEBSOCKET_DIR)/FunasrWsClient.java
LIB_BUILD_DIR = ./lib
JAVAC = javac
BUILD_DIR = build
RUNJFLAGS = -Dfile.encoding=utf-8
vpath %.class $(BUILD_DIR)
vpath %.java src
rebuild: clean all
.PHONY: clean run downjar
downjar:
wget https://repo1.maven.org/maven2/org/slf4j/slf4j-api/1.7.25/slf4j-api-1.7.25.jar -P ./lib/
wget https://repo1.maven.org/maven2/org/slf4j/slf4j-simple/1.7.25/slf4j-simple-1.7.25.jar -P ./lib/
#wget https://github.com/TooTallNate/Java-WebSocket/releases/download/v1.5.3/Java-WebSocket-1.5.3.jar -P ./lib/
wget https://repo1.maven.org/maven2/org/java-websocket/Java-WebSocket/1.5.3/Java-WebSocket-1.5.3.jar -P ./lib/
wget https://storage.googleapis.com/google-code-archive-downloads/v2/code.google.com/json-simple/json-simple-1.1.1.jar -P ./lib/
wget https://github.com/argparse4j/argparse4j/releases/download/argparse4j-0.9.0/argparse4j-0.9.0.jar -P ./lib/
rm -frv build
mkdir build
clean:
rm -frv $(BUILD_DIR)/*
rm -frv $(LIB_BUILD_DIR)/*
mkdir -p $(BUILD_DIR)
mkdir -p ./lib
runclient:
java -cp $(BUILD_DIR):lib/Java-WebSocket-1.5.3.jar:lib/slf4j-simple-1.7.25.jar:lib/slf4j-api-1.7.25.jar:lib/json-simple-1.1.1.jar:lib/argparse4j-0.9.0.jar $(RUNJFLAGS) websocket.FunasrWsClient --host localhost --port 8889 --audio_in ./asr_example.wav --num_threads 1 --mode 2pass
buildwebsocket: $(WEBSOCKET_FILES:.java=.class)
%.class: %.java
$(JAVAC) -cp $(BUILD_DIR):lib/slf4j-simple-1.7.25.jar:lib/slf4j-api-1.7.25.jar:lib/Java-WebSocket-1.5.3.jar:lib/json-simple-1.1.1.jar:lib/argparse4j-0.9.0.jar -d $(BUILD_DIR) -encoding UTF-8 $<
packjar:
jar cvfe lib/funasrclient.jar . -C $(BUILD_DIR) .
all: clean downjar buildwebsocket packjar

View File

@ -0,0 +1,66 @@
# Java websocket client example
## Building for Linux/Unix
### Install the Java environment
```shell
# on Ubuntu
apt-get install openjdk-11-jdk
```
### Build and run by make
```shell
cd funasr/runtime/java
# download java lib
make downjar
# compile
make buildwebsocket
# run client
make runclient
```
## Run the Java websocket client from the shell
```shell
# full command refer to Makefile runclient
usage: FunasrWsClient [-h] [--port PORT] [--host HOST] [--audio_in AUDIO_IN] [--num_threads NUM_THREADS]
[--chunk_size CHUNK_SIZE] [--chunk_interval CHUNK_INTERVAL] [--mode MODE]
Where:
--host <string>
server IP address (default: 127.0.0.1)
--port <int>
server port (default: 8889)
--audio_in <string>
path of the wav or pcm file to decode (default: asr_example.wav)
--num_threads <int>
number of client threads for the test (default: 1)
--chunk_size <string>
chunk size for streaming asr (default: "5, 10, 5")
--chunk_interval <int>
chunk interval for streaming asr (default: 10)
--mode
asr mode, one of "offline", "online", "2pass" (default: offline)
example:
FunasrWsClient --host localhost --port 8889 --audio_in ./asr_example.wav --num_threads 1 --mode 2pass
result JSON, for example:
{"mode":"offline","text":"欢迎大家来体验达摩院推出的语音识别模型","wav_name":"javatest"}
```
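The client can also be driven from other Java code instead of the command line. Below is a minimal sketch (the `EmbedExample` class name is made up for illustration); it assumes the jars fetched by `make downjar` and the compiled `websocket` package are on the classpath, and that a FunASR websocket server is reachable at the given address:
```java
package websocket;

import java.net.URI;

public class EmbedExample {
  public static void main(String[] args) throws Exception {
    // static options read by recWav() once the connection is opened
    FunasrWsClient.wavPath = "./asr_example.wav";
    FunasrWsClient.mode = "offline";
    FunasrWsClient.strChunkSize = "5,10,5";
    FunasrWsClient.chunkInterval = 10;
    FunasrWsClient.sendChunkSize = 1920; // 60 ms of 16 kHz, 16-bit mono PCM

    // onOpen() starts a thread that streams the file; results are logged from onMessage()
    FunasrWsClient client = new FunasrWsClient(new URI("ws://127.0.0.1:8889"));
    client.connect();
  }
}
```
In `offline` mode the client closes itself after the final result arrives; in `online`/`2pass` mode it sends an end-of-stream message and closes a few seconds later.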
## Acknowledge
1. This project is maintained by [FunASR community](https://github.com/alibaba-damo-academy/FunASR).
2. We acknowledge [zhaoming](https://github.com/zhaomingwork/FunASR/tree/java-ws-client-support/funasr/runtime/java) for contributing the java websocket client example.

View File

@ -65,7 +65,7 @@ void runReg(FUNASR_HANDLE asr_handle, vector<string> wav_list, vector<string> wa
n_total_length += snippet_time;
FunASRFreeResult(result);
}else{
LOG(ERROR) << ("No return data!\n");
LOG(ERROR) << wav_ids[i] << (": No return data!\n");
}
}
{

View File

@ -18,6 +18,7 @@ void CTTransformer::InitPunc(const std::string &punc_model, const std::string &p
try{
m_session = std::make_unique<Ort::Session>(env_, punc_model.c_str(), session_options);
LOG(INFO) << "Successfully load model from " << punc_model;
}
catch (std::exception const &e) {
LOG(ERROR) << "Error when load punc onnx model: " << e.what();

View File

@ -33,6 +33,7 @@ void Paraformer::InitAsr(const std::string &am_model, const std::string &am_cmvn
try {
m_session = std::make_unique<Ort::Session>(env_, am_model.c_str(), session_options);
LOG(INFO) << "Successfully load model from " << am_model;
} catch (std::exception const &e) {
LOG(ERROR) << "Error when load am onnx model: " << e.what();
exit(0);

View File

@ -13,7 +13,7 @@ def get_readme():
MODULE_NAME = 'funasr_onnx'
VERSION_NUM = '0.1.0'
VERSION_NUM = '0.1.1'
setuptools.setup(
name=MODULE_NAME,
@ -31,7 +31,7 @@ setuptools.setup(
"onnxruntime>=1.7.0",
"scipy",
"numpy>=1.19.3",
"typeguard",
"typeguard==2.13.3",
"kaldi-native-fbank",
"PyYAML>=5.1.2",
"funasr",

View File

@ -3,7 +3,7 @@ generated certificate may not suitable for all browsers due to security concerns
```shell
### 1) Generate a private key
openssl genrsa -des3 -out server.key 1024
openssl genrsa -des3 -out server.key 2048
### 2) Generate a csr file
openssl req -new -key server.key -out server.csr
@ -14,4 +14,4 @@ openssl rsa -in server.key.org -out server.key
### 4) Generate a crt file, valid for 1 year
openssl x509 -req -days 365 -in server.csr -signkey server.key -out server.crt
```
```

View File

@ -1,15 +1,21 @@
-----BEGIN CERTIFICATE-----
MIICSDCCAbECFCObiVAMkMlCGmMDGDFx5Nx3XYvOMA0GCSqGSIb3DQEBCwUAMGMx
CzAJBgNVBAYTAkNOMRAwDgYDVQQIDAdCZWlqaW5nMRAwDgYDVQQHDAdCZWlqaW5n
MRAwDgYDVQQKDAdhbGliYWJhMQwwCgYDVQQLDANhc3IxEDAOBgNVBAMMB2FsaWJh
YmEwHhcNMjMwNTEyMTQzNjAxWhcNMjQwNTExMTQzNjAxWjBjMQswCQYDVQQGEwJD
TjEQMA4GA1UECAwHQmVpamluZzEQMA4GA1UEBwwHQmVpamluZzEQMA4GA1UECgwH
YWxpYmFiYTEMMAoGA1UECwwDYXNyMRAwDgYDVQQDDAdhbGliYWJhMIGfMA0GCSqG
SIb3DQEBAQUAA4GNADCBiQKBgQDEINLLMasJtJQPoesCfcwJsjiUkx3hLnoUyETS
NBrrRfjbBv6ucAgZIF+/V15IfJZR6u2ULpJN0wUg8xNQReu4kdpjSdNGuQ0aoWbc
38+VLo9UjjsoOeoeCro6b0u+GosPoEuI4t7Ky09zw+FBibD95daJ3GDY1DGCbDdL
mV/toQIDAQABMA0GCSqGSIb3DQEBCwUAA4GBAB5KNWF1XIIYD1geMsyT6/ZRnGNA
dmeUyMcwYvIlQG3boSipNk/JI4W5fFOg1O2sAqflYHmwZfmasAQsC2e5bSzHZ+PB
uMJhKYxfj81p175GumHTw5Lbp2CvFSLrnuVB0ThRdcCqEh1MDt0D3QBuBr/ZKgGS
hXtozVCgkSJzX6uD
MIIDhTCCAm0CFGB0Po2IZ0hESavFpcSGRNb9xrNXMA0GCSqGSIb3DQEBCwUAMH8x
CzAJBgNVBAYTAkNOMRAwDgYDVQQIDAdiZWlqaW5nMRAwDgYDVQQHDAdiZWlqaW5n
MRAwDgYDVQQKDAdhbGliYWJhMRAwDgYDVQQLDAdhbGliYWJhMRAwDgYDVQQDDAdh
bGliYWJhMRYwFAYJKoZIhvcNAQkBFgdhbGliYWJhMB4XDTIzMDYxODA2NTcxM1oX
DTI0MDYxNzA2NTcxM1owfzELMAkGA1UEBhMCQ04xEDAOBgNVBAgMB2JlaWppbmcx
EDAOBgNVBAcMB2JlaWppbmcxEDAOBgNVBAoMB2FsaWJhYmExEDAOBgNVBAsMB2Fs
aWJhYmExEDAOBgNVBAMMB2FsaWJhYmExFjAUBgkqhkiG9w0BCQEWB2FsaWJhYmEw
ggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDH9Np1oBunQKMt5M/nU2nD
qVHojXwKKwyiK9DSeGikKwArH2S9NUZNu5RDg46u0iWmT+Vz+toQhkJnfatOVskW
f2bsI54n5eOvmoWOKDXYm2MscvjkuNiYRbqzgUuP9ZSx8k3uyRs++wvmwIoU+PV1
EYFcjk1P2jUGUvKaUlmIDsjs1wOMIbKO6I0UX20FNKlGWacqMR/Dx2ltmGKT1Kaz
Y335lor0bcfQtH542rGS7PDz6JMRNjFT1VFcmnrjRElf4STbaOiIfOjMVZ/9O8Hr
LFItyvkb01Mt7O0jhAXHuE1l/8Y0N3MCYkELG9mQA0BYCFHY0FLuJrGoU03b8KWj
AgMBAAEwDQYJKoZIhvcNAQELBQADggEBAEjC9jB1WZe2ki2JgCS+eAMFsFegiNEz
D0klVB3kiCPK0g7DCxvfWR6kAgEynxRxVX6TN9QcLr4paZItC1Fu2gUMTteNqEuc
dcixJdu9jumuUMBlAKgL5Yyk3alSErsn9ZVF/Q8Kx5arMO/TW3Ulsd8SWQL5C/vq
Fe0SRhpKKoADPfl8MT/XMfB/MwNxVhYDSHzJ1EiN8O5ce6q2tTdi1mlGquzNxhjC
7Q0F36V1HksfzolrlRWRKYP16isnaKUdFfeAzaJsYw33o6VRbk6fo2fTQDHS0wOs
Q48Moc5UxKMLaMMCqLPpWu0TZse+kIw1nTWXk7yJtK0HK5PN3rTocEw=
-----END CERTIFICATE-----

View File

@ -1,15 +1,27 @@
-----BEGIN RSA PRIVATE KEY-----
MIICXQIBAAKBgQDEINLLMasJtJQPoesCfcwJsjiUkx3hLnoUyETSNBrrRfjbBv6u
cAgZIF+/V15IfJZR6u2ULpJN0wUg8xNQReu4kdpjSdNGuQ0aoWbc38+VLo9Ujjso
OeoeCro6b0u+GosPoEuI4t7Ky09zw+FBibD95daJ3GDY1DGCbDdLmV/toQIDAQAB
AoGARpA0pwygp+ZDWvh7kDLoZRitCK+BkZHiNHX1ZNeAU+Oh7FOw79u43ilqqXHq
pxPEFYb7oVO8Kanhb4BlE32EmApBlvhd3SW07kn0dS7WVGsTvPFwKKpF88W8E+pc
2i8At5tr2O1DZhvqNdIN7r8FRrGQ/Hpm3ItypUdz2lZnMwECQQD3dILOMJ84O2JE
NxUwk8iOYefMJftQUO57Gm7XBVke/i3r9uajSqB2xmOvUaSyaHoJfx/mmfgfxYcD
M+Re6mERAkEAyuaV5+eD82eG2I8PgxJ2p5SOb1x5F5qpb4KuKAlfHEkdolttMwN3
7vl1ZWUZLVu2rHnUmvbYV2gkQO1os7/DkQJBAIDYfbN2xbC12vjB5ZqhmG/qspMt
w6mSOlqG7OewtTLaDncq2/RySxMNQaJr1GHA3KpNMwMTcIq6gw472tFBIMECQF0z
fjiASEROkcp4LI/ws0BXJPZSa+1DxgDK7mTFqUK88zfY91gvh6/mNt7UibQkJM0l
SVvFd6ru03hflXC77YECQQDDQrB9ApwVOMGQw+pwbxn9p8tPYVi3oBiUfYgd1RDO
uhcRgxv7gT4BSiyI4nFBMCYyI28azTLlUiJhMr9MNUpB
MIIEowIBAAKCAQEAx/TadaAbp0CjLeTP51Npw6lR6I18CisMoivQ0nhopCsAKx9k
vTVGTbuUQ4OOrtIlpk/lc/raEIZCZ32rTlbJFn9m7COeJ+Xjr5qFjig12JtjLHL4
5LjYmEW6s4FLj/WUsfJN7skbPvsL5sCKFPj1dRGBXI5NT9o1BlLymlJZiA7I7NcD
jCGyjuiNFF9tBTSpRlmnKjEfw8dpbZhik9Sms2N9+ZaK9G3H0LR+eNqxkuzw8+iT
ETYxU9VRXJp640RJX+Ek22joiHzozFWf/TvB6yxSLcr5G9NTLeztI4QFx7hNZf/G
NDdzAmJBCxvZkANAWAhR2NBS7iaxqFNN2/ClowIDAQABAoIBAQC1/STX6eFBWJMs
MhUHdePNMU5bWmqK1qOo9jgZV33l7T06Alit3M8f8JoA2LwEYT/jHtS3upi+cXP+
vWIs6tAaqdoDEmff6FxSd1EXEYHwo3yf+ASQJ6z66nwC5KrhW6L6Uo6bxm4F5Hfw
jU0fyXeeFVCn7Nxw0SlxmA02Z70VFsL8BK9i3kajU18y6drf4VUm55oMEtdEmOh2
eKn4qspBcNblbw+L0QJ+5kN1iRUyJHesQ1GpS+L3yeMVFCW7ctL4Bgw8Z7LE+z7i
C0Weyhul8vuT+7nfF2T37zsSa8iixqpkTokeYh96CZ5nDqa2IDx3oNHWSlkIsV6g
6EUEl9gBAoGBAPIw/M6fIDetMj8f1wG7mIRgJsxI817IS6aBSwB5HkoCJFfrR9Ua
jMNCFIWNs/Om8xeGhq/91hbnCYDNK06V5CUa/uk4CYRs2eQZ3FKoNowtp6u/ieuU
qg8bXM/vR2VWtWVixAMdouT3+KtvlgaVmSnrPiwO4pecGrwu5NW1oJCFAoGBANNb
aE3AcwTDYsqh0N/75G56Q5s1GZ6MCDQGQSh8IkxL6Vg59KnJiIKQ7AxNKFgJZMtY
zZHaqjazeHjOGTiYiC7MMVJtCcOBEfjCouIG8btNYv7Y3dWnOXRZni2telAsRrH9
xS5LaFdCRTjVAwSsppMGwiQtyl6sGLMyz0SXoYoHAoGAKdkFFb6xFm26zOV3hTkg
9V6X1ZyVUL9TMwYMK5zB+w+7r+VbmBrqT6LPYPRHL8adImeARlCZ+YMaRUMuRHnp
3e94NFwWaOdWDu/Y/f9KzZXl7us9rZMWf12+/77cm0oMNeSG8fLg/qdKNHUneyPG
P1QCfiJkTMYQaIvBxpuHjvECgYAKlZ9JlYOtD2PZJfVh4il0ZucP1L7ts7GNeWq1
7lGBZKPQ6UYZYqBVeZB4pTyJ/B5yGIZi8YJoruAvnJKixPC89zjZGeDNS59sx8KE
cziT2rJEdPPXCULVUs+bFf70GOOJcl33jYsyI3139SLrjwHghwwd57UkvJWYE8lR
dA6A7QKBgEfTC+NlzqLPhbB+HPl6CvcUczcXcI9M0heVz/DNMA+4pjxPnv2aeIwh
cL2wq2xr+g1wDBWGVGkVSuZhXm5E6gDetdyVeJnbIUhVjBblnbhHV6GrudjbXGnJ
W9cBgu6DswyHU2cOsqmimu8zLmG6/dQYFHt+kUWGxN8opCzVjgWa
-----END RSA PRIVATE KEY-----

View File

@ -91,7 +91,6 @@ class WebsocketClient {
using websocketpp::lib::placeholders::_1;
m_client.set_open_handler(bind(&WebsocketClient::on_open, this, _1));
m_client.set_close_handler(bind(&WebsocketClient::on_close, this, _1));
// m_client.set_close_handler(bind(&WebsocketClient::on_close, this, _1));
m_client.set_message_handler(
[this](websocketpp::connection_hdl hdl, message_ptr msg) {
@ -218,7 +217,7 @@ class WebsocketClient {
}
}
if (wait) {
LOG(INFO) << "wait.." << m_open;
// LOG(INFO) << "wait.." << m_open;
WaitABit();
continue;
}
@ -292,7 +291,7 @@ int main(int argc, char* argv[]) {
false, 1, "int");
TCLAP::ValueArg<int> is_ssl_(
"", "is-ssl", "is-ssl is 1 means use wss connection, or use ws connection",
false, 0, "int");
false, 1, "int");
cmd.add(server_ip_);
cmd.add(port_);

View File

@ -38,6 +38,7 @@ from funasr.models.decoder.transformer_decoder import (
from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
from funasr.models.decoder.transformer_decoder import TransformerDecoder
from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
from funasr.models.decoder.transformer_decoder import SAAsrTransformerDecoder
from funasr.models.e2e_asr import ASRModel
from funasr.models.decoder.rnnt_decoder import RNNTDecoder
from funasr.models.joint_net.joint_network import JointNetwork
@ -45,6 +46,7 @@ from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerOnline, Paraf
from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
from funasr.models.e2e_tp import TimestampPredictor
from funasr.models.e2e_asr_mfcca import MFCCA
from funasr.models.e2e_sa_asr import SAASRModel
from funasr.models.e2e_uni_asr import UniASR
from funasr.models.e2e_asr_transducer import TransducerModel, UnifiedTransducerModel
from funasr.models.encoder.abs_encoder import AbsEncoder
@ -54,6 +56,7 @@ from funasr.models.encoder.rnn_encoder import RNNEncoder
from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
from funasr.models.encoder.transformer_encoder import TransformerEncoder
from funasr.models.encoder.mfcca_encoder import MFCCAEncoder
from funasr.models.encoder.resnet34_encoder import ResNet34Diar
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.frontend.default import DefaultFrontend
from funasr.models.frontend.default import MultiChannelFrontend
@ -134,6 +137,7 @@ model_choices = ClassChoices(
timestamp_prediction=TimestampPredictor,
rnnt=TransducerModel,
rnnt_unified=UnifiedTransducerModel,
sa_asr=SAASRModel,
),
type_check=FunASRModel,
default="asr",
@ -175,6 +179,27 @@ encoder_choices2 = ClassChoices(
type_check=AbsEncoder,
default="rnn",
)
asr_encoder_choices = ClassChoices(
"asr_encoder",
classes=dict(
conformer=ConformerEncoder,
transformer=TransformerEncoder,
rnn=RNNEncoder,
sanm=SANMEncoder,
sanm_chunk_opt=SANMEncoderChunkOpt,
data2vec_encoder=Data2VecEncoder,
mfcca_enc=MFCCAEncoder,
),
type_check=AbsEncoder,
default="rnn",
)
spk_encoder_choices = ClassChoices(
"spk_encoder",
classes=dict(
resnet34_diar=ResNet34Diar,
),
default="resnet34_diar",
)
postencoder_choices = ClassChoices(
name="postencoder",
classes=dict(
@ -197,6 +222,7 @@ decoder_choices = ClassChoices(
paraformer_decoder_sanm=ParaformerSANMDecoder,
paraformer_decoder_san=ParaformerDecoderSAN,
contextual_paraformer_decoder=ContextualParaformerDecoder,
sa_decoder=SAAsrTransformerDecoder,
),
type_check=AbsDecoder,
default="rnn",
@ -329,6 +355,12 @@ class ASRTask(AbsTask):
default=True,
help="whether to split text using <space>",
)
group.add_argument(
"--max_spk_num",
type=int_or_none,
default=None,
help="A text mapping int-id to token",
)
group.add_argument(
"--seg_dict_file",
type=str,
@ -1495,3 +1527,123 @@ class ASRTransducerTask(ASRTask):
#assert check_return_type(model)
return model
class ASRTaskSAASR(ASRTask):
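# Task for speaker-attributed ASR (SA-ASR): builds a shared frontend, a separate
# ASR encoder and speaker encoder, and a speaker-attributed decoder on top of them.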
# If you need more than one optimizers, change this value
num_optimizers: int = 1
# Add variable objects configurations
class_choices_list = [
# --frontend and --frontend_conf
frontend_choices,
# --specaug and --specaug_conf
specaug_choices,
# --normalize and --normalize_conf
normalize_choices,
# --model and --model_conf
model_choices,
# --preencoder and --preencoder_conf
preencoder_choices,
# --encoder and --encoder_conf
# --asr_encoder and --asr_encoder_conf
asr_encoder_choices,
# --spk_encoder and --spk_encoder_conf
spk_encoder_choices,
# --decoder and --decoder_conf
decoder_choices,
]
# If you need to modify train() or eval() procedures, change Trainer class here
trainer = Trainer
@classmethod
def build_model(cls, args: argparse.Namespace):
assert check_argument_types()
if isinstance(args.token_list, str):
with open(args.token_list, encoding="utf-8") as f:
token_list = [line.rstrip() for line in f]
# Overwriting token_list to keep it as "portable".
args.token_list = list(token_list)
elif isinstance(args.token_list, (tuple, list)):
token_list = list(args.token_list)
else:
raise RuntimeError("token_list must be str or list")
vocab_size = len(token_list)
logging.info(f"Vocabulary size: {vocab_size}")
# 1. frontend
if args.input_size is None:
# Extract features in the model
frontend_class = frontend_choices.get_class(args.frontend)
if args.frontend == 'wav_frontend' or args.frontend == "multichannelfrontend":
frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
else:
frontend = frontend_class(**args.frontend_conf)
input_size = frontend.output_size()
else:
# Give features from data-loader
args.frontend = None
args.frontend_conf = {}
frontend = None
input_size = args.input_size
# 2. Data augmentation for spectrogram
if args.specaug is not None:
specaug_class = specaug_choices.get_class(args.specaug)
specaug = specaug_class(**args.specaug_conf)
else:
specaug = None
# 3. Normalization layer
if args.normalize is not None:
normalize_class = normalize_choices.get_class(args.normalize)
normalize = normalize_class(**args.normalize_conf)
else:
normalize = None
# 5. Encoder
asr_encoder_class = asr_encoder_choices.get_class(args.asr_encoder)
asr_encoder = asr_encoder_class(input_size=input_size, **args.asr_encoder_conf)
spk_encoder_class = spk_encoder_choices.get_class(args.spk_encoder)
spk_encoder = spk_encoder_class(input_size=input_size, **args.spk_encoder_conf)
# 7. Decoder
decoder_class = decoder_choices.get_class(args.decoder)
decoder = decoder_class(
vocab_size=vocab_size,
encoder_output_size=asr_encoder.output_size(),
**args.decoder_conf,
)
# 8. CTC
ctc = CTC(
odim=vocab_size, encoder_output_size=asr_encoder.output_size(), **args.ctc_conf
)
# 9. Build model
try:
model_class = model_choices.get_class(args.model)
except AttributeError:
model_class = model_choices.get_class("asr")
model = model_class(
vocab_size=vocab_size,
frontend=frontend,
specaug=specaug,
normalize=normalize,
asr_encoder=asr_encoder,
spk_encoder=spk_encoder,
decoder=decoder,
ctc=ctc,
token_list=token_list,
**args.model_conf,
)
# 10. Initialize
if args.init is not None:
initialize(model, args.init)
assert check_return_type(model)
return model

View File

@ -39,7 +39,7 @@ from funasr.models.decoder.transformer_decoder import (
from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
from funasr.models.decoder.transformer_decoder import TransformerDecoder
from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
from funasr.models.e2e_sa_asr import ESPnetASRModel
from funasr.models.e2e_sa_asr import SAASRModel
from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerBert, BiCifParaformer, ContextualParaformer
from funasr.models.e2e_tp import TimestampPredictor
from funasr.models.e2e_asr_mfcca import MFCCA
@ -120,7 +120,7 @@ normalize_choices = ClassChoices(
model_choices = ClassChoices(
"model",
classes=dict(
asr=ESPnetASRModel,
asr=SAASRModel,
uniasr=UniASR,
paraformer=Paraformer,
paraformer_bert=ParaformerBert,
@ -620,4 +620,4 @@ class ASRTask(AbsTask):
initialize(model, args.init)
assert check_return_type(model)
return model
return model

View File

@ -1 +1 @@
0.6.1
0.6.2