mirror of https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00

add speaker-attributed ASR task for alimeeting

This commit is contained in:
parent 49f13908de
commit af6740a220

egs/alimeeting/sa-asr/README.md (new file, 79 lines)
@@ -0,0 +1,79 @@

# Get Started

Speaker Attributed Automatic Speech Recognition (SA-ASR) is a task proposed to solve "who spoke what". Specifically, the goal of SA-ASR is not only to obtain multi-speaker transcriptions, but also to identify the corresponding speaker for each utterance. The method used in this example follows the paper [End-to-End Speaker-Attributed ASR with Transformer](https://www.isca-speech.org/archive/pdfs/interspeech_2021/kanda21b_interspeech.pdf).
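To make the attribution step concrete, below is a minimal, hypothetical sketch (an illustration only, not the recipe's actual model) of how a decoded utterance can be assigned to a speaker: each utterance-level speaker query from the decoder is matched against an inventory of speaker profiles by cosine similarity.

```python
# Minimal sketch of speaker attribution; illustration only, not the recipe's model.
import numpy as np

def attribute_speakers(utt_queries: np.ndarray, profiles: np.ndarray) -> list:
    """utt_queries: (num_utts, dim) decoder speaker queries;
    profiles: (num_speakers, dim) one embedding per enrolled speaker."""
    q = utt_queries / np.linalg.norm(utt_queries, axis=1, keepdims=True)
    p = profiles / np.linalg.norm(profiles, axis=1, keepdims=True)
    similarity = q @ p.T  # cosine similarity, shape (num_utts, num_speakers)
    return similarity.argmax(axis=1).tolist()  # speaker index for each utterance
```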
To run this recipe, you first need to install FunASR and ModelScope ([installation](https://alibaba-damo-academy.github.io/FunASR/en/installation.html)).

There are two startup scripts: `run.sh` for training and evaluating on the old Eval and Test sets, and `run_m2met_2023_infer.sh` for inference on the new test set of the Multi-Channel Multi-Party Meeting Transcription 2.0 ([M2MET2.0](https://alibaba-damo-academy.github.io/FunASR/m2met2/index.html)) Challenge.

Before running `run.sh`, you must manually download and unpack the [AliMeeting](http://www.openslr.org/119/) corpus and place it in the `./dataset` directory:
```shell
dataset
|—— Eval_Ali_far
|—— Eval_Ali_near
|—— Test_Ali_far
|—— Test_Ali_near
|—— Train_Ali_far
|—— Train_Ali_near
```
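As a quick pre-flight check (a hypothetical helper, not part of the recipe), you can verify that all six subsets are in place before launching `run.sh`:

```python
# Hypothetical pre-flight check that the six AliMeeting subsets exist under ./dataset.
import os

expected = ["Eval_Ali_far", "Eval_Ali_near", "Test_Ali_far",
            "Test_Ali_near", "Train_Ali_far", "Train_Ali_near"]
missing = [d for d in expected if not os.path.isdir(os.path.join("dataset", d))]
if missing:
    raise SystemExit(f"missing AliMeeting subsets under ./dataset: {missing}")
```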
There are 18 stages in `run.sh`:

```shell
stage 1 - 5: Data preparation and processing.
stage 6: Generate speaker profiles (this stage takes a lot of time).
stage 7 - 9: Language model training (optional).
stage 10 - 11: ASR training (SA-ASR requires loading the pre-trained ASR model).
stage 12: SA-ASR training.
stage 13 - 18: Inference and evaluation.
```
Before running `run_m2met_2023_infer.sh`, you need to place the new test set `Test_2023_Ali_far` (to be released after the challenge starts), which contains only raw audio, in the `./dataset` directory. Then put the given `wav.scp`, `wav_raw.scp`, `segments`, `utt2spk` and `spk2utt` files in the `./data/Test_2023_Ali_far` directory:

```shell
data/Test_2023_Ali_far
|—— wav.scp
|—— wav_raw.scp
|—— segments
|—— utt2spk
|—— spk2utt
```
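These files follow the usual Kaldi conventions (`wav.scp` maps a recording ID to audio, `segments` maps an utterance ID to a recording ID plus start/end times, and `utt2spk`/`spk2utt` map between utterances and speakers). A small sanity check along these lines (hypothetical, not part of the recipe) can catch inconsistencies before inference:

```python
# Hypothetical consistency check for the Kaldi-style files above.
def read_map(path):
    with open(path, encoding="utf-8") as f:
        return {line.split()[0]: line.split()[1:] for line in f if line.strip()}

wav = read_map("data/Test_2023_Ali_far/wav.scp")
segments = read_map("data/Test_2023_Ali_far/segments")
utt2spk = read_map("data/Test_2023_Ali_far/utt2spk")

for utt, fields in segments.items():
    rec_id = fields[0]
    assert rec_id in wav, f"segments refers to unknown recording {rec_id}"
    assert utt in utt2spk, f"utterance {utt} missing from utt2spk"
```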
There are 4 stages in `run_m2met_2023_infer.sh`:

```shell
stage 1: Data preparation and processing.
stage 2: Generate speaker profiles for inference.
stage 3: Inference.
stage 4: Generate the SA-ASR results required for final submission.
```
# Format of Final Submission

Finally, you need to submit a file called `text_spk_merge` in the following format:

```shell
Meeting_1 text_spk_1_A$text_spk_1_B$text_spk_1_C ...
Meeting_2 text_spk_2_A$text_spk_2_B$text_spk_2_C ...
...
```

Here, `text_spk_1_A` represents the full transcription of speaker A of Meeting_1 (merged in chronological order), and `$` is the separator symbol. There is no need to worry about the speaker permutation, as the optimal permutation is computed at scoring time. For more information, please refer to the results generated after executing the baseline code.
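As an illustration of that merging step, here is a hedged sketch (a hypothetical helper; the baseline code produces this file for you) that assumes per-segment results are available as `(meeting_id, speaker_id, start_time, text)` tuples:

```python
# Hypothetical sketch: build text_spk_merge lines from per-segment results.
from collections import defaultdict

def write_text_spk_merge(results, out_path):
    """results: iterable of (meeting_id, speaker_id, start_time, text) tuples."""
    per_speaker = defaultdict(list)
    for meeting, speaker, start, text in results:
        per_speaker[(meeting, speaker)].append((start, text))

    merged = defaultdict(dict)
    for (meeting, speaker), segs in per_speaker.items():
        # Concatenate each speaker's segments in chronological order.
        merged[meeting][speaker] = "".join(t for _, t in sorted(segs))

    with open(out_path, "w", encoding="utf-8") as f:
        for meeting in sorted(merged):
            spk_texts = [merged[meeting][s] for s in sorted(merged[meeting])]
            f.write(meeting + " " + "$".join(spk_texts) + "\n")
```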
# Baseline Results

The results of the baseline system are given below. They include the speaker-independent character error rate (SI-CER) and the concatenated minimum-permutation character error rate (cpCER); the former ignores speaker labels, while the latter is speaker dependent. During training, the speaker profile adopts the oracle speaker embedding. However, because oracle speaker labels are unavailable during evaluation, the speaker profile produced by an additional spectral clustering step is used instead. The results of using the oracle speaker profile on the Eval and Test sets are also provided to show the impact of speaker profile accuracy. A rough sketch of the cpCER computation follows the table.
<table>
<tr>
<td rowspan="2"></td>
<td colspan="2">SI-CER (%)</td>
<td colspan="2">cpCER (%)</td>
</tr>
<tr>
<td>Eval</td>
<td>Test</td>
<td>Eval</td>
<td>Test</td>
</tr>
<tr>
<td>oracle profile</td>
<td>31.93</td>
<td>32.75</td>
<td>48.56</td>
<td>53.33</td>
</tr>
<tr>
<td>cluster profile</td>
<td>31.94</td>
<td>32.77</td>
<td>55.49</td>
<td>58.17</td>
</tr>
</table>
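For intuition, cpCER concatenates each speaker's transcriptions in both reference and hypothesis, then takes the minimum character error rate over all speaker permutations. A rough sketch under that definition (not the official scoring script):

```python
# Rough sketch of cpCER (not the official scorer): minimum over speaker
# permutations of the CER between concatenated per-speaker transcriptions.
from itertools import permutations

def edit_distance(a: str, b: str) -> int:
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        cur = [i]
        for j, cb in enumerate(b, 1):
            cur.append(min(prev[j] + 1, cur[j - 1] + 1, prev[j - 1] + (ca != cb)))
        prev = cur
    return prev[-1]

def cpcer(ref_spk_texts: list, hyp_spk_texts: list) -> float:
    """ref_spk_texts / hyp_spk_texts: one concatenated string per speaker."""
    n = max(len(ref_spk_texts), len(hyp_spk_texts))
    refs = ref_spk_texts + [""] * (n - len(ref_spk_texts))
    hyps = hyp_spk_texts + [""] * (n - len(hyp_spk_texts))
    total_ref_len = sum(len(r) for r in refs)
    best = min(sum(edit_distance(r, h) for r, h in zip(refs, perm))
               for perm in permutations(hyps))
    return 100.0 * best / total_ref_len
```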
# Reference

N. Kanda, G. Ye, Y. Gaur, X. Wang, Z. Meng, Z. Chen, and T. Yoshioka, "End-to-End Speaker-Attributed ASR with Transformer," in Proc. Interspeech, 2021, pp. 4413-4417.
```diff
@@ -475,7 +475,9 @@ if ! "${skip_data_prep}"; then
         fi
         local/copy_data_dir.sh --validate_opts --non-print data/"${dset}" "${data_feats}${_suf}/${dset}"

-        cp data/"${dset}"/utt2spk_all_fifo "${data_feats}${_suf}/${dset}/"
+        if [ "${dset}" = "Train_Ali_far" ] || [ "${dset}" = "Eval_Ali_far" ] || [ "${dset}" = "Test_Ali_far" ]; then
+            cp data/"${dset}"/utt2spk_all_fifo "${data_feats}${_suf}/${dset}/"
+        fi

         rm -f ${data_feats}${_suf}/${dset}/{segments,wav.scp,reco2file_and_channel,reco2dur}
         _opts=
```
```diff
@@ -568,8 +570,11 @@ if ! "${skip_data_prep}"; then

         # generate uttid
         cut -d ' ' -f 1 "${data_feats}/${dset}/wav.scp" > "${data_feats}/${dset}/uttid"
-        # filter utt2spk_all_fifo
-        python local/filter_utt2spk_all_fifo.py ${data_feats}/${dset}/uttid ${data_feats}/org/${dset} ${data_feats}/${dset}
+
+        if [ "${dset}" = "Train_Ali_far" ] || [ "${dset}" = "Eval_Ali_far" ] || [ "${dset}" = "Test_Ali_far" ]; then
+            # filter utt2spk_all_fifo
+            python local/filter_utt2spk_all_fifo.py ${data_feats}/${dset}/uttid ${data_feats}/org/${dset} ${data_feats}/${dset}
+        fi
     done

     # shellcheck disable=SC2002
```
```diff
@@ -585,7 +590,7 @@ if ! "${skip_data_prep}"; then
         echo "<blank>" > ${token_list}
         echo "<s>" >> ${token_list}
         echo "</s>" >> ${token_list}
-        local/text2token.py -s 1 -n 1 --space "" ${data_feats}/lm_train.txt | cut -f 2- -d" " | tr " " "\n" \
+        utils/text2token.py -s 1 -n 1 --space "" ${data_feats}/lm_train.txt | cut -f 2- -d" " | tr " " "\n" \
             | sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list}
         num_token=$(cat ${token_list} | wc -l)
         echo "<unk>" >> ${token_list}
```
```diff
@@ -603,6 +608,7 @@ if ! "${skip_data_prep}"; then
         python local/process_text_id.py ${data_feats}/${dset}
         log "Successfully generate ${data_feats}/${dset}/text_id_train"
         # generate oracle_embedding from single-speaker audio segment
+        log "oracle_embedding is being generated in the background, and the log is profile_log/gen_oracle_embedding_${dset}.log"
         python local/gen_oracle_embedding.py "${data_feats}/${dset}" "data/local/${dset}_correct_single_speaker" &> "profile_log/gen_oracle_embedding_${dset}.log"
         log "Successfully generate oracle embedding for ${dset} (${data_feats}/${dset}/oracle_embedding.scp)"
         # generate oracle_profile and cluster_profile from oracle_embedding and cluster_embedding (padding the speaker during training)
```
```diff
@@ -615,6 +621,7 @@ if ! "${skip_data_prep}"; then
         fi
         # generate cluster_profile with spectral-cluster directly (for infering and without oracle information)
         if [ "${dset}" = "${valid_set}" ] || [ "${dset}" = "${test_sets}" ]; then
+            log "cluster_profile is being generated in the background, and the log is profile_log/gen_cluster_profile_infer_${dset}.log"
             python local/gen_cluster_profile_infer.py "${data_feats}/${dset}" "data/local/${dset}" 0.996 0.815 &> "profile_log/gen_cluster_profile_infer_${dset}.log"
             log "Successfully generate cluster profile for ${dset} (${data_feats}/${dset}/cluster_profile_infer.scp)"
         fi
```
```diff
@@ -449,7 +449,7 @@ if ! "${skip_data_prep}"; then
             _opts+="--segments data/${dset}/segments "
         fi
         # shellcheck disable=SC2086
-        scripts/audio/format_wav_scp.sh --nj "${nj}" --cmd "${train_cmd}" \
+        local/format_wav_scp.sh --nj "${nj}" --cmd "${train_cmd}" \
             --audio-format "${audio_format}" --fs "${fs}" ${_opts} \
             "data/${dset}/wav.scp" "${data_feats}${_suf}/${dset}"
```
```diff
@@ -467,7 +467,7 @@ if ! "${skip_data_prep}"; then
     mkdir -p "profile_log"
     for dset in "${test_sets}"; do
         # generate cluster_profile with spectral-cluster directly (for infering and without oracle information)
-        python local/gen_cluster_profile_infer.py "${data_feats}/${dset}" "data/local/${dset}" 0.996 0.815 &> "profile_log/gen_cluster_profile_infer_${dset}.log"
+        python local/gen_cluster_profile_infer.py "${data_feats}/${dset}" "data/${dset}" 0.996 0.815 &> "profile_log/gen_cluster_profile_infer_${dset}.log"
         log "Successfully generate cluster profile for ${dset} (${data_feats}/${dset}/cluster_profile_infer.scp)"
     done
 fi
```
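For context, the `gen_cluster_profile_infer.py` step builds the cluster-based speaker profile from per-segment speaker embeddings via spectral clustering. Below is a hypothetical sketch of that idea; how the script's two positional arguments (0.996 and 0.815) map onto the pruning quantile and similarity floor here is an assumption, not read from the script.

```python
# Hypothetical sketch of building a cluster-based speaker profile with
# spectral clustering; the threshold roles below are assumptions.
import numpy as np
from scipy.cluster.vq import kmeans2

def cluster_profile(embeddings, prune_q=0.996, sim_floor=0.815, max_spk=4):
    """embeddings: (num_segments, dim) L2-normalized speaker embeddings."""
    affinity = embeddings @ embeddings.T            # cosine similarity matrix
    affinity[affinity < sim_floor] = 0.0            # drop weak similarities
    row_q = np.quantile(affinity, prune_q, axis=1, keepdims=True)
    affinity[affinity < row_q] = 0.0                # keep only strongest edges per row
    affinity = 0.5 * (affinity + affinity.T)        # re-symmetrize after pruning
    laplacian = np.diag(affinity.sum(axis=1)) - affinity
    eigvals, eigvecs = np.linalg.eigh(laplacian)
    # Estimate the speaker count from the largest eigengap
    # (AliMeeting sessions have at most 4 speakers).
    k = int(np.argmax(np.diff(eigvals[:max_spk + 1]))) + 1
    _, labels = kmeans2(eigvecs[:, :k], k, minit="++", seed=0)
    # One profile per cluster: the mean of its members' embeddings.
    profiles = np.stack([embeddings[labels == i].mean(axis=0) for i in range(k)])
    return profiles / np.linalg.norm(profiles, axis=1, keepdims=True)
```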
@@ -1,157 +0,0 @@

```python
import sys

import numpy as np


def compute_wer(ref_file, hyp_file, cer_detail_file):
    rst = {
        'Wrd': 0,           # total reference words
        'Corr': 0,          # correct words
        'Ins': 0,
        'Del': 0,
        'Sub': 0,
        'Snt': 0,           # scored sentences
        'Err': 0.0,         # word error rate (%)
        'S.Err': 0.0,       # sentence error rate (%)
        'wrong_words': 0,
        'wrong_sentences': 0
    }

    hyp_dict = {}
    ref_dict = {}
    with open(hyp_file, 'r') as hyp_reader:
        for line in hyp_reader:
            key = line.strip().split()[0]
            value = line.strip().split()[1:]
            hyp_dict[key] = value
    with open(ref_file, 'r') as ref_reader:
        for line in ref_reader:
            key = line.strip().split()[0]
            value = line.strip().split()[1:]
            ref_dict[key] = value

    cer_detail_writer = open(cer_detail_file, 'w')
    for hyp_key in hyp_dict:
        if hyp_key in ref_dict:
            out_item = compute_wer_by_line(hyp_dict[hyp_key], ref_dict[hyp_key])
            rst['Wrd'] += out_item['nwords']
            rst['Corr'] += out_item['cor']
            rst['wrong_words'] += out_item['wrong']
            rst['Ins'] += out_item['ins']
            rst['Del'] += out_item['del']
            rst['Sub'] += out_item['sub']
            rst['Snt'] += 1
            if out_item['wrong'] > 0:
                rst['wrong_sentences'] += 1
            cer_detail_writer.write(hyp_key + print_cer_detail(out_item) + '\n')
            cer_detail_writer.write("ref:" + '\t' + "".join(ref_dict[hyp_key]) + '\n')
            cer_detail_writer.write("hyp:" + '\t' + "".join(hyp_dict[hyp_key]) + '\n')

    if rst['Wrd'] > 0:
        rst['Err'] = round(rst['wrong_words'] * 100 / rst['Wrd'], 2)
    if rst['Snt'] > 0:
        rst['S.Err'] = round(rst['wrong_sentences'] * 100 / rst['Snt'], 2)

    cer_detail_writer.write('\n')
    cer_detail_writer.write("%WER " + str(rst['Err']) + " [ " + str(rst['wrong_words']) + " / " + str(rst['Wrd']) +
                            ", " + str(rst['Ins']) + " ins, " + str(rst['Del']) + " del, " + str(rst['Sub']) + " sub ]" + '\n')
    cer_detail_writer.write("%SER " + str(rst['S.Err']) + " [ " + str(rst['wrong_sentences']) + " / " + str(rst['Snt']) + " ]" + '\n')
    cer_detail_writer.write("Scored " + str(len(hyp_dict)) + " sentences, " + str(len(hyp_dict) - rst['Snt']) + " not present in hyp." + '\n')


def compute_wer_by_line(hyp, ref):
    # Levenshtein alignment between hypothesis and reference tokens.
    hyp = list(map(lambda x: x.lower(), hyp))
    ref = list(map(lambda x: x.lower(), ref))

    len_hyp = len(hyp)
    len_ref = len(ref)

    cost_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int16)
    # 0 = correct, 1 = substitution, 2 = insertion, 3 = deletion
    ops_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int8)

    for i in range(len_hyp + 1):
        cost_matrix[i][0] = i
    for j in range(len_ref + 1):
        cost_matrix[0][j] = j

    for i in range(1, len_hyp + 1):
        for j in range(1, len_ref + 1):
            if hyp[i - 1] == ref[j - 1]:
                cost_matrix[i][j] = cost_matrix[i - 1][j - 1]
            else:
                substitution = cost_matrix[i - 1][j - 1] + 1
                insertion = cost_matrix[i - 1][j] + 1
                deletion = cost_matrix[i][j - 1] + 1

                compare_val = [substitution, insertion, deletion]

                min_val = min(compare_val)
                operation_idx = compare_val.index(min_val) + 1
                cost_matrix[i][j] = min_val
                ops_matrix[i][j] = operation_idx

    # Backtrack through the operations matrix to count each error type.
    match_idx = []
    i = len_hyp
    j = len_ref
    rst = {
        'nwords': len_ref,
        'cor': 0,
        'wrong': 0,
        'ins': 0,
        'del': 0,
        'sub': 0
    }
    while i >= 0 or j >= 0:
        i_idx = max(0, i)
        j_idx = max(0, j)

        if ops_matrix[i_idx][j_idx] == 0:  # correct
            if i - 1 >= 0 and j - 1 >= 0:
                match_idx.append((j - 1, i - 1))
                rst['cor'] += 1

            i -= 1
            j -= 1

        elif ops_matrix[i_idx][j_idx] == 2:  # insertion
            i -= 1
            rst['ins'] += 1

        elif ops_matrix[i_idx][j_idx] == 3:  # deletion
            j -= 1
            rst['del'] += 1

        elif ops_matrix[i_idx][j_idx] == 1:  # substitution
            i -= 1
            j -= 1
            rst['sub'] += 1

        if i < 0 and j >= 0:
            rst['del'] += 1
        elif j < 0 and i >= 0:
            rst['ins'] += 1

    match_idx.reverse()
    wrong_cnt = cost_matrix[len_hyp][len_ref]
    rst['wrong'] = wrong_cnt

    return rst


def print_cer_detail(rst):
    return ("(" + "nwords=" + str(rst['nwords']) + ",cor=" + str(rst['cor'])
            + ",ins=" + str(rst['ins']) + ",del=" + str(rst['del']) + ",sub="
            + str(rst['sub']) + ") corr:" + '{:.2%}'.format(rst['cor'] / rst['nwords'])
            + ",cer:" + '{:.2%}'.format(rst['wrong'] / rst['nwords']))


if __name__ == '__main__':
    if len(sys.argv) != 4:
        print("usage : python compute-wer.py test.ref test.hyp test.wer")
        sys.exit(0)

    ref_file = sys.argv[1]
    hyp_file = sys.argv[2]
    cer_detail_file = sys.argv[3]
    compute_wer(ref_file, hyp_file, cer_detail_file)
```
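A quick sanity check of the alignment logic above (hypothetical usage, with the functions in scope):

```python
# "a x c" vs. reference "a b c": one substitution, two correct words.
out = compute_wer_by_line(hyp=['a', 'x', 'c'], ref=['a', 'b', 'c'])
assert out == {'nwords': 3, 'cor': 2, 'wrong': 1, 'ins': 0, 'del': 0, 'sub': 1}
```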
```diff
@@ -63,20 +63,20 @@ else
 fi


-<"${srcdir}"/utt2spk utils/apply_map.pl -f 1 "${destdir}"/utt_map | \
-    utils/apply_map.pl -f 2 "${destdir}"/spk_map >"${destdir}"/utt2spk
+<"${srcdir}"/utt2spk local/apply_map.pl -f 1 "${destdir}"/utt_map | \
+    local/apply_map.pl -f 2 "${destdir}"/spk_map >"${destdir}"/utt2spk

-utils/utt2spk_to_spk2utt.pl <"${destdir}"/utt2spk >"${destdir}"/spk2utt
+local/utt2spk_to_spk2utt.pl <"${destdir}"/utt2spk >"${destdir}"/spk2utt

 if [[ -f ${srcdir}/segments ]]; then

-  utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/segments | \
-    utils/apply_map.pl -f 2 "${destdir}"/reco_map | \
+  local/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/segments | \
+    local/apply_map.pl -f 2 "${destdir}"/reco_map | \
     awk -v factor="${factor}" \
       '{s=$3/factor; e=$4/factor; if (e > s + 0.01) { printf("%s %s %.2f %.2f\n", $1, $2, $3/factor, $4/factor);} }' \
       >"${destdir}"/segments

-  utils/apply_map.pl -f 1 "${destdir}"/reco_map <"${srcdir}"/wav.scp | sed 's/| *$/ |/' | \
+  local/apply_map.pl -f 1 "${destdir}"/reco_map <"${srcdir}"/wav.scp | sed 's/| *$/ |/' | \
   # Handle three cases of rxfilenames appropriately; "input piped command", "file offset" and "filename"
   awk -v factor="${factor}" \
     '{wid=$1; $1=""; if ($NF=="|") {print wid $_ " sox -t wav - -t wav - speed " factor " |"}
```
```diff
@@ -84,13 +84,13 @@ if [[ -f ${srcdir}/segments ]]; then
      else {print wid " sox" $_ " -t wav - speed " factor " |"}}' \
     > "${destdir}"/wav.scp
   if [[ -f ${srcdir}/reco2file_and_channel ]]; then
-    utils/apply_map.pl -f 1 "${destdir}"/reco_map \
+    local/apply_map.pl -f 1 "${destdir}"/reco_map \
       <"${srcdir}"/reco2file_and_channel >"${destdir}"/reco2file_and_channel
   fi

 else # no segments->wav indexed by utterance.
   if [[ -f ${srcdir}/wav.scp ]]; then
-    utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/wav.scp | sed 's/| *$/ |/' | \
+    local/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/wav.scp | sed 's/| *$/ |/' | \
     # Handle three cases of rxfilenames appropriately; "input piped command", "file offset" and "filename"
     awk -v factor="${factor}" \
       '{wid=$1; $1=""; if ($NF=="|") {print wid $_ " sox -t wav - -t wav - speed " factor " |"}
```
```diff
@@ -101,16 +101,16 @@ else # no segments->wav indexed by utterance.
   fi

 if [[ -f ${srcdir}/text ]]; then
-  utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/text >"${destdir}"/text
+  local/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/text >"${destdir}"/text
 fi
 if [[ -f ${srcdir}/spk2gender ]]; then
-  utils/apply_map.pl -f 1 "${destdir}"/spk_map <"${srcdir}"/spk2gender >"${destdir}"/spk2gender
+  local/apply_map.pl -f 1 "${destdir}"/spk_map <"${srcdir}"/spk2gender >"${destdir}"/spk2gender
 fi
 if [[ -f ${srcdir}/utt2lang ]]; then
-  utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/utt2lang >"${destdir}"/utt2lang
+  local/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/utt2lang >"${destdir}"/utt2lang
 fi

 rm "${destdir}"/spk_map "${destdir}"/utt_map "${destdir}"/reco_map 2>/dev/null
 echo "$0: generated speed-perturbed version of data in ${srcdir}, in ${destdir}"

-utils/validate_data_dir.sh --no-feats --no-text "${destdir}"
+local/validate_data_dir.sh --no-feats --no-text "${destdir}"
```
@@ -1,32 +0,0 @@

```python
import re
import sys

in_f = sys.argv[1]
out_f = sys.argv[2]

with open(in_f, "r", encoding="utf-8") as f:
    lines = f.readlines()

with open(out_f, "w", encoding="utf-8") as f:
    for line in lines:
        outs = line.strip().split(" ", 1)
        if len(outs) == 2:
            idx, text = outs
            # Strip special tokens, BPE markers and separators, then lowercase.
            text = re.sub("</s>", "", text)
            text = re.sub("<s>", "", text)
            text = re.sub("@@", "", text)
            text = re.sub("@", "", text)
            text = re.sub("<unk>", "", text)
            text = re.sub(" ", "", text)
            text = re.sub(r"\$", "", text)
            text = text.lower()
        else:
            idx = outs[0]
            text = " "

        # Space-separate individual characters for character-level scoring.
        text = " ".join(text)
        out = "{} {}\n".format(idx, text)
        f.write(out)
```
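For illustration, a hedged walk-through of the cleaning on one made-up line (the utterance ID and tokens are hypothetical):

```python
# Hypothetical example of the cleaning performed by the removed script.
import re

line = "uttid <s>AB@@ cd<unk></s>"
idx, text = line.strip().split(" ", 1)
for pat in ("</s>", "<s>", "@@", "@", "<unk>", " ", r"\$"):
    text = re.sub(pat, "", text)
print(idx, " ".join(text.lower()))  # -> uttid a b c d
```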
```diff
@@ -8,7 +8,6 @@ set -o pipefail
 ngpu=4
 device="0,1,2,3"

-#stage 1 creat both near and far
 stage=1
 stop_stage=18
```
```diff
@@ -22,7 +22,7 @@ inference_config=conf/decode_asr_rnn.yaml
 lm_config=conf/train_lm_transformer.yaml
 use_lm=false
 use_wordlm=false
-./asr_local_infer.sh \
+./asr_local_m2met_2023_infer.sh \
     --device ${device} \
     --ngpu ${ngpu} \
     --stage ${stage} \
```
```diff
@@ -94,7 +94,7 @@ class Speech2Text:
         frontend = None
         if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
             if asr_train_args.frontend=='wav_frontend':
-                frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf).eval()
+                frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
             else:
                 frontend_class=frontend_choices.get_class(asr_train_args.frontend)
                 frontend = frontend_class(**asr_train_args.frontend_conf).eval()
```
```diff
@@ -147,13 +147,6 @@ class Speech2Text:
             pre_beam_score_key=None if ctc_weight == 1.0 else "full",
         )

-        beam_search.to(device=device, dtype=getattr(torch, dtype)).eval()
-        for scorer in scorers.values():
-            if isinstance(scorer, torch.nn.Module):
-                scorer.to(device=device, dtype=getattr(torch, dtype)).eval()
-        logging.info(f"Beam_search: {beam_search}")
-        logging.info(f"Decoding device={device}, dtype={dtype}")
-
         # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
         if token_type is None:
             token_type = asr_train_args.token_type
```
```diff
@@ -89,7 +89,7 @@ class Speech2Text:
         frontend = None
         if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
             if asr_train_args.frontend=='wav_frontend':
-                frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf).eval()
+                frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
             else:
                 frontend_class=frontend_choices.get_class(asr_train_args.frontend)
                 frontend = frontend_class(**asr_train_args.frontend_conf).eval()
```
```diff
@@ -142,13 +142,6 @@ class Speech2Text:
             pre_beam_score_key=None if ctc_weight == 1.0 else "full",
         )

-        beam_search.to(device=device, dtype=getattr(torch, dtype)).eval()
-        for scorer in scorers.values():
-            if isinstance(scorer, torch.nn.Module):
-                scorer.to(device=device, dtype=getattr(torch, dtype)).eval()
-        logging.info(f"Beam_search: {beam_search}")
-        logging.info(f"Decoding device={device}, dtype={dtype}")
-
        # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
         if token_type is None:
             token_type = asr_train_args.token_type
```
```diff
@@ -97,7 +97,7 @@ class NllLoss(nn.Module):
         normalize_length=False,
         criterion=nn.NLLLoss(reduction='none'),
     ):
-        """Construct an LabelSmoothingLoss object."""
+        """Construct an NllLoss object."""
         super(NllLoss, self).__init__()
         self.criterion = criterion
         self.padding_idx = padding_idx
```