diff --git a/egs/callhome/eend_ola/local/dump_feature.py b/egs/callhome/eend_ola/local/dump_feature.py index 169615e1b..332edd2a1 100644 --- a/egs/callhome/eend_ola/local/dump_feature.py +++ b/egs/callhome/eend_ola/local/dump_feature.py @@ -1,10 +1,11 @@ import argparse import os -import numpy as np +from kaldiio import WriteHelper import funasr.modules.eend_ola.utils.feature as feature -import funasr.modules.eend_ola.utils.kaldi_data as kaldi_data +from funasr.modules.eend_ola.utils.kaldi_data import load_segments_rechash, load_utt2spk, load_wav_scp, load_reco2dur, \ + load_spk2utt, load_wav def _count_frames(data_len, size, step): @@ -24,10 +25,34 @@ def _gen_frame_indices( yield (i + 1) * step, data_length +class KaldiData: + def __init__(self, data_dir, idx): + self.data_dir = data_dir + segment_file = os.path.join(self.data_dir, 'segments.{}'.format(idx)) + self.segments = load_segments_rechash(segment_file) + + utt2spk_file = os.path.join(self.data_dir, 'utt2spk.{}'.format(idx)) + self.utt2spk = load_utt2spk(utt2spk_file) + + wav_file = os.path.join(self.data_dir, 'wav.scp.{}'.format(idx)) + self.wavs = load_wav_scp(wav_file) + + reco2dur_file = os.path.join(self.data_dir, 'reco2dur.{}'.format(idx)) + self.reco2dur = load_reco2dur(reco2dur_file) + + spk2utt_file = os.path.join(self.data_dir, 'spk2utt.{}'.format(idx)) + self.spk2utt = load_spk2utt(spk2utt_file) + + def load_wav(self, recid, start=0, end=None): + data, rate = load_wav(self.wavs[recid], start, end) + return data, rate + + class KaldiDiarizationDataset(): def __init__( self, data_dir, + index, chunk_size=2000, context_size=0, frame_size=1024, @@ -40,6 +65,7 @@ class KaldiDiarizationDataset(): n_speakers=None, ): self.data_dir = data_dir + self.index = index self.chunk_size = chunk_size self.context_size = context_size self.frame_size = frame_size @@ -50,9 +76,8 @@ class KaldiDiarizationDataset(): self.chunk_indices = [] self.label_delay = label_delay - self.data = kaldi_data.KaldiData(self.data_dir) + self.data = KaldiData(self.data_dir, index) - # make chunk indices: filepath, start_frame, end_frame for rec, path in self.data.wavs.items(): data_len = int(self.data.reco2dur[rec] * rate / frame_shift) data_len = int(data_len / self.subsampling) @@ -66,62 +91,54 @@ class KaldiDiarizationDataset(): def convert(args): - f = open(out_wav_file, 'w') dataset = KaldiDiarizationDataset( data_dir=args.data_dir, + index=args.index, chunk_size=args.num_frames, context_size=args.context_size, - input_transform=args.input_transform, + input_transform="logmel23_mn", frame_size=args.frame_size, frame_shift=args.frame_shift, subsampling=args.subsampling, rate=8000, use_last_samples=True, ) - length = len(dataset.chunk_indices) - for idx, (rec, path, st, ed) in enumerate(dataset.chunk_indices): - Y, T = feature.get_labeledSTFT( - dataset.data, - rec, - st, - ed, - dataset.frame_size, - dataset.frame_shift, - dataset.n_speakers) - Y = feature.transform(Y, dataset.input_transform) - Y_spliced = feature.splice(Y, dataset.context_size) - Y_ss, T_ss = feature.subsample(Y_spliced, T, dataset.subsampling) - st = '{:0>7d}'.format(st) - ed = '{:0>7d}'.format(ed) - suffix = '_' + st + '_' + ed - parts = os.readlink('/'.join(path.split('/')[:-1])).split('/') - # print('parts: ', parts) - parts = parts[:4] + ['numpy_data'] + parts[4:] - cur_path = '/'.join(parts) - # print('cur path: ', cur_path) - out_path = os.path.join(cur_path, path.split('/')[-1].split('.')[0] + suffix + '.npz') - # print(out_path) - # print(cur_path) - if not os.path.exists(cur_path): - os.makedirs(cur_path) - np.savez(out_path, Y=Y_ss, T=T_ss) - if idx == length - 1: - f.write(rec + suffix + ' ' + out_path) - else: - f.write(rec + suffix + ' ' + out_path + '\n') + feature_ark_file = os.path.join(args.output_dir, "feature.ark.{}".format(args.index)) + feature_scp_file = os.path.join(args.output_dir, "feature.scp.{}".format(args.index)) + label_ark_file = os.path.join(args.output_dir, "label.ark.{}".format(args.index)) + label_scp_file = os.path.join(args.output_dir, "label.scp.{}".format(args.index)) + with WriteHelper('ark,scp:{},{}'.format(feature_ark_file, feature_scp_file)) as feature_writer, \ + WriteHelper('ark,scp:{},{}'.format(label_ark_file, label_scp_file)) as label_writer: + for idx, (rec, path, st, ed) in enumerate(dataset.chunk_indices): + Y, T = feature.get_labeledSTFT( + dataset.data, + rec, + st, + ed, + dataset.frame_size, + dataset.frame_shift, + dataset.n_speakers) + Y = feature.transform(Y, dataset.input_transform) + Y_spliced = feature.splice(Y, dataset.context_size) + Y_ss, T_ss = feature.subsample(Y_spliced, T, dataset.subsampling) + st = '{:0>7d}'.format(st) + ed = '{:0>7d}'.format(ed) + key = "{}_{}_{}".format(rec, st, ed) + feature_writer(key, Y_ss) + label_writer(key, T_ss.reshape(-1)) if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument("data_dir") - parser.add_argument("num_frames") - parser.add_argument("context_size") - parser.add_argument("frame_size") - parser.add_argument("frame_shift") - parser.add_argument("subsampling") - - + parser.add_argument("output_dir") + parser.add_argument("index") + parser.add_argument("num_frames", default=500) + parser.add_argument("context_size", default=7) + parser.add_argument("frame_size", default=200) + parser.add_argument("frame_shift", default=80) + parser.add_argument("subsampling", default=10) args = parser.parse_args() convert(args) diff --git a/egs/callhome/eend_ola/run_test.sh b/egs/callhome/eend_ola/run_test.sh index 8ba8d5742..c6a3a7109 100644 --- a/egs/callhome/eend_ola/run_test.sh +++ b/egs/callhome/eend_ola/run_test.sh @@ -78,17 +78,26 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then for dset in swb_sre_tr swb_sre_cv; do if [ "$dset" == "swb_sre_tr" ]; then n_mixtures=${simu_opts_num_train} + dataset=train else n_mixtures=500 + dataset=dev fi simu_data_dir=${dset}_ns"$(IFS="n"; echo "${simu_opts_num_speaker_array[*]}")"_beta"$(IFS="n"; echo "${simu_opts_sil_scale_array[*]}")"_${n_mixtures} - mkdir -p ${data_dir}/simu/data/${simu_data_dir}/.work - split_scps= - for n in $(seq $nj); do - split_scps="$split_scps ${data_dir}/simu/data/${simu_data_dir}/.work/wav.$n.scp" - done - utils/split_scp.pl "${data_dir}/simu/data/${simu_data_dir}/wav.scp" $split_scps || exit 1 - python local/split.py ${data_dir}/simu/data/${simu_data_dir} +# mkdir -p ${data_dir}/simu/data/${simu_data_dir}/.work +# split_scps= +# for n in $(seq $nj); do +# split_scps="$split_scps ${data_dir}/simu/data/${simu_data_dir}/.work/wav.$n.scp" +# done +# utils/split_scp.pl "${data_dir}/simu/data/${simu_data_dir}/wav.scp" $split_scps || exit 1 +# python local/split.py ${data_dir}/simu/data/${simu_data_dir} + output_dir=${data_dir}/ark_data/dump/simu_data/$dataset + mkdir -p $output_dir/.logs + $dump_cmd --max-jobs-run $nj JOB=1:$nj $output_dir/.logs/dump.JOB.log \ + python local/dump_feature.py \ + --data_dir ${data_dir}/simu/data/${simu_data_dir}/.work \ + --output_dir ${data_dir}/ark_data/dump/simu_data/$dataset \ + --index JOB done fi