add speaker-attributed ASR task for alimeeting

smohan-speech 2023-05-08 16:13:23 +08:00
parent 49f13908de
commit af6740a220
11 changed files with 108 additions and 226 deletions

View File

@ -0,0 +1,79 @@
# Get Started
Speaker-Attributed Automatic Speech Recognition (SA-ASR) is a task proposed to answer "who spoke what". Specifically, the goal of SA-ASR is not only to produce multi-speaker transcriptions, but also to identify the speaker of each utterance. The method used in this example follows the paper [End-to-End Speaker-Attributed ASR with Transformer](https://www.isca-speech.org/archive/pdfs/interspeech_2021/kanda21b_interspeech.pdf).
To run this recipe, first install FunASR and ModelScope ([installation guide](https://alibaba-damo-academy.github.io/FunASR/en/installation.html)).
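For reference, a minimal installation sketch; the linked guide is authoritative, and the from-source editable install below is an assumption:
```shell
# Sketch only -- follow the installation guide above for the supported steps
git clone https://github.com/alibaba-damo-academy/FunASR.git
cd FunASR && pip install --editable ./
pip install -U modelscope   # ModelScope is used to download pre-trained models
```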
There are two entry scripts: `run.sh`, for training and evaluation on the original Eval and Test sets, and `run_m2met_2023_infer.sh`, for inference on the new test set of the Multi-Channel Multi-Party Meeting Transcription 2.0 ([M2MET2.0](https://alibaba-damo-academy.github.io/FunASR/m2met2/index.html)) Challenge.
Before running `run.sh`, you must manually download the [AliMeeting](http://www.openslr.org/119/) corpus and unpack it into the `./dataset` directory:
```shell
dataset
|—— Eval_Ali_far
|—— Eval_Ali_near
|—— Test_Ali_far
|—— Test_Ali_near
|—— Train_Ali_far
|—— Train_Ali_near
```
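If you prefer to script the download, a sketch along these lines may help (the archive names are placeholders; check the OpenSLR page for the actual file list):
```shell
mkdir -p dataset && cd dataset
# NOTE: archive names are assumptions -- verify them on http://www.openslr.org/119/
for archive in Train_Ali_far Train_Ali_near Eval_Ali Test_Ali; do
    wget "https://www.openslr.org/resources/119/${archive}.tar.gz"
    tar -xzf "${archive}.tar.gz"
done
cd ..
```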
There are 18 stages in `run.sh`:
```shell
stage 1 - 5: Data preparation and processing.
stage 6: Generate speaker profiles (this stage is time-consuming).
stage 7 - 9: Language model training (optional).
stage 10 - 11: ASR training (SA-ASR requires a pre-trained ASR model to be loaded).
stage 12: SA-ASR training.
stage 13 - 18: Inference and evaluation.
```
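Stages can also be run selectively. Assuming the script parses `--stage`/`--stop_stage` options in the usual Kaldi style (if not, edit the `stage` and `stop_stage` variables at the top of `run.sh`), for example:
```shell
# Data preparation and speaker-profile generation only (stages 1 - 6)
bash run.sh --stage 1 --stop_stage 6
# SA-ASR training alone (stage 12)
bash run.sh --stage 12 --stop_stage 12
```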
Before running `run_m2met_2023_infer.sh`, you need to place the new test set `Test_2023_Ali_far` (to be released after the challenge starts) in the `./dataset` directory; it contains only the raw audio. Then put the provided `wav.scp`, `wav_raw.scp`, `segments`, `utt2spk` and `spk2utt` files in the `./data/Test_2023_Ali_far` directory:
```shell
data/Test_2023_Ali_far
|—— wav.scp
|—— wav_raw.scp
|—— segments
|—— utt2spk
|—— spk2utt
```
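Once these files are in place, it can be worth validating the Kaldi-style data directory; a sketch, assuming the standard `validate_data_dir.sh` helper is available under `local/` in this recipe:
```shell
# Check that wav.scp, segments, utt2spk and spk2utt are mutually consistent
local/validate_data_dir.sh --no-feats --no-text data/Test_2023_Ali_far
```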
There are 4 stages in `run_m2met_2023_infer.sh`:
```shell
stage 1: Data preparation and processing.
stage 2: Generate speaker profiles for inference.
stage 3: Inference.
stage 4: Generate the SA-ASR results required for the final submission.
```
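For example, to run the whole inference pipeline end to end (again assuming Kaldi-style option parsing; otherwise edit the stage variables at the top of the script):
```shell
bash run_m2met_2023_infer.sh --stage 1 --stop_stage 4
```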
# Format of Final Submission
Finally, you need to submit a file named `text_spk_merge` with the following format:
```shell
Meeting_1 text_spk_1_A$text_spk_1_B$text_spk_1_C ...
Meeting_2 text_spk_2_A$text_spk_2_B$text_spk_2_C ...
...
```
Here, `text_spk_1_A` is the full transcription of speaker A in `Meeting_1`, with that speaker's utterances merged in chronological order, and `$` is the separator symbol. There is no need to worry about the speaker permutation, as the optimal permutation is computed during scoring. For more information, refer to the results generated by the baseline code.
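As a quick self-check before submitting, you can count how many `$`-separated speaker transcriptions each meeting line contains; a minimal sketch, assuming the layout above:
```shell
# Print each meeting ID with its number of speaker transcriptions
awk -F'$' '{split($1, head, " "); print head[1], NF}' text_spk_merge
```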
# Baseline Results
The results of the baseline system are shown below. Two metrics are reported: the speaker-independent character error rate (SI-CER), which ignores speaker labels, and the concatenated minimum-permutation character error rate (cpCER), which is speaker-dependent: each speaker's utterances are concatenated in chronological order, and the CER is computed under the speaker permutation that minimizes the error. During training, the speaker profile uses oracle speaker embeddings. Because oracle speaker labels are not available during evaluation, the speaker profile is instead produced by an additional spectral clustering step. Results with the oracle speaker profile on the Eval and Test sets are also provided, to show the impact of speaker-profile accuracy.
<table>
<tr>
<td rowspan="2"></td>
<td colspan="2">SI-CER(%)</td>
<td colspan="2">cpCER(%)</td>
</tr>
<tr>
<td>Eval</td>
<td>Test</td>
<td>Eval</td>
<td>Test</td>
</tr>
<tr>
<td>oracle profile</td>
<td>31.93</td>
<td>32.75</td>
<td>48.56</td>
<td>53.33</td>
</tr>
<tr>
<td>cluster profile</td>
<td>31.94</td>
<td>32.77</td>
<td>55.49</td>
<td>58.17</td>
</tr>
</table>
# Reference
N. Kanda, G. Ye, Y. Gaur, X. Wang, Z. Meng, Z. Chen, and T. Yoshioka, "End-to-end speaker-attributed ASR with transformer," in Proc. Interspeech, ISCA, 2021, pp. 4413–4417.

View File

@@ -475,7 +475,9 @@ if ! "${skip_data_prep}"; then
         fi
         local/copy_data_dir.sh --validate_opts --non-print data/"${dset}" "${data_feats}${_suf}/${dset}"
-        cp data/"${dset}"/utt2spk_all_fifo "${data_feats}${_suf}/${dset}/"
+        if [ "${dset}" = "Train_Ali_far" ] || [ "${dset}" = "Eval_Ali_far" ] || [ "${dset}" = "Test_Ali_far" ]; then
+            cp data/"${dset}"/utt2spk_all_fifo "${data_feats}${_suf}/${dset}/"
+        fi
         rm -f ${data_feats}${_suf}/${dset}/{segments,wav.scp,reco2file_and_channel,reco2dur}
         _opts=
@@ -568,8 +570,11 @@ if ! "${skip_data_prep}"; then
         # generate uttid
         cut -d ' ' -f 1 "${data_feats}/${dset}/wav.scp" > "${data_feats}/${dset}/uttid"
-        # filter utt2spk_all_fifo
-        python local/filter_utt2spk_all_fifo.py ${data_feats}/${dset}/uttid ${data_feats}/org/${dset} ${data_feats}/${dset}
+        if [ "${dset}" = "Train_Ali_far" ] || [ "${dset}" = "Eval_Ali_far" ] || [ "${dset}" = "Test_Ali_far" ]; then
+            # filter utt2spk_all_fifo
+            python local/filter_utt2spk_all_fifo.py ${data_feats}/${dset}/uttid ${data_feats}/org/${dset} ${data_feats}/${dset}
+        fi
     done
     # shellcheck disable=SC2002
@@ -585,7 +590,7 @@ if ! "${skip_data_prep}"; then
     echo "<blank>" > ${token_list}
     echo "<s>" >> ${token_list}
     echo "</s>" >> ${token_list}
-    local/text2token.py -s 1 -n 1 --space "" ${data_feats}/lm_train.txt | cut -f 2- -d" " | tr " " "\n" \
+    utils/text2token.py -s 1 -n 1 --space "" ${data_feats}/lm_train.txt | cut -f 2- -d" " | tr " " "\n" \
         | sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list}
     num_token=$(cat ${token_list} | wc -l)
     echo "<unk>" >> ${token_list}
@@ -603,6 +608,7 @@ if ! "${skip_data_prep}"; then
         python local/process_text_id.py ${data_feats}/${dset}
         log "Successfully generate ${data_feats}/${dset}/text_id_train"
         # generate oracle_embedding from single-speaker audio segment
+        log "oracle_embedding is being generated in the background, and the log is profile_log/gen_oracle_embedding_${dset}.log"
         python local/gen_oracle_embedding.py "${data_feats}/${dset}" "data/local/${dset}_correct_single_speaker" &> "profile_log/gen_oracle_embedding_${dset}.log"
         log "Successfully generate oracle embedding for ${dset} (${data_feats}/${dset}/oracle_embedding.scp)"
         # generate oracle_profile and cluster_profile from oracle_embedding and cluster_embedding (padding the speaker during training)
@@ -615,6 +621,7 @@ if ! "${skip_data_prep}"; then
         fi
         # generate cluster_profile with spectral-cluster directly (for infering and without oracle information)
         if [ "${dset}" = "${valid_set}" ] || [ "${dset}" = "${test_sets}" ]; then
+            log "cluster_profile is being generated in the background, and the log is profile_log/gen_cluster_profile_infer_${dset}.log"
             python local/gen_cluster_profile_infer.py "${data_feats}/${dset}" "data/local/${dset}" 0.996 0.815 &> "profile_log/gen_cluster_profile_infer_${dset}.log"
             log "Successfully generate cluster profile for ${dset} (${data_feats}/${dset}/cluster_profile_infer.scp)"
         fi

View File

@@ -449,7 +449,7 @@ if ! "${skip_data_prep}"; then
             _opts+="--segments data/${dset}/segments "
         fi
         # shellcheck disable=SC2086
-        scripts/audio/format_wav_scp.sh --nj "${nj}" --cmd "${train_cmd}" \
+        local/format_wav_scp.sh --nj "${nj}" --cmd "${train_cmd}" \
             --audio-format "${audio_format}" --fs "${fs}" ${_opts} \
             "data/${dset}/wav.scp" "${data_feats}${_suf}/${dset}"
@@ -467,7 +467,7 @@ if ! "${skip_data_prep}"; then
     mkdir -p "profile_log"
     for dset in "${test_sets}"; do
         # generate cluster_profile with spectral-cluster directly (for infering and without oracle information)
-        python local/gen_cluster_profile_infer.py "${data_feats}/${dset}" "data/local/${dset}" 0.996 0.815 &> "profile_log/gen_cluster_profile_infer_${dset}.log"
+        python local/gen_cluster_profile_infer.py "${data_feats}/${dset}" "data/${dset}" 0.996 0.815 &> "profile_log/gen_cluster_profile_infer_${dset}.log"
         log "Successfully generate cluster profile for ${dset} (${data_feats}/${dset}/cluster_profile_infer.scp)"
     done
 fi

View File

@@ -1,157 +0,0 @@
import os
import numpy as np
import sys


def compute_wer(ref_file, hyp_file, cer_detail_file):
    # Accumulated statistics over the whole test set.
    rst = {
        'Wrd': 0,
        'Corr': 0,
        'Ins': 0,
        'Del': 0,
        'Sub': 0,
        'Snt': 0,
        'Err': 0.0,
        'S.Err': 0.0,
        'wrong_words': 0,
        'wrong_sentences': 0
    }

    # Read "<key> <token> <token> ..." lines into dicts keyed by utterance id.
    hyp_dict = {}
    ref_dict = {}
    with open(hyp_file, 'r') as hyp_reader:
        for line in hyp_reader:
            key = line.strip().split()[0]
            value = line.strip().split()[1:]
            hyp_dict[key] = value
    with open(ref_file, 'r') as ref_reader:
        for line in ref_reader:
            key = line.strip().split()[0]
            value = line.strip().split()[1:]
            ref_dict[key] = value

    cer_detail_writer = open(cer_detail_file, 'w')
    for hyp_key in hyp_dict:
        if hyp_key in ref_dict:
            out_item = compute_wer_by_line(hyp_dict[hyp_key], ref_dict[hyp_key])
            rst['Wrd'] += out_item['nwords']
            rst['Corr'] += out_item['cor']
            rst['wrong_words'] += out_item['wrong']
            rst['Ins'] += out_item['ins']
            rst['Del'] += out_item['del']
            rst['Sub'] += out_item['sub']
            rst['Snt'] += 1
            if out_item['wrong'] > 0:
                rst['wrong_sentences'] += 1
            cer_detail_writer.write(hyp_key + print_cer_detail(out_item) + '\n')
            cer_detail_writer.write("ref:" + '\t' + "".join(ref_dict[hyp_key]) + '\n')
            cer_detail_writer.write("hyp:" + '\t' + "".join(hyp_dict[hyp_key]) + '\n')

    if rst['Wrd'] > 0:
        rst['Err'] = round(rst['wrong_words'] * 100 / rst['Wrd'], 2)
    if rst['Snt'] > 0:
        rst['S.Err'] = round(rst['wrong_sentences'] * 100 / rst['Snt'], 2)

    cer_detail_writer.write('\n')
    cer_detail_writer.write("%WER " + str(rst['Err']) + " [ " + str(rst['wrong_words']) + " / " + str(rst['Wrd']) +
                            ", " + str(rst['Ins']) + " ins, " + str(rst['Del']) + " del, " + str(rst['Sub']) + " sub ]" + '\n')
    cer_detail_writer.write("%SER " + str(rst['S.Err']) + " [ " + str(rst['wrong_sentences']) + " / " + str(rst['Snt']) + " ]" + '\n')
    cer_detail_writer.write("Scored " + str(len(hyp_dict)) + " sentences, " + str(len(hyp_dict) - rst['Snt']) + " not present in hyp." + '\n')


def compute_wer_by_line(hyp, ref):
    hyp = list(map(lambda x: x.lower(), hyp))
    ref = list(map(lambda x: x.lower(), ref))
    len_hyp = len(hyp)
    len_ref = len(ref)

    # Levenshtein alignment: cost_matrix holds edit distances, ops_matrix the
    # back-pointers (0: correct, 1: substitution, 2: insertion, 3: deletion).
    cost_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int16)
    ops_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int8)
    for i in range(len_hyp + 1):
        cost_matrix[i][0] = i
    for j in range(len_ref + 1):
        cost_matrix[0][j] = j
    for i in range(1, len_hyp + 1):
        for j in range(1, len_ref + 1):
            if hyp[i - 1] == ref[j - 1]:
                cost_matrix[i][j] = cost_matrix[i - 1][j - 1]
            else:
                substitution = cost_matrix[i - 1][j - 1] + 1
                insertion = cost_matrix[i - 1][j] + 1
                deletion = cost_matrix[i][j - 1] + 1
                compare_val = [substitution, insertion, deletion]
                min_val = min(compare_val)
                operation_idx = compare_val.index(min_val) + 1
                cost_matrix[i][j] = min_val
                ops_matrix[i][j] = operation_idx

    # Trace the alignment back from the bottom-right corner, counting each
    # operation type along the way.
    match_idx = []
    i = len_hyp
    j = len_ref
    rst = {
        'nwords': len_ref,
        'cor': 0,
        'wrong': 0,
        'ins': 0,
        'del': 0,
        'sub': 0
    }
    while i >= 0 or j >= 0:
        i_idx = max(0, i)
        j_idx = max(0, j)
        if ops_matrix[i_idx][j_idx] == 0:  # correct
            if i - 1 >= 0 and j - 1 >= 0:
                match_idx.append((j - 1, i - 1))
                rst['cor'] += 1
            i -= 1
            j -= 1
        elif ops_matrix[i_idx][j_idx] == 2:  # insertion
            i -= 1
            rst['ins'] += 1
        elif ops_matrix[i_idx][j_idx] == 3:  # deletion
            j -= 1
            rst['del'] += 1
        elif ops_matrix[i_idx][j_idx] == 1:  # substitution
            i -= 1
            j -= 1
            rst['sub'] += 1
        if i < 0 and j >= 0:
            rst['del'] += 1
        elif j < 0 and i >= 0:
            rst['ins'] += 1
    match_idx.reverse()
    wrong_cnt = cost_matrix[len_hyp][len_ref]
    rst['wrong'] = wrong_cnt
    return rst


def print_cer_detail(rst):
    return ("(" + "nwords=" + str(rst['nwords']) + ",cor=" + str(rst['cor'])
            + ",ins=" + str(rst['ins']) + ",del=" + str(rst['del']) + ",sub="
            + str(rst['sub']) + ") corr:" + '{:.2%}'.format(rst['cor'] / rst['nwords'])
            + ",cer:" + '{:.2%}'.format(rst['wrong'] / rst['nwords']))


if __name__ == '__main__':
    if len(sys.argv) != 4:
        print("usage : python compute-wer.py test.ref test.hyp test.wer")
        sys.exit(0)
    ref_file = sys.argv[1]
    hyp_file = sys.argv[2]
    cer_detail_file = sys.argv[3]
    compute_wer(ref_file, hyp_file, cer_detail_file)

View File

@@ -63,20 +63,20 @@ else
 fi
-<"${srcdir}"/utt2spk utils/apply_map.pl -f 1 "${destdir}"/utt_map | \
-    utils/apply_map.pl -f 2 "${destdir}"/spk_map >"${destdir}"/utt2spk
+<"${srcdir}"/utt2spk local/apply_map.pl -f 1 "${destdir}"/utt_map | \
+    local/apply_map.pl -f 2 "${destdir}"/spk_map >"${destdir}"/utt2spk
-utils/utt2spk_to_spk2utt.pl <"${destdir}"/utt2spk >"${destdir}"/spk2utt
+local/utt2spk_to_spk2utt.pl <"${destdir}"/utt2spk >"${destdir}"/spk2utt
 if [[ -f ${srcdir}/segments ]]; then
-    utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/segments | \
-        utils/apply_map.pl -f 2 "${destdir}"/reco_map | \
+    local/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/segments | \
+        local/apply_map.pl -f 2 "${destdir}"/reco_map | \
         awk -v factor="${factor}" \
             '{s=$3/factor; e=$4/factor; if (e > s + 0.01) { printf("%s %s %.2f %.2f\n", $1, $2, $3/factor, $4/factor);} }' \
             >"${destdir}"/segments
-    utils/apply_map.pl -f 1 "${destdir}"/reco_map <"${srcdir}"/wav.scp | sed 's/| *$/ |/' | \
+    local/apply_map.pl -f 1 "${destdir}"/reco_map <"${srcdir}"/wav.scp | sed 's/| *$/ |/' | \
     # Handle three cases of rxfilenames appropriately; "input piped command", "file offset" and "filename"
     awk -v factor="${factor}" \
         '{wid=$1; $1=""; if ($NF=="|") {print wid $_ " sox -t wav - -t wav - speed " factor " |"}
@@ -84,13 +84,13 @@ if [[ -f ${srcdir}/segments ]]; then
         else {print wid " sox" $_ " -t wav - speed " factor " |"}}' \
         > "${destdir}"/wav.scp
     if [[ -f ${srcdir}/reco2file_and_channel ]]; then
-        utils/apply_map.pl -f 1 "${destdir}"/reco_map \
+        local/apply_map.pl -f 1 "${destdir}"/reco_map \
             <"${srcdir}"/reco2file_and_channel >"${destdir}"/reco2file_and_channel
     fi
 else # no segments->wav indexed by utterance.
     if [[ -f ${srcdir}/wav.scp ]]; then
-        utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/wav.scp | sed 's/| *$/ |/' | \
+        local/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/wav.scp | sed 's/| *$/ |/' | \
         # Handle three cases of rxfilenames appropriately; "input piped command", "file offset" and "filename"
         awk -v factor="${factor}" \
             '{wid=$1; $1=""; if ($NF=="|") {print wid $_ " sox -t wav - -t wav - speed " factor " |"}
@@ -101,16 +101,16 @@ else # no segments->wav indexed by utterance.
     fi
 if [[ -f ${srcdir}/text ]]; then
-    utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/text >"${destdir}"/text
+    local/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/text >"${destdir}"/text
 fi
 if [[ -f ${srcdir}/spk2gender ]]; then
-    utils/apply_map.pl -f 1 "${destdir}"/spk_map <"${srcdir}"/spk2gender >"${destdir}"/spk2gender
+    local/apply_map.pl -f 1 "${destdir}"/spk_map <"${srcdir}"/spk2gender >"${destdir}"/spk2gender
 fi
 if [[ -f ${srcdir}/utt2lang ]]; then
-    utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/utt2lang >"${destdir}"/utt2lang
+    local/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/utt2lang >"${destdir}"/utt2lang
 fi
 rm "${destdir}"/spk_map "${destdir}"/utt_map "${destdir}"/reco_map 2>/dev/null
 echo "$0: generated speed-perturbed version of data in ${srcdir}, in ${destdir}"
-utils/validate_data_dir.sh --no-feats --no-text "${destdir}"
+local/validate_data_dir.sh --no-feats --no-text "${destdir}"

View File

@@ -1,32 +0,0 @@
import sys
import re

in_f = sys.argv[1]
out_f = sys.argv[2]

with open(in_f, "r", encoding="utf-8") as f:
    lines = f.readlines()

with open(out_f, "w", encoding="utf-8") as f:
    for line in lines:
        outs = line.strip().split(" ", 1)
        if len(outs) == 2:
            idx, text = outs
            # Strip model symbols (<s>, </s>, <unk>), BPE markers ("@@"/"@"),
            # spaces and the "$" separator from the transcription.
            text = re.sub("</s>", "", text)
            text = re.sub("<s>", "", text)
            text = re.sub("@@", "", text)
            text = re.sub("@", "", text)
            text = re.sub("<unk>", "", text)
            text = re.sub(" ", "", text)
            text = re.sub(r"\$", "", text)
            text = text.lower()
        else:
            idx = outs[0]
            text = " "
        # Re-split into single characters for character-level scoring.
        text = [x for x in text]
        text = " ".join(text)
        out = "{} {}\n".format(idx, text)
        f.write(out)

View File

@@ -8,7 +8,6 @@ set -o pipefail
 ngpu=4
 device="0,1,2,3"
-#stage 1 creat both near and far
 stage=1
 stop_stage=18

View File

@@ -22,7 +22,7 @@ inference_config=conf/decode_asr_rnn.yaml
 lm_config=conf/train_lm_transformer.yaml
 use_lm=false
 use_wordlm=false
-./asr_local_infer.sh \
+./asr_local_m2met_2023_infer.sh \
     --device ${device} \
     --ngpu ${ngpu} \
     --stage ${stage} \

View File

@@ -94,7 +94,7 @@ class Speech2Text:
         frontend = None
         if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
             if asr_train_args.frontend=='wav_frontend':
-                frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf).eval()
+                frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
             else:
                 frontend_class=frontend_choices.get_class(asr_train_args.frontend)
                 frontend = frontend_class(**asr_train_args.frontend_conf).eval()
@@ -147,13 +147,6 @@ class Speech2Text:
             pre_beam_score_key=None if ctc_weight == 1.0 else "full",
         )
         beam_search.to(device=device, dtype=getattr(torch, dtype)).eval()
-        for scorer in scorers.values():
-            if isinstance(scorer, torch.nn.Module):
-                scorer.to(device=device, dtype=getattr(torch, dtype)).eval()
-        logging.info(f"Beam_search: {beam_search}")
-        logging.info(f"Decoding device={device}, dtype={dtype}")
         # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
         if token_type is None:
             token_type = asr_train_args.token_type

View File

@@ -89,7 +89,7 @@ class Speech2Text:
         frontend = None
         if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
             if asr_train_args.frontend=='wav_frontend':
-                frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf).eval()
+                frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
             else:
                 frontend_class=frontend_choices.get_class(asr_train_args.frontend)
                 frontend = frontend_class(**asr_train_args.frontend_conf).eval()
@@ -142,13 +142,6 @@ class Speech2Text:
             pre_beam_score_key=None if ctc_weight == 1.0 else "full",
         )
         beam_search.to(device=device, dtype=getattr(torch, dtype)).eval()
-        for scorer in scorers.values():
-            if isinstance(scorer, torch.nn.Module):
-                scorer.to(device=device, dtype=getattr(torch, dtype)).eval()
-        logging.info(f"Beam_search: {beam_search}")
-        logging.info(f"Decoding device={device}, dtype={dtype}")
         # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
         if token_type is None:
             token_type = asr_train_args.token_type

View File

@@ -97,7 +97,7 @@ class NllLoss(nn.Module):
         normalize_length=False,
         criterion=nn.NLLLoss(reduction='none'),
     ):
-        """Construct an LabelSmoothingLoss object."""
+        """Construct an NllLoss object."""
         super(NllLoss, self).__init__()
         self.criterion = criterion
         self.padding_idx = padding_idx