mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
update
This commit is contained in:
parent
f5bd371837
commit
21536068b9
127
egs/callhome/eend_ola/local/dump_feature.py
Normal file
127
egs/callhome/eend_ola/local/dump_feature.py
Normal file
@ -0,0 +1,127 @@
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
|
||||
import funasr.modules.eend_ola.utils.feature as feature
|
||||
import funasr.modules.eend_ola.utils.kaldi_data as kaldi_data
|
||||
|
||||
|
||||
def _count_frames(data_len, size, step):
|
||||
return int((data_len - size + step) / step)
|
||||
|
||||
|
||||
def _gen_frame_indices(
        data_length, size=2000, step=2000,
        use_last_samples=False,
        label_delay=0,
        subsampling=1):
    """Yield ``(start, end)`` frame ranges that tile a recording.

    Full windows of ``size`` frames are produced every ``step`` frames.
    When ``use_last_samples`` is true and frames remain after the last full
    window, one extra range running up to ``data_length`` is produced,
    provided the remainder is longer than ``subsampling * label_delay``.
    """
    last = -1  # stays -1 when there is no full window at all
    for last in range(_count_frames(data_length, size, step)):
        begin = last * step
        yield begin, begin + size

    if not use_last_samples:
        return
    if last * step + size >= data_length:
        return

    tail_begin = (last + 1) * step
    if data_length - tail_begin - subsampling * label_delay > 0:
        yield tail_begin, data_length
|
||||
|
||||
|
||||
class KaldiDiarizationDataset():
    """Chunk index over the recordings of a Kaldi-style data directory.

    The constructor loads the directory through ``kaldi_data.KaldiData`` and
    fills ``chunk_indices`` with ``(rec, path, start_frame, end_frame)``
    tuples. Frame positions are expressed before subsampling, i.e. the
    chunk boundaries are multiplied back by ``subsampling``.
    """

    def __init__(
            self,
            data_dir,
            chunk_size=2000,
            context_size=0,
            frame_size=1024,
            frame_shift=256,
            subsampling=1,
            rate=16000,
            input_transform=None,
            use_last_samples=False,
            label_delay=0,
            n_speakers=None,
    ):
        self.data_dir = data_dir
        self.chunk_size = chunk_size
        self.context_size = context_size
        self.frame_size = frame_size
        self.frame_shift = frame_shift
        self.subsampling = subsampling
        self.label_delay = label_delay
        self.input_transform = input_transform
        self.n_speakers = n_speakers
        self.chunk_indices = []

        self.data = kaldi_data.KaldiData(self.data_dir)

        # One entry per chunk: recording id, wav path, start and end frame.
        for rec, path in self.data.wavs.items():
            # Recording duration converted to frames, then to subsampled
            # frames; both conversions truncate.
            n_frames = int(self.data.reco2dur[rec] * rate / frame_shift)
            n_frames = int(n_frames / self.subsampling)
            spans = _gen_frame_indices(
                n_frames, chunk_size, chunk_size, use_last_samples,
                label_delay=self.label_delay,
                subsampling=self.subsampling)
            for st, ed in spans:
                chunk = (rec, path,
                         st * self.subsampling, ed * self.subsampling)
                self.chunk_indices.append(chunk)
        print(len(self.chunk_indices), " chunks")
|
||||
|
||||
|
||||
def convert(args):
    """Dump per-chunk features and labels to ``.npz`` and write an index file.

    For every chunk of the dataset the STFT and labels are computed,
    transformed, spliced and subsampled, then saved with ``np.savez`` under
    the keys ``Y`` and ``T``. One ``"<rec>_<start>_<end> <npz path>"`` line
    per chunk is written to the index file (no trailing newline).

    Args:
        args: namespace with ``data_dir``, ``num_frames``, ``context_size``,
            ``frame_size``, ``frame_shift`` and ``subsampling``. Optional
            attributes: ``input_transform`` (defaults to ``None``) and
            ``out_wav_file`` (index file path; defaults to
            ``<data_dir>/feats.scp``).

    Raises:
        OSError: if the directory holding a wav file is not a symbolic link
            (``os.readlink`` is called on it).
    """
    # The original referenced an undefined ``out_wav_file``; take it from
    # the arguments and fall back to a file inside the data directory.
    out_wav_file = getattr(args, 'out_wav_file', None)
    if not out_wav_file:
        out_wav_file = os.path.join(args.data_dir, 'feats.scp')

    dataset = KaldiDiarizationDataset(
        data_dir=args.data_dir,
        chunk_size=args.num_frames,
        context_size=args.context_size,
        input_transform=getattr(args, 'input_transform', None),
        frame_size=args.frame_size,
        frame_shift=args.frame_shift,
        subsampling=args.subsampling,
        rate=8000,  # sampling rate is fixed here, not taken from args
        use_last_samples=True,
    )

    with open(out_wav_file, 'w') as f:
        for idx, (rec, path, st, ed) in enumerate(dataset.chunk_indices):
            Y, T = feature.get_labeledSTFT(
                dataset.data,
                rec,
                st,
                ed,
                dataset.frame_size,
                dataset.frame_shift,
                dataset.n_speakers)
            Y = feature.transform(Y, dataset.input_transform)
            Y_spliced = feature.splice(Y, dataset.context_size)
            Y_ss, T_ss = feature.subsample(Y_spliced, T, dataset.subsampling)

            suffix = '_{:0>7d}_{:0>7d}'.format(st, ed)

            # The wav's parent directory is resolved through its symlink
            # target and 'numpy_data' is inserted after the first four path
            # components of that target.
            # NOTE(review): the fixed position 4 presumably matches one
            # specific storage layout - confirm before reusing elsewhere.
            wav_dir = '/'.join(path.split('/')[:-1])
            parts = os.readlink(wav_dir).split('/')
            parts = parts[:4] + ['numpy_data'] + parts[4:]
            cur_path = '/'.join(parts)

            wav_stem = path.split('/')[-1].split('.')[0]
            out_path = os.path.join(cur_path, wav_stem + suffix + '.npz')
            os.makedirs(cur_path, exist_ok=True)
            np.savez(out_path, Y=Y_ss, T=T_ss)

            # Newline before every entry but the first, so the file ends
            # without a trailing newline (same layout as before).
            if idx > 0:
                f.write('\n')
            f.write(rec + suffix + ' ' + out_path)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("data_dir")
    # Numeric settings must be parsed as int: they are used in arithmetic
    # when building chunk indices and computing features.
    parser.add_argument("num_frames", type=int)
    parser.add_argument("context_size", type=int)
    parser.add_argument("frame_size", type=int)
    parser.add_argument("frame_shift", type=int)
    parser.add_argument("subsampling", type=int)
    parser.add_argument("--input_transform", default=None,
                        help="feature transform name passed to "
                             "feature.transform (default: none)")
    parser.add_argument("--out_wav_file", default=None,
                        help="index file to write "
                             "(default: <data_dir>/feats.scp)")

    args = parser.parse_args()
    convert(args)
|
||||
117
egs/callhome/eend_ola/local/split.py
Normal file
117
egs/callhome/eend_ola/local/split.py
Normal file
@ -0,0 +1,117 @@
|
||||
import argparse
|
||||
import os
|
||||
|
||||
def _rec_from_utt(utt):
    """Derive the recording id from an utterance id.

    The text after the first ``'data'`` marker is taken, its leading
    separator character and its last two ``_``-separated fields are
    dropped, and the result is prefixed with ``'data_'``.
    """
    tail = utt.split('data')[1]
    return 'data_' + '_'.join(tail[1:].split('_')[:-2])


def _read_fields(path):
    """Yield the whitespace-separated fields of each non-blank line."""
    with open(path) as f:
        for line in f:
            parts = line.strip().split()
            if parts:
                yield parts


def _write_lines(path, lines):
    """Write ``lines`` joined by newlines, without a trailing newline."""
    with open(path, 'w') as f:
        f.write('\n'.join(lines))


def main():
    """Split reco2dur/spk2utt/segments/utt2spk to match the split wav scps.

    ``<root_path>/.work`` is expected to contain the split scp files
    (``*.scp``). For each of them, the entries of the four Kaldi files in
    ``root_path`` that belong to the listed recordings are written to
    ``reco2dur.<idx>``, ``spk2utt.<idx>``, ``segments.<idx>`` and
    ``utt2spk.<idx>`` in the same directory, where ``<idx>`` is the
    second-to-last dot-separated field of the scp file name.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('root_path', help='raw data path')
    args = parser.parse_args()

    root_path = args.root_path
    work_path = os.path.join(root_path, ".work")
    # Only scp files are inputs; the files generated below live in the same
    # directory and must not be picked up when the script is run again.
    scp_files = sorted(
        name for name in os.listdir(work_path) if name.endswith('.scp'))

    # recording id -> duration (kept as text)
    reco2dur_dict = {}
    for parts in _read_fields(os.path.join(root_path, 'reco2dur')):
        reco2dur_dict[parts[0]] = parts[1]

    # recording id -> [(speaker, utterance), ...]
    spk2utt_dict = {}
    for parts in _read_fields(os.path.join(root_path, 'spk2utt')):
        spk = parts[0]
        for utt in parts[1:]:
            spk2utt_dict.setdefault(_rec_from_utt(utt), []).append((spk, utt))

    # recording id -> [(utterance, start, end), ...]
    segment_dict = {}
    for parts in _read_fields(os.path.join(root_path, 'segments')):
        segment_dict.setdefault(parts[1], []).append(
            (parts[0], parts[2], parts[3]))

    # recording id -> [(utterance, speaker), ...]
    utt2spk_dict = {}
    for parts in _read_fields(os.path.join(root_path, 'utt2spk')):
        utt2spk_dict.setdefault(_rec_from_utt(parts[0]), []).append(
            (parts[0], parts[1]))

    for file_name in scp_files:
        idx = file_name.split('.')[-2]
        scp_file = os.path.join(work_path, file_name)
        keys = [parts[0] for parts in _read_fields(scp_file)]

        _write_lines(
            os.path.join(work_path, 'reco2dur.' + idx),
            [key + ' ' + reco2dur_dict[key] for key in keys])

        # One "speaker utterance" line per utterance, grouped by recording.
        _write_lines(
            os.path.join(work_path, 'spk2utt.' + idx),
            [spk + ' ' + utt
             for key in keys
             for spk, utt in spk2utt_dict[key]])

        _write_lines(
            os.path.join(work_path, 'segments.' + idx),
            [utt + ' ' + key + ' ' + seg_start + ' ' + seg_end
             for key in keys
             for utt, seg_start, seg_end in segment_dict[key]])

        _write_lines(
            os.path.join(work_path, 'utt2spk.' + idx),
            [utt + ' ' + spk
             for key in keys
             for utt, spk in utt2spk_dict[key]])


if __name__ == '__main__':
    main()
|
||||
@ -8,6 +8,11 @@ gpu_num=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
|
||||
count=1
|
||||
|
||||
# general configuration
|
||||
dump_cmd=utils/run.pl
|
||||
nj=64
|
||||
|
||||
# feature configuration
|
||||
data_dir="./data"
|
||||
simu_feats_dir="/nfs/wangjiaming.wjm/EEND_ARK_DATA/dump/simu_data/data"
|
||||
simu_feats_dir_chunk2000="/nfs/wangjiaming.wjm/EEND_ARK_DATA/dump/simu_data_chunk2000/data"
|
||||
callhome_feats_dir_chunk2000="/nfs/wangjiaming.wjm/EEND_ARK_DATA/dump/callhome_chunk2000/data"
|
||||
@ -62,13 +67,33 @@ if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
|
||||
local/run_prepare_shared_eda.sh
|
||||
fi
|
||||
|
||||
## Prepare data for training and inference
|
||||
#if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
|
||||
# echo "stage 0: Prepare data for training and inference"
|
||||
# echo "The detail can be found in https://github.com/hitachi-speech/EEND"
|
||||
# . ./local/
|
||||
#fi
|
||||
#
|
||||
# Prepare data for training and inference
|
||||
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    echo "stage 0: Prepare data for training and inference"
    simu_opts_num_speaker_array=(1 2 3 4)
    simu_opts_sil_scale_array=(2 2 5 9)
    # NOTE(review): `i` is not assigned anywhere in this stage. If it is not
    # set earlier in the script, bash evaluates the index as 0 and the first
    # entry of each array is used - confirm this is intended.
    simu_opts_num_speaker=${simu_opts_num_speaker_array[i]}
    simu_opts_sil_scale=${simu_opts_sil_scale_array[i]}
    simu_opts_num_train=100000

    # for simulated data of chunk500
    for dset in swb_sre_tr swb_sre_cv; do
        if [ "$dset" == "swb_sre_tr" ]; then
            n_mixtures=${simu_opts_num_train}
        else
            n_mixtures=500
        fi
        simu_data_dir=${dset}_ns${simu_opts_num_speaker}_beta${simu_opts_sil_scale}_${n_mixtures}
        # -p keeps the stage re-runnable when .work already exists.
        mkdir -p "${data_dir}/simu/data/${simu_data_dir}/.work" || exit 1
        split_scps=
        for n in $(seq $nj); do
            split_scps="$split_scps ${data_dir}/simu/data/${simu_data_dir}/.work/wav.$n.scp"
        done
        utils/split_scp.pl "${data_dir}/simu/data/${simu_data_dir}/wav.scp" $split_scps || exit 1
        python local/split.py ${data_dir}/simu/data/${simu_data_dir} || exit 1
    done
fi
|
||||
|
||||
|
||||
# Training on simulated two-speaker data
|
||||
world_size=$gpu_num
|
||||
|
||||
286
funasr/modules/eend_ola/utils/feature.py
Normal file
286
funasr/modules/eend_ola/utils/feature.py
Normal file
@ -0,0 +1,286 @@
|
||||
# Copyright 2019 Hitachi, Ltd. (author: Yusuke Fujita)
|
||||
# Licensed under the MIT license.
|
||||
#
|
||||
# This module is for computing audio features
|
||||
|
||||
import numpy as np
|
||||
import librosa
|
||||
|
||||
|
||||
def get_input_dim(
        frame_size,
        context_size,
        transform_type,
):
    """Return the feature dimension after transform and splicing.

    Args:
        frame_size (int): number of samples in a frame; only used when the
            transform is not a log-mel one.
        context_size (int): number of context frames spliced on each side.
        transform_type (str or None): transform name. Names starting with
            ``'logmel23'`` give 23 features per frame, other names starting
            with ``'logmel'`` give 40. Anything else, including ``None`` or
            an empty string, gives the number of STFT bins for
            ``frame_size`` rounded up to a power of two.

    Returns:
        int: ``(2 * context_size + 1)`` times the per-frame dimension.
    """
    # ``transform`` accepts None / '' as "no transform"; accept it here too
    # instead of failing on ``None.startswith``.
    kind = transform_type or ''
    if kind.startswith('logmel23'):
        feat_dim = 23
    elif kind.startswith('logmel'):
        feat_dim = 40
    else:
        fft_size = 1 << (frame_size - 1).bit_length()
        feat_dim = fft_size // 2 + 1
    return (2 * context_size + 1) * feat_dim
|
||||
|
||||
|
||||
def _log_mel_energies(Y_abs, sr, n_mels):
    """Return log10 mel-filterbank energies of a magnitude spectrogram.

    The FFT size is recovered from the number of bins of ``Y_abs``. The mel
    filterbank is requested with keyword arguments, which works both with
    librosa releases that accept positional arguments and with those where
    they are keyword-only.
    """
    n_fft = 2 * (Y_abs.shape[1] - 1)
    mel_basis = librosa.filters.mel(sr=sr, n_fft=n_fft, n_mels=n_mels)
    mel_power = np.dot(Y_abs ** 2, mel_basis.T)
    return np.log10(np.maximum(mel_power, 1e-10))


def transform(
        Y,
        transform_type=None,
        dtype=np.float32):
    """ Transform STFT feature

    Args:
        Y: STFT
            (n_frames, n_bins)-shaped np.complex array
        transform_type:
            None (magnitude only), "log" (natural log magnitude),
            "logmel" (40 mel bands, sr 16000),
            "logmel23" (23 mel bands, sr 8000),
            "logmel23_mn" (logmel23 minus the per-band mean),
            "logmel23_swn" (logmel23 minus the mean of high-power frames),
            "logmel23_mvn" (logmel23 with mean and variance normalisation)
        dtype: output data type
            np.float32 is expected
    Returns:
        Y (numpy.array): transformed feature
    Raises:
        ValueError: if transform_type is not one of the names above
    """
    Y = np.abs(Y)
    if not transform_type:
        pass
    elif transform_type == 'log':
        Y = np.log(np.maximum(Y, 1e-10))
    elif transform_type == 'logmel':
        Y = _log_mel_energies(Y, 16000, 40)
    elif transform_type == 'logmel23':
        Y = _log_mel_energies(Y, 8000, 23)
    elif transform_type == 'logmel23_mn':
        Y = _log_mel_energies(Y, 8000, 23)
        Y = Y - np.mean(Y, axis=0)
    elif transform_type == 'logmel23_swn':
        Y = _log_mel_energies(Y, 8000, 23)
        # simple 2-means based thresholding for mean calculation:
        # the threshold is refined for ten rounds as the midpoint between
        # the mean power of frames above and below it.
        powers = np.sum(Y, axis=1)
        th = (np.max(powers) + np.min(powers)) / 2.0
        for _ in range(10):
            th = (np.mean(powers[powers >= th])
                  + np.mean(powers[powers < th])) / 2
        # Only frames strictly above the threshold contribute to the mean.
        mean = np.mean(Y[powers > th, :], axis=0)
        Y = Y - mean
    elif transform_type == 'logmel23_mvn':
        Y = _log_mel_energies(Y, 8000, 23)
        Y = Y - np.mean(Y, axis=0)
        std = np.maximum(np.std(Y, axis=0), 1e-10)
        Y = Y / std
    else:
        raise ValueError('Unknown transform_type: %s' % transform_type)
    return Y.astype(dtype)
|
||||
|
||||
|
||||
def subsample(Y, T, subsampling=1):
    """ Frame subsampling

    Keeps every ``subsampling``-th entry of both ``Y`` and ``T``, starting
    with the first one, and returns the two reduced sequences.
    """
    every_nth = slice(None, None, subsampling)
    return Y[every_nth], T[every_nth]
|
||||
|
||||
|
||||
def splice(Y, context_size=0):
    """ Frame splicing

    Args:
        Y: feature
            (n_frames, n_featdim)-shaped numpy array
        context_size:
            number of frames concatenated on each side
            if context_size = 5, 11 frames are concatenated.

    Returns:
        Y_spliced: spliced feature
            (n_frames, n_featdim * (2 * context_size + 1))-shaped,
            read-only strided view; frames beyond the edges are zeros
    """
    n_frames, n_feat = Y.shape
    window = 2 * context_size + 1

    # Zero-pad ``context_size`` frames before and after the sequence.
    padded = np.ascontiguousarray(
        np.pad(Y, [(context_size, context_size), (0, 0)], 'constant'))

    # Row i of the result starts at padded frame i and spans ``window``
    # consecutive frames, so rows advance by one frame in memory.
    row_stride = Y.itemsize * n_feat
    return np.lib.stride_tricks.as_strided(
        padded,
        (n_frames, n_feat * window),
        (row_stride, Y.itemsize),
        writeable=False)
|
||||
|
||||
|
||||
def stft(
        data,
        frame_size=1024,
        frame_shift=256):
    """ Compute STFT features

    Args:
        data: audio signal
            (n_samples,)-shaped np.float32 array
        frame_size: number of samples in a frame; the FFT size is this
            value rounded up to the nearest power of two
        frame_shift: number of samples between frames

    Returns:
        stft: STFT frames
            (n_frames, n_bins)-shaped np.complex64 array
    """
    # round up to nearest power of 2
    fft_size = 1 << (frame_size - 1).bit_length()
    frames = librosa.stft(
        data, n_fft=fft_size, win_length=frame_size,
        hop_length=frame_shift).T
    # HACK: when the signal length is a multiple of the shift, librosa.stft
    # produces one excess frame at the end, which is omitted.
    if len(data) % frame_shift == 0:
        return frames[:-1]
    return frames
|
||||
|
||||
|
||||
def _count_frames(data_len, size, shift):
|
||||
# HACK: Assuming librosa.stft(..., center=True)
|
||||
n_frames = 1 + int(data_len / shift)
|
||||
if data_len % shift == 0:
|
||||
n_frames = n_frames - 1
|
||||
return n_frames
|
||||
|
||||
|
||||
def get_frame_labels(
        kaldi_obj,
        rec,
        start=0,
        end=None,
        frame_size=1024,
        frame_shift=256,
        n_speakers=None):
    """ Get frame-aligned labels of given recording
    Args:
        kaldi_obj (KaldiData)
        rec (str): recording id
        start (int): start frame index
        end (int): end frame index
            None means the last frame of recording
        frame_size (int): number of samples in a frame
        frame_shift (int): number of shift samples
        n_speakers (int): number of speakers
            if None, the value is given from data
    Returns:
        T: label
            (n_frames, n_speakers)-shaped np.int32 array
    """
    # get_labeledSTFT indexes ``segments`` by recording id; accept that
    # mapping form here as well and keep the record-array mask as fallback.
    segments = kaldi_obj.segments
    if isinstance(segments, dict):
        filtered_segments = segments[rec]
    else:
        filtered_segments = segments[segments['rec'] == rec]
    speakers = np.unique(
        [kaldi_obj.utt2spk[seg['utt']] for seg
         in filtered_segments]).tolist()
    if n_speakers is None:
        n_speakers = len(speakers)
    es = end * frame_shift if end is not None else None
    data, rate = kaldi_obj.load_wav(rec, start * frame_shift, es)
    n_frames = _count_frames(len(data), frame_size, frame_shift)
    T = np.zeros((n_frames, n_speakers), dtype=np.int32)
    if end is None:
        # ``end`` is compared with absolute frame indices below, so it must
        # be offset by ``start`` (identical to before when start == 0).
        end = start + n_frames

    for seg in filtered_segments:
        speaker_index = speakers.index(kaldi_obj.utt2spk[seg['utt']])
        start_frame = np.rint(
            seg['st'] * rate / frame_shift).astype(int)
        end_frame = np.rint(
            seg['et'] * rate / frame_shift).astype(int)
        # Boundaries relative to the chunk; None means "outside the chunk",
        # which as a slice bound extends to the chunk edge.
        rel_start = rel_end = None
        if start <= start_frame and start_frame < end:
            rel_start = start_frame - start
        if start < end_frame and end_frame <= end:
            rel_end = end_frame - start
        # A segment that begins before the chunk and ends after it has both
        # boundaries outside; it still covers every frame of the chunk.
        covers_chunk = start_frame < start and end < end_frame
        if rel_start is not None or rel_end is not None or covers_chunk:
            T[rel_start:rel_end, speaker_index] = 1
    return T
|
||||
|
||||
|
||||
def get_labeledSTFT(
        kaldi_obj,
        rec, start, end, frame_size, frame_shift,
        n_speakers=None,
        use_speaker_id=False):
    """ Extracts STFT and corresponding labels

    Extracts STFT and corresponding diarization labels for
    given recording id and start/end times

    Args:
        kaldi_obj (KaldiData)
        rec (str): recording id
        start (int): start frame index
        end (int): end frame index
        frame_size (int): number of samples in a frame
        frame_shift (int): number of shift samples
        n_speakers (int): number of speakers
            if None, the value is given from data
        use_speaker_id (bool): if True, additionally return S, a label
            array over all speakers listed in kaldi_obj.spk2utt
    Returns:
        Y: STFT
            (n_frames, n_bins)-shaped np.complex64 array,
        T: label
            (n_frames, n_speakers)-shaped np.int32 array.
        S: (only when use_speaker_id is True)
            (n_frames, n_all_speakers)-shaped np.int32 array.
    """
    data, rate = kaldi_obj.load_wav(
        rec, start * frame_shift, end * frame_shift)
    Y = stft(data, frame_size, frame_shift)
    filtered_segments = kaldi_obj.segments[rec]
    speakers = np.unique(
        [kaldi_obj.utt2spk[seg['utt']] for seg
         in filtered_segments]).tolist()
    if n_speakers is None:
        n_speakers = len(speakers)
    T = np.zeros((Y.shape[0], n_speakers), dtype=np.int32)

    if use_speaker_id:
        all_speakers = sorted(kaldi_obj.spk2utt.keys())
        S = np.zeros((Y.shape[0], len(all_speakers)), dtype=np.int32)

    for seg in filtered_segments:
        speaker = kaldi_obj.utt2spk[seg['utt']]
        speaker_index = speakers.index(speaker)
        if use_speaker_id:
            all_speaker_index = all_speakers.index(speaker)
        start_frame = np.rint(
            seg['st'] * rate / frame_shift).astype(int)
        end_frame = np.rint(
            seg['et'] * rate / frame_shift).astype(int)
        # Boundaries relative to the chunk; None means "outside the chunk",
        # which as a slice bound extends to the chunk edge.
        rel_start = rel_end = None
        if start <= start_frame and start_frame < end:
            rel_start = start_frame - start
        if start < end_frame and end_frame <= end:
            rel_end = end_frame - start
        # A segment that begins before the chunk and ends after it has both
        # boundaries outside; it still covers every frame of the chunk.
        covers_chunk = start_frame < start and end < end_frame
        if rel_start is not None or rel_end is not None or covers_chunk:
            T[rel_start:rel_end, speaker_index] = 1
            if use_speaker_id:
                S[rel_start:rel_end, all_speaker_index] = 1

    if use_speaker_id:
        return Y, T, S
    else:
        return Y, T
|
||||
Loading…
Reference in New Issue
Block a user