From 21536068b9e1d94a3c0de09b6b166a786f98361f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=98=89=E6=B8=8A?=
Date: Thu, 20 Jul 2023 17:09:45 +0800
Subject: [PATCH] update

---
 egs/callhome/eend_ola/local/dump_feature.py | 146 ++++++++++
 egs/callhome/eend_ola/local/split.py        |  87 ++++++
 egs/callhome/eend_ola/run.sh                |  40 ++-
 funasr/modules/eend_ola/utils/feature.py    | 290 ++++++++++++++++++++
 4 files changed, 556 insertions(+), 7 deletions(-)
 create mode 100644 egs/callhome/eend_ola/local/dump_feature.py
 create mode 100644 egs/callhome/eend_ola/local/split.py
 create mode 100644 funasr/modules/eend_ola/utils/feature.py

diff --git a/egs/callhome/eend_ola/local/dump_feature.py b/egs/callhome/eend_ola/local/dump_feature.py
new file mode 100644
index 000000000..169615e1b
--- /dev/null
+++ b/egs/callhome/eend_ola/local/dump_feature.py
@@ -0,0 +1,146 @@
+"""Dump chunked EEND features and labels of a Kaldi data directory as .npz files."""
+import argparse
+import os
+
+import numpy as np
+
+import funasr.modules.eend_ola.utils.feature as feature
+import funasr.modules.eend_ola.utils.kaldi_data as kaldi_data
+
+
+def _count_frames(data_len, size, step):
+    """Return the number of full `size`-long windows in `data_len` at stride `step`."""
+    return int((data_len - size + step) / step)
+
+
+def _gen_frame_indices(
+        data_length, size=2000, step=2000,
+        use_last_samples=False,
+        label_delay=0,
+        subsampling=1):
+    """Yield (start, end) index pairs of consecutive chunks.
+
+    Full chunks of `size` are yielded every `step`. If `use_last_samples` is
+    set, what remains after the last full chunk is yielded as one shorter
+    chunk, provided it is longer than `subsampling * label_delay`.
+    """
+    # i stays -1 when no full chunk fits, so the remainder then starts at 0.
+    i = -1
+    for i in range(_count_frames(data_length, size, step)):
+        yield i * step, i * step + size
+    if use_last_samples and i * step + size < data_length:
+        if data_length - (i + 1) * step - subsampling * label_delay > 0:
+            yield (i + 1) * step, data_length
+
+
+class KaldiDiarizationDataset():
+    """Index of fixed-size chunks over the recordings of a Kaldi data directory.
+
+    `chunk_indices` holds one (rec, wav path, start, end) tuple per chunk,
+    with start/end counted in frames of `frame_shift` samples.
+    """
+
+    def __init__(
+            self,
+            data_dir,
+            chunk_size=2000,
+            context_size=0,
+            frame_size=1024,
+            frame_shift=256,
+            subsampling=1,
+            rate=16000,
+            input_transform=None,
+            use_last_samples=False,
+            label_delay=0,
+            n_speakers=None,
+    ):
+        self.data_dir = data_dir
+        self.chunk_size = chunk_size
+        self.context_size = context_size
+        self.frame_size = frame_size
+        self.frame_shift = frame_shift
+        self.subsampling = subsampling
+        self.input_transform = input_transform
+        self.n_speakers = n_speakers
+        self.chunk_indices = []
+        self.label_delay = label_delay
+
+        self.data = kaldi_data.KaldiData(self.data_dir)
+
+        # make chunk indices: filepath, start_frame, end_frame
+        for rec, path in self.data.wavs.items():
+            data_len = int(self.data.reco2dur[rec] * rate / frame_shift)
+            data_len = int(data_len / self.subsampling)
+            for st, ed in _gen_frame_indices(
+                    data_len, chunk_size, chunk_size, use_last_samples,
+                    label_delay=self.label_delay,
+                    subsampling=self.subsampling):
+                self.chunk_indices.append(
+                    (rec, path, st * self.subsampling, ed * self.subsampling))
+        print(len(self.chunk_indices), " chunks")
+
+
+def convert(args):
+    """Save the features/labels of every chunk as .npz and write an index of them.
+
+    Each index line reads "REC_START_END NPZ_PATH"; the last line has no
+    trailing newline. The index goes to `args.out_wav_file`, or to
+    feature.scp inside `args.data_dir` when that option is not given.
+    """
+    dataset = KaldiDiarizationDataset(
+        data_dir=args.data_dir,
+        chunk_size=args.num_frames,
+        context_size=args.context_size,
+        input_transform=args.input_transform,
+        frame_size=args.frame_size,
+        frame_shift=args.frame_shift,
+        subsampling=args.subsampling,
+        rate=8000,
+        use_last_samples=True,
+    )
+    out_wav_file = args.out_wav_file
+    if out_wav_file is None:
+        out_wav_file = os.path.join(args.data_dir, 'feature.scp')
+    # The context manager closes the index even if a chunk fails.
+    with open(out_wav_file, 'w') as f:
+        for idx, (rec, path, st, ed) in enumerate(dataset.chunk_indices):
+            Y, T = feature.get_labeledSTFT(
+                dataset.data,
+                rec,
+                st,
+                ed,
+                dataset.frame_size,
+                dataset.frame_shift,
+                dataset.n_speakers)
+            Y = feature.transform(Y, dataset.input_transform)
+            Y_spliced = feature.splice(Y, dataset.context_size)
+            Y_ss, T_ss = feature.subsample(Y_spliced, T, dataset.subsampling)
+            suffix = '_{:0>7d}_{:0>7d}'.format(st, ed)
+
+            # The wav's directory must be a symlink: the output directory is its
+            # target with 'numpy_data' inserted after the first four components.
+            parts = os.readlink('/'.join(path.split('/')[:-1])).split('/')
+            parts = parts[:4] + ['numpy_data'] + parts[4:]
+            cur_path = '/'.join(parts)
+            out_path = os.path.join(cur_path, path.split('/')[-1].split('.')[0] + suffix + '.npz')
+            os.makedirs(cur_path, exist_ok=True)
+            np.savez(out_path, Y=Y_ss, T=T_ss)
+            if idx:
+                f.write('\n')
+            f.write(rec + suffix + ' ' + out_path)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument("data_dir")
+    # Numeric options are used in arithmetic, so they must be parsed as int.
+    parser.add_argument("num_frames", type=int)
+    parser.add_argument("context_size", type=int)
+    parser.add_argument("frame_size", type=int)
+    parser.add_argument("frame_shift", type=int)
+    parser.add_argument("subsampling", type=int)
+    parser.add_argument("--input_transform", default=None,
+                        help="feature transform name; no transform if omitted")
+    parser.add_argument("--out_wav_file", default=None,
+                        help="index file to write (default: feature.scp in data_dir)")
+    args = parser.parse_args()
+    convert(args)
diff --git a/egs/callhome/eend_ola/local/split.py b/egs/callhome/eend_ola/local/split.py
new file mode 100644
index 000000000..6f313ccd4
--- /dev/null
+++ b/egs/callhome/eend_ola/local/split.py
@@ -0,0 +1,87 @@
+"""Split the Kaldi files of a data directory to match its .work/wav.N.scp splits."""
+import argparse
+import os
+import re
+
+
+def _rec_from_utt(utt):
+    """Derive the recording id from an utterance id."""
+    tmp = utt.split('data')
+    return 'data_' + '_'.join(tmp[1][1:].split('_')[:-2])
+
+
+def _read_fields(path):
+    """Return the whitespace-separated fields of each non-blank line of `path`."""
+    with open(path) as f:
+        return [line.split() for line in f if line.strip()]
+
+
+def _write_lines(path, lines):
+    """Write `lines` joined by newlines, without a trailing newline."""
+    with open(path, 'w') as f:
+        f.write('\n'.join(lines))
+
+
+def main():
+    """Write reco2dur/spk2utt/segments/utt2spk.N next to every wav.N.scp split."""
+    parser = argparse.ArgumentParser()
+    parser.add_argument('root_path', help='raw data path')
+    args = parser.parse_args()
+
+    root_path = args.root_path
+    work_path = os.path.join(root_path, ".work")
+
+    # os.path.join: root_path is passed without a trailing slash by run.sh.
+    reco2dur_dict = {}
+    for parts in _read_fields(os.path.join(root_path, 'reco2dur')):
+        reco2dur_dict[parts[0]] = parts[1]
+
+    # recording id -> [(speaker, utterance), ...]
+    spk2utt_dict = {}
+    for parts in _read_fields(os.path.join(root_path, 'spk2utt')):
+        for utt in parts[1:]:
+            spk2utt_dict.setdefault(_rec_from_utt(utt), []).append((parts[0], utt))
+
+    # recording id -> [(utterance, start, end), ...]
+    segment_dict = {}
+    for parts in _read_fields(os.path.join(root_path, 'segments')):
+        segment_dict.setdefault(parts[1], []).append((parts[0], parts[2], parts[3]))
+
+    # recording id -> [(utterance, speaker), ...]
+    utt2spk_dict = {}
+    for parts in _read_fields(os.path.join(root_path, 'utt2spk')):
+        utt2spk_dict.setdefault(_rec_from_utt(parts[0]), []).append((parts[0], parts[1]))
+
+    for file in sorted(os.listdir(work_path)):
+        # Only wav.N.scp files are inputs; skip outputs of an earlier run.
+        match = re.fullmatch(r'wav\.(\d+)\.scp', file)
+        if match is None:
+            continue
+        idx = match.group(1)
+        keys = [parts[0] for parts in _read_fields(os.path.join(work_path, file))]
+
+        _write_lines(
+            os.path.join(work_path, 'reco2dur.' + idx),
+            [key + ' ' + reco2dur_dict[key] for key in keys])
+
+        # Kaldi's spk2utt has one line per speaker listing all its utterances.
+        utts_by_spk = {}
+        for key in keys:
+            for spk, utt in spk2utt_dict[key]:
+                utts_by_spk.setdefault(spk, []).append(utt)
+        _write_lines(
+            os.path.join(work_path, 'spk2utt.' + idx),
+            [spk + ' ' + ' '.join(utts) for spk, utts in utts_by_spk.items()])
+
+        _write_lines(
+            os.path.join(work_path, 'segments.' + idx),
+            [utt + ' ' + key + ' ' + st + ' ' + et
+             for key in keys for utt, st, et in segment_dict[key]])
+
+        _write_lines(
+            os.path.join(work_path, 'utt2spk.' + idx),
+            [utt + ' ' + spk for key in keys for utt, spk in utt2spk_dict[key]])
+
+
+if __name__ == '__main__':
+    main()
diff --git a/egs/callhome/eend_ola/run.sh b/egs/callhome/eend_ola/run.sh
index 40fb04113..cd246feee 100644
--- a/egs/callhome/eend_ola/run.sh
+++ b/egs/callhome/eend_ola/run.sh
@@ -8,6 +8,11 @@ gpu_num=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
 count=1
 
 # general configuration
+dump_cmd=utils/run.pl
+nj=64
+
+# feature configuration
+data_dir="./data"
 simu_feats_dir="/nfs/wangjiaming.wjm/EEND_ARK_DATA/dump/simu_data/data"
 simu_feats_dir_chunk2000="/nfs/wangjiaming.wjm/EEND_ARK_DATA/dump/simu_data_chunk2000/data"
 callhome_feats_dir_chunk2000="/nfs/wangjiaming.wjm/EEND_ARK_DATA/dump/callhome_chunk2000/data"
@@ -62,13 +67,34 @@ if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
     local/run_prepare_shared_eda.sh
 fi
 
-## Prepare data for training and inference
-#if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
-#    echo "stage 0: Prepare data for training and inference"
-#    echo "The detail can be found in https://github.com/hitachi-speech/EEND"
-#    . ./local/
-#fi
-#
+# Prepare data for training and inference
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+    echo "stage 0: Prepare data for training and inference"
+    simu_opts_num_speaker_array=(1 2 3 4)
+    simu_opts_sil_scale_array=(2 2 5 9)
+    # NOTE: "i" is not assigned in this stage; if unset, bash evaluates both subscripts to 0.
+    simu_opts_num_speaker=${simu_opts_num_speaker_array[i]}
+    simu_opts_sil_scale=${simu_opts_sil_scale_array[i]}
+    simu_opts_num_train=100000
+
+    # for simulated data of chunk500
+    for dset in swb_sre_tr swb_sre_cv; do
+        if [ "$dset" == "swb_sre_tr" ]; then
+            n_mixtures=${simu_opts_num_train}
+        else
+            n_mixtures=500
+        fi
+        simu_data_dir=${dset}_ns${simu_opts_num_speaker}_beta${simu_opts_sil_scale}_${n_mixtures}
+        mkdir -p ${data_dir}/simu/data/${simu_data_dir}/.work
+        split_scps=
+        for n in $(seq $nj); do
+            split_scps="$split_scps ${data_dir}/simu/data/${simu_data_dir}/.work/wav.$n.scp"
+        done
+        utils/split_scp.pl "${data_dir}/simu/data/${simu_data_dir}/wav.scp" $split_scps || exit 1
+        python local/split.py ${data_dir}/simu/data/${simu_data_dir}
+    done
+fi
+
 # Training on simulated two-speaker data
 world_size=$gpu_num
 
diff --git a/funasr/modules/eend_ola/utils/feature.py b/funasr/modules/eend_ola/utils/feature.py
new file mode 100644
index 000000000..544a3521d
--- /dev/null
+++ b/funasr/modules/eend_ola/utils/feature.py
@@ -0,0 +1,290 @@
+# Copyright 2019 Hitachi, Ltd. (author: Yusuke Fujita)
+# Licensed under the MIT license.
+#
+# This module is for computing audio features
+
+import numpy as np
+import librosa
+
+
+def get_input_dim(
+        frame_size,
+        context_size,
+        transform_type,
+):
+    """ Compute the dimension of a transformed and spliced feature frame
+
+    Args:
+        frame_size: number of samples in a frame
+        context_size: number of frames spliced on each side
+        transform_type: transform name as accepted by transform()
+    Returns:
+        input_dim (int): (2 * context_size + 1) * per-frame dimension
+    """
+    # transform() treats None/'' as "no transform"; accept them here as well
+    transform_type = transform_type or ''
+    if transform_type.startswith('logmel23'):
+        frame_size = 23
+    elif transform_type.startswith('logmel'):
+        frame_size = 40
+    else:
+        fft_size = 1 << (frame_size - 1).bit_length()
+        frame_size = int(fft_size / 2) + 1
+    input_dim = (2 * context_size + 1) * frame_size
+    return input_dim
+
+
+def _log_mel(Y, sr, n_mels):
+    """ Convert a magnitude spectrogram to a log10 mel power spectrogram """
+    n_fft = 2 * (Y.shape[1] - 1)
+    # keyword arguments: librosa >= 0.10 rejects positional sr/n_fft
+    mel_basis = librosa.filters.mel(sr=sr, n_fft=n_fft, n_mels=n_mels)
+    Y = np.dot(Y ** 2, mel_basis.T)
+    return np.log10(np.maximum(Y, 1e-10))
+
+
+def transform(
+        Y,
+        transform_type=None,
+        dtype=np.float32):
+    """ Transform STFT feature
+
+    Args:
+        Y: STFT
+            (n_frames, n_bins)-shaped np.complex array
+        transform_type:
+            None, "log", "logmel", "logmel23", "logmel23_mn",
+            "logmel23_swn" or "logmel23_mvn"
+        dtype: output data type
+            np.float32 is expected
+    Returns:
+        Y (numpy.array): transformed feature
+    Raises:
+        ValueError: if transform_type is not one of the above
+    """
+    Y = np.abs(Y)
+    if not transform_type:
+        pass
+    elif transform_type == 'log':
+        Y = np.log(np.maximum(Y, 1e-10))
+    elif transform_type == 'logmel':
+        Y = _log_mel(Y, sr=16000, n_mels=40)
+    elif transform_type == 'logmel23':
+        Y = _log_mel(Y, sr=8000, n_mels=23)
+    elif transform_type == 'logmel23_mn':
+        Y = _log_mel(Y, sr=8000, n_mels=23)
+        Y = Y - np.mean(Y, axis=0)
+    elif transform_type == 'logmel23_swn':
+        Y = _log_mel(Y, sr=8000, n_mels=23)
+        # simple 2-means based thresholding for mean calculation
+        powers = np.sum(Y, axis=1)
+        th = (np.max(powers) + np.min(powers)) / 2.0
+        for _ in range(10):
+            th = (np.mean(powers[powers >= th]) + np.mean(powers[powers < th])) / 2
+        mean = np.mean(Y[powers > th, :], axis=0)
+        Y = Y - mean
+    elif transform_type == 'logmel23_mvn':
+        Y = _log_mel(Y, sr=8000, n_mels=23)
+        Y = Y - np.mean(Y, axis=0)
+        std = np.maximum(np.std(Y, axis=0), 1e-10)
+        Y = Y / std
+    else:
+        raise ValueError('Unknown transform_type: %s' % transform_type)
+    return Y.astype(dtype)
+
+
+def subsample(Y, T, subsampling=1):
+    """ Frame subsampling: keep every `subsampling`-th row of Y and T
+    """
+    Y_ss = Y[::subsampling]
+    T_ss = T[::subsampling]
+    return Y_ss, T_ss
+
+
+def splice(Y, context_size=0):
+    """ Frame splicing
+
+    Args:
+        Y: feature
+            (n_frames, n_featdim)-shaped numpy array
+        context_size:
+            number of frames concatenated on each side
+            if context_size = 5, 11 frames are concatenated.
+
+    Returns:
+        Y_spliced: spliced feature
+            (n_frames, n_featdim * (2 * context_size + 1))-shaped
+            read-only view; edge frames are zero-padded
+    """
+    Y_pad = np.pad(
+        Y,
+        [(context_size, context_size), (0, 0)],
+        'constant')
+    Y_spliced = np.lib.stride_tricks.as_strided(
+        np.ascontiguousarray(Y_pad),
+        (Y.shape[0], Y.shape[1] * (2 * context_size + 1)),
+        (Y.itemsize * Y.shape[1], Y.itemsize), writeable=False)
+    return Y_spliced
+
+
+def stft(
+        data,
+        frame_size=1024,
+        frame_shift=256):
+    """ Compute STFT features
+
+    Args:
+        data: audio signal
+            (n_samples,)-shaped np.float32 array
+        frame_size: number of samples in a frame
+            (the FFT size is frame_size rounded up to a power of two)
+        frame_shift: number of samples between frames
+
+    Returns:
+        stft: STFT frames
+            (n_frames, n_bins)-shaped np.complex64 array
+    """
+    # round up to nearest power of 2
+    fft_size = 1 << (frame_size - 1).bit_length()
+    Y = librosa.stft(data, n_fft=fft_size, win_length=frame_size,
+                     hop_length=frame_shift).T
+    # HACK: The last frame is omitted
+    # as librosa.stft produces such an excessive frame
+    if len(data) % frame_shift == 0:
+        return Y[:-1]
+    return Y
+
+
+def _count_frames(data_len, size, shift):
+    # HACK: Assuming librosa.stft(..., center=True)
+    n_frames = 1 + int(data_len / shift)
+    if data_len % shift == 0:
+        n_frames = n_frames - 1
+    return n_frames
+
+
+def _clip_to_chunk(seg_start, seg_end, start, end):
+    """ Intersect a segment with the chunk [start, end)
+
+    Returns:
+        (rel_start, rel_end) relative to `start`, or None if they do not
+        overlap. Clipping (rather than requiring a segment boundary inside
+        the chunk) also labels segments that span the whole chunk.
+    """
+    rel_start = max(seg_start, start) - start
+    rel_end = min(seg_end, end) - start
+    if rel_start < rel_end:
+        return rel_start, rel_end
+    return None
+
+
+def get_frame_labels(
+        kaldi_obj,
+        rec,
+        start=0,
+        end=None,
+        frame_size=1024,
+        frame_shift=256,
+        n_speakers=None):
+    """ Get frame-aligned labels of given recording
+    Args:
+        kaldi_obj (KaldiData)
+        rec (str): recording id
+        start (int): start frame index
+        end (int): end frame index
+            None means the last frame of recording
+        frame_size (int): number of samples in a frame
+        frame_shift (int): number of shift samples
+        n_speakers (int): number of speakers
+            if None, the value is given from data
+    Returns:
+        T: label
+            (n_frames, n_speakers)-shaped np.int32 array
+    """
+    # segments is indexed by recording id, as in get_labeledSTFT
+    filtered_segments = kaldi_obj.segments[rec]
+    speakers = np.unique(
+        [kaldi_obj.utt2spk[seg['utt']] for seg
+         in filtered_segments]).tolist()
+    if n_speakers is None:
+        n_speakers = len(speakers)
+    es = end * frame_shift if end is not None else None
+    data, rate = kaldi_obj.load_wav(rec, start * frame_shift, es)
+    n_frames = _count_frames(len(data), frame_size, frame_shift)
+    T = np.zeros((n_frames, n_speakers), dtype=np.int32)
+    if end is None:
+        end = n_frames
+
+    for seg in filtered_segments:
+        speaker_index = speakers.index(kaldi_obj.utt2spk[seg['utt']])
+        start_frame = np.rint(
+            seg['st'] * rate / frame_shift).astype(int)
+        end_frame = np.rint(
+            seg['et'] * rate / frame_shift).astype(int)
+        span = _clip_to_chunk(start_frame, end_frame, start, end)
+        if span is not None:
+            T[span[0]:span[1], speaker_index] = 1
+    return T
+
+
+def get_labeledSTFT(
+        kaldi_obj,
+        rec, start, end, frame_size, frame_shift,
+        n_speakers=None,
+        use_speaker_id=False):
+    """ Extracts STFT and corresponding labels
+
+    Extracts STFT and corresponding diarization labels for
+    given recording id and start/end times
+
+    Args:
+        kaldi_obj (KaldiData)
+        rec (str): recording id
+        start (int): start frame index
+        end (int): end frame index
+        frame_size (int): number of samples in a frame
+        frame_shift (int): number of shift samples
+        n_speakers (int): number of speakers
+            if None, the value is given from data
+        use_speaker_id (bool): also return labels over all speakers
+    Returns:
+        Y: STFT
+            (n_frames, n_bins)-shaped np.complex64 array,
+        T: label
+            (n_frames, n_speakers)-shaped np.int32 array.
+        S: only if use_speaker_id; label over sorted(kaldi_obj.spk2utt)
+            (n_frames, n_all_speakers)-shaped np.int32 array.
+    """
+    data, rate = kaldi_obj.load_wav(
+        rec, start * frame_shift, end * frame_shift)
+    Y = stft(data, frame_size, frame_shift)
+    filtered_segments = kaldi_obj.segments[rec]
+    speakers = np.unique(
+        [kaldi_obj.utt2spk[seg['utt']] for seg
+         in filtered_segments]).tolist()
+    if n_speakers is None:
+        n_speakers = len(speakers)
+    T = np.zeros((Y.shape[0], n_speakers), dtype=np.int32)
+
+    if use_speaker_id:
+        all_speakers = sorted(kaldi_obj.spk2utt.keys())
+        S = np.zeros((Y.shape[0], len(all_speakers)), dtype=np.int32)
+
+    for seg in filtered_segments:
+        speaker_index = speakers.index(kaldi_obj.utt2spk[seg['utt']])
+        if use_speaker_id:
+            all_speaker_index = all_speakers.index(kaldi_obj.utt2spk[seg['utt']])
+        start_frame = np.rint(
+            seg['st'] * rate / frame_shift).astype(int)
+        end_frame = np.rint(
+            seg['et'] * rate / frame_shift).astype(int)
+        span = _clip_to_chunk(start_frame, end_frame, start, end)
+        if span is not None:
+            T[span[0]:span[1], speaker_index] = 1
+            if use_speaker_id:
+                S[span[0]:span[1], all_speaker_index] = 1
+
+    if use_speaker_id:
+        return Y, T, S
+    else:
+        return Y, T