mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
update
This commit is contained in:
parent
fa618d7634
commit
0109889f1c
@ -1,10 +1,11 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import numpy as np
|
from kaldiio import WriteHelper
|
||||||
|
|
||||||
import funasr.modules.eend_ola.utils.feature as feature
|
import funasr.modules.eend_ola.utils.feature as feature
|
||||||
import funasr.modules.eend_ola.utils.kaldi_data as kaldi_data
|
from funasr.modules.eend_ola.utils.kaldi_data import load_segments_rechash, load_utt2spk, load_wav_scp, load_reco2dur, \
|
||||||
|
load_spk2utt, load_wav
|
||||||
|
|
||||||
|
|
||||||
def _count_frames(data_len, size, step):
|
def _count_frames(data_len, size, step):
|
||||||
@ -24,10 +25,34 @@ def _gen_frame_indices(
|
|||||||
yield (i + 1) * step, data_length
|
yield (i + 1) * step, data_length
|
||||||
|
|
||||||
|
|
||||||
|
class KaldiData:
|
||||||
|
def __init__(self, data_dir, idx):
|
||||||
|
self.data_dir = data_dir
|
||||||
|
segment_file = os.path.join(self.data_dir, 'segments.{}'.format(idx))
|
||||||
|
self.segments = load_segments_rechash(segment_file)
|
||||||
|
|
||||||
|
utt2spk_file = os.path.join(self.data_dir, 'utt2spk.{}'.format(idx))
|
||||||
|
self.utt2spk = load_utt2spk(utt2spk_file)
|
||||||
|
|
||||||
|
wav_file = os.path.join(self.data_dir, 'wav.scp.{}'.format(idx))
|
||||||
|
self.wavs = load_wav_scp(wav_file)
|
||||||
|
|
||||||
|
reco2dur_file = os.path.join(self.data_dir, 'reco2dur.{}'.format(idx))
|
||||||
|
self.reco2dur = load_reco2dur(reco2dur_file)
|
||||||
|
|
||||||
|
spk2utt_file = os.path.join(self.data_dir, 'spk2utt.{}'.format(idx))
|
||||||
|
self.spk2utt = load_spk2utt(spk2utt_file)
|
||||||
|
|
||||||
|
def load_wav(self, recid, start=0, end=None):
|
||||||
|
data, rate = load_wav(self.wavs[recid], start, end)
|
||||||
|
return data, rate
|
||||||
|
|
||||||
|
|
||||||
class KaldiDiarizationDataset():
|
class KaldiDiarizationDataset():
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
data_dir,
|
data_dir,
|
||||||
|
index,
|
||||||
chunk_size=2000,
|
chunk_size=2000,
|
||||||
context_size=0,
|
context_size=0,
|
||||||
frame_size=1024,
|
frame_size=1024,
|
||||||
@ -40,6 +65,7 @@ class KaldiDiarizationDataset():
|
|||||||
n_speakers=None,
|
n_speakers=None,
|
||||||
):
|
):
|
||||||
self.data_dir = data_dir
|
self.data_dir = data_dir
|
||||||
|
self.index = index
|
||||||
self.chunk_size = chunk_size
|
self.chunk_size = chunk_size
|
||||||
self.context_size = context_size
|
self.context_size = context_size
|
||||||
self.frame_size = frame_size
|
self.frame_size = frame_size
|
||||||
@ -50,9 +76,8 @@ class KaldiDiarizationDataset():
|
|||||||
self.chunk_indices = []
|
self.chunk_indices = []
|
||||||
self.label_delay = label_delay
|
self.label_delay = label_delay
|
||||||
|
|
||||||
self.data = kaldi_data.KaldiData(self.data_dir)
|
self.data = KaldiData(self.data_dir, index)
|
||||||
|
|
||||||
# make chunk indices: filepath, start_frame, end_frame
|
|
||||||
for rec, path in self.data.wavs.items():
|
for rec, path in self.data.wavs.items():
|
||||||
data_len = int(self.data.reco2dur[rec] * rate / frame_shift)
|
data_len = int(self.data.reco2dur[rec] * rate / frame_shift)
|
||||||
data_len = int(data_len / self.subsampling)
|
data_len = int(data_len / self.subsampling)
|
||||||
@ -66,19 +91,25 @@ class KaldiDiarizationDataset():
|
|||||||
|
|
||||||
|
|
||||||
def convert(args):
|
def convert(args):
|
||||||
f = open(out_wav_file, 'w')
|
|
||||||
dataset = KaldiDiarizationDataset(
|
dataset = KaldiDiarizationDataset(
|
||||||
data_dir=args.data_dir,
|
data_dir=args.data_dir,
|
||||||
|
index=args.index,
|
||||||
chunk_size=args.num_frames,
|
chunk_size=args.num_frames,
|
||||||
context_size=args.context_size,
|
context_size=args.context_size,
|
||||||
input_transform=args.input_transform,
|
input_transform="logmel23_mn",
|
||||||
frame_size=args.frame_size,
|
frame_size=args.frame_size,
|
||||||
frame_shift=args.frame_shift,
|
frame_shift=args.frame_shift,
|
||||||
subsampling=args.subsampling,
|
subsampling=args.subsampling,
|
||||||
rate=8000,
|
rate=8000,
|
||||||
use_last_samples=True,
|
use_last_samples=True,
|
||||||
)
|
)
|
||||||
length = len(dataset.chunk_indices)
|
|
||||||
|
feature_ark_file = os.path.join(args.output_dir, "feature.ark.{}".format(args.index))
|
||||||
|
feature_scp_file = os.path.join(args.output_dir, "feature.scp.{}".format(args.index))
|
||||||
|
label_ark_file = os.path.join(args.output_dir, "label.ark.{}".format(args.index))
|
||||||
|
label_scp_file = os.path.join(args.output_dir, "label.scp.{}".format(args.index))
|
||||||
|
with WriteHelper('ark,scp:{},{}'.format(feature_ark_file, feature_scp_file)) as feature_writer, \
|
||||||
|
WriteHelper('ark,scp:{},{}'.format(label_ark_file, label_scp_file)) as label_writer:
|
||||||
for idx, (rec, path, st, ed) in enumerate(dataset.chunk_indices):
|
for idx, (rec, path, st, ed) in enumerate(dataset.chunk_indices):
|
||||||
Y, T = feature.get_labeledSTFT(
|
Y, T = feature.get_labeledSTFT(
|
||||||
dataset.data,
|
dataset.data,
|
||||||
@ -93,35 +124,21 @@ def convert(args):
|
|||||||
Y_ss, T_ss = feature.subsample(Y_spliced, T, dataset.subsampling)
|
Y_ss, T_ss = feature.subsample(Y_spliced, T, dataset.subsampling)
|
||||||
st = '{:0>7d}'.format(st)
|
st = '{:0>7d}'.format(st)
|
||||||
ed = '{:0>7d}'.format(ed)
|
ed = '{:0>7d}'.format(ed)
|
||||||
suffix = '_' + st + '_' + ed
|
key = "{}_{}_{}".format(rec, st, ed)
|
||||||
|
feature_writer(key, Y_ss)
|
||||||
parts = os.readlink('/'.join(path.split('/')[:-1])).split('/')
|
label_writer(key, T_ss.reshape(-1))
|
||||||
# print('parts: ', parts)
|
|
||||||
parts = parts[:4] + ['numpy_data'] + parts[4:]
|
|
||||||
cur_path = '/'.join(parts)
|
|
||||||
# print('cur path: ', cur_path)
|
|
||||||
out_path = os.path.join(cur_path, path.split('/')[-1].split('.')[0] + suffix + '.npz')
|
|
||||||
# print(out_path)
|
|
||||||
# print(cur_path)
|
|
||||||
if not os.path.exists(cur_path):
|
|
||||||
os.makedirs(cur_path)
|
|
||||||
np.savez(out_path, Y=Y_ss, T=T_ss)
|
|
||||||
if idx == length - 1:
|
|
||||||
f.write(rec + suffix + ' ' + out_path)
|
|
||||||
else:
|
|
||||||
f.write(rec + suffix + ' ' + out_path + '\n')
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("data_dir")
|
parser.add_argument("data_dir")
|
||||||
parser.add_argument("num_frames")
|
parser.add_argument("output_dir")
|
||||||
parser.add_argument("context_size")
|
parser.add_argument("index")
|
||||||
parser.add_argument("frame_size")
|
parser.add_argument("num_frames", default=500)
|
||||||
parser.add_argument("frame_shift")
|
parser.add_argument("context_size", default=7)
|
||||||
parser.add_argument("subsampling")
|
parser.add_argument("frame_size", default=200)
|
||||||
|
parser.add_argument("frame_shift", default=80)
|
||||||
|
parser.add_argument("subsampling", default=10)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
convert(args)
|
convert(args)
|
||||||
|
|||||||
@ -78,17 +78,26 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
|
|||||||
for dset in swb_sre_tr swb_sre_cv; do
|
for dset in swb_sre_tr swb_sre_cv; do
|
||||||
if [ "$dset" == "swb_sre_tr" ]; then
|
if [ "$dset" == "swb_sre_tr" ]; then
|
||||||
n_mixtures=${simu_opts_num_train}
|
n_mixtures=${simu_opts_num_train}
|
||||||
|
dataset=train
|
||||||
else
|
else
|
||||||
n_mixtures=500
|
n_mixtures=500
|
||||||
|
dataset=dev
|
||||||
fi
|
fi
|
||||||
simu_data_dir=${dset}_ns"$(IFS="n"; echo "${simu_opts_num_speaker_array[*]}")"_beta"$(IFS="n"; echo "${simu_opts_sil_scale_array[*]}")"_${n_mixtures}
|
simu_data_dir=${dset}_ns"$(IFS="n"; echo "${simu_opts_num_speaker_array[*]}")"_beta"$(IFS="n"; echo "${simu_opts_sil_scale_array[*]}")"_${n_mixtures}
|
||||||
mkdir -p ${data_dir}/simu/data/${simu_data_dir}/.work
|
# mkdir -p ${data_dir}/simu/data/${simu_data_dir}/.work
|
||||||
split_scps=
|
# split_scps=
|
||||||
for n in $(seq $nj); do
|
# for n in $(seq $nj); do
|
||||||
split_scps="$split_scps ${data_dir}/simu/data/${simu_data_dir}/.work/wav.$n.scp"
|
# split_scps="$split_scps ${data_dir}/simu/data/${simu_data_dir}/.work/wav.$n.scp"
|
||||||
done
|
# done
|
||||||
utils/split_scp.pl "${data_dir}/simu/data/${simu_data_dir}/wav.scp" $split_scps || exit 1
|
# utils/split_scp.pl "${data_dir}/simu/data/${simu_data_dir}/wav.scp" $split_scps || exit 1
|
||||||
python local/split.py ${data_dir}/simu/data/${simu_data_dir}
|
# python local/split.py ${data_dir}/simu/data/${simu_data_dir}
|
||||||
|
output_dir=${data_dir}/ark_data/dump/simu_data/$dataset
|
||||||
|
mkdir -p $output_dir/.logs
|
||||||
|
$dump_cmd --max-jobs-run $nj JOB=1:$nj $output_dir/.logs/dump.JOB.log \
|
||||||
|
python local/dump_feature.py \
|
||||||
|
--data_dir ${data_dir}/simu/data/${simu_data_dir}/.work \
|
||||||
|
--output_dir ${data_dir}/ark_data/dump/simu_data/$dataset \
|
||||||
|
--index JOB
|
||||||
done
|
done
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user