mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
Add modular SA-ASR recipe for M2MeT2.0 (#831)
* add modular saasr * update readme * Delete train_paraformer.yaml * update setup.py * update setup.py * update setup.py
This commit is contained in:
parent
ea2c102e61
commit
08ee9e6aac
103
egs/alimeeting/modular_sa_asr/README.md
Normal file
103
egs/alimeeting/modular_sa_asr/README.md
Normal file
@ -0,0 +1,103 @@
|
||||
# Get Started
|
||||
This is an official modular SA-ASR system used in M2MeT 2.0 challenge. We developed this system based on various pre-trained models after the challenge and reach the ***SOTA***(until 2023.8.9) performance on the AliMeeting *Test_2023* set. You can also transcribe your own dataset by preparing it into the specific format shown in
|
||||
|
||||
# Dependency
|
||||
|
||||
To run this receipe, you should install [Kaldi](https://github.com/kaldi-asr/kaldi) and set the `KALDI_ROOT` in `path.sh`.
|
||||
```shell
|
||||
export KALDI_ROOT=/your_kaldi_path
|
||||
```
|
||||
|
||||
We use the [VBx](https://github.com/BUTSpeechFIT/VBx) to provide initial diarization result to SOND and [dscore](https://github.com/nryant/dscore.git) to compute the DER. You should clone them before running this receipe.
|
||||
```shell
|
||||
$ mkdir VBx && cd VBx
|
||||
$ git init
|
||||
$ git remote add origin https://github.com/BUTSpeechFIT/VBx.git
|
||||
$ git config core.sparsecheckout true
|
||||
$ echo "VBx/*" >> .git/info/sparse-checkout
|
||||
$ git pull origin master
|
||||
$ mv VBx/* .
|
||||
$ cd ..
|
||||
$ git clone https://github.com/nryant/dscore.git
|
||||
```
|
||||
|
||||
We use the [pb_chime5](https://github.com/fgnt/pb_chime5) to perform GSS. So you should install the dependencies of this repo using the following command.
|
||||
```shell
|
||||
$ git clone https://github.com/fgnt/pb_chime5.git
|
||||
$ cd pb_chime5
|
||||
$ git submodule init
|
||||
$ git submodule update
|
||||
$ pip install -e pb_bss/
|
||||
$ pip install -e .
|
||||
```
|
||||
|
||||
# Infer on the AliMeeting Test_2023 set
|
||||
We follow the workflow shown below.
|
||||
|
||||
<div align="left"><img src="figure/20230809161919.jpg" width="500"/>
|
||||
|
||||
First you should set the `DATA_SOURCE` in `path.sh` to the data path. Your data path should be organized as follow:
|
||||
```shell
|
||||
Test_2023_Ali_far_release
|
||||
|—— audio_dir/
|
||||
| |—— R1014_M1710.wav
|
||||
| |—— R1014_M1750.wav
|
||||
| |—— ...
|
||||
|—— textgrid_dir/
|
||||
| |—— R1014_M1710.textgrid
|
||||
| |—— R1014_M1750.textgrid
|
||||
| |—— ...
|
||||
|—— wav.scp
|
||||
|—— segments
|
||||
```
|
||||
|
||||
Then you can do speaker diarization with following command.
|
||||
```shell
|
||||
$ bash run_diar.sh
|
||||
```
|
||||
After diarization, you can check the result at the last line of `data/Test_2023_Ali_far_sond/dia_outputs/dia_result`. You should get a DER about 1.51%.
|
||||
|
||||
When you get the similar diarization result with us, then you can do the WPE and GSS using the following command.
|
||||
```shell
|
||||
$ bash run_enh.sh 8
|
||||
```
|
||||
|
||||
The number 8 should be replaced with the channel number of your dataset. Here we use the AliMeeting corpus which has 8 channels.
|
||||
|
||||
Finally, you can decode the processed audio with the pre-trained ASR model directly using the flollowing commands.
|
||||
```shell
|
||||
$ bash run_asr.sh --stage 0 --stop-stage 1
|
||||
$ bash run_asr.sh --stage 3 --stop-stage 3
|
||||
```
|
||||
The ASR result is saved at `./speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/decode_Test_2023_Ali_far_wpegss/text_cpcer`.
|
||||
|
||||
# Infer on the AliMeeting Test_2023 set after finetune
|
||||
You can finetune the pre-trained ASR model with the AliMeeting train set to obtain a further reduction on the cpCER. To infer on the AliMeeting Test 2023 set after finetuning, you can run this commands after the train set is processed with WPE and GSS mentioned above.
|
||||
```shell
|
||||
$ bash run_asr.sh --stage 2 --stop-stage 3
|
||||
```
|
||||
|
||||
# Infer with your own dataset
|
||||
We also support infer with your own dataset. Your dataset should be organized as above. The `wav.scp` and `segments` file should format as:
|
||||
```shell
|
||||
# wav.scp
|
||||
sessionA wav_path/wav_name_A.wav
|
||||
sessionB wav_path/wav_name_B.wav
|
||||
sessionC wav_path/wav_name_C.wav
|
||||
...
|
||||
|
||||
# segments
|
||||
sessionA-start_time-end_time sessionA start_time end_time
|
||||
sessionB-start_time-end_time sessionA start_time end_time
|
||||
sessionC-start_time-end_time sessionA start_time end_time
|
||||
...
|
||||
```
|
||||
Then you should set the `DATA_SOURCE` and `DATA_NAME` in `path.sh`. The rest of the process is the same as [Infer on the AliMeeting Test_2023 set](#infer-on-the-alimeeting-test_2023-set).
|
||||
|
||||
# Result
|
||||
|
||||
| |VBx DER(%) | SOND DER(%)|cp-CER(%) |
|
||||
|:---------------|:------------:|:------------:|----------:|
|
||||
|before finetune | 16.87 | 1.51 | 10.18 |
|
||||
|after finetune | 16.87 | 1.51 | |
|
||||
|
||||
11
egs/alimeeting/modular_sa_asr/conf/mfcc_hires.conf
Normal file
11
egs/alimeeting/modular_sa_asr/conf/mfcc_hires.conf
Normal file
@ -0,0 +1,11 @@
|
||||
# config for high-resolution MFCC features, intended for neural network training.
|
||||
# Note: we keep all cepstra, so it has the same info as filterbank features,
|
||||
# but MFCC is more easily compressible (because less correlated) which is why
|
||||
# we prefer this method.
|
||||
--use-energy=false # use average of log energy, not energy.
|
||||
--sample-frequency=16000
|
||||
--num-mel-bins=40
|
||||
--num-ceps=40
|
||||
--low-freq=40
|
||||
--high-freq=-400
|
||||
|
||||
@ -0,0 +1 @@
|
||||
--norm-means=false --norm-vars=false
|
||||
BIN
egs/alimeeting/modular_sa_asr/exp/segmentation_1a/tdnn_stats_sad_1a/final.raw
Executable file
BIN
egs/alimeeting/modular_sa_asr/exp/segmentation_1a/tdnn_stats_sad_1a/final.raw
Executable file
Binary file not shown.
@ -0,0 +1 @@
|
||||
3
|
||||
BIN
egs/alimeeting/modular_sa_asr/exp/segmentation_1a/tdnn_stats_sad_1a/lda.mat
Executable file
BIN
egs/alimeeting/modular_sa_asr/exp/segmentation_1a/tdnn_stats_sad_1a/lda.mat
Executable file
Binary file not shown.
BIN
egs/alimeeting/modular_sa_asr/exp/segmentation_1a/tdnn_stats_sad_1a/lda_stats
Executable file
BIN
egs/alimeeting/modular_sa_asr/exp/segmentation_1a/tdnn_stats_sad_1a/lda_stats
Executable file
Binary file not shown.
@ -0,0 +1 @@
|
||||
[ 30 2 1 ]
|
||||
@ -0,0 +1 @@
|
||||
0
|
||||
BIN
egs/alimeeting/modular_sa_asr/figure/20230809161919.jpg
Normal file
BIN
egs/alimeeting/modular_sa_asr/figure/20230809161919.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 146 KiB |
103
egs/alimeeting/modular_sa_asr/local/compute_cpcer.py
Normal file
103
egs/alimeeting/modular_sa_asr/local/compute_cpcer.py
Normal file
@ -0,0 +1,103 @@
|
||||
import editdistance
|
||||
import sys
|
||||
import os
|
||||
from itertools import permutations
|
||||
|
||||
|
||||
def load_transcripts(file_path):
|
||||
trans_list = []
|
||||
for one_line in open(file_path, "rt"):
|
||||
meeting_id, trans = one_line.strip().split(" ")
|
||||
trans_list.append((meeting_id.strip(), trans.strip()))
|
||||
|
||||
return trans_list
|
||||
|
||||
def calc_spk_trans(trans):
|
||||
spk_trans_ = [x.strip() for x in trans.split("$")]
|
||||
spk_trans = []
|
||||
for i in range(len(spk_trans_)):
|
||||
spk_trans.append((str(i), spk_trans_[i]))
|
||||
return spk_trans
|
||||
|
||||
def calc_cer(ref_trans, hyp_trans):
|
||||
ref_spk_trans = calc_spk_trans(ref_trans)
|
||||
hyp_spk_trans = calc_spk_trans(hyp_trans)
|
||||
ref_spk_num, hyp_spk_num = len(ref_spk_trans), len(hyp_spk_trans)
|
||||
num_spk = max(len(ref_spk_trans), len(hyp_spk_trans))
|
||||
ref_spk_trans.extend([("", "")] * (num_spk - len(ref_spk_trans)))
|
||||
hyp_spk_trans.extend([("", "")] * (num_spk - len(hyp_spk_trans)))
|
||||
|
||||
errors, counts, permutes = [], [], []
|
||||
min_error = 0
|
||||
cost_dict = {}
|
||||
for perm in permutations(range(num_spk)):
|
||||
flag = True
|
||||
p_err, p_count = 0, 0
|
||||
for idx, p in enumerate(perm):
|
||||
if abs(len(ref_spk_trans[idx][1]) - len(hyp_spk_trans[p][1])) > min_error > 0:
|
||||
flag = False
|
||||
break
|
||||
cost_key = "{}-{}".format(idx, p)
|
||||
if cost_key in cost_dict:
|
||||
_e = cost_dict[cost_key]
|
||||
else:
|
||||
_e = editdistance.eval(ref_spk_trans[idx][1], hyp_spk_trans[p][1])
|
||||
cost_dict[cost_key] = _e
|
||||
if _e > min_error > 0:
|
||||
flag = False
|
||||
break
|
||||
p_err += _e
|
||||
p_count += len(ref_spk_trans[idx][1])
|
||||
|
||||
if flag:
|
||||
if p_err < min_error or min_error == 0:
|
||||
min_error = p_err
|
||||
|
||||
errors.append(p_err)
|
||||
counts.append(p_count)
|
||||
permutes.append(perm)
|
||||
|
||||
sd_cer = [(err, cnt, err/cnt, permute)
|
||||
for err, cnt, permute in zip(errors, counts, permutes)]
|
||||
best_rst = min(sd_cer, key=lambda x: x[2])
|
||||
|
||||
return best_rst[0], best_rst[1], ref_spk_num, hyp_spk_num
|
||||
|
||||
|
||||
def main():
|
||||
ref=sys.argv[1]
|
||||
hyp=sys.argv[2]
|
||||
result_path="/".join(hyp.split("/")[:-1]) + "/text_cpcer"
|
||||
ref_list = load_transcripts(ref)
|
||||
hyp_list = load_transcripts(hyp)
|
||||
result_file = open(result_path,'w')
|
||||
record_2_spk = [0, 0]
|
||||
record_3_spk = [0, 0]
|
||||
record_4_spk = [0, 0]
|
||||
error, count = 0, 0
|
||||
for (ref_id, ref_trans), (hyp_id, hyp_trans) in zip(ref_list, hyp_list):
|
||||
assert ref_id == hyp_id
|
||||
mid = ref_id
|
||||
dist, length, ref_spk_num, hyp_spk_num = calc_cer(ref_trans, hyp_trans)
|
||||
error, count = error + dist, count + length
|
||||
result_file.write("{} {:.2f} {} {}\n".format(mid, dist / length * 100.0, ref_spk_num, hyp_spk_num))
|
||||
ref_spk = len(ref_trans.split("$"))
|
||||
hyp_spk = len(hyp_trans.split("$"))
|
||||
if ref_spk == 2:
|
||||
record_2_spk[0] += dist
|
||||
record_2_spk[1] += length
|
||||
elif ref_spk == 3:
|
||||
record_3_spk[0] += dist
|
||||
record_3_spk[1] += length
|
||||
else:
|
||||
record_4_spk[0] += dist
|
||||
record_4_spk[1] += length
|
||||
print(record_2_spk[0]/record_2_spk[1]*100.0)
|
||||
print(record_3_spk[0]/record_3_spk[1]*100.0)
|
||||
print(record_4_spk[0]/record_4_spk[1]*100.0)
|
||||
result_file.write("CP-CER: {:.2f}\n".format(error / count * 100.0))
|
||||
result_file.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
132
egs/alimeeting/modular_sa_asr/local/convert_label_to_rttm.py
Normal file
132
egs/alimeeting/modular_sa_asr/local/convert_label_to_rttm.py
Normal file
@ -0,0 +1,132 @@
|
||||
import os
|
||||
from funasr.utils.job_runner import MultiProcessRunnerV3
|
||||
import numpy as np
|
||||
from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
|
||||
from collections import OrderedDict
|
||||
from tqdm import tqdm
|
||||
from scipy.ndimage import median_filter
|
||||
|
||||
|
||||
class MyRunner(MultiProcessRunnerV3):
|
||||
def prepare(self, parser):
|
||||
parser.add_argument("label_txt", type=str)
|
||||
parser.add_argument("map_scp", type=str)
|
||||
parser.add_argument("out_rttm", type=str)
|
||||
parser.add_argument("--n_spk", type=int, default=4)
|
||||
parser.add_argument("--chunk_len", type=int, default=1600)
|
||||
parser.add_argument("--shift_len", type=int, default=400)
|
||||
parser.add_argument("--ignore_len", type=int, default=5)
|
||||
parser.add_argument("--smooth_size", type=int, default=7)
|
||||
parser.add_argument("--vote_prob", type=float, default=0.5)
|
||||
args = parser.parse_args()
|
||||
|
||||
if not os.path.exists(os.path.dirname(args.out_rttm)):
|
||||
os.makedirs(os.path.dirname(args.out_rttm))
|
||||
|
||||
utt2labels = load_scp_as_list(args.label_txt, 'list')
|
||||
utt2labels = sorted(utt2labels, key=lambda x: x[0])
|
||||
meeting2map = load_scp_as_dict(args.map_scp)
|
||||
meeting2labels = OrderedDict()
|
||||
for utt_id, chunk_label in utt2labels:
|
||||
mid = utt_id.split("-")[0]
|
||||
if mid not in meeting2labels:
|
||||
meeting2labels[mid] = []
|
||||
meeting2labels[mid].append(chunk_label)
|
||||
task_list = [(mid, labels, meeting2map[mid]) for mid, labels in meeting2labels.items()]
|
||||
|
||||
return task_list, None, args
|
||||
|
||||
def post(self, result_list, args):
|
||||
with open(args.out_rttm, "wt") as fd:
|
||||
for results in result_list:
|
||||
fd.writelines(results)
|
||||
|
||||
|
||||
def int2vec(x, vec_dim=8, dtype=np.int):
|
||||
b = ('{:0' + str(vec_dim) + 'b}').format(x)
|
||||
# little-endian order: lower bit first
|
||||
return (np.array(list(b)[::-1]) == '1').astype(dtype)
|
||||
|
||||
|
||||
def seq2arr(seq, vec_dim=8):
|
||||
return np.row_stack([int2vec(int(x), vec_dim) for x in seq])
|
||||
|
||||
|
||||
def sample2ms(sample, sr=16000):
|
||||
return int(float(sample) / sr * 100)
|
||||
|
||||
|
||||
def calc_multi_labels(chunk_label_list, chunk_len, shift_len, n_spk, vote_prob=0.5):
|
||||
n_chunk = len(chunk_label_list)
|
||||
last_chunk_valid_frame = len(chunk_label_list[-1]) - (chunk_len - shift_len)
|
||||
n_frame = (n_chunk - 2) * shift_len + chunk_len + last_chunk_valid_frame
|
||||
multi_labels = np.zeros((n_frame, n_spk), dtype=float)
|
||||
weight = np.zeros((n_frame, 1), dtype=float)
|
||||
for i in range(n_chunk):
|
||||
raw_label = chunk_label_list[i]
|
||||
for k in range(len(raw_label)):
|
||||
if raw_label[k] == '<unk>':
|
||||
raw_label[k] = raw_label[k-1] if k > 0 else '0'
|
||||
chunk_multi_label = seq2arr(raw_label, n_spk)
|
||||
chunk_len = chunk_multi_label.shape[0]
|
||||
multi_labels[i*shift_len:i*shift_len+chunk_len, :] += chunk_multi_label
|
||||
weight[i*shift_len:i*shift_len+chunk_len, :] += 1
|
||||
multi_labels = multi_labels / weight # normalizing vote
|
||||
multi_labels = (multi_labels > vote_prob).astype(int) # voting results
|
||||
return multi_labels
|
||||
|
||||
|
||||
def calc_spk_turns(label_arr, spk_list):
|
||||
turn_list = []
|
||||
length = label_arr.shape[0]
|
||||
n_spk = label_arr.shape[1]
|
||||
for k in range(n_spk):
|
||||
if spk_list[k] == "None":
|
||||
continue
|
||||
in_utt = False
|
||||
start = 0
|
||||
for i in range(length):
|
||||
if label_arr[i, k] == 1 and in_utt is False:
|
||||
start = i
|
||||
in_utt = True
|
||||
if label_arr[i, k] == 0 and in_utt is True:
|
||||
turn_list.append([spk_list[k], start, i - start])
|
||||
in_utt = False
|
||||
if in_utt:
|
||||
turn_list.append([spk_list[k], start, length - start])
|
||||
return turn_list
|
||||
|
||||
|
||||
def smooth_multi_labels(multi_label, win_len):
|
||||
multi_label = median_filter(multi_label, (win_len, 1), mode="constant", cval=0.0).astype(int)
|
||||
return multi_label
|
||||
|
||||
|
||||
def process(task_args):
|
||||
_, task_list, _, args = task_args
|
||||
spk_list = ["spk{}".format(i+1) for i in range(args.n_spk)]
|
||||
template = "SPEAKER {} 1 {:.2f} {:.2f} <NA> <NA> {} <NA> <NA>\n"
|
||||
results = []
|
||||
for mid, chunk_label_list, map_file_path in tqdm(task_list, total=len(task_list), ascii=True, disable=args.no_pbar):
|
||||
utt2map = load_scp_as_list(map_file_path, 'list')
|
||||
multi_labels = calc_multi_labels(chunk_label_list, args.chunk_len, args.shift_len, args.n_spk, args.vote_prob)
|
||||
multi_labels = smooth_multi_labels(multi_labels, args.smooth_size)
|
||||
org_len = sample2ms(int(utt2map[-1][1][1]), args.sr)
|
||||
org_multi_labels = np.zeros((org_len, args.n_spk))
|
||||
for seg_id, [org_st, org_ed, st, ed] in utt2map:
|
||||
org_st, org_dur = sample2ms(int(org_st), args.sr), sample2ms(int(org_ed) - int(org_st), args.sr)
|
||||
st, dur = sample2ms(int(st), args.sr), sample2ms(int(ed) - int(st), args.sr)
|
||||
ll = min(org_multi_labels[org_st: org_st+org_dur, :].shape[0], multi_labels[st: st+dur, :].shape[0])
|
||||
org_multi_labels[org_st: org_st+ll, :] = multi_labels[st: st+ll, :]
|
||||
spk_turns = calc_spk_turns(org_multi_labels, spk_list)
|
||||
spk_turns = sorted(spk_turns, key=lambda x: x[1])
|
||||
for spk, st, dur in spk_turns:
|
||||
# TODO: handle the leak of segments at the change points
|
||||
if dur > args.ignore_len:
|
||||
results.append(template.format(mid, float(st)/100, float(dur)/100, spk))
|
||||
return results
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
my_runner = MyRunner(process)
|
||||
my_runner.run()
|
||||
@ -0,0 +1,104 @@
|
||||
import codecs
|
||||
import sys
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.utils.constant import Tasks
|
||||
import numpy as np
|
||||
import os
|
||||
import soundfile
|
||||
|
||||
|
||||
data_path = sys.argv[1]
|
||||
segment_file_path = data_path + "/segments_nooverlap"
|
||||
utt2spk_file_path = data_path + "/utt2spk_nooverlap"
|
||||
wav_scp_path = data_path + "/wav.scp"
|
||||
cluster_emb_dir = data_path + '/cluster_embedding/'
|
||||
os.system("mkdir -p " + cluster_emb_dir)
|
||||
cluster_profile_dir = data_path + '/cluster_profile_zeropadding16/'
|
||||
os.system('mkdir -p ' + cluster_profile_dir)
|
||||
|
||||
utt2spk = {}
|
||||
spk2seg = {}
|
||||
with codecs.open(utt2spk_file_path, "r", "utf-8") as f1:
|
||||
with codecs.open(segment_file_path, "r", "utf-8") as f2:
|
||||
for line in f1.readlines():
|
||||
uttid, spkid = line.strip().split(" ")
|
||||
utt2spk[uttid] = spkid
|
||||
for line in f2.readlines():
|
||||
uttid, sessionid, stime, etime = line.strip().split(" ")
|
||||
spkid = utt2spk[uttid]
|
||||
if spkid not in spk2seg.keys():
|
||||
spk2seg[spkid] = [(int(float(stime) * 16000), int(float(etime) * 16000) - int(float(stime) * 16000))]
|
||||
else:
|
||||
spk2seg[spkid].append((int(float(stime) * 16000), int(float(etime) * 16000) - int(float(stime) * 16000)))
|
||||
|
||||
inference_sv_pipline = pipeline(
|
||||
task=Tasks.speaker_verification,
|
||||
model='damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch',
|
||||
device='gpu'
|
||||
)
|
||||
|
||||
wav_dict = {}
|
||||
|
||||
with codecs.open(wav_scp_path, "r", "utf-8") as fi:
|
||||
with codecs.open(data_path + "/cluster_embedding.scp", "w", "utf-8") as fo:
|
||||
for line in fi.readlines():
|
||||
sessionid, wav_path = line.strip().split()
|
||||
wav_dict[sessionid] = wav_path
|
||||
for spkid, segs in spk2seg.items():
|
||||
sessionid = spkid.split("-")[0]
|
||||
wav_path = wav_dict[sessionid]
|
||||
wav = soundfile.read(wav_path)[0]
|
||||
if wav.ndim == 2:
|
||||
wav = wav[:, 0]
|
||||
all_seg_embedding_list=[]
|
||||
for seg in segs:
|
||||
if seg[0] < wav.shape[0] - 0.5 * 16000:
|
||||
if seg[1] > wav.shape[0]:
|
||||
cur_seg_embedding = inference_sv_pipline(audio_in=wav[seg[0]: ])["spk_embedding"]
|
||||
else:
|
||||
cur_seg_embedding = inference_sv_pipline(audio_in=wav[seg[0]: seg[0] + seg[1]])["spk_embedding"]
|
||||
all_seg_embedding_list.append(cur_seg_embedding)
|
||||
all_seg_embedding = np.vstack(all_seg_embedding_list)
|
||||
spk_embedding = np.mean(all_seg_embedding, axis=0)
|
||||
np.save(cluster_emb_dir + spkid + '.npy', spk_embedding)
|
||||
fo.write(spkid + ' ' + cluster_emb_dir + spkid + '.npy' + '\n')
|
||||
|
||||
session2embs = {}
|
||||
|
||||
with codecs.open(data_path + "/cluster_embedding.scp", "r", "utf-8") as fi:
|
||||
with codecs.open(data_path + "/cluster_profile_zeropadding16.scp", "w", "utf-8") as fo:
|
||||
for line in fi.readlines():
|
||||
spkid, emb_path = line.strip().split(" ")
|
||||
sessionid = spkid.split("-")[0]
|
||||
if sessionid not in session2embs.keys():
|
||||
session2embs[sessionid] = [emb_path]
|
||||
else:
|
||||
session2embs[sessionid].append(emb_path)
|
||||
for sessionid, embs in session2embs.items():
|
||||
emb_list = [np.load(x) for x in embs]
|
||||
tmp = []
|
||||
for i in range(len(emb_list) - 1):
|
||||
flag = True
|
||||
for j in range(i + 1, len(emb_list)):
|
||||
cos_sim = emb_list[i].dot(emb_list[j]) / (np.linalg.norm(emb_list[i]) * np.linalg.norm(emb_list[j]))
|
||||
if cos_sim > 0.99:
|
||||
flag = False
|
||||
if flag:
|
||||
tmp.append(emb_list[i][np.newaxis, :])
|
||||
tmp.append(emb_list[-1][np.newaxis, :])
|
||||
emb_list = tmp
|
||||
# tmp = []
|
||||
# for i in range(len(emb_list)):
|
||||
# for emb in tmp:
|
||||
# cos_sim = emb_list[i].dot(emb_list[j]) / (np.linalg.norm(emb_list[i]) * np.linalg.norm(emb_list[j]))
|
||||
# if cos_sim > 0.99:
|
||||
# flag = False
|
||||
# if flag:
|
||||
# tmp.append(emb_list[i][np.newaxis, :])
|
||||
# emb_list = tmp
|
||||
for i in range(16 - len(emb_list)):
|
||||
emb_list.append(np.zeros((1, 256)))
|
||||
emb = np.concatenate(emb_list, axis=0)
|
||||
save_path = cluster_profile_dir + sessionid + ".npy"
|
||||
np.save(save_path, emb)
|
||||
fo.write("%s %s\n" % (sessionid, save_path))
|
||||
34
egs/alimeeting/modular_sa_asr/local/finetune.py
Normal file
34
egs/alimeeting/modular_sa_asr/local/finetune.py
Normal file
@ -0,0 +1,34 @@
|
||||
import os
|
||||
import sys
|
||||
from modelscope.metainfo import Trainers
|
||||
from modelscope.trainers import build_trainer
|
||||
from modelscope.msdatasets.audio.asr_dataset import ASRDataset
|
||||
|
||||
def modelscope_finetune(params):
|
||||
if not os.path.exists(params.output_dir):
|
||||
os.makedirs(params.output_dir, exist_ok=True)
|
||||
# dataset split ["train", "validation"]
|
||||
ds_dict = ASRDataset.load(params.data_path, namespace='speech_asr')
|
||||
kwargs = dict(
|
||||
model=params.model,
|
||||
data_dir=ds_dict,
|
||||
dataset_type=params.dataset_type,
|
||||
work_dir=params.output_dir,
|
||||
batch_bins=params.batch_bins,
|
||||
max_epoch=params.max_epoch,
|
||||
lr=params.lr)
|
||||
trainer = build_trainer(Trainers.speech_asr_trainer, default_args=kwargs)
|
||||
trainer.train()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
from funasr.utils.modelscope_param import modelscope_args
|
||||
params = modelscope_args(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
|
||||
params.output_dir = "./checkpoint" # 模型保存路径
|
||||
params.data_path = "./data" # 数据路径,可以为modelscope中已上传数据,也可以是本地数据
|
||||
params.dataset_type = "small" # 小数据量设置small,若数据量大于1000小时,请使用large
|
||||
params.batch_bins = 2000 # batch size,如果dataset_type="small",batch_bins单位为fbank特征帧数,如果dataset_type="large",batch_bins单位为毫秒,
|
||||
params.max_epoch = 20 # 最大训练轮数
|
||||
params.lr = 0.00005 # 设置学习率
|
||||
|
||||
modelscope_finetune(params)
|
||||
1
egs/alimeeting/modular_sa_asr/local/format_wav_scp.py
Symbolic link
1
egs/alimeeting/modular_sa_asr/local/format_wav_scp.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../sa_asr/local/format_wav_scp.py
|
||||
1
egs/alimeeting/modular_sa_asr/local/format_wav_scp.sh
Symbolic link
1
egs/alimeeting/modular_sa_asr/local/format_wav_scp.sh
Symbolic link
@ -0,0 +1 @@
|
||||
../../sa_asr/local/format_wav_scp.sh
|
||||
27
egs/alimeeting/modular_sa_asr/local/infer_sond.py
Normal file
27
egs/alimeeting/modular_sa_asr/local/infer_sond.py
Normal file
@ -0,0 +1,27 @@
|
||||
from funasr.bin.diar_inference_launch import inference_launch
|
||||
import sys
|
||||
import os
|
||||
os.environ['CUDA_VISIBLE_DEVICES']='7'
|
||||
|
||||
def main():
|
||||
diar_config_path = sys.argv[1] if len(sys.argv) > 1 else "sond_fbank.yaml"
|
||||
diar_model_path = sys.argv[2] if len(sys.argv) > 2 else "sond.pb"
|
||||
input_dir = sys.argv[3] if len(sys.argv) > 3 else "./inputs"
|
||||
output_dir = sys.argv[4] if len(sys.argv) > 4 else "./outputs"
|
||||
data_path_and_name_and_type = [
|
||||
(input_dir + "/wav.scp", "speech", "sound"),
|
||||
(input_dir + "/profile.scp", "profile", "npy"),
|
||||
]
|
||||
pipeline = inference_launch(
|
||||
mode="sond",
|
||||
diar_train_config=diar_config_path,
|
||||
diar_model_file=diar_model_path,
|
||||
output_dir=output_dir,
|
||||
num_workers=16,
|
||||
ngpu=1,
|
||||
)
|
||||
pipeline(data_path_and_name_and_type)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
58
egs/alimeeting/modular_sa_asr/local/make_textgrid_rttm.py
Executable file
58
egs/alimeeting/modular_sa_asr/local/make_textgrid_rttm.py
Executable file
@ -0,0 +1,58 @@
|
||||
import argparse
|
||||
import tqdm
|
||||
import codecs
|
||||
import textgrid
|
||||
import pdb
|
||||
|
||||
class Segment(object):
|
||||
def __init__(self, uttid, spkr, stime, etime, text):
|
||||
self.uttid = uttid
|
||||
self.spkr = spkr
|
||||
self.stime = round(stime, 2)
|
||||
self.etime = round(etime, 2)
|
||||
self.text = text
|
||||
|
||||
def change_stime(self, time):
|
||||
self.stime = time
|
||||
|
||||
def change_etime(self, time):
|
||||
self.etime = time
|
||||
|
||||
|
||||
def main(args):
|
||||
tg = textgrid.TextGrid.fromFile(args.input_textgrid_file)
|
||||
segments = []
|
||||
spk = {}
|
||||
num_spk = 1
|
||||
uttid = args.uttid
|
||||
for i in range(tg.__len__()):
|
||||
for j in range(tg[i].__len__()):
|
||||
if tg[i][j].mark:
|
||||
if tg[i].name not in spk:
|
||||
spk[tg[i].name] = num_spk
|
||||
num_spk += 1
|
||||
segments.append(
|
||||
Segment(
|
||||
uttid,
|
||||
spk[tg[i].name],
|
||||
tg[i][j].minTime,
|
||||
tg[i][j].maxTime,
|
||||
tg[i][j].mark.strip(),
|
||||
)
|
||||
)
|
||||
segments = sorted(segments, key=lambda x: x.stime)
|
||||
|
||||
rttm_file = codecs.open(args.output_rttm_file, "w", "utf-8")
|
||||
|
||||
for i in range(len(segments)):
|
||||
fmt = "SPEAKER {:s} 1 {:.2f} {:.2f} <NA> <NA> {:s} <NA> <NA>"
|
||||
#pdb.set_trace()
|
||||
rttm_file.write(f"{fmt.format(segments[i].uttid, float(segments[i].stime), float(segments[i].etime) - float(segments[i].stime), str(segments[i].spkr))}\n")
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description="Make rttm for true label")
|
||||
parser.add_argument("--input_textgrid_file", required=True, help="The textgrid file")
|
||||
parser.add_argument("--output_rttm_file", required=True, help="The output rttm file")
|
||||
parser.add_argument("--uttid", required=True, help="The utt id of the file")
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
83
egs/alimeeting/modular_sa_asr/local/meeting_speaker_number_process.py
Executable file
83
egs/alimeeting/modular_sa_asr/local/meeting_speaker_number_process.py
Executable file
@ -0,0 +1,83 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Process the textgrid files
|
||||
"""
|
||||
import argparse
|
||||
import codecs
|
||||
from distutils.util import strtobool
|
||||
from pathlib import Path
|
||||
import textgrid
|
||||
import pdb
|
||||
|
||||
class Segment(object):
|
||||
def __init__(self, uttid, spkr, stime, etime, text):
|
||||
self.uttid = uttid
|
||||
self.spkr = spkr
|
||||
self.stime = round(stime, 2)
|
||||
self.etime = round(etime, 2)
|
||||
self.text = text
|
||||
|
||||
def change_stime(self, time):
|
||||
self.stime = time
|
||||
|
||||
def change_etime(self, time):
|
||||
self.etime = time
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(description="process the textgrid files")
|
||||
parser.add_argument("--path", type=str, required=True, help="textgrid path")
|
||||
parser.add_argument("--label_path", type=str, required=True, help="label rttm file path")
|
||||
parser.add_argument("--predict_path", type=str, required=True, help="predict rttm file path")
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
def main(args):
|
||||
textgrid_flist = codecs.open(Path(args.path)/"uttid_textgrid.flist", "r", "utf-8")
|
||||
|
||||
|
||||
# parse the textgrid file for each utterance
|
||||
speaker2_uttidset = []
|
||||
speaker3_uttidset = []
|
||||
speaker4_uttidset = []
|
||||
for line in textgrid_flist:
|
||||
uttid ,textgrid_file = line.strip().split("\t")
|
||||
tg = textgrid.TextGrid()
|
||||
tg.read(textgrid_file)
|
||||
|
||||
num_speaker = len(tg)
|
||||
if num_speaker ==2:
|
||||
speaker2_uttidset.append(uttid)
|
||||
elif num_speaker ==3:
|
||||
speaker3_uttidset.append(uttid)
|
||||
elif num_speaker ==4:
|
||||
speaker4_uttidset.append(uttid)
|
||||
textgrid_flist.close()
|
||||
|
||||
speaker2_id_label = codecs.open(Path(args.label_path) / "speaker2_id", "w", "utf-8")
|
||||
speaker2_id_predict = codecs.open(Path(args.predict_path) / "speaker2_id", "w", "utf-8")
|
||||
speaker3_id_label = codecs.open(Path(args.label_path) / "speaker3_id", "w", "utf-8")
|
||||
speaker3_id_predict = codecs.open(Path(args.predict_path) / "speaker3_id", "w", "utf-8")
|
||||
speaker4_id_label = codecs.open(Path(args.label_path) / "speaker4_id", "w", "utf-8")
|
||||
speaker4_id_predict = codecs.open(Path(args.predict_path) / "speaker4_id", "w", "utf-8")
|
||||
|
||||
for i in range(len(speaker2_uttidset)):
|
||||
speaker2_id_label.write("%s\n" % (args.label_path+"/"+speaker2_uttidset[i]+".rttm"))
|
||||
speaker2_id_predict.write("%s\n" % (args.predict_path+"/"+speaker2_uttidset[i]+".rttm"))
|
||||
for i in range(len(speaker3_uttidset)):
|
||||
speaker3_id_label.write("%s\n" % (args.label_path+"/"+speaker3_uttidset[i]+".rttm"))
|
||||
speaker3_id_predict.write("%s\n" % (args.predict_path+"/"+speaker3_uttidset[i]+".rttm"))
|
||||
for i in range(len(speaker4_uttidset)):
|
||||
speaker4_id_label.write("%s\n" % (args.label_path+"/"+speaker4_uttidset[i]+".rttm"))
|
||||
speaker4_id_predict.write("%s\n" % (args.predict_path+"/"+speaker4_uttidset[i]+".rttm"))
|
||||
|
||||
speaker2_id_label.close()
|
||||
speaker2_id_predict.close()
|
||||
speaker3_id_label.close()
|
||||
speaker3_id_predict.close()
|
||||
speaker4_id_label.close()
|
||||
speaker4_id_predict.close()
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
main(args)
|
||||
52
egs/alimeeting/modular_sa_asr/local/merge_spk_text.py
Normal file
52
egs/alimeeting/modular_sa_asr/local/merge_spk_text.py
Normal file
@ -0,0 +1,52 @@
|
||||
import sys
|
||||
import codecs
|
||||
import zhconv
|
||||
|
||||
decode_result = sys.argv[1]
|
||||
utt2spk_file = sys.argv[2]
|
||||
merged_result = "/".join(decode_result.split("/")[:-1]) + "/text_merge"
|
||||
|
||||
utt2text = {}
|
||||
utt2spk = {}
|
||||
spk2texts = {}
|
||||
spk2text = {}
|
||||
meeting2text = {}
|
||||
|
||||
with codecs.open(decode_result, "r", "utf-8") as f1:
|
||||
with codecs.open(utt2spk_file, "r", "utf-8") as f2:
|
||||
for line in f1.readlines():
|
||||
try:
|
||||
line_list = line.strip().split()
|
||||
uttid = line_list[0]
|
||||
text = "".join(line_list[1:])
|
||||
except:
|
||||
continue
|
||||
utt2text[uttid] = text
|
||||
for line in f2.readlines():
|
||||
uttid, spkid = line.strip().split()
|
||||
utt2spk[uttid] = spkid
|
||||
|
||||
for utt, text in utt2text.items():
|
||||
spk = utt2spk[utt]
|
||||
stime = int(utt.split("-")[-2])
|
||||
if spk in spk2texts.keys():
|
||||
spk2texts[spk].append([stime, text])
|
||||
else:
|
||||
spk2texts[spk] = [[stime, text]]
|
||||
|
||||
for spk, texts in spk2texts.items():
|
||||
texts = sorted(texts, key=lambda x: x[0])
|
||||
text = "".join([x[1] for x in texts])
|
||||
spk2text[spk] = text
|
||||
|
||||
with codecs.open(merged_result, "w", "utf-8") as f:
|
||||
for spk, text in spk2text.items():
|
||||
# meeting = spk.split("-")[2]
|
||||
meeting = spk.split("-")[0]
|
||||
if meeting in meeting2text.keys():
|
||||
meeting2text[meeting] = meeting2text[meeting] + "$" + text
|
||||
else:
|
||||
meeting2text[meeting] = text
|
||||
for meeting, text in meeting2text.items():
|
||||
f.write("%s %s\n" % (meeting, text))
|
||||
|
||||
140
egs/alimeeting/modular_sa_asr/local/remove_overlap.py
Executable file
140
egs/alimeeting/modular_sa_asr/local/remove_overlap.py
Executable file
@ -0,0 +1,140 @@
|
||||
import sys
|
||||
import pdb
|
||||
import codecs
|
||||
import os
|
||||
|
||||
input_segments_file = sys.argv[1]
|
||||
input_utt2spk_file = sys.argv[2]
|
||||
output_segments_file = sys.argv[3]
|
||||
output_utt2spk_file = sys.argv[4]
|
||||
threshold = sys.argv[5]
|
||||
|
||||
|
||||
class Segment(object):
|
||||
def __init__(self, baseid, spkid, meetingid, stime, etime, uttid=None):
|
||||
self.baseid = baseid
|
||||
self.spkid = spkid
|
||||
self.meetingid = meetingid
|
||||
self.stime = round(stime, 2)
|
||||
self.etime = round(etime, 2)
|
||||
self.uttid = uttid
|
||||
self.dur = self.etime - self.stime
|
||||
if self.uttid is None:
|
||||
self.uttid = "%s-%s-%07d-%07d" % (
|
||||
self.baseid,
|
||||
self.spkid,
|
||||
self.stime * 100,
|
||||
self.etime * 100,
|
||||
)
|
||||
|
||||
|
||||
def cut(cur_max_end_time, seg_list, cur_seg, next_c):
|
||||
global out_segment_dict
|
||||
if next_c == len(seg_list):
|
||||
single_stime = max(cur_max_end_time, cur_seg.stime)
|
||||
single_etime = cur_seg.etime
|
||||
if single_stime < single_etime and single_etime - single_stime > float(threshold):
|
||||
# only save segment which duration more than threshold for sv's accuracy
|
||||
if cur_seg.spkid not in out_segment_dict.keys():
|
||||
out_segment_dict[cur_seg.spkid] = [
|
||||
Segment(
|
||||
cur_seg.baseid,
|
||||
cur_seg.spkid,
|
||||
cur_seg.meetingid,
|
||||
single_stime,
|
||||
single_etime,
|
||||
)]
|
||||
else:
|
||||
out_segment_dict[cur_seg.spkid].append(
|
||||
Segment(
|
||||
cur_seg.baseid,
|
||||
cur_seg.spkid,
|
||||
cur_seg.meetingid,
|
||||
single_stime,
|
||||
single_etime,
|
||||
)
|
||||
)
|
||||
else:
|
||||
next_seg = seg_list[next_c]
|
||||
single_stime = max(cur_max_end_time, cur_seg.stime)
|
||||
single_etime = min(cur_seg.etime, next_seg.stime)
|
||||
if single_stime < single_etime and single_etime - single_stime > float(threshold):
|
||||
if cur_seg.spkid not in out_segment_dict.keys():
|
||||
out_segment_dict[cur_seg.spkid] = [
|
||||
Segment(
|
||||
cur_seg.baseid,
|
||||
cur_seg.spkid,
|
||||
cur_seg.meetingid,
|
||||
single_stime,
|
||||
single_etime,
|
||||
)]
|
||||
else:
|
||||
out_segment_dict[cur_seg.spkid].append(
|
||||
Segment(
|
||||
cur_seg.baseid,
|
||||
cur_seg.spkid,
|
||||
cur_seg.meetingid,
|
||||
single_stime,
|
||||
single_etime,
|
||||
)
|
||||
)
|
||||
if cur_seg.etime > next_seg.etime:
|
||||
cut(max(cur_max_end_time, next_seg.etime), seg_list, cur_seg, next_c + 1)
|
||||
|
||||
|
||||
meeting2seg = {}
|
||||
utt2spk = {}
|
||||
i = 0
|
||||
|
||||
with codecs.open(input_utt2spk_file, "r", "utf-8") as f:
|
||||
for line in f.readlines():
|
||||
utt, spk = line.strip().split()
|
||||
utt2spk[utt] = spk
|
||||
|
||||
with codecs.open(input_segments_file, "r", "utf-8") as f:
|
||||
for line in f.readlines():
|
||||
i += 1
|
||||
uttid, meetingid, stime, etime = line.strip().split(" ")
|
||||
spkid = utt2spk[uttid].split("-")[1]
|
||||
baseid = meetingid
|
||||
one_seg = Segment(baseid, spkid, meetingid, float(stime), float(etime))
|
||||
if one_seg.meetingid not in meeting2seg.keys():
|
||||
meeting2seg[one_seg.meetingid] = [one_seg]
|
||||
else:
|
||||
meeting2seg[one_seg.meetingid].append(one_seg)
|
||||
|
||||
out_segment_dict = {}
|
||||
|
||||
for k, v in meeting2seg.items():
|
||||
meeting2seg[k] = sorted(v, key=lambda x: x.stime)
|
||||
cur_max_end_time = 0
|
||||
for i in range(len(v)):
|
||||
cur_seg = meeting2seg[k][i]
|
||||
cut(cur_max_end_time, meeting2seg[k], cur_seg, i + 1)
|
||||
cur_max_end_time = max(cur_max_end_time, cur_seg.etime)
|
||||
|
||||
out_segment_list = []
|
||||
|
||||
for k, v in out_segment_dict.items():
|
||||
out_segment_list.extend(out_segment_dict[k])
|
||||
|
||||
with codecs.open(output_segments_file, "w", "utf-8") as f_seg:
|
||||
with codecs.open(output_utt2spk_file, "w", "utf-8") as f_utt2spk:
|
||||
for out_seg in out_segment_list:
|
||||
f_seg.write(
|
||||
"%s %s %.2f %.2f\n"
|
||||
% (
|
||||
out_seg.uttid,
|
||||
out_seg.meetingid,
|
||||
out_seg.stime,
|
||||
out_seg.etime,
|
||||
)
|
||||
)
|
||||
f_utt2spk.write(
|
||||
"%s %s-%s\n"
|
||||
% (
|
||||
out_seg.uttid,
|
||||
out_seg.baseid,
|
||||
out_seg.spkid,
|
||||
)
|
||||
)
|
||||
78
egs/alimeeting/modular_sa_asr/local/resegment_data.py
Normal file
78
egs/alimeeting/modular_sa_asr/local/resegment_data.py
Normal file
@ -0,0 +1,78 @@
|
||||
import soundfile
|
||||
import os
|
||||
import sys
|
||||
import codecs
|
||||
import numpy as np
|
||||
import pdb
|
||||
|
||||
|
||||
segment_file_path = sys.argv[1]
|
||||
wav_scp_file_path = sys.argv[2]
|
||||
data_path = sys.argv[3]
|
||||
|
||||
wav_save_path = data_path + "/wav/"
|
||||
os.system("mkdir -p " + wav_save_path)
|
||||
pos_path = data_path + "/pos_map/"
|
||||
os.system("mkdir -p " + pos_path)
|
||||
|
||||
wav_dict = {}
|
||||
seg2time = {}
|
||||
seg2time_new = {}
|
||||
session2profile = {}
|
||||
|
||||
with codecs.open(wav_scp_file_path, "r", "utf-8") as f:
|
||||
for line in f.readlines():
|
||||
sessionid, wav_path = line.strip().split()
|
||||
wav_dict[sessionid] = wav_path
|
||||
|
||||
with codecs.open(segment_file_path, "r", "utf-8") as f:
|
||||
for line in f.readlines():
|
||||
_, sessionid, stime, etime = line.strip().split()
|
||||
if sessionid not in seg2time.keys():
|
||||
seg2time[sessionid] = [(int(16000 * float(stime)), int(16000 * float(etime)))]
|
||||
else:
|
||||
seg2time[sessionid].append((int(16000 * float(stime)), int(16000 * float(etime))))
|
||||
with codecs.open(data_path + "/map.scp", "w", "utf-8") as f1:
|
||||
for sessionid, seg_times in seg2time.items():
|
||||
seg2time_new[sessionid] = []
|
||||
last_time = 0
|
||||
with codecs.open(pos_path + sessionid + ".pos", "w", "utf-8") as f2:
|
||||
for seg_time in seg_times:
|
||||
tmp = seg_time[0] - last_time
|
||||
cur_seg = (seg_time[0] - tmp, seg_time[1] - tmp)
|
||||
seg2time_new[sessionid].append((seg_time[0] - last_time, seg_time[1] - last_time))
|
||||
last_time = cur_seg[1]
|
||||
f2.write("%s-%07d-%07d %d %d %d %d\n" % (sessionid, seg_time[0]/160, seg_time[1]/160, seg_time[0], seg_time[1], cur_seg[0], cur_seg[1]))
|
||||
f1.write("%s %s\n" % (sessionid, pos_path + sessionid + ".pos"))
|
||||
|
||||
with codecs.open(data_path + "/cluster_profile_zeropadding16.scp", "r", "utf-8") as f:
|
||||
for line in f.readlines():
|
||||
session, path = line.strip().split()
|
||||
session2profile[session] = path
|
||||
|
||||
with codecs.open(data_path + "/wav.scp", "w", "utf-8") as f1:
|
||||
with codecs.open(data_path + "/profile.scp", "w", "utf-8") as f2:
|
||||
for sessionid, wav_path in wav_dict.items():
|
||||
wav = soundfile.read(wav_path)[0]
|
||||
if wav.ndim == 2:
|
||||
wav = wav[:, 0]
|
||||
seg_list = [wav[seg[0]: seg[1]] for seg in seg2time[sessionid]]
|
||||
wav_new = np.concatenate(seg_list, axis=0)
|
||||
cur_time = 0
|
||||
flag = True
|
||||
while flag:
|
||||
start = cur_time
|
||||
end = cur_time + 256000
|
||||
if end < wav_new.shape[0]:
|
||||
cur_wav = wav_new[start: end]
|
||||
else:
|
||||
cur_wav = wav_new[start: ]
|
||||
flag = False
|
||||
cur_time = cur_time + 64000
|
||||
wav_name = "%s-%07d_%07d.wav" % (sessionid, start/160, end/160)
|
||||
soundfile.write(wav_save_path + wav_name, cur_wav, 16000)
|
||||
f1.write("%s %s\n" % (wav_name, wav_save_path + wav_name))
|
||||
f2.write("%s %s\n" % (wav_name, session2profile[sessionid]))
|
||||
|
||||
|
||||
|
||||
29
egs/alimeeting/modular_sa_asr/local/rttm2segments.py
Normal file
29
egs/alimeeting/modular_sa_asr/local/rttm2segments.py
Normal file
@ -0,0 +1,29 @@
|
||||
import codecs
|
||||
import sys
|
||||
|
||||
rttm_file_path = sys.argv[1]
|
||||
segment_file_path = sys.argv[2]
|
||||
mode = sys.argv[3] # 0 for diarization, 1 for asr
|
||||
|
||||
|
||||
meeting2spk = {}
|
||||
|
||||
with codecs.open(rttm_file_path, "r", "utf-8") as fi:
|
||||
with codecs.open(segment_file_path + "/segments", "w", "utf-8") as f1:
|
||||
with codecs.open(segment_file_path + "/utt2spk", "w", "utf-8") as f2:
|
||||
for line in fi.readlines():
|
||||
_, sessionid, _, stime, dur, _, _, spkid, _, _ = line.strip().split(" ")
|
||||
if float(dur) < 0.3:
|
||||
continue
|
||||
uttid = "%s-%07d-%07d" % (sessionid, int(float(stime) * 100), int(float(stime) * 100 + float(dur) * 100))
|
||||
spkid = "%s-%s" % (sessionid, spkid)
|
||||
if int(mode) == 0:
|
||||
f1.write("%s %s %.2f %.2f\n" % (uttid, sessionid, float(stime), float(stime) + float(dur)))
|
||||
f2.write("%s %s\n" % (uttid, spkid))
|
||||
elif int(mode) == 1:
|
||||
f1.write("%s %s %.2f %.2f\n" % (uttid, spkid, float(stime), float(stime) + float(dur)))
|
||||
f2.write("%s %s\n" % (uttid, spkid))
|
||||
else:
|
||||
exit("mode only support 0 or 1!")
|
||||
|
||||
|
||||
139
egs/alimeeting/modular_sa_asr/local/run_gss.py
Normal file
139
egs/alimeeting/modular_sa_asr/local/run_gss.py
Normal file
@ -0,0 +1,139 @@
|
||||
#!/usr/bin/env python
|
||||
# -- coding: UTF-8
|
||||
|
||||
import argparse
|
||||
import codecs
|
||||
import glob
|
||||
import logging
|
||||
import os
|
||||
from nara_wpe.utils import stft, istft
|
||||
import numpy as np
|
||||
import scipy.io.wavfile as wf
|
||||
from tqdm import tqdm
|
||||
|
||||
from test_gss import *
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser("Doing GSS based enhancement.")
|
||||
parser.add_argument(
|
||||
"--wav-scp",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Wav scp file for enhancement.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--segments",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Wav scp file for enhancement.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Output directory of GSS enhanced data.",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def wfread(f):
|
||||
fs, data = wf.read(f)
|
||||
if data.dtype == np.int16:
|
||||
data = np.float32(data) / 32768
|
||||
return data, fs
|
||||
|
||||
|
||||
def wfwrite(z, fs, store_path):
|
||||
tmpwav = np.int16(z * 32768)
|
||||
wf.write(store_path, fs, tmpwav)
|
||||
|
||||
|
||||
def main():
|
||||
args = get_parser().parse_args()
|
||||
|
||||
# logging info
|
||||
logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
|
||||
logging.basicConfig(level=logging.INFO, format=logfmt)
|
||||
|
||||
stft_window, stft_shift = 512, 256
|
||||
gss = GSS(iterations=20, iterations_post=1)
|
||||
bf = Beamformer("mvdrSouden_ban", "mask_mul")
|
||||
|
||||
with codecs.open(args.wav_scp, "r") as handle:
|
||||
lines_content = handle.readlines()
|
||||
wav_lines = [*map(lambda x: x[:-1] if x[-1] in ["\n"] else x, lines_content)]
|
||||
|
||||
cnt = 0
|
||||
|
||||
session2spk2dur = {}
|
||||
with codecs.open(args.segments, "r") as handle:
|
||||
for line in handle.readlines():
|
||||
uttid, spkid, stime, etime = line.strip().split(" ")
|
||||
sessionid = spkid.split("-")[0]
|
||||
if sessionid not in session2spk2dur.keys():
|
||||
session2spk2dur[sessionid] = {}
|
||||
if spkid not in session2spk2dur[sessionid].keys():
|
||||
session2spk2dur[sessionid][spkid] = []
|
||||
session2spk2dur[sessionid][spkid].append((float(stime), float(etime)))
|
||||
# import pdb;pdb.set_trace()
|
||||
|
||||
for wav_idx in tqdm(range(len(wav_lines)), leave=True, desc="0"):
|
||||
# get wav files from scp file
|
||||
file_list = wav_lines[wav_idx].split(" ")
|
||||
sessionid, wav_list = file_list[0], file_list[1:]
|
||||
|
||||
signal_list = []
|
||||
time_activity = []
|
||||
cnt += 1
|
||||
logging.info("Processing {}: {}".format(cnt, wav_list[0]))
|
||||
|
||||
# read all wavs
|
||||
for wav in wav_list:
|
||||
data, fs = wfread(wav)
|
||||
signal_list.append(data)
|
||||
try:
|
||||
obstft = np.stack(signal_list, axis=0)
|
||||
except:
|
||||
minlen = min([len(s) for s in signal_list])
|
||||
obstft = np.stack([s[:minlen] for s in signal_list])
|
||||
wavlen = obstft.shape[1]
|
||||
obstft = stft(obstft, stft_window, stft_shift)
|
||||
|
||||
# get activated timestamps and frequencies
|
||||
speaker_list = []
|
||||
for spk, dur in session2spk2dur[sessionid].items():
|
||||
speaker_list.append(spk.split("-")[-1])
|
||||
time_activity.append(get_time_activity(dur, wavlen, fs))
|
||||
time_activity.append([True] * wavlen)
|
||||
frequency_activity = get_frequency_activity(
|
||||
time_activity, stft_window, stft_shift
|
||||
)
|
||||
# import pdb;pdb.set_trace()
|
||||
|
||||
# generate mask
|
||||
masks = gss(obstft, frequency_activity)
|
||||
masks_bak = masks
|
||||
|
||||
for i in range(masks.shape[0] - 1):
|
||||
target_mask = masks[i]
|
||||
distortion_mask = np.sum(np.delete(masks, i, axis=0), axis=0)
|
||||
xhat = bf(obstft, target_mask=target_mask, distortion_mask=distortion_mask)
|
||||
xhat = istft(xhat, stft_window, stft_shift)
|
||||
audio_dir = "/".join(wav_list[0].split("/")[:-1])
|
||||
store_path = (
|
||||
wav_list[0]
|
||||
.replace(audio_dir, args.output_dir)
|
||||
.replace(".wav", "-{}.wav".format(speaker_list[i]))
|
||||
)
|
||||
if not os.path.exists(os.path.split(store_path)[0]):
|
||||
os.makedirs(os.path.split(store_path)[0], exist_ok=True)
|
||||
|
||||
logging.info("Save wav file {}.".format(store_path))
|
||||
wfwrite(xhat, fs, store_path)
|
||||
masks = masks_bak
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
153
egs/alimeeting/modular_sa_asr/local/run_wpe.py
Normal file
153
egs/alimeeting/modular_sa_asr/local/run_wpe.py
Normal file
@ -0,0 +1,153 @@
|
||||
#!/usr/bin/env python
|
||||
# _*_ coding: UTF-8 _*_
|
||||
|
||||
import argparse
|
||||
import codecs
|
||||
import os
|
||||
import logging
|
||||
from multiprocessing import Pool
|
||||
|
||||
import numpy as np
|
||||
import scipy.io.wavfile as wf
|
||||
from nara_wpe.utils import istft, stft
|
||||
from nara_wpe.wpe import wpe_v8 as wpe
|
||||
|
||||
|
||||
def wpe_worker(
|
||||
wav_scp,
|
||||
audio_dir="",
|
||||
output_dir="",
|
||||
channel=0,
|
||||
processing_id=None,
|
||||
processing_num=None,
|
||||
):
|
||||
sampling_rate = 16000
|
||||
iterations = 5
|
||||
stft_options = dict(
|
||||
size=512,
|
||||
shift=128,
|
||||
window_length=None,
|
||||
fading=True,
|
||||
pad=True,
|
||||
symmetric_window=False,
|
||||
)
|
||||
with codecs.open(wav_scp, "r") as handle:
|
||||
lines_content = handle.readlines()
|
||||
wav_lines = [*map(lambda x: x[:-1] if x[-1] in ["\n"] else x, lines_content)]
|
||||
for wav_idx in range(len(wav_lines)):
|
||||
if processing_id is None:
|
||||
processing_token = True
|
||||
else:
|
||||
if wav_idx % processing_num == processing_id:
|
||||
processing_token = True
|
||||
else:
|
||||
processing_token = False
|
||||
if processing_token:
|
||||
wav_list = wav_lines[wav_idx].split(" ")
|
||||
file_exist = True
|
||||
for wav_path in wav_list:
|
||||
file_exist = file_exist and os.path.exists(
|
||||
wav_path.replace(audio_dir, output_dir)
|
||||
)
|
||||
if not file_exist:
|
||||
break
|
||||
if not file_exist:
|
||||
logging.info("wait to process {} : {}".format(wav_idx, wav_list[0]))
|
||||
signal_list = []
|
||||
for f in wav_list:
|
||||
_, data = wf.read(f)
|
||||
data = data[:, channel - 1]
|
||||
if data.dtype == np.int16:
|
||||
data = np.float32(data) / 32768
|
||||
signal_list.append(data)
|
||||
min_len = len(signal_list[0])
|
||||
max_len = len(signal_list[0])
|
||||
for i in range(1, len(signal_list)):
|
||||
min_len = min(min_len, len(signal_list[i]))
|
||||
max_len = max(max_len, len(signal_list[i]))
|
||||
if min_len != max_len:
|
||||
for i in range(len(signal_list)):
|
||||
signal_list[i] = signal_list[i][:min_len]
|
||||
y = np.stack(signal_list, axis=0)
|
||||
Y = stft(y, **stft_options).transpose(2, 0, 1)
|
||||
Z = wpe(Y, iterations=iterations, statistics_mode="full").transpose(
|
||||
1, 2, 0
|
||||
)
|
||||
z = istft(Z, size=stft_options["size"], shift=stft_options["shift"])
|
||||
for d in range(len(signal_list)):
|
||||
store_path = wav_list[d].replace(audio_dir, output_dir)
|
||||
if not os.path.exists(os.path.split(store_path)[0]):
|
||||
os.makedirs(os.path.split(store_path)[0], exist_ok=True)
|
||||
tmpwav = np.int16(z[d, :] * 32768)
|
||||
wf.write(store_path, sampling_rate, tmpwav)
|
||||
else:
|
||||
logging.info("file exist {} : {}".format(wav_idx, wav_list[0]))
|
||||
return None
|
||||
|
||||
|
||||
def wpe_manager(
|
||||
wav_scp, processing_num=1, audio_dir="", output_dir="", channel=1
|
||||
):
|
||||
if processing_num > 1:
|
||||
pool = Pool(processes=processing_num)
|
||||
for i in range(processing_num):
|
||||
pool.apply_async(
|
||||
wpe_worker,
|
||||
kwds={
|
||||
"wav_scp": wav_scp,
|
||||
"processing_id": i,
|
||||
"processing_num": processing_num,
|
||||
"audio_dir": audio_dir,
|
||||
"output_dir": output_dir,
|
||||
},
|
||||
)
|
||||
pool.close()
|
||||
pool.join()
|
||||
else:
|
||||
wpe_worker(wav_scp, audio_dir=audio_dir, output_dir=output_dir, channel=channel)
|
||||
return None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser("run_wpe")
|
||||
parser.add_argument(
|
||||
"--wav-scp",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path pf wav scp file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--audio-dir",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Directory of input audio files",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Output directory of WPE enhanced audio files",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--channel",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Channel number of input audio",
|
||||
)
|
||||
parser.add_argument("--nj", type=int, default="1", help="number of process")
|
||||
args = parser.parse_args()
|
||||
|
||||
# logging info
|
||||
logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
|
||||
logging.basicConfig(level=logging.INFO, format=logfmt)
|
||||
|
||||
logging.info("wavfile={}".format(args.wav_scp))
|
||||
logging.info("processingnum={}".format(args.nj))
|
||||
|
||||
wpe_manager(
|
||||
wav_scp=args.wav_scp,
|
||||
processing_num=args.nj,
|
||||
audio_dir=args.audio_dir,
|
||||
output_dir=args.output_dir,
|
||||
channel=int(args.channel)
|
||||
)
|
||||
58
egs/alimeeting/modular_sa_asr/local/segment_to_lab.py
Executable file
58
egs/alimeeting/modular_sa_asr/local/segment_to_lab.py
Executable file
@ -0,0 +1,58 @@
|
||||
import argparse
|
||||
import os
|
||||
|
||||
|
||||
def read_segments_file(segments_file):
|
||||
utt2segments = dict()
|
||||
with open(segments_file, "r") as fr:
|
||||
lines = fr.readlines()
|
||||
for line in lines:
|
||||
parts = line.strip().split()
|
||||
segment_utt_id, utt_id, start, end = parts[0], parts[1], float(parts[2]), float(parts[3])
|
||||
if utt_id not in utt2segments:
|
||||
utt2segments[utt_id] = []
|
||||
utt2segments[utt_id].append((segment_utt_id, start, end))
|
||||
return utt2segments
|
||||
|
||||
|
||||
def write_label(label_file, label_list):
|
||||
with open(label_file, "w") as fw:
|
||||
for (start, end) in label_list:
|
||||
fw.write(f"{start} {end} sp\n")
|
||||
fw.flush()
|
||||
|
||||
|
||||
def write_label_scp_file(label_scp_file, label_scp: dict):
|
||||
with open(label_scp_file, "w") as fw:
|
||||
for (utt_id, label_path) in label_scp.items():
|
||||
fw.write(f"{utt_id} {label_path}\n")
|
||||
fw.flush()
|
||||
|
||||
|
||||
def main(args):
|
||||
input_segments = args.input_segments
|
||||
label_path = args.label_path
|
||||
output_label_scp_file = args.output_label_scp_file
|
||||
|
||||
utt2segments = read_segments_file(input_segments)
|
||||
print(f"Collect {len(utt2segments)} utt2segments in file {input_segments}")
|
||||
|
||||
result_label_scp = dict()
|
||||
for utt_id in utt2segments.keys():
|
||||
segment_list = utt2segments[utt_id]
|
||||
cur_label_path = os.path.join(label_path, f"{utt_id}.lab")
|
||||
write_label(cur_label_path, label_list=[(i1, i2) for (_, i1, i2) in segment_list])
|
||||
result_label_scp[utt_id] = cur_label_path
|
||||
write_label_scp_file(output_label_scp_file, result_label_scp)
|
||||
print(f"Write {len(result_label_scp)} labels")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description="Make the lab file for segments")
|
||||
parser.add_argument("--input_segments", required=True, help="The input segments file")
|
||||
parser.add_argument("--label_path", required=True, help="The label_path to save file.lab")
|
||||
parser.add_argument("--output_label_scp_file", required=True, help="The output label.scp file")
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
||||
225
egs/alimeeting/modular_sa_asr/local/segmentation/detect_speech_activity.sh
Executable file
225
egs/alimeeting/modular_sa_asr/local/segmentation/detect_speech_activity.sh
Executable file
@ -0,0 +1,225 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Copyright 2016-17 Vimal Manohar
|
||||
# 2017 Nagendra Kumar Goel
|
||||
# Apache 2.0.
|
||||
|
||||
# This script does nnet3-based speech activity detection given an input
|
||||
# kaldi data directory and outputs a segmented kaldi data directory.
|
||||
# This script can also do music detection and other similar segmentation
|
||||
# using appropriate options such as --output-name output-music.
|
||||
|
||||
set -e
|
||||
set -o pipefail
|
||||
set -u
|
||||
|
||||
if [ -f ./path.sh ]; then . ./path.sh; fi
|
||||
|
||||
#export PATH=/usr/local/cuda-10.0/bin:$PATH
|
||||
#export LD_LIBRARY_PATH=/usr/local/cuda-10.0/lib64:$LD_LIBRARY_PATH
|
||||
|
||||
#echo $PATH
|
||||
#echo $LD_LIBRARY_PATH
|
||||
|
||||
affix= # Affix for the segmentation
|
||||
nj=32
|
||||
cmd=run.pl
|
||||
stage=-1
|
||||
|
||||
# Feature options (Must match training)
|
||||
mfcc_config=conf/mfcc_hires.conf
|
||||
feat_affix= # Affix for the type of feature used
|
||||
|
||||
output_name=output # The output node in the network
|
||||
sad_name=sad # Base name for the directory storing the computed loglikes
|
||||
# Can be music for music detection
|
||||
segmentation_name=segmentation # Base name for the directory doing segmentation
|
||||
# Can be segmentation_music for music detection
|
||||
|
||||
# SAD network config
|
||||
iter=final # Model iteration to use
|
||||
|
||||
# Contexts must ideally match training for LSTM models, but
|
||||
# may not necessarily for stats components
|
||||
extra_left_context=0 # Set to some large value, typically 40 for LSTM (must match training)
|
||||
extra_right_context=0
|
||||
extra_left_context_initial=-1
|
||||
extra_right_context_final=-1
|
||||
frames_per_chunk=150
|
||||
|
||||
# Decoding options
|
||||
graph_opts="--min-silence-duration=0.03 --min-speech-duration=0.3 --max-speech-duration=10.0"
|
||||
acwt=0.3
|
||||
|
||||
# These <from>_in_<to>_weight represent the fraction of <from> probability
|
||||
# to transfer to <to> class.
|
||||
# e.g. --speech-in-sil-weight=0.0 --garbage-in-sil-weight=0.0 --sil-in-speech-weight=0.0 --garbage-in-speech-weight=0.3
|
||||
transform_probs_opts=""
|
||||
|
||||
# Postprocessing options
|
||||
segment_padding=0.2 # Duration (in seconds) of padding added to segments
|
||||
min_segment_dur=0 # Minimum duration (in seconds) required for a segment to be included
|
||||
# This is before any padding. Segments shorter than this duration will be removed.
|
||||
# This is an alternative to --min-speech-duration above.
|
||||
merge_consecutive_max_dur=0 # Merge consecutive segments as long as the merged segment is no longer than this many
|
||||
# seconds. The segments are only merged if their boundaries are touching.
|
||||
# This is after padding by --segment-padding seconds.
|
||||
# 0 means do not merge. Use 'inf' to not limit the duration.
|
||||
|
||||
echo $*
|
||||
|
||||
. utils/parse_options.sh
|
||||
|
||||
if [ $# -ne 5 ]; then
|
||||
echo "This script does nnet3-based speech activity detection given an input kaldi "
|
||||
echo "data directory and outputs an output kaldi data directory."
|
||||
echo "See script for details of the options to be supplied."
|
||||
echo "Usage: $0 <src-data-dir> <sad-nnet-dir> <mfcc-dir> <work-dir> <out-data-dir>"
|
||||
echo " e.g.: $0 ~/workspace/egs/ami/s5b/data/sdm1/dev exp/nnet3_sad_snr/nnet_tdnn_j_n4 \\"
|
||||
echo " mfcc_hires exp/segmentation_sad_snr/nnet_tdnn_j_n4 data/ami_sdm1_dev"
|
||||
echo ""
|
||||
echo "Options: "
|
||||
echo " --cmd (utils/run.pl|utils/queue.pl <queue opts>) # how to run jobs."
|
||||
echo " --nj <num-job> # number of parallel jobs to run."
|
||||
echo " --stage <stage> # stage to do partial re-run from."
|
||||
echo " --convert-data-dir-to-whole <true|false> # If true, the input data directory is "
|
||||
echo " # first converted to whole data directory (i.e. whole recordings) "
|
||||
echo " # and segmentation is done on that."
|
||||
echo " # If false, then the original segments are "
|
||||
echo " # retained and they are split into sub-segments."
|
||||
echo " --output-name <name> # The output node in the network"
|
||||
echo " --extra-left-context <context|0> # Set to some large value, typically 40 for LSTM (must match training)"
|
||||
echo " --extra-right-context <context|0> # For BLSTM or statistics pooling"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
src_data_dir=$1 # The input data directory that needs to be segmented.
|
||||
# If convert_data_dir_to_whole is true, any segments in that will be ignored.
|
||||
sad_nnet_dir=$2 # The SAD neural network
|
||||
mfcc_dir=$3 # The directory to store the features
|
||||
dir=$4 # Work directory
|
||||
data_dir=$5 # The output data directory will be ${data_dir}_seg
|
||||
|
||||
affix=${affix:+_$affix}
|
||||
feat_affix=${feat_affix:+_$feat_affix}
|
||||
|
||||
data_id=`basename $data_dir`
|
||||
sad_dir=${dir}/${sad_name}${affix}_${data_id}${feat_affix}
|
||||
seg_dir=${dir}/${segmentation_name}${affix}_${data_id}${feat_affix}
|
||||
# test_data_dir=data/${data_id}${feat_affix}
|
||||
test_data_dir=${src_data_dir}
|
||||
|
||||
###############################################################################
|
||||
## Forward pass through the network network and dump the log-likelihoods.
|
||||
###############################################################################
|
||||
|
||||
frame_subsampling_factor=1
|
||||
if [ -f $sad_nnet_dir/frame_subsampling_factor ]; then
|
||||
frame_subsampling_factor=$(cat $sad_nnet_dir/frame_subsampling_factor)
|
||||
fi
|
||||
|
||||
mkdir -p $dir
|
||||
if [ $stage -le 1 ]; then
|
||||
if [ "$(readlink -f $sad_nnet_dir)" != "$(readlink -f $dir)" ]; then
|
||||
cp $sad_nnet_dir/cmvn_opts $dir || exit 1
|
||||
fi
|
||||
|
||||
########################################################################
|
||||
## Initialize neural network for decoding using the output $output_name
|
||||
########################################################################
|
||||
|
||||
if [ ! -z "$output_name" ] && [ "$output_name" != output ]; then
|
||||
$cmd $dir/log/get_nnet_${output_name}.log \
|
||||
nnet3-copy --edits="rename-node old-name=$output_name new-name=output" \
|
||||
$sad_nnet_dir/$iter.raw $dir/${iter}_${output_name}.raw || exit 1
|
||||
iter=${iter}_${output_name}
|
||||
else
|
||||
if ! diff $sad_nnet_dir/$iter.raw $dir/$iter.raw; then
|
||||
cp $sad_nnet_dir/$iter.raw $dir/
|
||||
fi
|
||||
fi
|
||||
|
||||
echo ${test_data_dir}
|
||||
steps/nnet3/compute_output.sh --nj $nj --cmd "$cmd" \
|
||||
--iter ${iter} \
|
||||
--extra-left-context $extra_left_context \
|
||||
--extra-right-context $extra_right_context \
|
||||
--extra-left-context-initial $extra_left_context_initial \
|
||||
--extra-right-context-final $extra_right_context_final \
|
||||
--frames-per-chunk $frames_per_chunk --apply-exp true \
|
||||
--frame-subsampling-factor $frame_subsampling_factor \
|
||||
${test_data_dir} $dir $sad_dir || exit 1
|
||||
fi
|
||||
|
||||
###############################################################################
|
||||
## Prepare FST we search to make speech/silence decisions.
|
||||
###############################################################################
|
||||
|
||||
utils/data/get_utt2dur.sh --nj $nj --cmd "$cmd" $test_data_dir || exit 1
|
||||
frame_shift=$(utils/data/get_frame_shift.sh $test_data_dir) || exit 1
|
||||
|
||||
graph_dir=${dir}/graph_${output_name}
|
||||
if [ $stage -le 2 ]; then
|
||||
mkdir -p $graph_dir
|
||||
|
||||
# 1 for silence and 2 for speech
|
||||
cat <<EOF > $graph_dir/words.txt
|
||||
<eps> 0
|
||||
silence 1
|
||||
speech 2
|
||||
EOF
|
||||
|
||||
$cmd $graph_dir/log/make_graph.log \
|
||||
steps/segmentation/internal/prepare_sad_graph.py $graph_opts \
|
||||
--frame-shift=$(perl -e "print $frame_shift * $frame_subsampling_factor") - \| \
|
||||
fstcompile --isymbols=$graph_dir/words.txt --osymbols=$graph_dir/words.txt '>' \
|
||||
$graph_dir/HCLG.fst
|
||||
fi
|
||||
|
||||
###############################################################################
|
||||
## Do Viterbi decoding to create per-frame alignments.
|
||||
###############################################################################
|
||||
|
||||
post_vec=$sad_nnet_dir/post_${output_name}.vec
|
||||
if [ ! -f $sad_nnet_dir/post_${output_name}.vec ]; then
|
||||
if [ ! -f $sad_nnet_dir/post_${output_name}.txt ]; then
|
||||
echo "$0: Could not find $sad_nnet_dir/post_${output_name}.vec. "
|
||||
echo "Re-run the corresponding stage in the training script possibly "
|
||||
echo "with --compute-average-posteriors=true or compute the priors "
|
||||
echo "from the training labels"
|
||||
exit 1
|
||||
else
|
||||
post_vec=$sad_nnet_dir/post_${output_name}.txt
|
||||
fi
|
||||
fi
|
||||
|
||||
mkdir -p $seg_dir
|
||||
if [ $stage -le 3 ]; then
|
||||
steps/segmentation/internal/get_transform_probs_mat.py \
|
||||
--priors="$post_vec" $transform_probs_opts > $seg_dir/transform_probs.mat
|
||||
|
||||
steps/segmentation/decode_sad.sh --acwt $acwt --cmd "$cmd" \
|
||||
--nj $nj \
|
||||
--transform "$seg_dir/transform_probs.mat" \
|
||||
$graph_dir $sad_dir $seg_dir
|
||||
fi
|
||||
|
||||
###############################################################################
|
||||
## Post-process segmentation to create kaldi data directory.
|
||||
###############################################################################
|
||||
|
||||
if [ $stage -le 4 ]; then
|
||||
steps/segmentation/post_process_sad_to_segments.sh \
|
||||
--segment-padding $segment_padding --min-segment-dur $min_segment_dur \
|
||||
--merge-consecutive-max-dur $merge_consecutive_max_dur \
|
||||
--cmd "$cmd" --frame-shift $(perl -e "print $frame_subsampling_factor * $frame_shift") \
|
||||
${test_data_dir} ${seg_dir} ${seg_dir}
|
||||
fi
|
||||
|
||||
if [ $stage -le 5 ]; then
|
||||
utils/data/subsegment_data_dir.sh ${test_data_dir} ${seg_dir}/segments \
|
||||
${data_dir}_seg
|
||||
fi
|
||||
|
||||
echo "$0: Created output segmented kaldi data directory in ${data_dir}_seg"
|
||||
exit 0
|
||||
141
egs/alimeeting/modular_sa_asr/local/test_gss.py
Normal file
141
egs/alimeeting/modular_sa_asr/local/test_gss.py
Normal file
@ -0,0 +1,141 @@
|
||||
import io
|
||||
import functools
|
||||
import logging
|
||||
|
||||
# import soundfile as sf
|
||||
import numpy as np
|
||||
import matplotlib
|
||||
import matplotlib.pylab as plt
|
||||
|
||||
# from IPython.display import display, Audio
|
||||
|
||||
from nara_wpe.utils import stft, istft
|
||||
|
||||
from pb_bss.distribution import CACGMMTrainer
|
||||
from pb_bss.evaluation import InputMetrics, OutputMetrics
|
||||
from dataclasses import dataclass
|
||||
|
||||
# from beamforming_wrapper import beamform_mvdr_souden_from_masks
|
||||
from pb_chime5.utils.numpy_utils import segment_axis_v2
|
||||
from textgrid_processor import read_textgrid_from_file
|
||||
|
||||
|
||||
def get_time_activity(dur_list, wavlen, sr):
|
||||
time_activity = [False] * wavlen
|
||||
for dur in dur_list:
|
||||
xmax = int(dur[1] * sr)
|
||||
xmin = int(dur[0] * sr)
|
||||
if xmax > wavlen:
|
||||
continue
|
||||
for i in range(xmin, xmax):
|
||||
time_activity[i] = True
|
||||
logging.info("Num of actived samples {}".format(time_activity.count(True)))
|
||||
return time_activity
|
||||
|
||||
def get_frequency_activity(
|
||||
time_activity,
|
||||
stft_window_length,
|
||||
stft_shift,
|
||||
stft_fading=True,
|
||||
stft_pad=True,
|
||||
):
|
||||
time_activity = np.asarray(time_activity)
|
||||
|
||||
if stft_fading:
|
||||
pad_width = np.array([(0, 0)] * time_activity.ndim)
|
||||
pad_width[-1, :] = stft_window_length - stft_shift # Consider fading
|
||||
time_activity = np.pad(time_activity, pad_width, mode="constant")
|
||||
|
||||
return segment_axis_v2(
|
||||
time_activity,
|
||||
length=stft_window_length,
|
||||
shift=stft_shift,
|
||||
end="pad" if stft_pad else "cut",
|
||||
).any(axis=-1)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Beamformer:
|
||||
type: str
|
||||
postfilter: str
|
||||
|
||||
def __call__(self, Obs, target_mask, distortion_mask, debug=False):
|
||||
bf = self.type
|
||||
|
||||
if bf == "mvdrSouden_ban":
|
||||
from pb_chime5.speech_enhancement.beamforming_wrapper import (
|
||||
beamform_mvdr_souden_from_masks,
|
||||
)
|
||||
|
||||
X_hat = beamform_mvdr_souden_from_masks(
|
||||
Y=Obs,
|
||||
X_mask=target_mask,
|
||||
N_mask=distortion_mask,
|
||||
ban=True,
|
||||
)
|
||||
elif bf == "ch0":
|
||||
X_hat = Obs[0]
|
||||
elif bf == "sum":
|
||||
X_hat = np.sum(Obs, axis=0)
|
||||
else:
|
||||
raise NotImplementedError(bf)
|
||||
|
||||
if self.postfilter is None:
|
||||
pass
|
||||
elif self.postfilter == "mask_mul":
|
||||
X_hat = X_hat * target_mask
|
||||
else:
|
||||
raise NotImplementedError(self.postfilter)
|
||||
|
||||
return X_hat
|
||||
|
||||
|
||||
@dataclass
|
||||
class GSS:
|
||||
iterations: int = 20
|
||||
iterations_post: int = 0
|
||||
verbose: bool = True
|
||||
# use_pinv: bool = False
|
||||
# stable: bool = True
|
||||
|
||||
def __call__(self, Obs, acitivity_freq=None, debug=False):
|
||||
initialization = np.asarray(acitivity_freq, dtype=np.float64)
|
||||
initialization = np.where(initialization == 0, 1e-10, initialization)
|
||||
initialization = initialization / np.sum(initialization, keepdims=True, axis=0)
|
||||
initialization = np.repeat(initialization[None, ...], 257, axis=0)
|
||||
|
||||
source_active_mask = np.asarray(acitivity_freq, dtype=bool)
|
||||
source_active_mask = np.repeat(source_active_mask[None, ...], 257, axis=0)
|
||||
|
||||
cacGMM = CACGMMTrainer()
|
||||
|
||||
if debug:
|
||||
learned = []
|
||||
all_affiliations = []
|
||||
F = Obs.shape[-1]
|
||||
T = Obs.T.shape[-2]
|
||||
|
||||
for f in range(F):
|
||||
if self.verbose:
|
||||
if f % 50 == 0:
|
||||
logging.info(f"{f}/{F}")
|
||||
|
||||
# T: Consider end of signal.
|
||||
# This should not be nessesary, but activity is for inear and not for
|
||||
# array.
|
||||
cur = cacGMM.fit(
|
||||
y=Obs.T[f, ...],
|
||||
initialization=initialization[f, ..., :T],
|
||||
iterations=self.iterations,
|
||||
source_activity_mask=source_active_mask[f, ..., :T],
|
||||
)
|
||||
affiliation = cur.predict(
|
||||
Obs.T[f, ...],
|
||||
source_activity_mask=source_active_mask[f, ..., :T],
|
||||
)
|
||||
|
||||
all_affiliations.append(affiliation)
|
||||
|
||||
posterior = np.array(all_affiliations).transpose(1, 2, 0)
|
||||
|
||||
return posterior
|
||||
316
egs/alimeeting/modular_sa_asr/local/textgrid_processor.py
Normal file
316
egs/alimeeting/modular_sa_asr/local/textgrid_processor.py
Normal file
@ -0,0 +1,316 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import codecs
|
||||
|
||||
|
||||
class TextGrid(object):
|
||||
def __init__(
|
||||
self,
|
||||
file_type="",
|
||||
object_class="",
|
||||
xmin=0.0,
|
||||
xmax=0.0,
|
||||
tiers_status="",
|
||||
tiers=[],
|
||||
):
|
||||
self.file_type = file_type
|
||||
self.object_class = object_class
|
||||
self.xmin = xmin
|
||||
self.xmax = xmax
|
||||
self.tiers_status = tiers_status
|
||||
self.tiers = tiers
|
||||
|
||||
if self.xmax < self.xmin:
|
||||
raise ValueError("xmax ({}) < xmin ({})".format(self.xmax, self.xmin))
|
||||
|
||||
def cutoff(self, xstart=None, xend=None):
|
||||
if xstart is None:
|
||||
xstart = self.xmin
|
||||
|
||||
if xend is None:
|
||||
xend = self.xmax
|
||||
|
||||
if xend < xstart:
|
||||
raise ValueError("xend ({}) < xstart ({})".format(xend, xstart))
|
||||
|
||||
new_xmax = xend - xstart + self.xmin
|
||||
new_xmin = self.xmin
|
||||
new_tiers = []
|
||||
|
||||
for tier in self.tiers:
|
||||
new_tiers.append(tier.cutoff(xstart=xstart, xend=xend))
|
||||
return TextGrid(
|
||||
file_type=self.file_type,
|
||||
object_class=self.object_class,
|
||||
xmin=new_xmin,
|
||||
xmax=new_xmax,
|
||||
tiers_status=self.tiers_status,
|
||||
tiers=new_tiers,
|
||||
)
|
||||
|
||||
|
||||
class Tier(object):
|
||||
def __init__(self, tier_class="", name="", xmin=0.0, xmax=0.0, intervals=[]):
|
||||
self.tier_class = tier_class
|
||||
self.name = name
|
||||
self.xmin = xmin
|
||||
self.xmax = xmax
|
||||
self.intervals = intervals
|
||||
|
||||
if self.xmax < self.xmin:
|
||||
raise ValueError("xmax ({}) < xmin ({})".format(self.xmax, self.xmin))
|
||||
|
||||
def cutoff(self, xstart=None, xend=None):
|
||||
if xstart is None:
|
||||
xstart = self.xmin
|
||||
|
||||
if xend is None:
|
||||
xend = self.xmax
|
||||
|
||||
if xend < xstart:
|
||||
raise ValueError("xend ({}) < xstart ({})".format(xend, xstart))
|
||||
|
||||
bias = xstart - self.xmin
|
||||
new_xmax = xend - bias
|
||||
new_xmin = self.xmin
|
||||
new_intervals = []
|
||||
for interval in self.intervals:
|
||||
if interval.xmax <= xstart or interval.xmin >= xend:
|
||||
pass
|
||||
elif interval.xmin < xstart:
|
||||
new_intervals.append(
|
||||
Interval(
|
||||
xmin=new_xmin, xmax=interval.xmax - bias, text=interval.text
|
||||
)
|
||||
)
|
||||
elif interval.xmax > xend:
|
||||
new_intervals.append(
|
||||
Interval(
|
||||
xmin=interval.xmin - bias, xmax=new_xmax, text=interval.text
|
||||
)
|
||||
)
|
||||
else:
|
||||
new_intervals.append(
|
||||
Interval(
|
||||
xmin=interval.xmin - bias,
|
||||
xmax=interval.xmax - bias,
|
||||
text=interval.text,
|
||||
)
|
||||
)
|
||||
|
||||
return Tier(
|
||||
tier_class=self.tier_class,
|
||||
name=self.name,
|
||||
xmin=new_xmin,
|
||||
xmax=new_xmax,
|
||||
intervals=new_intervals,
|
||||
)
|
||||
|
||||
|
||||
class Interval(object):
|
||||
def __init__(self, xmin=0.0, xmax=0.0, text=""):
|
||||
self.xmin = xmin
|
||||
self.xmax = xmax
|
||||
self.text = text
|
||||
|
||||
if self.xmax < self.xmin:
|
||||
raise ValueError("xmax ({}) < xmin ({})".format(self.xmax, self.xmin))
|
||||
|
||||
|
||||
def read_textgrid_from_file(filepath):
|
||||
with codecs.open(filepath, "r", encoding="utf-8") as handle:
|
||||
lines = handle.readlines()
|
||||
|
||||
if lines[-1] == "\r\n":
|
||||
lines = lines[:-1]
|
||||
|
||||
assert "File type" in lines[0], "error line 0, {}".format(lines[0])
|
||||
file_type = (
|
||||
lines[0]
|
||||
.split("=")[1]
|
||||
.replace(" ", "")
|
||||
.replace('"', "")
|
||||
.replace("\r", "")
|
||||
.replace("\n", "")
|
||||
)
|
||||
|
||||
assert "Object class" in lines[1], "error line 1, {}".format(lines[1])
|
||||
object_class = (
|
||||
lines[1]
|
||||
.split("=")[1]
|
||||
.replace(" ", "")
|
||||
.replace('"', "")
|
||||
.replace("\r", "")
|
||||
.replace("\n", "")
|
||||
)
|
||||
|
||||
assert lines[2] == "\r\n", "error line 2, {}".format(lines[2])
|
||||
|
||||
assert "xmin" in lines[3], "error line 3, {}".format(lines[3])
|
||||
xmin = float(
|
||||
lines[3].split("=")[1].replace(" ", "").replace("\r", "").replace("\n", "")
|
||||
)
|
||||
|
||||
assert "xmax" in lines[4], "error line 4, {}".format(lines[4])
|
||||
xmax = float(
|
||||
lines[4].split("=")[1].replace(" ", "").replace("\r", "").replace("\n", "")
|
||||
)
|
||||
|
||||
assert "tiers?" in lines[5], "error line 5, {}".format(lines[5])
|
||||
tiers_status = (
|
||||
lines[5].split("?")[1].replace(" ", "").replace("\r", "").replace("\n", "")
|
||||
)
|
||||
|
||||
assert "size" in lines[6], "error line 6, {}".format(lines[6])
|
||||
size = int(
|
||||
lines[6].split("=")[1].replace(" ", "").replace("\r", "").replace("\n", "")
|
||||
)
|
||||
|
||||
assert lines[7] == "item []:\r\n", "error line 7, {}".format(lines[7])
|
||||
|
||||
tier_start = []
|
||||
for item_idx in range(size):
|
||||
tier_start.append(lines.index(" " * 4 + "item [{}]:\r\n".format(item_idx + 1)))
|
||||
|
||||
tier_end = tier_start[1:] + [len(lines)]
|
||||
|
||||
tiers = []
|
||||
for tier_idx in range(size):
|
||||
tiers.append(
|
||||
read_tier_from_lines(
|
||||
tier_lines=lines[tier_start[tier_idx] + 1 : tier_end[tier_idx]]
|
||||
)
|
||||
)
|
||||
|
||||
return TextGrid(
|
||||
file_type=file_type,
|
||||
object_class=object_class,
|
||||
xmin=xmin,
|
||||
xmax=xmax,
|
||||
tiers_status=tiers_status,
|
||||
tiers=tiers,
|
||||
)
|
||||
|
||||
|
||||
def read_tier_from_lines(tier_lines):
|
||||
assert "class" in tier_lines[0], "error line 0, {}".format(tier_lines[0])
|
||||
tier_class = (
|
||||
tier_lines[0]
|
||||
.split("=")[1]
|
||||
.replace(" ", "")
|
||||
.replace('"', "")
|
||||
.replace("\r", "")
|
||||
.replace("\n", "")
|
||||
)
|
||||
|
||||
assert "name" in tier_lines[1], "error line 1, {}".format(tier_lines[1])
|
||||
name = (
|
||||
tier_lines[1]
|
||||
.split("=")[1]
|
||||
.replace(" ", "")
|
||||
.replace('"', "")
|
||||
.replace("\r", "")
|
||||
.replace("\n", "")
|
||||
)
|
||||
|
||||
assert "xmin" in tier_lines[2], "error line 2, {}".format(tier_lines[2])
|
||||
xmin = float(
|
||||
tier_lines[2].split("=")[1].replace(" ", "").replace("\r", "").replace("\n", "")
|
||||
)
|
||||
|
||||
assert "xmax" in tier_lines[3], "error line 3, {}".format(tier_lines[3])
|
||||
xmax = float(
|
||||
tier_lines[3].split("=")[1].replace(" ", "").replace("\r", "").replace("\n", "")
|
||||
)
|
||||
|
||||
assert "intervals: size" in tier_lines[4], "error line 4, {}".format(tier_lines[4])
|
||||
intervals_num = int(
|
||||
tier_lines[4].split("=")[1].replace(" ", "").replace("\r", "").replace("\n", "")
|
||||
)
|
||||
|
||||
# handle unformatted case
|
||||
# R12_S203204205_C09_I1_Near_203.TextGrid
|
||||
# R12_S203204205_C09_I1_Near_205.TextGrid
|
||||
if tier_lines[-1] == "\n":
|
||||
tier_lines = tier_lines[:-1]
|
||||
|
||||
if len(tier_lines[5:]) == intervals_num * 5:
|
||||
intervals = []
|
||||
for intervals_idx in range(intervals_num):
|
||||
assert tier_lines[
|
||||
5 + 5 * intervals_idx + 0
|
||||
] == " " * 8 + "intervals [{}]:\r\n".format(intervals_idx + 1)
|
||||
assert tier_lines[
|
||||
5 + 5 * intervals_idx + 1
|
||||
] == " " * 8 + "intervals [{}]:\r\n".format(intervals_idx + 1)
|
||||
intervals.append(
|
||||
read_interval_from_lines(
|
||||
interval_lines=tier_lines[
|
||||
7 + 5 * intervals_idx : 10 + 5 * intervals_idx
|
||||
]
|
||||
)
|
||||
)
|
||||
elif len(tier_lines[5:]) == intervals_num * 4:
|
||||
# handle unformatted case
|
||||
# R12_S203204205_C09_I1_Near_203.TextGrid
|
||||
# R12_S203204205_C09_I1_Near_204.TextGrid
|
||||
# R12_S203204205_C09_I1_Near_205.TextGrid
|
||||
intervals = []
|
||||
for intervals_idx in range(intervals_num):
|
||||
assert tier_lines[
|
||||
5 + 4 * intervals_idx + 0
|
||||
] == " " * 8 + "intervals [{}]:\r\n".format(intervals_idx + 1)
|
||||
|
||||
intervals.append(
|
||||
read_interval_from_lines(
|
||||
interval_lines=tier_lines[
|
||||
6 + 4 * intervals_idx : 9 + 4 * intervals_idx
|
||||
]
|
||||
)
|
||||
)
|
||||
else:
|
||||
import pdb
|
||||
|
||||
pdb.set_trace()
|
||||
raise ValueError(
|
||||
"error lines {} % {} != 0".format(len(tier_lines[5:]), intervals_num)
|
||||
)
|
||||
|
||||
return Tier(
|
||||
tier_class=tier_class, name=name, xmin=xmin, xmax=xmax, intervals=intervals
|
||||
)
|
||||
|
||||
|
||||
def read_interval_from_lines(interval_lines):
|
||||
assert len(interval_lines) == 3, "error lines"
|
||||
|
||||
assert "xmin" in interval_lines[0], "error line 0, {}".format(interval_lines[0])
|
||||
xmin = float(
|
||||
interval_lines[0]
|
||||
.split("=")[1]
|
||||
.replace(" ", "")
|
||||
.replace("\r", "")
|
||||
.replace("\n", "")
|
||||
)
|
||||
|
||||
assert "xmax" in interval_lines[1], "error line 1, {}".format(interval_lines[1])
|
||||
xmax = float(
|
||||
interval_lines[1]
|
||||
.split("=")[1]
|
||||
.replace(" ", "")
|
||||
.replace("\r", "")
|
||||
.replace("\n", "")
|
||||
)
|
||||
|
||||
assert "text" in interval_lines[2], "error line 2, {}".format(interval_lines[2])
|
||||
text = (
|
||||
interval_lines[2]
|
||||
.split("=")[1]
|
||||
.replace(" ", "")
|
||||
.replace('"', "")
|
||||
.replace("\r", "")
|
||||
.replace("\n", "")
|
||||
)
|
||||
|
||||
return Interval(xmin=xmin, xmax=xmax, text=text)
|
||||
14
egs/alimeeting/modular_sa_asr/path.sh
Executable file
14
egs/alimeeting/modular_sa_asr/path.sh
Executable file
@ -0,0 +1,14 @@
|
||||
export FUNASR_DIR=$PWD/../../..
|
||||
|
||||
export KALDI_ROOT=/Your_Kaldi_root
|
||||
export DATA_SOURCE=/Your_data_path
|
||||
export DATA_NAME=Test_2023_Ali_far
|
||||
export PATH=$PWD/utils/:$KALDI_ROOT/tools/openfst/bin:$PATH
|
||||
[ ! -f $KALDI_ROOT/tools/config/common_path.sh ] && echo >&2 "The standard file $KALDI_ROOT/tools/config/common_path.sh is not present -> Exit!" && exit 1
|
||||
. $KALDI_ROOT/tools/config/common_path.sh
|
||||
export LC_ALL=C
|
||||
|
||||
# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
|
||||
export PYTHONIOENCODING=UTF-8
|
||||
export PATH=$FUNASR_DIR/funasr/bin:./utils:$FUNASR_DIR:$PATH
|
||||
export PYTHONPATH=$FUNASR_DIR:$PYTHONPATH
|
||||
152
egs/alimeeting/modular_sa_asr/run_asr.sh
Executable file
152
egs/alimeeting/modular_sa_asr/run_asr.sh
Executable file
@ -0,0 +1,152 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
. ./path.sh || exit 1;
|
||||
|
||||
# machines configuration
|
||||
CUDA_VISIBLE_DEVICES="4,5,6,7"
|
||||
gpu_num=4
|
||||
count=1
|
||||
gpu_inference=true # Whether to perform gpu decoding, set false for cpu decoding
|
||||
finetune=true
|
||||
# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob
|
||||
njob=2
|
||||
train_cmd=utils/run.pl
|
||||
infer_cmd=utils/run.pl
|
||||
|
||||
# general configuration
|
||||
feats_dir="data" #feature output dictionary
|
||||
exp_dir="."
|
||||
lang=zh
|
||||
token_type=char
|
||||
type=sound
|
||||
scp=wav.scp
|
||||
speed_perturb="1.0"
|
||||
stage=0
|
||||
stop_stage=1
|
||||
|
||||
# feature configuration
|
||||
feats_dim=80
|
||||
nj=64
|
||||
|
||||
# exp tag
|
||||
tag="finetune"
|
||||
|
||||
# Set bash to 'debug' mode, it will exit on :
|
||||
# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
|
||||
set -e
|
||||
set -u
|
||||
set -o pipefail
|
||||
|
||||
train_set=Train_Ali_far_wpegss
|
||||
valid_set=Test_Ali_far_wpegss
|
||||
test_sets="${DATA_NAME}_wpegss"
|
||||
|
||||
asr_config=conf/train_paraformer.yaml
|
||||
model_dir="$(basename "${asr_config}" .yaml)_${lang}_${token_type}_${tag}"
|
||||
pretrain_model_dir=./speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch
|
||||
inference_config=$pretrain_model_dir/decoding.yaml
|
||||
|
||||
|
||||
token_list=$pretrain_model_dir/tokens.txt
|
||||
|
||||
# you can set gpu num for decoding here
|
||||
gpuid_list=$CUDA_VISIBLE_DEVICES # set gpus for decoding, the same as training stage by default
|
||||
ngpu=$(echo $gpuid_list | awk -F "," '{print NF}')
|
||||
|
||||
if ${gpu_inference}; then
|
||||
inference_nj=$[${ngpu}*${njob}]
|
||||
_ngpu=1
|
||||
else
|
||||
inference_nj=$njob
|
||||
_ngpu=0
|
||||
fi
|
||||
|
||||
if ${finetune}; then
|
||||
inference_asr_model=./checkpoint/valid.acc.ave.pb
|
||||
finetune_tag="_finetune"
|
||||
else
|
||||
inference_asr_model=$pretrain_model_dir/model.pb
|
||||
finetune_tag=""
|
||||
fi
|
||||
|
||||
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
|
||||
if [ -L ./utils ]; then
|
||||
unlink ./utils
|
||||
ln -s ../../aishell/transformer/utils
|
||||
else
|
||||
ln -s ../../aishell/transformer/utils
|
||||
fi
|
||||
fi
|
||||
|
||||
# Download Model
|
||||
world_size=$gpu_num # run on one machine
|
||||
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
|
||||
echo "stage 1: Download Model"
|
||||
if [ ! -d $pretrain_model_dir ]; then
|
||||
git clone https://www.modelscope.cn/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch.git
|
||||
fi
|
||||
fi
|
||||
|
||||
# ASR Training Stage
|
||||
world_size=$gpu_num # run on one machine
|
||||
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
|
||||
|
||||
echo "stage 2: ASR Training"
|
||||
python -m torch.distributed.launch \
|
||||
--nproc_per_node $gpu_num local/finetune.py
|
||||
|
||||
fi
|
||||
|
||||
# Testing Stage
|
||||
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
||||
echo "stage 3: Inference"
|
||||
for dset in ${test_sets}; do
|
||||
_dir="$pretrain_model_dir/decode_${dset}${finetune_tag}"
|
||||
_logdir="${_dir}/logdir"
|
||||
if [ -d ${_dir} ]; then
|
||||
echo "${_dir} is already exists. if you want to decode again, please delete this dir first."
|
||||
exit 0
|
||||
fi
|
||||
mkdir -p "${_logdir}"
|
||||
_data="./data/${dset}"
|
||||
key_file=${_data}/${scp}
|
||||
num_scp_file="$(<${key_file} wc -l)"
|
||||
_nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
|
||||
split_scps=
|
||||
for n in $(seq "${_nj}"); do
|
||||
split_scps+=" ${_logdir}/keys.${n}.scp"
|
||||
done
|
||||
# shellcheck disable=SC2086
|
||||
utils/split_scp.pl "${key_file}" ${split_scps}
|
||||
_opts=
|
||||
if [ -n "${inference_config}" ]; then
|
||||
_opts+="--config ${inference_config} "
|
||||
fi
|
||||
${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
|
||||
python -m funasr.bin.asr_inference_launch \
|
||||
--batch_size 1 \
|
||||
--ngpu "${_ngpu}" \
|
||||
--njob ${njob} \
|
||||
--gpuid_list ${gpuid_list} \
|
||||
--data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
|
||||
--cmvn_file $pretrain_model_dir/am.mvn \
|
||||
--key_file "${_logdir}"/keys.JOB.scp \
|
||||
--asr_train_config $pretrain_model_dir/config.yaml \
|
||||
--asr_model_file $inference_asr_model \
|
||||
--output_dir "${_logdir}"/output.JOB \
|
||||
--mode paraformer \
|
||||
${_opts}
|
||||
|
||||
for f in token token_int score text; do
|
||||
if [ -f "${_logdir}/output.1/1best_recog/${f}" ]; then
|
||||
for i in $(seq "${_nj}"); do
|
||||
cat "${_logdir}/output.${i}/1best_recog/${f}"
|
||||
done | sort -k1 >"${_dir}/${f}"
|
||||
fi
|
||||
done
|
||||
python local/merge_spk_text.py ${_dir}/text ${_data}/utt2spk
|
||||
python local/compute_cpcer.py ${_data}/text_merge ${_dir}/text_merge
|
||||
echo "cpCER is saved at ${_dir}/text_cpcer"
|
||||
done
|
||||
fi
|
||||
|
||||
233
egs/alimeeting/modular_sa_asr/run_diar.sh
Executable file
233
egs/alimeeting/modular_sa_asr/run_diar.sh
Executable file
@ -0,0 +1,233 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -e
|
||||
set -o pipefail
|
||||
. path.sh || exit 1
|
||||
train_cmd=utils/run.pl
|
||||
|
||||
# data path
|
||||
data_source_dir=$DATA_SOURCE
|
||||
textgrid_dir=$data_source_dir/textgrid_dir/
|
||||
wav_dir=$data_source_dir/audio_dir/
|
||||
|
||||
# work path
|
||||
work_dir=./data/${DATA_NAME}_sc/
|
||||
sad_dir=$work_dir/sad_part/
|
||||
sad_work_dir=$sad_dir/exp/
|
||||
sad_result_dir=$sad_dir/sad
|
||||
dia_dir=$work_dir/dia_part/
|
||||
dia_vad_dir=$dia_dir/vad/
|
||||
dia_rttm_dir=$dia_dir/rttm/
|
||||
dia_emb_dir=$dia_dir/embedding/
|
||||
dia_rtt_label_dir=$dia_dir/label_rttm/
|
||||
dia_result_dir=$dia_dir/result_DER/
|
||||
sond_work_dir=./data/${DATA_NAME}_sond/
|
||||
asr_work_dir=./data/${DATA_NAME}_wpegss/org/
|
||||
|
||||
mkdir -p $work_dir || exit 1;
|
||||
mkdir -p $sad_dir || exit 1;
|
||||
mkdir -p $sad_work_dir || exit 1;
|
||||
mkdir -p $sad_result_dir || exit 1;
|
||||
mkdir -p $dia_dir || exit 1;
|
||||
mkdir -p $dia_vad_dir || exit 1;
|
||||
mkdir -p $dia_rttm_dir || exit 1;
|
||||
mkdir -p $dia_emb_dir || exit 1;
|
||||
mkdir -p $dia_rtt_label_dir || exit 1;
|
||||
mkdir -p $dia_result_dir || exit 1;
|
||||
mkdir -p $sond_work_dir || exit 1;
|
||||
mkdir -p $asr_work_dir || exit 1;
|
||||
|
||||
stage=0
|
||||
stop_stage=9
|
||||
nj=4
|
||||
sm_size=83
|
||||
|
||||
if [ $stage -le 0 ] && [ ${stop_stage} -ge 0 ]; then
|
||||
# Check the installtion of kaldi
|
||||
if [ -L ./steps ]; then
|
||||
unlink ./steps
|
||||
else
|
||||
ln -s $KALDI_ROOT/egs/wsj/s5/steps || { echo "You must install kaldi first, and set the KALDI_ROOT in path.sh" && exit 1; }
|
||||
fi
|
||||
|
||||
if [ -L ./utils ]; then
|
||||
unlink ./utils
|
||||
else
|
||||
ln -s $KALDI_ROOT/egs/wsj/s5/utils || { echo "You must install kaldi first, and set the KALDI_ROOT in path.sh" && exit 1; }
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 1 ] && [ ${stop_stage} -ge 1 ]; then
|
||||
# Prepare the AliMeeting data
|
||||
echo "Prepare Alimeeting data"
|
||||
find $wav_dir -name "*\.wav" > $work_dir/wavlist
|
||||
sort $work_dir/wavlist > $work_dir/tmp
|
||||
cp $work_dir/tmp $work_dir/wavlist
|
||||
awk -F '/' '{print $NF}' $work_dir/wavlist | awk -F '.' '{print $1}' > $work_dir/uttid
|
||||
paste -d " " $work_dir/uttid $work_dir/wavlist > $work_dir/wav.scp
|
||||
paste -d " " $work_dir/uttid $work_dir/uttid > $work_dir/utt2spk
|
||||
cp $work_dir/utt2spk $work_dir/spk2utt
|
||||
cp $work_dir/uttid $work_dir/text
|
||||
|
||||
sad_feat=$sad_dir/feat/mfcc
|
||||
cp $work_dir/wav.scp $sad_dir
|
||||
cp $work_dir/utt2spk $sad_dir
|
||||
cp $work_dir/spk2utt $sad_dir
|
||||
cp $work_dir/text $sad_dir
|
||||
|
||||
utils/fix_data_dir.sh $sad_dir
|
||||
|
||||
## first we extract the feature for sad model
|
||||
steps/make_mfcc.sh --nj $nj --cmd "$train_cmd" \
|
||||
--mfcc-config conf/mfcc_hires.conf \
|
||||
$sad_dir $sad_dir/make_mfcc $sad_feat
|
||||
fi
|
||||
|
||||
if [ $stage -le 2 ] && [ ${stop_stage} -ge 2 ]; then
|
||||
# Do Speech Activity Detectation
|
||||
echo "Do SAD"
|
||||
./utils/split_data.sh $sad_dir $nj
|
||||
## do the segmentations
|
||||
local/segmentation/detect_speech_activity.sh --nj $nj --stage 0 \
|
||||
--cmd "$train_cmd" $sad_dir exp/segmentation_1a/tdnn_stats_sad_1a/ \
|
||||
$sad_dir/feat/mfcc $sad_work_dir $sad_result_dir
|
||||
fi
|
||||
|
||||
if [ $stage -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
||||
echo "Do Speaker Embedding Extractor"
|
||||
cp $work_dir/wav.scp $dia_dir
|
||||
|
||||
python local/segment_to_lab.py --input_segments $sad_dir/sad_seg/segments \
|
||||
--label_path $dia_vad_dir \
|
||||
--output_label_scp_file $dia_dir/label.scp ||exit 1;
|
||||
|
||||
./utils/split_data.sh $work_dir $nj
|
||||
${train_cmd} JOB=1:${nj} $dia_dir/exp/extract_embedding.JOB.log \
|
||||
python VBx/predict.py --in-file-list $work_dir/split${nj}/JOB/text \
|
||||
--in-lab-dir $dia_dir/vad \
|
||||
--in-wav-dir $wav_dir \
|
||||
--out-ark-fn $dia_emb_dir/embedding_out.JOB.ark \
|
||||
--out-seg-fn $dia_emb_dir/embedding_out.JOB.seg \
|
||||
--weights VBx/models/ResNet101_16kHz/nnet/final.onnx \
|
||||
--backend onnx
|
||||
|
||||
echo "success"
|
||||
fi
|
||||
|
||||
if [ $stage -le 4 ] && [ ${stop_stage} -ge 4 ]; then
|
||||
# The Speaker Embedding Cluster
|
||||
echo "Do the Speaker Embedding Cluster"
|
||||
# The meeting data is long so that the cluster is a little bit slow
|
||||
${train_cmd} JOB=1:${nj} $dia_dir/exp/cluster.JOB.log \
|
||||
python VBx/vbhmm.py --init AHC+VB \
|
||||
--out-rttm-dir $dia_rttm_dir \
|
||||
--xvec-ark-file $dia_emb_dir/embedding_out.JOB.ark \
|
||||
--segments-file $dia_emb_dir/embedding_out.JOB.seg \
|
||||
--xvec-transform VBx/models/ResNet101_16kHz/transform.h5 \
|
||||
--plda-file VBx/models/ResNet101_16kHz/plda \
|
||||
--threshold 0.14 \
|
||||
--lda-dim 128 \
|
||||
--Fa 0.3 \
|
||||
--Fb 17 \
|
||||
--loopP 0.99
|
||||
fi
|
||||
|
||||
if [ $stage -le 5 ] && [ ${stop_stage} -ge 5 ]; then
|
||||
echo "Process textgrid to obtain rttm label"
|
||||
find -L $textgrid_dir -iname "*.TextGrid" > $work_dir/textgrid.flist
|
||||
sort $work_dir/textgrid.flist > $work_dir/tmp
|
||||
cp $work_dir/tmp $work_dir/textgrid.flist
|
||||
paste $work_dir/uttid $work_dir/textgrid.flist > $work_dir/uttid_textgrid.flist
|
||||
while read text_file
|
||||
do
|
||||
text_grid=`echo $text_file | awk '{print $1}'`
|
||||
text_grid_path=`echo $text_file | awk '{print $2}'`
|
||||
python local/make_textgrid_rttm.py --input_textgrid_file $text_grid_path \
|
||||
--uttid $text_grid \
|
||||
--output_rttm_file $dia_rtt_label_dir/${text_grid}.rttm
|
||||
done < $work_dir/uttid_textgrid.flist
|
||||
if [ -f "$dia_rtt_label_dir/all.rttm" ]; then
|
||||
rm -f $dia_rtt_label_dir/all.rttm
|
||||
fi
|
||||
cat $dia_rtt_label_dir/*.rttm > $dia_rtt_label_dir/all.rttm
|
||||
fi
|
||||
|
||||
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
|
||||
echo "Get VBx DER result"
|
||||
find $dia_rtt_label_dir -name "*.rttm" > $dia_rtt_label_dir/ref.scp
|
||||
find $dia_rttm_dir -name "*.rttm" > $dia_rttm_dir/sys.scp
|
||||
if [ -f "$dia_rttm_dir/all.rttm" ]; then
|
||||
rm -f $dia_rttm_dir/all.rttm
|
||||
fi
|
||||
cat $dia_rttm_dir/*.rttm > $dia_rttm_dir/all.rttm
|
||||
|
||||
collar_set="0 0.25"
|
||||
python local/meeting_speaker_number_process.py --path=$work_dir \
|
||||
--label_path=$dia_rtt_label_dir --predict_path=$dia_rttm_dir
|
||||
speaker_number="2 3 4"
|
||||
for weight_collar in $collar_set;
|
||||
do
|
||||
# all meeting
|
||||
python dscore/score.py --collar $weight_collar \
|
||||
-R $dia_rtt_label_dir/ref.scp -S $dia_rttm_dir/sys.scp > $dia_result_dir/speaker_all_DER_overlaps_${weight_collar}.log
|
||||
# 2,3,4 speaker meeting
|
||||
for speaker_count in $speaker_number;
|
||||
do
|
||||
python dscore/score.py --collar $weight_collar \
|
||||
-R $dia_rtt_label_dir/speaker${speaker_count}_id -S $dia_rttm_dir/speaker${speaker_count}_id > $dia_result_dir/speaker_${speaker_count}_DER_overlaps_${weight_collar}.log
|
||||
done
|
||||
done
|
||||
fi
|
||||
|
||||
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
|
||||
echo "Downloading Pre-trained model..."
|
||||
mkdir ./SOND
|
||||
cd ./SOND
|
||||
git clone https://www.modelscope.cn/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch.git
|
||||
git clone https://www.modelscope.cn/damo/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch.git
|
||||
ln -s speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pth ./sv.pb
|
||||
cp speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.yaml ./sv.yaml
|
||||
ln -s speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/sond.pth ./sond.pb
|
||||
cp speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/sond_fbank.yaml ./sond_fbank.yaml
|
||||
cp speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/sond.yaml ./sond.yaml
|
||||
cd ..
|
||||
fi
|
||||
|
||||
if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
|
||||
echo "Prepare data for sond"
|
||||
cp $work_dir/wav.scp $sond_work_dir
|
||||
# convert rttm to segments
|
||||
python local/rttm2segments.py $dia_rttm_dir/all.rttm $sond_work_dir 0
|
||||
# remove the overlapped part
|
||||
python local/remove_overlap.py $sond_work_dir/segments $sond_work_dir/utt2spk \
|
||||
$sond_work_dir/segments_nooverlap $sond_work_dir/utt2spk_nooverlap 0.3
|
||||
# extract speaker profile from the filtered segments file
|
||||
python local/extract_profile_from_segments.py $sond_work_dir
|
||||
# segment data to 16s
|
||||
python local/resegment_data.py \
|
||||
$data_source_dir/segments \
|
||||
$data_source_dir/wav.scp \
|
||||
$sond_work_dir
|
||||
fi
|
||||
|
||||
if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
|
||||
echo "Diarization with SOND"
|
||||
|
||||
python local/infer_sond.py SOND/sond.yaml SOND/sond.pb $sond_work_dir $sond_work_dir/dia_outputs
|
||||
|
||||
python local/convert_label_to_rttm.py \
|
||||
$sond_work_dir/dia_outputs/labels.txt \
|
||||
$sond_work_dir/map.scp \
|
||||
$sond_work_dir/dia_outputs/prediction_sm_${sm_size}.rttm \
|
||||
--ignore_len 10 --no_pbar --smooth_size ${sm_size} \
|
||||
--vote_prob 0.5 --n_spk 16
|
||||
|
||||
python dscore/score.py \
|
||||
-r $dia_rtt_label_dir/all.rttm \
|
||||
-s $sond_work_dir/dia_outputs/prediction_sm_${sm_size}.rttm \
|
||||
--collar 0.25 &> $sond_work_dir/dia_outputs/dia_result
|
||||
# convert rttm to segments
|
||||
python local/rttm2segments.py $sond_work_dir/dia_outputs/prediction_sm_${sm_size}.rttm $asr_work_dir 1
|
||||
fi
|
||||
|
||||
|
||||
114
egs/alimeeting/modular_sa_asr/run_enh.sh
Executable file
114
egs/alimeeting/modular_sa_asr/run_enh.sh
Executable file
@ -0,0 +1,114 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -e
|
||||
set -o pipefail
|
||||
|
||||
log() {
|
||||
local fname=${BASH_SOURCE[1]##*/}
|
||||
echo -e "$(date '+%Y-%m-%dT%H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||
}
|
||||
SECONDS=0
|
||||
|
||||
# general configuration
|
||||
stage=1
|
||||
stop_stage=3
|
||||
nj=10
|
||||
|
||||
log "$0 $*"
|
||||
. utils/parse_options.sh
|
||||
. ./path.sh || exit 1
|
||||
train_cmd=utils/run.pl
|
||||
|
||||
|
||||
data_source_dir=$DATA_SOURCE
|
||||
audio_dir=$data_source_dir/audio_dir
|
||||
output_wpe_dir=$data_source_dir/wpe_audio_dir
|
||||
output_gss_dir=$data_source_dir/gss_audio_dir
|
||||
asr_data_path=./data/${DATA_NAME}_wpegss
|
||||
channel=$1
|
||||
|
||||
log "Start Speech Enhancement."
|
||||
|
||||
if [ ! -L ./utils ]; then
|
||||
ln -s ./pb_chime5/pb_bss
|
||||
fi
|
||||
|
||||
# WPE
|
||||
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
|
||||
log "stage 1: Start WPE."
|
||||
for ch in `seq ${channel}`; do
|
||||
mkdir -p ${output_wpe_dir}_${ch}/log/
|
||||
# split wav.scp
|
||||
find $audio_dir/ -name "*.wav" > ${output_wpe_dir}_${ch}/wav.scp
|
||||
arr=""
|
||||
for i in `seq ${nj}`; do
|
||||
arr="$arr ${output_wpe_dir}_${ch}/log/wav.${i}.scp"
|
||||
done
|
||||
split_scp.pl ${output_wpe_dir}_${ch}/wav.scp $arr
|
||||
# do wpe
|
||||
for n in `seq ${nj}`; do
|
||||
cat <<-EOF >${output_wpe_dir}_${ch}/log/wpe.${n}.sh
|
||||
python local/run_wpe.py \
|
||||
--wav-scp ${output_wpe_dir}_${ch}/log/wav.${n}.scp \
|
||||
--audio-dir ${audio_dir} \
|
||||
--output-dir ${output_wpe_dir}_${ch} \
|
||||
--ch $ch
|
||||
EOF
|
||||
done
|
||||
chmod a+x ${output_wpe_dir}_${ch}/log/wpe.*.sh
|
||||
${train_cmd} JOB=1:${nj} ${output_wpe_dir}_${ch}/log/wpe.JOB.log \
|
||||
${output_wpe_dir}_${ch}/log/wpe.JOB.sh
|
||||
done
|
||||
fi
|
||||
|
||||
# GSS
|
||||
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
|
||||
log "stage 2: Start GSS"
|
||||
if [ ! -d pb_chime5/ ]; then
|
||||
log "Please install pb_chime5 by local/install_pb_chime5.sh"
|
||||
exit 1
|
||||
fi
|
||||
mkdir -p $output_gss_dir/log
|
||||
# split wpe.scp
|
||||
for i in `seq ${channel}`; do
|
||||
find ${output_wpe_dir}_${i}/ -name "*.wav" > $output_gss_dir/tmp${i}
|
||||
done
|
||||
awk -F '/' '{print($NF)}' $output_gss_dir/tmp1 | cut -d "." -f1 > $output_gss_dir/tmp
|
||||
arr=""
|
||||
for i in `seq ${channel}`; do
|
||||
arr="$arr $output_gss_dir/tmp${i}"
|
||||
done
|
||||
paste -d " " $output_gss_dir/tmp $arr > $output_gss_dir/wpe.scp
|
||||
rm -f $output_gss_dir/tmp*
|
||||
arr=""
|
||||
for i in `seq ${nj}`; do
|
||||
arr="$arr $output_gss_dir/log/wpe.${i}.scp"
|
||||
done
|
||||
split_scp.pl $output_gss_dir/wpe.scp $arr
|
||||
|
||||
# do gss
|
||||
for n in `seq ${nj}`; do
|
||||
cat <<-EOF >${output_gss_dir}/log/gss.${n}.sh
|
||||
python local/run_gss.py \
|
||||
--wav-scp ${output_gss_dir}/log/wpe.${n}.scp \
|
||||
--segments $asr_data_path/org/segments \
|
||||
--output-dir ${output_gss_dir}
|
||||
EOF
|
||||
done
|
||||
chmod a+x ${output_gss_dir}/log/gss.*.sh
|
||||
${train_cmd} JOB=1:${nj} ${output_gss_dir}/log/gss.JOB.log \
|
||||
${output_gss_dir}/log/gss.JOB.sh
|
||||
fi
|
||||
|
||||
# Prepare data for ASR
|
||||
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
||||
log "stage 3: Preparing data for ASR"
|
||||
find $output_gss_dir -name "*.wav" > $asr_data_path/org/wav_list
|
||||
awk -F '/' '{print($NF)}' $asr_data_path/org/wav_list | sed 's/\.wav//g' > $asr_data_path/org/uttid
|
||||
paste -d " " $asr_data_path/org/uttid $asr_data_path/org/wav_list > $asr_data_path/org/wav.scp
|
||||
bash local/format_wav_scp.sh --nj "${nj}" --cmd "${train_cmd}" \
|
||||
--audio-format wav --segments $asr_data_path/org/segments \
|
||||
"$asr_data_path/org/wav.scp" "$asr_data_path"
|
||||
fi
|
||||
|
||||
log "End speech enhancement"
|
||||
7
setup.py
7
setup.py
@ -19,8 +19,11 @@ requirements = {
|
||||
"soundfile>=0.12.1",
|
||||
"h5py>=2.10.0",
|
||||
"kaldiio>=2.17.0",
|
||||
"kaldi-io==0.9.8",
|
||||
"torch_complex",
|
||||
"nltk>=3.4.5",
|
||||
"onnxruntime"
|
||||
"numexpr"
|
||||
# ASR
|
||||
"sentencepiece",
|
||||
"jieba",
|
||||
@ -32,6 +35,8 @@ requirements = {
|
||||
"editdistance>=0.5.2",
|
||||
"tensorboard",
|
||||
"g2p",
|
||||
"nara_wpe",
|
||||
"Cython",
|
||||
# PAI
|
||||
"oss2",
|
||||
"edit-distance",
|
||||
@ -123,4 +128,4 @@ setup(
|
||||
"License :: OSI Approved :: Apache Software License",
|
||||
"Topic :: Software Development :: Libraries :: Python Modules",
|
||||
],
|
||||
)
|
||||
)
|
||||
Loading…
Reference in New Issue
Block a user