Add modular SA-ASR recipe for M2MeT2.0 (#831)

* add modular saasr

* update readme

* Delete train_paraformer.yaml

* update setup.py

* update setup.py

* update setup.py
yhliang 2023-08-10 20:46:21 +08:00 committed by GitHub
parent ea2c102e61
commit 08ee9e6aac
34 changed files with 2511 additions and 1 deletion


@ -0,0 +1,103 @@
# Get Started
This is the official modular SA-ASR system used in the M2MeT 2.0 challenge. We developed this system on top of various pre-trained models after the challenge and reached the ***SOTA*** (as of 2023.8.9) performance on the AliMeeting *Test_2023* set. You can also transcribe your own dataset by preparing it in the format described in [Infer with your own dataset](#infer-with-your-own-dataset).
# Dependency
To run this recipe, you should install [Kaldi](https://github.com/kaldi-asr/kaldi) and set the `KALDI_ROOT` in `path.sh`.
```shell
export KALDI_ROOT=/your_kaldi_path
```
We use [VBx](https://github.com/BUTSpeechFIT/VBx) to provide the initial diarization result for SOND and [dscore](https://github.com/nryant/dscore.git) to compute the DER. Clone both before running this recipe.
```shell
$ mkdir VBx && cd VBx
$ git init
$ git remote add origin https://github.com/BUTSpeechFIT/VBx.git
$ git config core.sparsecheckout true
$ echo "VBx/*" >> .git/info/sparse-checkout
$ git pull origin master
$ mv VBx/* .
$ cd ..
$ git clone https://github.com/nryant/dscore.git
```
We use [pb_chime5](https://github.com/fgnt/pb_chime5) to perform GSS, so you should install the dependencies of this repo using the following commands.
```shell
$ git clone https://github.com/fgnt/pb_chime5.git
$ cd pb_chime5
$ git submodule init
$ git submodule update
$ pip install -e pb_bss/
$ pip install -e .
```
# Infer on the AliMeeting Test_2023 set
We follow the workflow shown below.
<div align="left"><img src="figure/20230809161919.jpg" width="500"/></div>

First, set the `DATA_SOURCE` in `path.sh` to your data path. Your data directory should be organized as follows:
```shell
Test_2023_Ali_far_release
|—— audio_dir/
| |—— R1014_M1710.wav
| |—— R1014_M1750.wav
| |—— ...
|—— textgrid_dir/
| |—— R1014_M1710.textgrid
| |—— R1014_M1750.textgrid
| |—— ...
|—— wav.scp
|—— segments
```
Then you can run speaker diarization with the following command.
```shell
$ bash run_diar.sh
```
After diarization, check the last line of `data/Test_2023_Ali_far_sond/dia_outputs/dia_result` for the result; you should get a DER of about 1.51%.
Once you obtain a similar diarization result, you can perform WPE and GSS with the following command.
```shell
$ bash run_enh.sh 8
```
Replace the number 8 with the channel count of your dataset; the AliMeeting corpus used here has 8 channels.
Finally, you can decode the processed audio directly with the pre-trained ASR model using the following commands.
```shell
$ bash run_asr.sh --stage 0 --stop-stage 1
$ bash run_asr.sh --stage 3 --stop-stage 3
```
The ASR result is saved at `./speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/decode_Test_2023_Ali_far_wpegss/text_cpcer`.
# Infer on the AliMeeting Test_2023 set after finetuning
You can finetune the pre-trained ASR model on the AliMeeting train set to further reduce the cpCER. To infer on the AliMeeting Test_2023 set after finetuning, run the following command once the train set has been processed with WPE and GSS as described above.
```shell
$ bash run_asr.sh --stage 2 --stop-stage 3
```
# Infer with your own dataset
We also support inference on your own dataset. Organize it as shown above; the `wav.scp` and `segments` files should be formatted as:
```shell
# wav.scp
sessionA wav_path/wav_name_A.wav
sessionB wav_path/wav_name_B.wav
sessionC wav_path/wav_name_C.wav
...
# segments
sessionA-start_time-end_time sessionA start_time end_time
sessionB-start_time-end_time sessionB start_time end_time
sessionC-start_time-end_time sessionC start_time end_time
...
```
Then set the `DATA_SOURCE` and `DATA_NAME` in `path.sh`. The rest of the process is the same as [Infer on the AliMeeting Test_2023 set](#infer-on-the-alimeeting-test_2023-set).
# Result
| |VBx DER (%) | SOND DER (%)| cpCER (%) |
|:---------------|:------------:|:------------:|----------:|
|before finetune | 16.87 | 1.51 | 10.18 |
|after finetune | 16.87 | 1.51 | |


@ -0,0 +1,11 @@
# config for high-resolution MFCC features, intended for neural network training.
# Note: we keep all cepstra, so it has the same info as filterbank features,
# but MFCC is more easily compressible (because less correlated) which is why
# we prefer this method.
--use-energy=false # use average of log energy, not energy.
--sample-frequency=16000
--num-mel-bins=40
--num-ceps=40
--low-freq=40
--high-freq=-400
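# note: Kaldi treats a negative --high-freq as an offset below the Nyquist frequency, so -400 here means 8000 - 400 = 7600 Hz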


@ -0,0 +1 @@
--norm-means=false --norm-vars=false

(Binary image added: figure/20230809161919.jpg, 146 KiB; not shown.)


@ -0,0 +1,103 @@
import editdistance
import sys
from itertools import permutations
def load_transcripts(file_path):
trans_list = []
for one_line in open(file_path, "rt"):
        meeting_id, trans = one_line.strip().split(" ", 1)
trans_list.append((meeting_id.strip(), trans.strip()))
return trans_list
def calc_spk_trans(trans):
spk_trans_ = [x.strip() for x in trans.split("$")]
spk_trans = []
for i in range(len(spk_trans_)):
spk_trans.append((str(i), spk_trans_[i]))
return spk_trans
def calc_cer(ref_trans, hyp_trans):
ref_spk_trans = calc_spk_trans(ref_trans)
hyp_spk_trans = calc_spk_trans(hyp_trans)
ref_spk_num, hyp_spk_num = len(ref_spk_trans), len(hyp_spk_trans)
num_spk = max(len(ref_spk_trans), len(hyp_spk_trans))
ref_spk_trans.extend([("", "")] * (num_spk - len(ref_spk_trans)))
hyp_spk_trans.extend([("", "")] * (num_spk - len(hyp_spk_trans)))
errors, counts, permutes = [], [], []
min_error = 0
cost_dict = {}
for perm in permutations(range(num_spk)):
flag = True
p_err, p_count = 0, 0
for idx, p in enumerate(perm):
if abs(len(ref_spk_trans[idx][1]) - len(hyp_spk_trans[p][1])) > min_error > 0:
flag = False
break
cost_key = "{}-{}".format(idx, p)
if cost_key in cost_dict:
_e = cost_dict[cost_key]
else:
_e = editdistance.eval(ref_spk_trans[idx][1], hyp_spk_trans[p][1])
cost_dict[cost_key] = _e
if _e > min_error > 0:
flag = False
break
p_err += _e
p_count += len(ref_spk_trans[idx][1])
if flag:
if p_err < min_error or min_error == 0:
min_error = p_err
errors.append(p_err)
counts.append(p_count)
permutes.append(perm)
sd_cer = [(err, cnt, err/cnt, permute)
for err, cnt, permute in zip(errors, counts, permutes)]
best_rst = min(sd_cer, key=lambda x: x[2])
return best_rst[0], best_rst[1], ref_spk_num, hyp_spk_num
def main():
ref=sys.argv[1]
hyp=sys.argv[2]
result_path="/".join(hyp.split("/")[:-1]) + "/text_cpcer"
ref_list = load_transcripts(ref)
hyp_list = load_transcripts(hyp)
result_file = open(result_path,'w')
record_2_spk = [0, 0]
record_3_spk = [0, 0]
record_4_spk = [0, 0]
error, count = 0, 0
for (ref_id, ref_trans), (hyp_id, hyp_trans) in zip(ref_list, hyp_list):
assert ref_id == hyp_id
mid = ref_id
dist, length, ref_spk_num, hyp_spk_num = calc_cer(ref_trans, hyp_trans)
error, count = error + dist, count + length
result_file.write("{} {:.2f} {} {}\n".format(mid, dist / length * 100.0, ref_spk_num, hyp_spk_num))
ref_spk = len(ref_trans.split("$"))
hyp_spk = len(hyp_trans.split("$"))
if ref_spk == 2:
record_2_spk[0] += dist
record_2_spk[1] += length
elif ref_spk == 3:
record_3_spk[0] += dist
record_3_spk[1] += length
else:
record_4_spk[0] += dist
record_4_spk[1] += length
print(record_2_spk[0]/record_2_spk[1]*100.0)
print(record_3_spk[0]/record_3_spk[1]*100.0)
print(record_4_spk[0]/record_4_spk[1]*100.0)
result_file.write("CP-CER: {:.2f}\n".format(error / count * 100.0))
result_file.close()
if __name__ == '__main__':
main()
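For reference, a minimal sanity check of the permutation search above (a sketch: it assumes this script is importable as `compute_cpcer.py`, a hypothetical name, and that `$` separates per-speaker transcripts as in the merged results):

```python
# Hypothetical sanity check for calc_cer; the module name is an assumption.
from compute_cpcer import calc_cer

ref = "abcd$efg"   # two reference speakers, joined by "$"
hyp = "efg$abxd"   # the same speakers in swapped order, with one substitution
dist, length, n_ref, n_hyp = calc_cer(ref, hyp)
print(dist, length)  # 1 7 -> the best permutation pairs "abxd" with "abcd"
```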


@ -0,0 +1,132 @@
import os
from funasr.utils.job_runner import MultiProcessRunnerV3
import numpy as np
from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
from collections import OrderedDict
from tqdm import tqdm
from scipy.ndimage import median_filter
class MyRunner(MultiProcessRunnerV3):
def prepare(self, parser):
parser.add_argument("label_txt", type=str)
parser.add_argument("map_scp", type=str)
parser.add_argument("out_rttm", type=str)
parser.add_argument("--n_spk", type=int, default=4)
parser.add_argument("--chunk_len", type=int, default=1600)
parser.add_argument("--shift_len", type=int, default=400)
parser.add_argument("--ignore_len", type=int, default=5)
parser.add_argument("--smooth_size", type=int, default=7)
parser.add_argument("--vote_prob", type=float, default=0.5)
args = parser.parse_args()
if not os.path.exists(os.path.dirname(args.out_rttm)):
os.makedirs(os.path.dirname(args.out_rttm))
utt2labels = load_scp_as_list(args.label_txt, 'list')
utt2labels = sorted(utt2labels, key=lambda x: x[0])
meeting2map = load_scp_as_dict(args.map_scp)
meeting2labels = OrderedDict()
for utt_id, chunk_label in utt2labels:
mid = utt_id.split("-")[0]
if mid not in meeting2labels:
meeting2labels[mid] = []
meeting2labels[mid].append(chunk_label)
task_list = [(mid, labels, meeting2map[mid]) for mid, labels in meeting2labels.items()]
return task_list, None, args
def post(self, result_list, args):
with open(args.out_rttm, "wt") as fd:
for results in result_list:
fd.writelines(results)
def int2vec(x, vec_dim=8, dtype=int):
b = ('{:0' + str(vec_dim) + 'b}').format(x)
# little-endian order: lower bit first
return (np.array(list(b)[::-1]) == '1').astype(dtype)
def seq2arr(seq, vec_dim=8):
return np.row_stack([int2vec(int(x), vec_dim) for x in seq])
def sample2ms(sample, sr=16000):
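    # note: despite the name, this returns 10 ms frame indices (100 per second), not milliseconds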
return int(float(sample) / sr * 100)
def calc_multi_labels(chunk_label_list, chunk_len, shift_len, n_spk, vote_prob=0.5):
n_chunk = len(chunk_label_list)
last_chunk_valid_frame = len(chunk_label_list[-1]) - (chunk_len - shift_len)
n_frame = (n_chunk - 2) * shift_len + chunk_len + last_chunk_valid_frame
multi_labels = np.zeros((n_frame, n_spk), dtype=float)
weight = np.zeros((n_frame, 1), dtype=float)
for i in range(n_chunk):
raw_label = chunk_label_list[i]
for k in range(len(raw_label)):
if raw_label[k] == '<unk>':
raw_label[k] = raw_label[k-1] if k > 0 else '0'
chunk_multi_label = seq2arr(raw_label, n_spk)
chunk_len = chunk_multi_label.shape[0]
multi_labels[i*shift_len:i*shift_len+chunk_len, :] += chunk_multi_label
weight[i*shift_len:i*shift_len+chunk_len, :] += 1
multi_labels = multi_labels / weight # normalizing vote
multi_labels = (multi_labels > vote_prob).astype(int) # voting results
return multi_labels
def calc_spk_turns(label_arr, spk_list):
turn_list = []
length = label_arr.shape[0]
n_spk = label_arr.shape[1]
for k in range(n_spk):
if spk_list[k] == "None":
continue
in_utt = False
start = 0
for i in range(length):
if label_arr[i, k] == 1 and in_utt is False:
start = i
in_utt = True
if label_arr[i, k] == 0 and in_utt is True:
turn_list.append([spk_list[k], start, i - start])
in_utt = False
if in_utt:
turn_list.append([spk_list[k], start, length - start])
return turn_list
def smooth_multi_labels(multi_label, win_len):
multi_label = median_filter(multi_label, (win_len, 1), mode="constant", cval=0.0).astype(int)
return multi_label
def process(task_args):
_, task_list, _, args = task_args
spk_list = ["spk{}".format(i+1) for i in range(args.n_spk)]
template = "SPEAKER {} 1 {:.2f} {:.2f} <NA> <NA> {} <NA> <NA>\n"
results = []
for mid, chunk_label_list, map_file_path in tqdm(task_list, total=len(task_list), ascii=True, disable=args.no_pbar):
utt2map = load_scp_as_list(map_file_path, 'list')
multi_labels = calc_multi_labels(chunk_label_list, args.chunk_len, args.shift_len, args.n_spk, args.vote_prob)
multi_labels = smooth_multi_labels(multi_labels, args.smooth_size)
org_len = sample2ms(int(utt2map[-1][1][1]), args.sr)
org_multi_labels = np.zeros((org_len, args.n_spk))
for seg_id, [org_st, org_ed, st, ed] in utt2map:
org_st, org_dur = sample2ms(int(org_st), args.sr), sample2ms(int(org_ed) - int(org_st), args.sr)
st, dur = sample2ms(int(st), args.sr), sample2ms(int(ed) - int(st), args.sr)
ll = min(org_multi_labels[org_st: org_st+org_dur, :].shape[0], multi_labels[st: st+dur, :].shape[0])
org_multi_labels[org_st: org_st+ll, :] = multi_labels[st: st+ll, :]
spk_turns = calc_spk_turns(org_multi_labels, spk_list)
spk_turns = sorted(spk_turns, key=lambda x: x[1])
for spk, st, dur in spk_turns:
# TODO: handle the leak of segments at the change points
if dur > args.ignore_len:
results.append(template.format(mid, float(st)/100, float(dur)/100, spk))
return results
if __name__ == '__main__':
my_runner = MyRunner(process)
my_runner.run()
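The frame labels consumed above are power-set encoded: each label token is an integer whose binary digits (lower bit first) mark the active speakers. A minimal decoding demo mirroring `int2vec`:

```python
import numpy as np

def int2vec(x, vec_dim=8, dtype=int):
    # little-endian: bit k set means speaker k+1 is active in this frame
    b = ('{:0' + str(vec_dim) + 'b}').format(x)
    return (np.array(list(b)[::-1]) == '1').astype(dtype)

print(int2vec(5, vec_dim=4))  # [1 0 1 0] -> speakers 1 and 3 are active
```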


@ -0,0 +1,104 @@
import codecs
import sys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
import numpy as np
import os
import soundfile
data_path = sys.argv[1]
segment_file_path = data_path + "/segments_nooverlap"
utt2spk_file_path = data_path + "/utt2spk_nooverlap"
wav_scp_path = data_path + "/wav.scp"
cluster_emb_dir = data_path + '/cluster_embedding/'
os.system("mkdir -p " + cluster_emb_dir)
cluster_profile_dir = data_path + '/cluster_profile_zeropadding16/'
os.system('mkdir -p ' + cluster_profile_dir)
utt2spk = {}
spk2seg = {}
with codecs.open(utt2spk_file_path, "r", "utf-8") as f1:
with codecs.open(segment_file_path, "r", "utf-8") as f2:
for line in f1.readlines():
uttid, spkid = line.strip().split(" ")
utt2spk[uttid] = spkid
for line in f2.readlines():
uttid, sessionid, stime, etime = line.strip().split(" ")
spkid = utt2spk[uttid]
if spkid not in spk2seg.keys():
spk2seg[spkid] = [(int(float(stime) * 16000), int(float(etime) * 16000) - int(float(stime) * 16000))]
else:
spk2seg[spkid].append((int(float(stime) * 16000), int(float(etime) * 16000) - int(float(stime) * 16000)))
inference_sv_pipline = pipeline(
task=Tasks.speaker_verification,
model='damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch',
device='gpu'
)
wav_dict = {}
with codecs.open(wav_scp_path, "r", "utf-8") as fi:
with codecs.open(data_path + "/cluster_embedding.scp", "w", "utf-8") as fo:
for line in fi.readlines():
sessionid, wav_path = line.strip().split()
wav_dict[sessionid] = wav_path
for spkid, segs in spk2seg.items():
sessionid = spkid.split("-")[0]
wav_path = wav_dict[sessionid]
wav = soundfile.read(wav_path)[0]
if wav.ndim == 2:
wav = wav[:, 0]
all_seg_embedding_list=[]
for seg in segs:
if seg[0] < wav.shape[0] - 0.5 * 16000:
                    if seg[0] + seg[1] > wav.shape[0]:
cur_seg_embedding = inference_sv_pipline(audio_in=wav[seg[0]: ])["spk_embedding"]
else:
cur_seg_embedding = inference_sv_pipline(audio_in=wav[seg[0]: seg[0] + seg[1]])["spk_embedding"]
all_seg_embedding_list.append(cur_seg_embedding)
all_seg_embedding = np.vstack(all_seg_embedding_list)
spk_embedding = np.mean(all_seg_embedding, axis=0)
np.save(cluster_emb_dir + spkid + '.npy', spk_embedding)
fo.write(spkid + ' ' + cluster_emb_dir + spkid + '.npy' + '\n')
session2embs = {}
with codecs.open(data_path + "/cluster_embedding.scp", "r", "utf-8") as fi:
with codecs.open(data_path + "/cluster_profile_zeropadding16.scp", "w", "utf-8") as fo:
for line in fi.readlines():
spkid, emb_path = line.strip().split(" ")
sessionid = spkid.split("-")[0]
if sessionid not in session2embs.keys():
session2embs[sessionid] = [emb_path]
else:
session2embs[sessionid].append(emb_path)
for sessionid, embs in session2embs.items():
emb_list = [np.load(x) for x in embs]
tmp = []
for i in range(len(emb_list) - 1):
flag = True
for j in range(i + 1, len(emb_list)):
cos_sim = emb_list[i].dot(emb_list[j]) / (np.linalg.norm(emb_list[i]) * np.linalg.norm(emb_list[j]))
if cos_sim > 0.99:
flag = False
if flag:
tmp.append(emb_list[i][np.newaxis, :])
tmp.append(emb_list[-1][np.newaxis, :])
emb_list = tmp
# tmp = []
# for i in range(len(emb_list)):
# for emb in tmp:
# cos_sim = emb_list[i].dot(emb_list[j]) / (np.linalg.norm(emb_list[i]) * np.linalg.norm(emb_list[j]))
# if cos_sim > 0.99:
# flag = False
# if flag:
# tmp.append(emb_list[i][np.newaxis, :])
# emb_list = tmp
for i in range(16 - len(emb_list)):
emb_list.append(np.zeros((1, 256)))
emb = np.concatenate(emb_list, axis=0)
save_path = cluster_profile_dir + sessionid + ".npy"
np.save(save_path, emb)
fo.write("%s %s\n" % (sessionid, save_path))
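The dedup loop above discards a cluster embedding whenever its cosine similarity to any later embedding exceeds 0.99 (the last embedding is always kept). A self-contained sketch of that criterion:

```python
import numpy as np

def cosine(a, b):
    return a.dot(b) / (np.linalg.norm(a) * np.linalg.norm(b))

embs = [np.array([1.0, 0.0]), np.array([0.999, 0.01]), np.array([0.0, 1.0])]
kept = [e for i, e in enumerate(embs)
        if i == len(embs) - 1
        or not any(cosine(e, o) > 0.99 for o in embs[i + 1:])]
print(len(kept))  # 2 -> the first embedding nearly duplicates the second and is dropped
```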


@ -0,0 +1,34 @@
import os
from modelscope.metainfo import Trainers
from modelscope.trainers import build_trainer
from modelscope.msdatasets.audio.asr_dataset import ASRDataset
def modelscope_finetune(params):
if not os.path.exists(params.output_dir):
os.makedirs(params.output_dir, exist_ok=True)
# dataset split ["train", "validation"]
ds_dict = ASRDataset.load(params.data_path, namespace='speech_asr')
kwargs = dict(
model=params.model,
data_dir=ds_dict,
dataset_type=params.dataset_type,
work_dir=params.output_dir,
batch_bins=params.batch_bins,
max_epoch=params.max_epoch,
lr=params.lr)
trainer = build_trainer(Trainers.speech_asr_trainer, default_args=kwargs)
trainer.train()
if __name__ == '__main__':
from funasr.utils.modelscope_param import modelscope_args
params = modelscope_args(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
params.output_dir = "./checkpoint" # 模型保存路径
params.data_path = "./data" # 数据路径可以为modelscope中已上传数据也可以是本地数据
params.dataset_type = "small" # 小数据量设置small若数据量大于1000小时请使用large
params.batch_bins = 2000 # batch size如果dataset_type="small"batch_bins单位为fbank特征帧数如果dataset_type="large"batch_bins单位为毫秒
params.max_epoch = 20 # 最大训练轮数
params.lr = 0.00005 # 设置学习率
modelscope_finetune(params)


@ -0,0 +1 @@
../../sa_asr/local/format_wav_scp.py


@ -0,0 +1 @@
../../sa_asr/local/format_wav_scp.sh


@ -0,0 +1,27 @@
from funasr.bin.diar_inference_launch import inference_launch
import sys
import os
os.environ['CUDA_VISIBLE_DEVICES']='7'
def main():
diar_config_path = sys.argv[1] if len(sys.argv) > 1 else "sond_fbank.yaml"
diar_model_path = sys.argv[2] if len(sys.argv) > 2 else "sond.pb"
input_dir = sys.argv[3] if len(sys.argv) > 3 else "./inputs"
output_dir = sys.argv[4] if len(sys.argv) > 4 else "./outputs"
data_path_and_name_and_type = [
(input_dir + "/wav.scp", "speech", "sound"),
(input_dir + "/profile.scp", "profile", "npy"),
]
pipeline = inference_launch(
mode="sond",
diar_train_config=diar_config_path,
diar_model_file=diar_model_path,
output_dir=output_dir,
num_workers=16,
ngpu=1,
)
pipeline(data_path_and_name_and_type)
if __name__ == '__main__':
main()


@ -0,0 +1,58 @@
import argparse
import codecs
import textgrid
class Segment(object):
def __init__(self, uttid, spkr, stime, etime, text):
self.uttid = uttid
self.spkr = spkr
self.stime = round(stime, 2)
self.etime = round(etime, 2)
self.text = text
def change_stime(self, time):
self.stime = time
def change_etime(self, time):
self.etime = time
def main(args):
tg = textgrid.TextGrid.fromFile(args.input_textgrid_file)
segments = []
spk = {}
num_spk = 1
uttid = args.uttid
for i in range(tg.__len__()):
for j in range(tg[i].__len__()):
if tg[i][j].mark:
if tg[i].name not in spk:
spk[tg[i].name] = num_spk
num_spk += 1
segments.append(
Segment(
uttid,
spk[tg[i].name],
tg[i][j].minTime,
tg[i][j].maxTime,
tg[i][j].mark.strip(),
)
)
segments = sorted(segments, key=lambda x: x.stime)
rttm_file = codecs.open(args.output_rttm_file, "w", "utf-8")
for i in range(len(segments)):
fmt = "SPEAKER {:s} 1 {:.2f} {:.2f} <NA> <NA> {:s} <NA> <NA>"
rttm_file.write(f"{fmt.format(segments[i].uttid, float(segments[i].stime), float(segments[i].etime) - float(segments[i].stime), str(segments[i].spkr))}\n")
if __name__ == '__main__':
parser = argparse.ArgumentParser(description="Make rttm for true label")
parser.add_argument("--input_textgrid_file", required=True, help="The textgrid file")
parser.add_argument("--output_rttm_file", required=True, help="The output rttm file")
parser.add_argument("--uttid", required=True, help="The utt id of the file")
args = parser.parse_args()
main(args)
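Each emitted line follows the standard RTTM layout `SPEAKER <file-id> <channel> <onset> <duration> <NA> <NA> <speaker> <NA> <NA>`. For example, a segment spoken by speaker 2 from 1.00 s to 3.50 s in recording R1014_M1710 is written as:

```python
fmt = "SPEAKER {:s} 1 {:.2f} {:.2f} <NA> <NA> {:s} <NA> <NA>"
print(fmt.format("R1014_M1710", 1.00, 2.50, "2"))
# SPEAKER R1014_M1710 1 1.00 2.50 <NA> <NA> 2 <NA> <NA>
```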


@ -0,0 +1,83 @@
# -*- coding: utf-8 -*-
"""
Process the textgrid files
"""
import argparse
import codecs
from pathlib import Path
import textgrid
class Segment(object):
def __init__(self, uttid, spkr, stime, etime, text):
self.uttid = uttid
self.spkr = spkr
self.stime = round(stime, 2)
self.etime = round(etime, 2)
self.text = text
def change_stime(self, time):
self.stime = time
def change_etime(self, time):
self.etime = time
def get_args():
parser = argparse.ArgumentParser(description="process the textgrid files")
parser.add_argument("--path", type=str, required=True, help="textgrid path")
parser.add_argument("--label_path", type=str, required=True, help="label rttm file path")
parser.add_argument("--predict_path", type=str, required=True, help="predict rttm file path")
args = parser.parse_args()
return args
def main(args):
textgrid_flist = codecs.open(Path(args.path)/"uttid_textgrid.flist", "r", "utf-8")
# parse the textgrid file for each utterance
speaker2_uttidset = []
speaker3_uttidset = []
speaker4_uttidset = []
for line in textgrid_flist:
        uttid, textgrid_file = line.strip().split("\t")
tg = textgrid.TextGrid()
tg.read(textgrid_file)
num_speaker = len(tg)
        if num_speaker == 2:
            speaker2_uttidset.append(uttid)
        elif num_speaker == 3:
            speaker3_uttidset.append(uttid)
        elif num_speaker == 4:
            speaker4_uttidset.append(uttid)
textgrid_flist.close()
speaker2_id_label = codecs.open(Path(args.label_path) / "speaker2_id", "w", "utf-8")
speaker2_id_predict = codecs.open(Path(args.predict_path) / "speaker2_id", "w", "utf-8")
speaker3_id_label = codecs.open(Path(args.label_path) / "speaker3_id", "w", "utf-8")
speaker3_id_predict = codecs.open(Path(args.predict_path) / "speaker3_id", "w", "utf-8")
speaker4_id_label = codecs.open(Path(args.label_path) / "speaker4_id", "w", "utf-8")
speaker4_id_predict = codecs.open(Path(args.predict_path) / "speaker4_id", "w", "utf-8")
for i in range(len(speaker2_uttidset)):
speaker2_id_label.write("%s\n" % (args.label_path+"/"+speaker2_uttidset[i]+".rttm"))
speaker2_id_predict.write("%s\n" % (args.predict_path+"/"+speaker2_uttidset[i]+".rttm"))
for i in range(len(speaker3_uttidset)):
speaker3_id_label.write("%s\n" % (args.label_path+"/"+speaker3_uttidset[i]+".rttm"))
speaker3_id_predict.write("%s\n" % (args.predict_path+"/"+speaker3_uttidset[i]+".rttm"))
for i in range(len(speaker4_uttidset)):
speaker4_id_label.write("%s\n" % (args.label_path+"/"+speaker4_uttidset[i]+".rttm"))
speaker4_id_predict.write("%s\n" % (args.predict_path+"/"+speaker4_uttidset[i]+".rttm"))
speaker2_id_label.close()
speaker2_id_predict.close()
speaker3_id_label.close()
speaker3_id_predict.close()
speaker4_id_label.close()
speaker4_id_predict.close()
if __name__ == "__main__":
args = get_args()
main(args)


@ -0,0 +1,52 @@
import sys
import codecs
decode_result = sys.argv[1]
utt2spk_file = sys.argv[2]
merged_result = "/".join(decode_result.split("/")[:-1]) + "/text_merge"
utt2text = {}
utt2spk = {}
spk2texts = {}
spk2text = {}
meeting2text = {}
with codecs.open(decode_result, "r", "utf-8") as f1:
with codecs.open(utt2spk_file, "r", "utf-8") as f2:
for line in f1.readlines():
try:
line_list = line.strip().split()
uttid = line_list[0]
text = "".join(line_list[1:])
except:
continue
utt2text[uttid] = text
for line in f2.readlines():
uttid, spkid = line.strip().split()
utt2spk[uttid] = spkid
for utt, text in utt2text.items():
spk = utt2spk[utt]
stime = int(utt.split("-")[-2])
if spk in spk2texts.keys():
spk2texts[spk].append([stime, text])
else:
spk2texts[spk] = [[stime, text]]
for spk, texts in spk2texts.items():
texts = sorted(texts, key=lambda x: x[0])
text = "".join([x[1] for x in texts])
spk2text[spk] = text
with codecs.open(merged_result, "w", "utf-8") as f:
for spk, text in spk2text.items():
# meeting = spk.split("-")[2]
meeting = spk.split("-")[0]
if meeting in meeting2text.keys():
meeting2text[meeting] = meeting2text[meeting] + "$" + text
else:
meeting2text[meeting] = text
for meeting, text in meeting2text.items():
f.write("%s %s\n" % (meeting, text))


@ -0,0 +1,140 @@
import sys
import codecs
input_segments_file = sys.argv[1]
input_utt2spk_file = sys.argv[2]
output_segments_file = sys.argv[3]
output_utt2spk_file = sys.argv[4]
threshold = sys.argv[5]
class Segment(object):
def __init__(self, baseid, spkid, meetingid, stime, etime, uttid=None):
self.baseid = baseid
self.spkid = spkid
self.meetingid = meetingid
self.stime = round(stime, 2)
self.etime = round(etime, 2)
self.uttid = uttid
self.dur = self.etime - self.stime
if self.uttid is None:
self.uttid = "%s-%s-%07d-%07d" % (
self.baseid,
self.spkid,
self.stime * 100,
self.etime * 100,
)
def cut(cur_max_end_time, seg_list, cur_seg, next_c):
global out_segment_dict
if next_c == len(seg_list):
single_stime = max(cur_max_end_time, cur_seg.stime)
single_etime = cur_seg.etime
if single_stime < single_etime and single_etime - single_stime > float(threshold):
# only save segment which duration more than threshold for sv's accuracy
if cur_seg.spkid not in out_segment_dict.keys():
out_segment_dict[cur_seg.spkid] = [
Segment(
cur_seg.baseid,
cur_seg.spkid,
cur_seg.meetingid,
single_stime,
single_etime,
)]
else:
out_segment_dict[cur_seg.spkid].append(
Segment(
cur_seg.baseid,
cur_seg.spkid,
cur_seg.meetingid,
single_stime,
single_etime,
)
)
else:
next_seg = seg_list[next_c]
single_stime = max(cur_max_end_time, cur_seg.stime)
single_etime = min(cur_seg.etime, next_seg.stime)
if single_stime < single_etime and single_etime - single_stime > float(threshold):
if cur_seg.spkid not in out_segment_dict.keys():
out_segment_dict[cur_seg.spkid] = [
Segment(
cur_seg.baseid,
cur_seg.spkid,
cur_seg.meetingid,
single_stime,
single_etime,
)]
else:
out_segment_dict[cur_seg.spkid].append(
Segment(
cur_seg.baseid,
cur_seg.spkid,
cur_seg.meetingid,
single_stime,
single_etime,
)
)
if cur_seg.etime > next_seg.etime:
cut(max(cur_max_end_time, next_seg.etime), seg_list, cur_seg, next_c + 1)
meeting2seg = {}
utt2spk = {}
i = 0
with codecs.open(input_utt2spk_file, "r", "utf-8") as f:
for line in f.readlines():
utt, spk = line.strip().split()
utt2spk[utt] = spk
with codecs.open(input_segments_file, "r", "utf-8") as f:
for line in f.readlines():
i += 1
uttid, meetingid, stime, etime = line.strip().split(" ")
spkid = utt2spk[uttid].split("-")[1]
baseid = meetingid
one_seg = Segment(baseid, spkid, meetingid, float(stime), float(etime))
if one_seg.meetingid not in meeting2seg.keys():
meeting2seg[one_seg.meetingid] = [one_seg]
else:
meeting2seg[one_seg.meetingid].append(one_seg)
out_segment_dict = {}
for k, v in meeting2seg.items():
meeting2seg[k] = sorted(v, key=lambda x: x.stime)
cur_max_end_time = 0
for i in range(len(v)):
cur_seg = meeting2seg[k][i]
cut(cur_max_end_time, meeting2seg[k], cur_seg, i + 1)
cur_max_end_time = max(cur_max_end_time, cur_seg.etime)
out_segment_list = []
for k, v in out_segment_dict.items():
out_segment_list.extend(out_segment_dict[k])
with codecs.open(output_segments_file, "w", "utf-8") as f_seg:
with codecs.open(output_utt2spk_file, "w", "utf-8") as f_utt2spk:
for out_seg in out_segment_list:
f_seg.write(
"%s %s %.2f %.2f\n"
% (
out_seg.uttid,
out_seg.meetingid,
out_seg.stime,
out_seg.etime,
)
)
f_utt2spk.write(
"%s %s-%s\n"
% (
out_seg.uttid,
out_seg.baseid,
out_seg.spkid,
)
)
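A minimal sketch of the idea behind `cut` above (not the exact recursion): for each speaker segment, subtract every overlapping segment so that only single-speaker regions survive.

```python
def subtract(seg, others):
    # keep the parts of `seg` = (start, end) not covered by any interval in `others`
    pieces = [seg]
    for o in sorted(others):
        nxt = []
        for s, e in pieces:
            if o[1] <= s or o[0] >= e:   # no overlap: keep the piece as-is
                nxt.append((s, e))
            else:                        # keep what lies outside the overlap
                if s < o[0]:
                    nxt.append((s, o[0]))
                if o[1] < e:
                    nxt.append((o[1], e))
        pieces = nxt
    return pieces

print(subtract((0.0, 5.0), [(2.0, 3.0)]))  # [(0.0, 2.0), (3.0, 5.0)]
```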


@ -0,0 +1,78 @@
import soundfile
import os
import sys
import codecs
import numpy as np
segment_file_path = sys.argv[1]
wav_scp_file_path = sys.argv[2]
data_path = sys.argv[3]
wav_save_path = data_path + "/wav/"
os.system("mkdir -p " + wav_save_path)
pos_path = data_path + "/pos_map/"
os.system("mkdir -p " + pos_path)
wav_dict = {}
seg2time = {}
seg2time_new = {}
session2profile = {}
with codecs.open(wav_scp_file_path, "r", "utf-8") as f:
for line in f.readlines():
sessionid, wav_path = line.strip().split()
wav_dict[sessionid] = wav_path
with codecs.open(segment_file_path, "r", "utf-8") as f:
for line in f.readlines():
_, sessionid, stime, etime = line.strip().split()
if sessionid not in seg2time.keys():
seg2time[sessionid] = [(int(16000 * float(stime)), int(16000 * float(etime)))]
else:
seg2time[sessionid].append((int(16000 * float(stime)), int(16000 * float(etime))))
with codecs.open(data_path + "/map.scp", "w", "utf-8") as f1:
for sessionid, seg_times in seg2time.items():
seg2time_new[sessionid] = []
last_time = 0
with codecs.open(pos_path + sessionid + ".pos", "w", "utf-8") as f2:
for seg_time in seg_times:
tmp = seg_time[0] - last_time
cur_seg = (seg_time[0] - tmp, seg_time[1] - tmp)
seg2time_new[sessionid].append((seg_time[0] - last_time, seg_time[1] - last_time))
last_time = cur_seg[1]
f2.write("%s-%07d-%07d %d %d %d %d\n" % (sessionid, seg_time[0]/160, seg_time[1]/160, seg_time[0], seg_time[1], cur_seg[0], cur_seg[1]))
f1.write("%s %s\n" % (sessionid, pos_path + sessionid + ".pos"))
with codecs.open(data_path + "/cluster_profile_zeropadding16.scp", "r", "utf-8") as f:
for line in f.readlines():
session, path = line.strip().split()
session2profile[session] = path
with codecs.open(data_path + "/wav.scp", "w", "utf-8") as f1:
with codecs.open(data_path + "/profile.scp", "w", "utf-8") as f2:
for sessionid, wav_path in wav_dict.items():
wav = soundfile.read(wav_path)[0]
if wav.ndim == 2:
wav = wav[:, 0]
seg_list = [wav[seg[0]: seg[1]] for seg in seg2time[sessionid]]
wav_new = np.concatenate(seg_list, axis=0)
cur_time = 0
flag = True
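            # slide a 16 s window (256000 samples at 16 kHz) with a 4 s shift (64000 samples) over the concatenated speech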
while flag:
start = cur_time
end = cur_time + 256000
if end < wav_new.shape[0]:
cur_wav = wav_new[start: end]
else:
cur_wav = wav_new[start: ]
flag = False
cur_time = cur_time + 64000
wav_name = "%s-%07d_%07d.wav" % (sessionid, start/160, end/160)
soundfile.write(wav_save_path + wav_name, cur_wav, 16000)
f1.write("%s %s\n" % (wav_name, wav_save_path + wav_name))
f2.write("%s %s\n" % (wav_name, session2profile[sessionid]))


@ -0,0 +1,29 @@
import codecs
import sys
rttm_file_path = sys.argv[1]
segment_file_path = sys.argv[2]
mode = sys.argv[3] # 0 for diarization, 1 for asr
meeting2spk = {}
with codecs.open(rttm_file_path, "r", "utf-8") as fi:
with codecs.open(segment_file_path + "/segments", "w", "utf-8") as f1:
with codecs.open(segment_file_path + "/utt2spk", "w", "utf-8") as f2:
for line in fi.readlines():
_, sessionid, _, stime, dur, _, _, spkid, _, _ = line.strip().split(" ")
if float(dur) < 0.3:
continue
uttid = "%s-%07d-%07d" % (sessionid, int(float(stime) * 100), int(float(stime) * 100 + float(dur) * 100))
spkid = "%s-%s" % (sessionid, spkid)
if int(mode) == 0:
f1.write("%s %s %.2f %.2f\n" % (uttid, sessionid, float(stime), float(stime) + float(dur)))
f2.write("%s %s\n" % (uttid, spkid))
elif int(mode) == 1:
f1.write("%s %s %.2f %.2f\n" % (uttid, spkid, float(stime), float(stime) + float(dur)))
f2.write("%s %s\n" % (uttid, spkid))
else:
exit("mode only support 0 or 1!")


@ -0,0 +1,139 @@
#!/usr/bin/env python
# -- coding: UTF-8
import argparse
import codecs
import logging
import os
from nara_wpe.utils import stft, istft
import numpy as np
import scipy.io.wavfile as wf
from tqdm import tqdm
from test_gss import GSS, Beamformer, get_time_activity, get_frequency_activity
def get_parser():
parser = argparse.ArgumentParser("Doing GSS based enhancement.")
parser.add_argument(
"--wav-scp",
type=str,
required=True,
help="Wav scp file for enhancement.",
)
parser.add_argument(
"--segments",
type=str,
required=True,
help="Wav scp file for enhancement.",
)
parser.add_argument(
"--output-dir",
type=str,
required=True,
help="Output directory of GSS enhanced data.",
)
return parser
def wfread(f):
fs, data = wf.read(f)
if data.dtype == np.int16:
data = np.float32(data) / 32768
return data, fs
def wfwrite(z, fs, store_path):
tmpwav = np.int16(z * 32768)
wf.write(store_path, fs, tmpwav)
def main():
args = get_parser().parse_args()
# logging info
logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
logging.basicConfig(level=logging.INFO, format=logfmt)
stft_window, stft_shift = 512, 256
gss = GSS(iterations=20, iterations_post=1)
bf = Beamformer("mvdrSouden_ban", "mask_mul")
with codecs.open(args.wav_scp, "r") as handle:
lines_content = handle.readlines()
wav_lines = [*map(lambda x: x[:-1] if x[-1] in ["\n"] else x, lines_content)]
cnt = 0
session2spk2dur = {}
with codecs.open(args.segments, "r") as handle:
for line in handle.readlines():
uttid, spkid, stime, etime = line.strip().split(" ")
sessionid = spkid.split("-")[0]
if sessionid not in session2spk2dur.keys():
session2spk2dur[sessionid] = {}
if spkid not in session2spk2dur[sessionid].keys():
session2spk2dur[sessionid][spkid] = []
session2spk2dur[sessionid][spkid].append((float(stime), float(etime)))
for wav_idx in tqdm(range(len(wav_lines)), leave=True, desc="0"):
# get wav files from scp file
file_list = wav_lines[wav_idx].split(" ")
sessionid, wav_list = file_list[0], file_list[1:]
signal_list = []
time_activity = []
cnt += 1
logging.info("Processing {}: {}".format(cnt, wav_list[0]))
# read all wavs
for wav in wav_list:
data, fs = wfread(wav)
signal_list.append(data)
try:
obstft = np.stack(signal_list, axis=0)
except:
minlen = min([len(s) for s in signal_list])
obstft = np.stack([s[:minlen] for s in signal_list])
wavlen = obstft.shape[1]
obstft = stft(obstft, stft_window, stft_shift)
# get activated timestamps and frequencies
speaker_list = []
for spk, dur in session2spk2dur[sessionid].items():
speaker_list.append(spk.split("-")[-1])
time_activity.append(get_time_activity(dur, wavlen, fs))
time_activity.append([True] * wavlen)
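        # append an always-active source so the mask estimator can also model residual noise (a garbage class)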
frequency_activity = get_frequency_activity(
time_activity, stft_window, stft_shift
)
# generate mask
masks = gss(obstft, frequency_activity)
masks_bak = masks
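        # the last mask belongs to the always-active noise source, so skip it when beamforming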
for i in range(masks.shape[0] - 1):
target_mask = masks[i]
distortion_mask = np.sum(np.delete(masks, i, axis=0), axis=0)
xhat = bf(obstft, target_mask=target_mask, distortion_mask=distortion_mask)
xhat = istft(xhat, stft_window, stft_shift)
audio_dir = "/".join(wav_list[0].split("/")[:-1])
store_path = (
wav_list[0]
.replace(audio_dir, args.output_dir)
.replace(".wav", "-{}.wav".format(speaker_list[i]))
)
if not os.path.exists(os.path.split(store_path)[0]):
os.makedirs(os.path.split(store_path)[0], exist_ok=True)
logging.info("Save wav file {}.".format(store_path))
wfwrite(xhat, fs, store_path)
masks = masks_bak
if __name__ == "__main__":
main()


@ -0,0 +1,153 @@
#!/usr/bin/env python
# _*_ coding: UTF-8 _*_
import argparse
import codecs
import os
import logging
from multiprocessing import Pool
import numpy as np
import scipy.io.wavfile as wf
from nara_wpe.utils import istft, stft
from nara_wpe.wpe import wpe_v8 as wpe
def wpe_worker(
wav_scp,
audio_dir="",
output_dir="",
channel=0,
processing_id=None,
processing_num=None,
):
sampling_rate = 16000
iterations = 5
stft_options = dict(
size=512,
shift=128,
window_length=None,
fading=True,
pad=True,
symmetric_window=False,
)
with codecs.open(wav_scp, "r") as handle:
lines_content = handle.readlines()
wav_lines = [*map(lambda x: x[:-1] if x[-1] in ["\n"] else x, lines_content)]
for wav_idx in range(len(wav_lines)):
if processing_id is None:
processing_token = True
else:
if wav_idx % processing_num == processing_id:
processing_token = True
else:
processing_token = False
if processing_token:
wav_list = wav_lines[wav_idx].split(" ")
file_exist = True
for wav_path in wav_list:
file_exist = file_exist and os.path.exists(
wav_path.replace(audio_dir, output_dir)
)
if not file_exist:
break
if not file_exist:
logging.info("wait to process {} : {}".format(wav_idx, wav_list[0]))
signal_list = []
for f in wav_list:
_, data = wf.read(f)
data = data[:, channel - 1]
if data.dtype == np.int16:
data = np.float32(data) / 32768
signal_list.append(data)
min_len = len(signal_list[0])
max_len = len(signal_list[0])
for i in range(1, len(signal_list)):
min_len = min(min_len, len(signal_list[i]))
max_len = max(max_len, len(signal_list[i]))
if min_len != max_len:
for i in range(len(signal_list)):
signal_list[i] = signal_list[i][:min_len]
y = np.stack(signal_list, axis=0)
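            # y: (channels, samples); stft yields (channels, frames, bins), transposed to (bins, channels, frames) for wpe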
Y = stft(y, **stft_options).transpose(2, 0, 1)
Z = wpe(Y, iterations=iterations, statistics_mode="full").transpose(
1, 2, 0
)
z = istft(Z, size=stft_options["size"], shift=stft_options["shift"])
for d in range(len(signal_list)):
store_path = wav_list[d].replace(audio_dir, output_dir)
if not os.path.exists(os.path.split(store_path)[0]):
os.makedirs(os.path.split(store_path)[0], exist_ok=True)
tmpwav = np.int16(z[d, :] * 32768)
wf.write(store_path, sampling_rate, tmpwav)
else:
logging.info("file exist {} : {}".format(wav_idx, wav_list[0]))
return None
def wpe_manager(
wav_scp, processing_num=1, audio_dir="", output_dir="", channel=1
):
if processing_num > 1:
pool = Pool(processes=processing_num)
for i in range(processing_num):
pool.apply_async(
wpe_worker,
kwds={
"wav_scp": wav_scp,
"processing_id": i,
"processing_num": processing_num,
"audio_dir": audio_dir,
"output_dir": output_dir,
},
)
pool.close()
pool.join()
else:
wpe_worker(wav_scp, audio_dir=audio_dir, output_dir=output_dir, channel=channel)
return None
if __name__ == "__main__":
parser = argparse.ArgumentParser("run_wpe")
parser.add_argument(
"--wav-scp",
type=str,
required=True,
help="Path pf wav scp file",
)
parser.add_argument(
"--audio-dir",
type=str,
required=True,
help="Directory of input audio files",
)
parser.add_argument(
"--output-dir",
type=str,
required=True,
help="Output directory of WPE enhanced audio files",
)
parser.add_argument(
"--channel",
type=str,
required=True,
help="Channel number of input audio",
)
parser.add_argument("--nj", type=int, default="1", help="number of process")
args = parser.parse_args()
# logging info
logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
logging.basicConfig(level=logging.INFO, format=logfmt)
logging.info("wavfile={}".format(args.wav_scp))
logging.info("processingnum={}".format(args.nj))
wpe_manager(
wav_scp=args.wav_scp,
processing_num=args.nj,
audio_dir=args.audio_dir,
output_dir=args.output_dir,
channel=int(args.channel)
)


@ -0,0 +1,58 @@
import argparse
import os
def read_segments_file(segments_file):
utt2segments = dict()
with open(segments_file, "r") as fr:
lines = fr.readlines()
for line in lines:
parts = line.strip().split()
segment_utt_id, utt_id, start, end = parts[0], parts[1], float(parts[2]), float(parts[3])
if utt_id not in utt2segments:
utt2segments[utt_id] = []
utt2segments[utt_id].append((segment_utt_id, start, end))
return utt2segments
def write_label(label_file, label_list):
with open(label_file, "w") as fw:
for (start, end) in label_list:
fw.write(f"{start} {end} sp\n")
fw.flush()
def write_label_scp_file(label_scp_file, label_scp: dict):
with open(label_scp_file, "w") as fw:
for (utt_id, label_path) in label_scp.items():
fw.write(f"{utt_id} {label_path}\n")
fw.flush()
def main(args):
input_segments = args.input_segments
label_path = args.label_path
output_label_scp_file = args.output_label_scp_file
utt2segments = read_segments_file(input_segments)
print(f"Collect {len(utt2segments)} utt2segments in file {input_segments}")
result_label_scp = dict()
for utt_id in utt2segments.keys():
segment_list = utt2segments[utt_id]
cur_label_path = os.path.join(label_path, f"{utt_id}.lab")
write_label(cur_label_path, label_list=[(i1, i2) for (_, i1, i2) in segment_list])
result_label_scp[utt_id] = cur_label_path
write_label_scp_file(output_label_scp_file, result_label_scp)
print(f"Write {len(result_label_scp)} labels")
if __name__ == '__main__':
parser = argparse.ArgumentParser(description="Make the lab file for segments")
parser.add_argument("--input_segments", required=True, help="The input segments file")
parser.add_argument("--label_path", required=True, help="The label_path to save file.lab")
parser.add_argument("--output_label_scp_file", required=True, help="The output label.scp file")
args = parser.parse_args()
main(args)


@ -0,0 +1,225 @@
#!/bin/bash
# Copyright 2016-17 Vimal Manohar
# 2017 Nagendra Kumar Goel
# Apache 2.0.
# This script does nnet3-based speech activity detection given an input
# kaldi data directory and outputs a segmented kaldi data directory.
# This script can also do music detection and other similar segmentation
# using appropriate options such as --output-name output-music.
set -e
set -o pipefail
set -u
if [ -f ./path.sh ]; then . ./path.sh; fi
#export PATH=/usr/local/cuda-10.0/bin:$PATH
#export LD_LIBRARY_PATH=/usr/local/cuda-10.0/lib64:$LD_LIBRARY_PATH
#echo $PATH
#echo $LD_LIBRARY_PATH
affix= # Affix for the segmentation
nj=32
cmd=run.pl
stage=-1
# Feature options (Must match training)
mfcc_config=conf/mfcc_hires.conf
feat_affix= # Affix for the type of feature used
output_name=output # The output node in the network
sad_name=sad # Base name for the directory storing the computed loglikes
# Can be music for music detection
segmentation_name=segmentation # Base name for the directory doing segmentation
# Can be segmentation_music for music detection
# SAD network config
iter=final # Model iteration to use
# Contexts must ideally match training for LSTM models, but
# may not necessarily for stats components
extra_left_context=0 # Set to some large value, typically 40 for LSTM (must match training)
extra_right_context=0
extra_left_context_initial=-1
extra_right_context_final=-1
frames_per_chunk=150
# Decoding options
graph_opts="--min-silence-duration=0.03 --min-speech-duration=0.3 --max-speech-duration=10.0"
acwt=0.3
# These <from>_in_<to>_weight represent the fraction of <from> probability
# to transfer to <to> class.
# e.g. --speech-in-sil-weight=0.0 --garbage-in-sil-weight=0.0 --sil-in-speech-weight=0.0 --garbage-in-speech-weight=0.3
transform_probs_opts=""
# Postprocessing options
segment_padding=0.2 # Duration (in seconds) of padding added to segments
min_segment_dur=0 # Minimum duration (in seconds) required for a segment to be included
# This is before any padding. Segments shorter than this duration will be removed.
# This is an alternative to --min-speech-duration above.
merge_consecutive_max_dur=0 # Merge consecutive segments as long as the merged segment is no longer than this many
# seconds. The segments are only merged if their boundaries are touching.
# This is after padding by --segment-padding seconds.
# 0 means do not merge. Use 'inf' to not limit the duration.
echo $*
. utils/parse_options.sh
if [ $# -ne 5 ]; then
echo "This script does nnet3-based speech activity detection given an input kaldi "
echo "data directory and outputs an output kaldi data directory."
echo "See script for details of the options to be supplied."
echo "Usage: $0 <src-data-dir> <sad-nnet-dir> <mfcc-dir> <work-dir> <out-data-dir>"
echo " e.g.: $0 ~/workspace/egs/ami/s5b/data/sdm1/dev exp/nnet3_sad_snr/nnet_tdnn_j_n4 \\"
echo " mfcc_hires exp/segmentation_sad_snr/nnet_tdnn_j_n4 data/ami_sdm1_dev"
echo ""
echo "Options: "
echo " --cmd (utils/run.pl|utils/queue.pl <queue opts>) # how to run jobs."
echo " --nj <num-job> # number of parallel jobs to run."
echo " --stage <stage> # stage to do partial re-run from."
echo " --convert-data-dir-to-whole <true|false> # If true, the input data directory is "
echo " # first converted to whole data directory (i.e. whole recordings) "
echo " # and segmentation is done on that."
echo " # If false, then the original segments are "
echo " # retained and they are split into sub-segments."
echo " --output-name <name> # The output node in the network"
echo " --extra-left-context <context|0> # Set to some large value, typically 40 for LSTM (must match training)"
echo " --extra-right-context <context|0> # For BLSTM or statistics pooling"
exit 1
fi
src_data_dir=$1 # The input data directory that needs to be segmented.
# If convert_data_dir_to_whole is true, any segments in that will be ignored.
sad_nnet_dir=$2 # The SAD neural network
mfcc_dir=$3 # The directory to store the features
dir=$4 # Work directory
data_dir=$5 # The output data directory will be ${data_dir}_seg
affix=${affix:+_$affix}
feat_affix=${feat_affix:+_$feat_affix}
data_id=`basename $data_dir`
sad_dir=${dir}/${sad_name}${affix}_${data_id}${feat_affix}
seg_dir=${dir}/${segmentation_name}${affix}_${data_id}${feat_affix}
# test_data_dir=data/${data_id}${feat_affix}
test_data_dir=${src_data_dir}
###############################################################################
## Forward pass through the network and dump the log-likelihoods.
###############################################################################
frame_subsampling_factor=1
if [ -f $sad_nnet_dir/frame_subsampling_factor ]; then
frame_subsampling_factor=$(cat $sad_nnet_dir/frame_subsampling_factor)
fi
mkdir -p $dir
if [ $stage -le 1 ]; then
if [ "$(readlink -f $sad_nnet_dir)" != "$(readlink -f $dir)" ]; then
cp $sad_nnet_dir/cmvn_opts $dir || exit 1
fi
########################################################################
## Initialize neural network for decoding using the output $output_name
########################################################################
if [ ! -z "$output_name" ] && [ "$output_name" != output ]; then
$cmd $dir/log/get_nnet_${output_name}.log \
nnet3-copy --edits="rename-node old-name=$output_name new-name=output" \
$sad_nnet_dir/$iter.raw $dir/${iter}_${output_name}.raw || exit 1
iter=${iter}_${output_name}
else
if ! diff $sad_nnet_dir/$iter.raw $dir/$iter.raw; then
cp $sad_nnet_dir/$iter.raw $dir/
fi
fi
echo ${test_data_dir}
steps/nnet3/compute_output.sh --nj $nj --cmd "$cmd" \
--iter ${iter} \
--extra-left-context $extra_left_context \
--extra-right-context $extra_right_context \
--extra-left-context-initial $extra_left_context_initial \
--extra-right-context-final $extra_right_context_final \
--frames-per-chunk $frames_per_chunk --apply-exp true \
--frame-subsampling-factor $frame_subsampling_factor \
${test_data_dir} $dir $sad_dir || exit 1
fi
###############################################################################
## Prepare FST we search to make speech/silence decisions.
###############################################################################
utils/data/get_utt2dur.sh --nj $nj --cmd "$cmd" $test_data_dir || exit 1
frame_shift=$(utils/data/get_frame_shift.sh $test_data_dir) || exit 1
graph_dir=${dir}/graph_${output_name}
if [ $stage -le 2 ]; then
mkdir -p $graph_dir
# 1 for silence and 2 for speech
cat <<EOF > $graph_dir/words.txt
<eps> 0
silence 1
speech 2
EOF
$cmd $graph_dir/log/make_graph.log \
steps/segmentation/internal/prepare_sad_graph.py $graph_opts \
--frame-shift=$(perl -e "print $frame_shift * $frame_subsampling_factor") - \| \
fstcompile --isymbols=$graph_dir/words.txt --osymbols=$graph_dir/words.txt '>' \
$graph_dir/HCLG.fst
fi
###############################################################################
## Do Viterbi decoding to create per-frame alignments.
###############################################################################
post_vec=$sad_nnet_dir/post_${output_name}.vec
if [ ! -f $sad_nnet_dir/post_${output_name}.vec ]; then
if [ ! -f $sad_nnet_dir/post_${output_name}.txt ]; then
echo "$0: Could not find $sad_nnet_dir/post_${output_name}.vec. "
echo "Re-run the corresponding stage in the training script possibly "
echo "with --compute-average-posteriors=true or compute the priors "
echo "from the training labels"
exit 1
else
post_vec=$sad_nnet_dir/post_${output_name}.txt
fi
fi
mkdir -p $seg_dir
if [ $stage -le 3 ]; then
steps/segmentation/internal/get_transform_probs_mat.py \
--priors="$post_vec" $transform_probs_opts > $seg_dir/transform_probs.mat
steps/segmentation/decode_sad.sh --acwt $acwt --cmd "$cmd" \
--nj $nj \
--transform "$seg_dir/transform_probs.mat" \
$graph_dir $sad_dir $seg_dir
fi
###############################################################################
## Post-process segmentation to create kaldi data directory.
###############################################################################
if [ $stage -le 4 ]; then
steps/segmentation/post_process_sad_to_segments.sh \
--segment-padding $segment_padding --min-segment-dur $min_segment_dur \
--merge-consecutive-max-dur $merge_consecutive_max_dur \
--cmd "$cmd" --frame-shift $(perl -e "print $frame_subsampling_factor * $frame_shift") \
${test_data_dir} ${seg_dir} ${seg_dir}
fi
if [ $stage -le 5 ]; then
utils/data/subsegment_data_dir.sh ${test_data_dir} ${seg_dir}/segments \
${data_dir}_seg
fi
echo "$0: Created output segmented kaldi data directory in ${data_dir}_seg"
exit 0


@ -0,0 +1,141 @@
import logging
import numpy as np
from pb_bss.distribution import CACGMMTrainer
from dataclasses import dataclass
from pb_chime5.utils.numpy_utils import segment_axis_v2
def get_time_activity(dur_list, wavlen, sr):
time_activity = [False] * wavlen
for dur in dur_list:
xmax = int(dur[1] * sr)
xmin = int(dur[0] * sr)
if xmax > wavlen:
continue
for i in range(xmin, xmax):
time_activity[i] = True
logging.info("Num of actived samples {}".format(time_activity.count(True)))
return time_activity
def get_frequency_activity(
time_activity,
stft_window_length,
stft_shift,
stft_fading=True,
stft_pad=True,
):
time_activity = np.asarray(time_activity)
if stft_fading:
pad_width = np.array([(0, 0)] * time_activity.ndim)
pad_width[-1, :] = stft_window_length - stft_shift # Consider fading
time_activity = np.pad(time_activity, pad_width, mode="constant")
return segment_axis_v2(
time_activity,
length=stft_window_length,
shift=stft_shift,
end="pad" if stft_pad else "cut",
).any(axis=-1)
@dataclass
class Beamformer:
type: str
postfilter: str
def __call__(self, Obs, target_mask, distortion_mask, debug=False):
bf = self.type
if bf == "mvdrSouden_ban":
from pb_chime5.speech_enhancement.beamforming_wrapper import (
beamform_mvdr_souden_from_masks,
)
X_hat = beamform_mvdr_souden_from_masks(
Y=Obs,
X_mask=target_mask,
N_mask=distortion_mask,
ban=True,
)
elif bf == "ch0":
X_hat = Obs[0]
elif bf == "sum":
X_hat = np.sum(Obs, axis=0)
else:
raise NotImplementedError(bf)
if self.postfilter is None:
pass
elif self.postfilter == "mask_mul":
X_hat = X_hat * target_mask
else:
raise NotImplementedError(self.postfilter)
return X_hat
@dataclass
class GSS:
iterations: int = 20
iterations_post: int = 0
verbose: bool = True
# use_pinv: bool = False
# stable: bool = True
def __call__(self, Obs, acitivity_freq=None, debug=False):
initialization = np.asarray(acitivity_freq, dtype=np.float64)
initialization = np.where(initialization == 0, 1e-10, initialization)
initialization = initialization / np.sum(initialization, keepdims=True, axis=0)
initialization = np.repeat(initialization[None, ...], 257, axis=0)
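        # 257 = 512 // 2 + 1 frequency bins, matching the 512-point STFT used by the caller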
source_active_mask = np.asarray(acitivity_freq, dtype=bool)
source_active_mask = np.repeat(source_active_mask[None, ...], 257, axis=0)
cacGMM = CACGMMTrainer()
if debug:
learned = []
all_affiliations = []
F = Obs.shape[-1]
T = Obs.T.shape[-2]
for f in range(F):
if self.verbose:
if f % 50 == 0:
logging.info(f"{f}/{F}")
            # T: consider the end of the signal.
            # This should not be necessary, but the activity comes from in-ear
            # microphones rather than from the array.
cur = cacGMM.fit(
y=Obs.T[f, ...],
initialization=initialization[f, ..., :T],
iterations=self.iterations,
source_activity_mask=source_active_mask[f, ..., :T],
)
affiliation = cur.predict(
Obs.T[f, ...],
source_activity_mask=source_active_mask[f, ..., :T],
)
all_affiliations.append(affiliation)
posterior = np.array(all_affiliations).transpose(1, 2, 0)
return posterior


@ -0,0 +1,316 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import codecs
class TextGrid(object):
def __init__(
self,
file_type="",
object_class="",
xmin=0.0,
xmax=0.0,
tiers_status="",
tiers=[],
):
self.file_type = file_type
self.object_class = object_class
self.xmin = xmin
self.xmax = xmax
self.tiers_status = tiers_status
self.tiers = tiers
if self.xmax < self.xmin:
raise ValueError("xmax ({}) < xmin ({})".format(self.xmax, self.xmin))
def cutoff(self, xstart=None, xend=None):
if xstart is None:
xstart = self.xmin
if xend is None:
xend = self.xmax
if xend < xstart:
raise ValueError("xend ({}) < xstart ({})".format(xend, xstart))
new_xmax = xend - xstart + self.xmin
new_xmin = self.xmin
new_tiers = []
for tier in self.tiers:
new_tiers.append(tier.cutoff(xstart=xstart, xend=xend))
return TextGrid(
file_type=self.file_type,
object_class=self.object_class,
xmin=new_xmin,
xmax=new_xmax,
tiers_status=self.tiers_status,
tiers=new_tiers,
)
class Tier(object):
def __init__(self, tier_class="", name="", xmin=0.0, xmax=0.0, intervals=[]):
self.tier_class = tier_class
self.name = name
self.xmin = xmin
self.xmax = xmax
self.intervals = intervals
if self.xmax < self.xmin:
raise ValueError("xmax ({}) < xmin ({})".format(self.xmax, self.xmin))
def cutoff(self, xstart=None, xend=None):
if xstart is None:
xstart = self.xmin
if xend is None:
xend = self.xmax
if xend < xstart:
raise ValueError("xend ({}) < xstart ({})".format(xend, xstart))
bias = xstart - self.xmin
new_xmax = xend - bias
new_xmin = self.xmin
new_intervals = []
for interval in self.intervals:
if interval.xmax <= xstart or interval.xmin >= xend:
pass
elif interval.xmin < xstart:
new_intervals.append(
Interval(
xmin=new_xmin, xmax=interval.xmax - bias, text=interval.text
)
)
elif interval.xmax > xend:
new_intervals.append(
Interval(
xmin=interval.xmin - bias, xmax=new_xmax, text=interval.text
)
)
else:
new_intervals.append(
Interval(
xmin=interval.xmin - bias,
xmax=interval.xmax - bias,
text=interval.text,
)
)
return Tier(
tier_class=self.tier_class,
name=self.name,
xmin=new_xmin,
xmax=new_xmax,
intervals=new_intervals,
)
class Interval(object):
def __init__(self, xmin=0.0, xmax=0.0, text=""):
self.xmin = xmin
self.xmax = xmax
self.text = text
if self.xmax < self.xmin:
raise ValueError("xmax ({}) < xmin ({})".format(self.xmax, self.xmin))
def read_textgrid_from_file(filepath):
with codecs.open(filepath, "r", encoding="utf-8") as handle:
lines = handle.readlines()
if lines[-1] == "\r\n":
lines = lines[:-1]
assert "File type" in lines[0], "error line 0, {}".format(lines[0])
file_type = (
lines[0]
.split("=")[1]
.replace(" ", "")
.replace('"', "")
.replace("\r", "")
.replace("\n", "")
)
assert "Object class" in lines[1], "error line 1, {}".format(lines[1])
object_class = (
lines[1]
.split("=")[1]
.replace(" ", "")
.replace('"', "")
.replace("\r", "")
.replace("\n", "")
)
assert lines[2] == "\r\n", "error line 2, {}".format(lines[2])
assert "xmin" in lines[3], "error line 3, {}".format(lines[3])
xmin = float(
lines[3].split("=")[1].replace(" ", "").replace("\r", "").replace("\n", "")
)
assert "xmax" in lines[4], "error line 4, {}".format(lines[4])
xmax = float(
lines[4].split("=")[1].replace(" ", "").replace("\r", "").replace("\n", "")
)
assert "tiers?" in lines[5], "error line 5, {}".format(lines[5])
tiers_status = (
lines[5].split("?")[1].replace(" ", "").replace("\r", "").replace("\n", "")
)
assert "size" in lines[6], "error line 6, {}".format(lines[6])
size = int(
lines[6].split("=")[1].replace(" ", "").replace("\r", "").replace("\n", "")
)
assert lines[7] == "item []:\r\n", "error line 7, {}".format(lines[7])
tier_start = []
for item_idx in range(size):
tier_start.append(lines.index(" " * 4 + "item [{}]:\r\n".format(item_idx + 1)))
tier_end = tier_start[1:] + [len(lines)]
tiers = []
for tier_idx in range(size):
tiers.append(
read_tier_from_lines(
tier_lines=lines[tier_start[tier_idx] + 1 : tier_end[tier_idx]]
)
)
return TextGrid(
file_type=file_type,
object_class=object_class,
xmin=xmin,
xmax=xmax,
tiers_status=tiers_status,
tiers=tiers,
)
def read_tier_from_lines(tier_lines):
assert "class" in tier_lines[0], "error line 0, {}".format(tier_lines[0])
tier_class = (
tier_lines[0]
.split("=")[1]
.replace(" ", "")
.replace('"', "")
.replace("\r", "")
.replace("\n", "")
)
assert "name" in tier_lines[1], "error line 1, {}".format(tier_lines[1])
name = (
tier_lines[1]
.split("=")[1]
.replace(" ", "")
.replace('"', "")
.replace("\r", "")
.replace("\n", "")
)
assert "xmin" in tier_lines[2], "error line 2, {}".format(tier_lines[2])
xmin = float(
tier_lines[2].split("=")[1].replace(" ", "").replace("\r", "").replace("\n", "")
)
assert "xmax" in tier_lines[3], "error line 3, {}".format(tier_lines[3])
xmax = float(
tier_lines[3].split("=")[1].replace(" ", "").replace("\r", "").replace("\n", "")
)
assert "intervals: size" in tier_lines[4], "error line 4, {}".format(tier_lines[4])
intervals_num = int(
tier_lines[4].split("=")[1].replace(" ", "").replace("\r", "").replace("\n", "")
)
# handle unformatted case
# R12_S203204205_C09_I1_Near_203.TextGrid
# R12_S203204205_C09_I1_Near_205.TextGrid
if tier_lines[-1] == "\n":
tier_lines = tier_lines[:-1]
if len(tier_lines[5:]) == intervals_num * 5:
intervals = []
for intervals_idx in range(intervals_num):
assert tier_lines[
5 + 5 * intervals_idx + 0
] == " " * 8 + "intervals [{}]:\r\n".format(intervals_idx + 1)
assert tier_lines[
5 + 5 * intervals_idx + 1
] == " " * 8 + "intervals [{}]:\r\n".format(intervals_idx + 1)
intervals.append(
read_interval_from_lines(
interval_lines=tier_lines[
7 + 5 * intervals_idx : 10 + 5 * intervals_idx
]
)
)
elif len(tier_lines[5:]) == intervals_num * 4:
# handle unformatted case
# R12_S203204205_C09_I1_Near_203.TextGrid
# R12_S203204205_C09_I1_Near_204.TextGrid
# R12_S203204205_C09_I1_Near_205.TextGrid
intervals = []
for intervals_idx in range(intervals_num):
assert tier_lines[
5 + 4 * intervals_idx + 0
] == " " * 8 + "intervals [{}]:\r\n".format(intervals_idx + 1)
intervals.append(
read_interval_from_lines(
interval_lines=tier_lines[
6 + 4 * intervals_idx : 9 + 4 * intervals_idx
]
)
)
    else:
        raise ValueError(
            "unexpected tier length: {} lines for {} intervals".format(
                len(tier_lines[5:]), intervals_num
            )
        )
return Tier(
tier_class=tier_class, name=name, xmin=xmin, xmax=xmax, intervals=intervals
)
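# Parse a three-line xmin/xmax/text block into a single Interval.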
def read_interval_from_lines(interval_lines):
assert len(interval_lines) == 3, "error lines"
assert "xmin" in interval_lines[0], "error line 0, {}".format(interval_lines[0])
xmin = float(
interval_lines[0]
.split("=")[1]
.replace(" ", "")
.replace("\r", "")
.replace("\n", "")
)
assert "xmax" in interval_lines[1], "error line 1, {}".format(interval_lines[1])
xmax = float(
interval_lines[1]
.split("=")[1]
.replace(" ", "")
.replace("\r", "")
.replace("\n", "")
)
assert "text" in interval_lines[2], "error line 2, {}".format(interval_lines[2])
text = (
interval_lines[2]
.split("=")[1]
.replace(" ", "")
.replace('"', "")
.replace("\r", "")
.replace("\n", "")
)
return Interval(xmin=xmin, xmax=xmax, text=text)
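
For orientation, here is a minimal usage sketch of this parser. It assumes the `TextGrid`, `Tier`, and `Interval` classes store their constructor arguments as attributes (as `Interval` above does), that the input keeps the CRLF line endings the asserts expect, and that the file name is a placeholder:

```python
# Minimal sketch: parse one far-field TextGrid and report per-tier speech time.
# In the AliMeeting annotations, each tier corresponds to one speaker.
tg = read_textgrid_from_file("R1014_M1710.textgrid")
for tier in tg.tiers:
    # non-empty interval text marks speech; empty text marks silence
    speech = sum(i.xmax - i.xmin for i in tier.intervals if i.text)
    print("{}: {:.2f}s speech in [{}, {}]".format(tier.name, speech, tg.xmin, tg.xmax))
```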

View File

@ -0,0 +1,14 @@
export FUNASR_DIR=$PWD/../../..
export KALDI_ROOT=/Your_Kaldi_root
export DATA_SOURCE=/Your_data_path
export DATA_NAME=Test_2023_Ali_far
export PATH=$PWD/utils/:$KALDI_ROOT/tools/openfst/bin:$PATH
[ ! -f $KALDI_ROOT/tools/config/common_path.sh ] && echo >&2 "The standard file $KALDI_ROOT/tools/config/common_path.sh is not present -> Exit!" && exit 1
. $KALDI_ROOT/tools/config/common_path.sh
export LC_ALL=C
# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PATH=$FUNASR_DIR/funasr/bin:./utils:$FUNASR_DIR:$PATH
export PYTHONPATH=$FUNASR_DIR:$PYTHONPATH

View File

@ -0,0 +1,152 @@
#!/usr/bin/env bash
. ./path.sh || exit 1;
# machines configuration
CUDA_VISIBLE_DEVICES="4,5,6,7"
gpu_num=4
count=1
gpu_inference=true # Whether to perform gpu decoding, set false for cpu decoding
finetune=true
# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob
njob=2
train_cmd=utils/run.pl
infer_cmd=utils/run.pl
# general configuration
feats_dir="data" #feature output dictionary
exp_dir="."
lang=zh
token_type=char
type=sound
scp=wav.scp
speed_perturb="1.0"
stage=0
stop_stage=1
# feature configuration
feats_dim=80
nj=64
# exp tag
tag="finetune"
# Set bash to 'strict' mode; it will exit on:
# -e 'error', -u 'undefined variable', -o pipefail 'error in pipeline'
set -e
set -u
set -o pipefail
train_set=Train_Ali_far_wpegss
valid_set=Test_Ali_far_wpegss
test_sets="${DATA_NAME}_wpegss"
asr_config=conf/train_paraformer.yaml
model_dir="$(basename "${asr_config}" .yaml)_${lang}_${token_type}_${tag}"
pretrain_model_dir=./speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch
inference_config=$pretrain_model_dir/decoding.yaml
token_list=$pretrain_model_dir/tokens.txt
# you can set gpu num for decoding here
gpuid_list=$CUDA_VISIBLE_DEVICES # set gpus for decoding, the same as training stage by default
ngpu=$(echo $gpuid_list | awk -F "," '{print NF}')
if ${gpu_inference}; then
    inference_nj=$((ngpu * njob))
_ngpu=1
else
inference_nj=$njob
_ngpu=0
fi
if ${finetune}; then
inference_asr_model=./checkpoint/valid.acc.ave.pb
finetune_tag="_finetune"
else
inference_asr_model=$pretrain_model_dir/model.pb
finetune_tag=""
fi
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    [ -L ./utils ] && unlink ./utils
    ln -s ../../aishell/transformer/utils
fi
# Download Model
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
echo "stage 1: Download Model"
if [ ! -d $pretrain_model_dir ]; then
git clone https://www.modelscope.cn/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch.git
fi
fi
# ASR Training Stage
world_size=$gpu_num # run on one machine
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
echo "stage 2: ASR Training"
python -m torch.distributed.launch \
--nproc_per_node $gpu_num local/finetune.py
fi
# Testing Stage
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
echo "stage 3: Inference"
for dset in ${test_sets}; do
_dir="$pretrain_model_dir/decode_${dset}${finetune_tag}"
_logdir="${_dir}/logdir"
if [ -d ${_dir} ]; then
echo "${_dir} is already exists. if you want to decode again, please delete this dir first."
exit 0
fi
mkdir -p "${_logdir}"
_data="./data/${dset}"
key_file=${_data}/${scp}
num_scp_file="$(<${key_file} wc -l)"
_nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
split_scps=
for n in $(seq "${_nj}"); do
split_scps+=" ${_logdir}/keys.${n}.scp"
done
# shellcheck disable=SC2086
utils/split_scp.pl "${key_file}" ${split_scps}
_opts=
if [ -n "${inference_config}" ]; then
_opts+="--config ${inference_config} "
fi
${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
python -m funasr.bin.asr_inference_launch \
--batch_size 1 \
--ngpu "${_ngpu}" \
--njob ${njob} \
--gpuid_list ${gpuid_list} \
--data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
--cmvn_file $pretrain_model_dir/am.mvn \
--key_file "${_logdir}"/keys.JOB.scp \
--asr_train_config $pretrain_model_dir/config.yaml \
--asr_model_file $inference_asr_model \
--output_dir "${_logdir}"/output.JOB \
--mode paraformer \
${_opts}
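        # collate the per-job 1best outputs into one sorted file per field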
for f in token token_int score text; do
if [ -f "${_logdir}/output.1/1best_recog/${f}" ]; then
for i in $(seq "${_nj}"); do
cat "${_logdir}/output.${i}/1best_recog/${f}"
done | sort -k1 >"${_dir}/${f}"
fi
done
python local/merge_spk_text.py ${_dir}/text ${_data}/utt2spk
python local/compute_cpcer.py ${_data}/text_merge ${_dir}/text_merge
echo "cpCER is saved at ${_dir}/text_cpcer"
done
fi

View File

@ -0,0 +1,233 @@
#!/usr/bin/env bash
set -e
set -o pipefail
. path.sh || exit 1
train_cmd=utils/run.pl
# data path
data_source_dir=$DATA_SOURCE
textgrid_dir=$data_source_dir/textgrid_dir/
wav_dir=$data_source_dir/audio_dir/
# work path
work_dir=./data/${DATA_NAME}_sc/
sad_dir=$work_dir/sad_part/
sad_work_dir=$sad_dir/exp/
sad_result_dir=$sad_dir/sad
dia_dir=$work_dir/dia_part/
dia_vad_dir=$dia_dir/vad/
dia_rttm_dir=$dia_dir/rttm/
dia_emb_dir=$dia_dir/embedding/
dia_rtt_label_dir=$dia_dir/label_rttm/
dia_result_dir=$dia_dir/result_DER/
sond_work_dir=./data/${DATA_NAME}_sond/
asr_work_dir=./data/${DATA_NAME}_wpegss/org/
mkdir -p $work_dir || exit 1;
mkdir -p $sad_dir || exit 1;
mkdir -p $sad_work_dir || exit 1;
mkdir -p $sad_result_dir || exit 1;
mkdir -p $dia_dir || exit 1;
mkdir -p $dia_vad_dir || exit 1;
mkdir -p $dia_rttm_dir || exit 1;
mkdir -p $dia_emb_dir || exit 1;
mkdir -p $dia_rtt_label_dir || exit 1;
mkdir -p $dia_result_dir || exit 1;
mkdir -p $sond_work_dir || exit 1;
mkdir -p $asr_work_dir || exit 1;
stage=0
stop_stage=9
nj=4
sm_size=83
if [ $stage -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    # Check the installation of Kaldi
if [ -L ./steps ]; then
unlink ./steps
else
ln -s $KALDI_ROOT/egs/wsj/s5/steps || { echo "You must install kaldi first, and set the KALDI_ROOT in path.sh" && exit 1; }
fi
if [ -L ./utils ]; then
unlink ./utils
else
ln -s $KALDI_ROOT/egs/wsj/s5/utils || { echo "You must install kaldi first, and set the KALDI_ROOT in path.sh" && exit 1; }
fi
fi
if [ $stage -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    # Prepare the AliMeeting data
    echo "Prepare AliMeeting data"
    find $wav_dir -name "*.wav" > $work_dir/wavlist
sort $work_dir/wavlist > $work_dir/tmp
cp $work_dir/tmp $work_dir/wavlist
awk -F '/' '{print $NF}' $work_dir/wavlist | awk -F '.' '{print $1}' > $work_dir/uttid
paste -d " " $work_dir/uttid $work_dir/wavlist > $work_dir/wav.scp
paste -d " " $work_dir/uttid $work_dir/uttid > $work_dir/utt2spk
cp $work_dir/utt2spk $work_dir/spk2utt
cp $work_dir/uttid $work_dir/text
sad_feat=$sad_dir/feat/mfcc
cp $work_dir/wav.scp $sad_dir
cp $work_dir/utt2spk $sad_dir
cp $work_dir/spk2utt $sad_dir
cp $work_dir/text $sad_dir
utils/fix_data_dir.sh $sad_dir
## first we extract the feature for sad model
steps/make_mfcc.sh --nj $nj --cmd "$train_cmd" \
--mfcc-config conf/mfcc_hires.conf \
$sad_dir $sad_dir/make_mfcc $sad_feat
fi
if [ $stage -le 2 ] && [ ${stop_stage} -ge 2 ]; then
    # Do Speech Activity Detection
echo "Do SAD"
./utils/split_data.sh $sad_dir $nj
## do the segmentations
local/segmentation/detect_speech_activity.sh --nj $nj --stage 0 \
--cmd "$train_cmd" $sad_dir exp/segmentation_1a/tdnn_stats_sad_1a/ \
$sad_dir/feat/mfcc $sad_work_dir $sad_result_dir
fi
if [ $stage -le 3 ] && [ ${stop_stage} -ge 3 ]; then
echo "Do Speaker Embedding Extractor"
cp $work_dir/wav.scp $dia_dir
python local/segment_to_lab.py --input_segments $sad_dir/sad_seg/segments \
--label_path $dia_vad_dir \
        --output_label_scp_file $dia_dir/label.scp || exit 1;
./utils/split_data.sh $work_dir $nj
${train_cmd} JOB=1:${nj} $dia_dir/exp/extract_embedding.JOB.log \
python VBx/predict.py --in-file-list $work_dir/split${nj}/JOB/text \
--in-lab-dir $dia_dir/vad \
--in-wav-dir $wav_dir \
--out-ark-fn $dia_emb_dir/embedding_out.JOB.ark \
--out-seg-fn $dia_emb_dir/embedding_out.JOB.seg \
--weights VBx/models/ResNet101_16kHz/nnet/final.onnx \
--backend onnx
echo "success"
fi
if [ $stage -le 4 ] && [ ${stop_stage} -ge 4 ]; then
    # Speaker embedding clustering
    echo "Do speaker embedding clustering"
    # The meetings are long, so clustering can be a little slow
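    # VBx: AHC initialization refined by VB-HMM resegmentation; Fa scales the
    # acoustic statistics, Fb regularizes the number of speakers, and loopP is
    # the probability of staying with the same speaker between frames.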
${train_cmd} JOB=1:${nj} $dia_dir/exp/cluster.JOB.log \
python VBx/vbhmm.py --init AHC+VB \
--out-rttm-dir $dia_rttm_dir \
--xvec-ark-file $dia_emb_dir/embedding_out.JOB.ark \
--segments-file $dia_emb_dir/embedding_out.JOB.seg \
--xvec-transform VBx/models/ResNet101_16kHz/transform.h5 \
--plda-file VBx/models/ResNet101_16kHz/plda \
--threshold 0.14 \
--lda-dim 128 \
--Fa 0.3 \
--Fb 17 \
--loopP 0.99
fi
if [ $stage -le 5 ] && [ ${stop_stage} -ge 5 ]; then
echo "Process textgrid to obtain rttm label"
find -L $textgrid_dir -iname "*.TextGrid" > $work_dir/textgrid.flist
sort $work_dir/textgrid.flist > $work_dir/tmp
cp $work_dir/tmp $work_dir/textgrid.flist
paste $work_dir/uttid $work_dir/textgrid.flist > $work_dir/uttid_textgrid.flist
while read text_file
do
text_grid=`echo $text_file | awk '{print $1}'`
text_grid_path=`echo $text_file | awk '{print $2}'`
python local/make_textgrid_rttm.py --input_textgrid_file $text_grid_path \
--uttid $text_grid \
--output_rttm_file $dia_rtt_label_dir/${text_grid}.rttm
done < $work_dir/uttid_textgrid.flist
if [ -f "$dia_rtt_label_dir/all.rttm" ]; then
rm -f $dia_rtt_label_dir/all.rttm
fi
cat $dia_rtt_label_dir/*.rttm > $dia_rtt_label_dir/all.rttm
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
echo "Get VBx DER result"
find $dia_rtt_label_dir -name "*.rttm" > $dia_rtt_label_dir/ref.scp
find $dia_rttm_dir -name "*.rttm" > $dia_rttm_dir/sys.scp
if [ -f "$dia_rttm_dir/all.rttm" ]; then
rm -f $dia_rttm_dir/all.rttm
fi
cat $dia_rttm_dir/*.rttm > $dia_rttm_dir/all.rttm
collar_set="0 0.25"
python local/meeting_speaker_number_process.py --path=$work_dir \
--label_path=$dia_rtt_label_dir --predict_path=$dia_rttm_dir
speaker_number="2 3 4"
for weight_collar in $collar_set;
do
# all meeting
python dscore/score.py --collar $weight_collar \
-R $dia_rtt_label_dir/ref.scp -S $dia_rttm_dir/sys.scp > $dia_result_dir/speaker_all_DER_overlaps_${weight_collar}.log
# 2,3,4 speaker meeting
for speaker_count in $speaker_number;
do
python dscore/score.py --collar $weight_collar \
-R $dia_rtt_label_dir/speaker${speaker_count}_id -S $dia_rttm_dir/speaker${speaker_count}_id > $dia_result_dir/speaker_${speaker_count}_DER_overlaps_${weight_collar}.log
done
done
fi
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
echo "Downloading Pre-trained model..."
mkdir ./SOND
cd ./SOND
git clone https://www.modelscope.cn/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch.git
git clone https://www.modelscope.cn/damo/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch.git
ln -s speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pth ./sv.pb
cp speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.yaml ./sv.yaml
ln -s speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/sond.pth ./sond.pb
cp speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/sond_fbank.yaml ./sond_fbank.yaml
cp speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/sond.yaml ./sond.yaml
cd ..
fi
if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
echo "Prepare data for sond"
cp $work_dir/wav.scp $sond_work_dir
# convert rttm to segments
python local/rttm2segments.py $dia_rttm_dir/all.rttm $sond_work_dir 0
# remove the overlapped part
python local/remove_overlap.py $sond_work_dir/segments $sond_work_dir/utt2spk \
$sond_work_dir/segments_nooverlap $sond_work_dir/utt2spk_nooverlap 0.3
# extract speaker profile from the filtered segments file
python local/extract_profile_from_segments.py $sond_work_dir
# segment data to 16s
python local/resegment_data.py \
$data_source_dir/segments \
$data_source_dir/wav.scp \
$sond_work_dir
fi
if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
echo "Diarization with SOND"
python local/infer_sond.py SOND/sond.yaml SOND/sond.pb $sond_work_dir $sond_work_dir/dia_outputs
python local/convert_label_to_rttm.py \
$sond_work_dir/dia_outputs/labels.txt \
$sond_work_dir/map.scp \
$sond_work_dir/dia_outputs/prediction_sm_${sm_size}.rttm \
--ignore_len 10 --no_pbar --smooth_size ${sm_size} \
--vote_prob 0.5 --n_spk 16
python dscore/score.py \
-r $dia_rtt_label_dir/all.rttm \
-s $sond_work_dir/dia_outputs/prediction_sm_${sm_size}.rttm \
--collar 0.25 &> $sond_work_dir/dia_outputs/dia_result
# convert rttm to segments
python local/rttm2segments.py $sond_work_dir/dia_outputs/prediction_sm_${sm_size}.rttm $asr_work_dir 1
fi

View File

@ -0,0 +1,114 @@
#!/usr/bin/env bash
set -e
set -o pipefail
log() {
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%dT%H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
SECONDS=0
# general configuration
stage=1
stop_stage=3
nj=10
log "$0 $*"
. utils/parse_options.sh
. ./path.sh || exit 1
train_cmd=utils/run.pl
data_source_dir=$DATA_SOURCE
audio_dir=$data_source_dir/audio_dir
output_wpe_dir=$data_source_dir/wpe_audio_dir
output_gss_dir=$data_source_dir/gss_audio_dir
asr_data_path=./data/${DATA_NAME}_wpegss
channel=$1
log "Start Speech Enhancement."
if [ ! -L ./pb_bss ]; then
    ln -s ./pb_chime5/pb_bss
fi
# WPE
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
log "stage 1: Start WPE."
for ch in `seq ${channel}`; do
mkdir -p ${output_wpe_dir}_${ch}/log/
# split wav.scp
find $audio_dir/ -name "*.wav" > ${output_wpe_dir}_${ch}/wav.scp
arr=""
for i in `seq ${nj}`; do
arr="$arr ${output_wpe_dir}_${ch}/log/wav.${i}.scp"
done
split_scp.pl ${output_wpe_dir}_${ch}/wav.scp $arr
# do wpe
for n in `seq ${nj}`; do
cat <<-EOF >${output_wpe_dir}_${ch}/log/wpe.${n}.sh
python local/run_wpe.py \
--wav-scp ${output_wpe_dir}_${ch}/log/wav.${n}.scp \
--audio-dir ${audio_dir} \
--output-dir ${output_wpe_dir}_${ch} \
--ch $ch
EOF
done
chmod a+x ${output_wpe_dir}_${ch}/log/wpe.*.sh
${train_cmd} JOB=1:${nj} ${output_wpe_dir}_${ch}/log/wpe.JOB.log \
${output_wpe_dir}_${ch}/log/wpe.JOB.sh
done
fi
# GSS
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
log "stage 2: Start GSS"
if [ ! -d pb_chime5/ ]; then
log "Please install pb_chime5 by local/install_pb_chime5.sh"
exit 1
fi
mkdir -p $output_gss_dir/log
# split wpe.scp
for i in `seq ${channel}`; do
find ${output_wpe_dir}_${i}/ -name "*.wav" > $output_gss_dir/tmp${i}
done
awk -F '/' '{print($NF)}' $output_gss_dir/tmp1 | cut -d "." -f1 > $output_gss_dir/tmp
arr=""
for i in `seq ${channel}`; do
arr="$arr $output_gss_dir/tmp${i}"
done
paste -d " " $output_gss_dir/tmp $arr > $output_gss_dir/wpe.scp
rm -f $output_gss_dir/tmp*
arr=""
for i in `seq ${nj}`; do
arr="$arr $output_gss_dir/log/wpe.${i}.scp"
done
split_scp.pl $output_gss_dir/wpe.scp $arr
# do gss
for n in `seq ${nj}`; do
cat <<-EOF >${output_gss_dir}/log/gss.${n}.sh
python local/run_gss.py \
--wav-scp ${output_gss_dir}/log/wpe.${n}.scp \
--segments $asr_data_path/org/segments \
--output-dir ${output_gss_dir}
EOF
done
chmod a+x ${output_gss_dir}/log/gss.*.sh
${train_cmd} JOB=1:${nj} ${output_gss_dir}/log/gss.JOB.log \
${output_gss_dir}/log/gss.JOB.sh
fi
# Prepare data for ASR
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
log "stage 3: Preparing data for ASR"
find $output_gss_dir -name "*.wav" > $asr_data_path/org/wav_list
awk -F '/' '{print($NF)}' $asr_data_path/org/wav_list | sed 's/\.wav//g' > $asr_data_path/org/uttid
paste -d " " $asr_data_path/org/uttid $asr_data_path/org/wav_list > $asr_data_path/org/wav.scp
bash local/format_wav_scp.sh --nj "${nj}" --cmd "${train_cmd}" \
--audio-format wav --segments $asr_data_path/org/segments \
"$asr_data_path/org/wav.scp" "$asr_data_path"
fi
log "End speech enhancement"

View File

@ -19,8 +19,11 @@ requirements = {
"soundfile>=0.12.1",
"h5py>=2.10.0",
"kaldiio>=2.17.0",
"kaldi-io==0.9.8",
"torch_complex",
"nltk>=3.4.5",
"onnxruntime"
"numexpr"
# ASR
"sentencepiece",
"jieba",
@ -32,6 +35,8 @@ requirements = {
"editdistance>=0.5.2",
"tensorboard",
"g2p",
"nara_wpe",
"Cython",
# PAI
"oss2",
"edit-distance",
@ -123,4 +128,4 @@ setup(
"License :: OSI Approved :: Apache Software License",
"Topic :: Software Development :: Libraries :: Python Modules",
],
)
)