# Copyright (c) Alibaba, Inc. and its affiliates.

import math
import os
import shutil
from multiprocessing import Pool
from typing import Any, Dict, Union

import kaldiio
import librosa
import numpy as np
import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi


def ndarray_resample(audio_in: np.ndarray,
                     fs_in: int = 16000,
                     fs_out: int = 16000) -> np.ndarray:
    audio_out = audio_in
    if fs_in != fs_out:
        audio_out = librosa.resample(audio_in, orig_sr=fs_in, target_sr=fs_out)
    return audio_out


def torch_resample(audio_in: torch.Tensor,
                   fs_in: int = 16000,
                   fs_out: int = 16000) -> torch.Tensor:
    audio_out = audio_in
    if fs_in != fs_out:
        audio_out = torchaudio.transforms.Resample(
            orig_freq=fs_in, new_freq=fs_out)(audio_in)
    return audio_out


def extract_CMVN_featrures(mvn_file):
    """Extract CMVN statistics from a Kaldi cmvn.ark file."""
    if not os.path.exists(mvn_file):
        return None

    try:
        cmvn = kaldiio.load_mat(mvn_file)
        means = []
        variance = []
        for i in range(cmvn.shape[1] - 1):
            means.append(float(cmvn[0][i]))
        count = float(cmvn[0][-1])
        for i in range(cmvn.shape[1] - 1):
            variance.append(float(cmvn[1][i]))
        for i in range(len(means)):
            means[i] /= count
            variance[i] = variance[i] / count - means[i] * means[i]
            if variance[i] < 1.0e-20:
                variance[i] = 1.0e-20
            variance[i] = 1.0 / math.sqrt(variance[i])
        cmvn = np.array([means, variance])
        return cmvn
    except Exception:
        # fall back to the text-format CMVN file
        cmvn = extract_CMVN_features_txt(mvn_file)
        return cmvn


def extract_CMVN_features_txt(mvn_file):  # noqa
    """Extract CMVN statistics from a text-format (Kaldi nnet) mvn file."""
    with open(mvn_file, 'r', encoding='utf-8') as f:
        lines = f.readlines()

    add_shift_list = []
    rescale_list = []
    for i in range(len(lines)):
        line_item = lines[i].split()
        if line_item[0] == '<AddShift>':
            line_item = lines[i + 1].split()
            if line_item[0] == '<LearnRateCoef>':
                add_shift_line = line_item[3:(len(line_item) - 1)]
                add_shift_list = list(add_shift_line)
                continue
        elif line_item[0] == '<Rescale>':
            line_item = lines[i + 1].split()
            if line_item[0] == '<LearnRateCoef>':
                rescale_line = line_item[3:(len(line_item) - 1)]
                rescale_list = list(rescale_line)
                continue

    add_shift_list_f = [float(s) for s in add_shift_list]
    rescale_list_f = [float(s) for s in rescale_list]
    cmvn = np.array([add_shift_list_f, rescale_list_f])
    return cmvn
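

# Note: both CMVN extractors above return a 2 x D np.ndarray, where row 0
# holds the per-dimension shift terms (means from the ark format, add-shift
# values from the text format) and row 1 holds the scale terms (inverse
# standard deviations / rescale values), to be applied to fbank features as a
# per-dimension shift and scale.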


def build_LFR_features(inputs, m=7, n=6):  # noqa
    """Build low frame rate (LFR) features by stacking and skipping frames.

    If m = 1 and n = 1, the original features are returned unchanged.
    If m = 1 and n > 1, it only skips frames.
    If m > 1 and n = 1, it only stacks frames (right-context frames only).
    If m > 1 and n > 1, it performs LFR.

    Args:
        inputs: T x D np.ndarray
        m: number of frames to stack
        n: number of frames to skip
    """
    LFR_inputs = []
    T = inputs.shape[0]
    T_lfr = int(np.ceil(T / n))
    left_padding = np.tile(inputs[0], ((m - 1) // 2, 1))
    inputs = np.vstack((left_padding, inputs))
    T = T + (m - 1) // 2
    for i in range(T_lfr):
        if m <= T - i * n:
            LFR_inputs.append(np.hstack(inputs[i * n:i * n + m]))
        else:
            # process the last LFR frame: pad with copies of the final frame
            num_padding = m - (T - i * n)
            frame = np.hstack(inputs[i * n:])
            for _ in range(num_padding):
                frame = np.hstack((frame, inputs[-1]))
            LFR_inputs.append(frame)
    return np.vstack(LFR_inputs)


def compute_fbank(wav_file,
                  num_mel_bins=80,
                  frame_length=25,
                  frame_shift=10,
                  dither=0.0,
                  is_pcm=False,
                  fs: Union[int, Dict[Any, int]] = 16000):
    audio_sr: int = 16000
    model_sr: int = 16000
    if isinstance(fs, int):
        model_sr = fs
        audio_sr = fs
    else:
        model_sr = fs['model_fs']
        audio_sr = fs['audio_fs']

    if is_pcm is True:
        # convert raw PCM16 bytes to float32 in [-1, 1), then resample
        value = wav_file
        middle_data = np.frombuffer(value, dtype=np.int16)
        middle_data = np.asarray(middle_data)
        if middle_data.dtype.kind not in 'iu':
            raise TypeError("'middle_data' must be an array of integers")
        dtype = np.dtype('float32')
        if dtype.kind != 'f':
            raise TypeError("'dtype' must be a floating point type")

        i = np.iinfo(middle_data.dtype)
        abs_max = 2**(i.bits - 1)
        offset = i.min + abs_max
        waveform = np.frombuffer(
            (middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32)
        waveform = ndarray_resample(waveform, audio_sr, model_sr)
        waveform = torch.from_numpy(waveform.reshape(1, -1))
    else:
        # load PCM from the wav file, then resample
        waveform, audio_sr = torchaudio.load(wav_file)
        waveform = waveform * (1 << 15)
        waveform = torch_resample(waveform, audio_sr, model_sr)

    mat = kaldi.fbank(
        waveform,
        num_mel_bins=num_mel_bins,
        frame_length=frame_length,
        frame_shift=frame_shift,
        dither=dither,
        energy_floor=0.0,
        window_type='hamming',
        sample_frequency=model_sr)
    input_feats = mat
    return input_feats
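

# Example (illustrative) front-end usage, assuming a 16 kHz wav file at a
# hypothetical path 'example.wav':
#   feats = compute_fbank('example.wav', num_mel_bins=80, fs=16000)
#   # feats is a (T, 80) torch.Tensor; LFR stacking with m=7, n=6 gives a
#   # (ceil(T / 6), 560) np.ndarray
#   lfr_feats = build_LFR_features(feats.numpy(), m=7, n=6)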


def wav2num_frame(wav_path, frontend_conf):
    waveform, sampling_rate = torchaudio.load(wav_path)
    speech_length = (waveform.shape[1] / sampling_rate) * 1000.
    n_frames = (waveform.shape[1] * 1000.0) / (
        sampling_rate * frontend_conf['frame_shift'] * frontend_conf['lfr_n'])
    feature_dim = frontend_conf['n_mels'] * frontend_conf['lfr_m']
    return n_frames, feature_dim, speech_length


def calc_shape_core(root_path, frontend_conf, speech_length_min,
                    speech_length_max, idx):
    wav_scp_file = os.path.join(root_path, 'wav.scp.{}'.format(idx))
    shape_file = os.path.join(root_path, 'speech_shape.{}'.format(idx))
    with open(wav_scp_file) as f:
        lines = f.readlines()
    with open(shape_file, 'w') as f:
        for line in lines:
            sample_name, wav_path = line.strip().split()
            n_frames, feature_dim, speech_length = wav2num_frame(
                wav_path, frontend_conf)
            write_flag = True
            if speech_length_min > 0 and speech_length < speech_length_min:
                write_flag = False
            if speech_length_max > 0 and speech_length > speech_length_max:
                write_flag = False
            if write_flag:
                f.write('{} {},{}\n'.format(sample_name,
                                            str(int(np.ceil(n_frames))),
                                            str(int(feature_dim))))
                f.flush()


def calc_shape(data_dir,
               dataset,
               frontend_conf,
               speech_length_min=-1,
               speech_length_max=-1,
               nj=32):
    shape_path = os.path.join(data_dir, dataset, 'shape_files')
    if os.path.exists(shape_path):
        assert os.path.exists(os.path.join(data_dir, dataset, 'speech_shape'))
        print('Shape file for small dataset already exists.')
        return

    os.makedirs(shape_path, exist_ok=True)

    # split
    wav_scp_file = os.path.join(data_dir, dataset, 'wav.scp')
    with open(wav_scp_file) as f:
        lines = f.readlines()
    num_lines = len(lines)
    num_job_lines = num_lines // nj
    start = 0
    for i in range(nj):
        end = start + num_job_lines
        file = os.path.join(shape_path, 'wav.scp.{}'.format(str(i + 1)))
        with open(file, 'w') as f:
            if i == nj - 1:
                f.writelines(lines[start:])
            else:
                f.writelines(lines[start:end])
        start = end

    p = Pool(nj)
    for i in range(nj):
        p.apply_async(
            calc_shape_core,
            args=(shape_path, frontend_conf, speech_length_min,
                  speech_length_max, str(i + 1)))
    print('Generating shape files, please wait a few minutes...')
    p.close()
    p.join()

    # combine
    file = os.path.join(data_dir, dataset, 'speech_shape')
    with open(file, 'w') as f:
        for i in range(nj):
            job_file = os.path.join(shape_path,
                                    'speech_shape.{}'.format(str(i + 1)))
            with open(job_file) as job_f:
                lines = job_f.readlines()
            f.writelines(lines)
    print('Generating shape files done.')


def generate_data_list(data_dir, dataset, nj=100):
    split_dir = os.path.join(data_dir, dataset, 'split')
    if os.path.exists(split_dir):
        assert os.path.exists(os.path.join(data_dir, dataset, 'data.list'))
        print('Data list for large dataset already exists.')
        return

    os.makedirs(split_dir, exist_ok=True)
    with open(os.path.join(data_dir, dataset, 'wav.scp')) as f_wav:
        wav_lines = f_wav.readlines()
    with open(os.path.join(data_dir, dataset, 'text')) as f_text:
        text_lines = f_text.readlines()
    total_num_lines = len(wav_lines)
    num_lines = total_num_lines // nj
    start_num = 0
    for i in range(nj):
        end_num = start_num + num_lines
        split_dir_nj = os.path.join(split_dir, str(i + 1))
        os.mkdir(split_dir_nj)
        wav_file = os.path.join(split_dir_nj, 'wav.scp')
        text_file = os.path.join(split_dir_nj, 'text')
        with open(wav_file, 'w') as fw, open(text_file, 'w') as ft:
            if i == nj - 1:
                fw.writelines(wav_lines[start_num:])
                ft.writelines(text_lines[start_num:])
            else:
                fw.writelines(wav_lines[start_num:end_num])
                ft.writelines(text_lines[start_num:end_num])
        start_num = end_num

    data_list_file = os.path.join(data_dir, dataset, 'data.list')
    with open(data_list_file, 'w') as f_data:
        for i in range(nj):
            wav_path = os.path.join(split_dir, str(i + 1), 'wav.scp')
            text_path = os.path.join(split_dir, str(i + 1), 'text')
            f_data.write(wav_path + ' ' + text_path + '\n')
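

# Example (illustrative) shape-file and data-list generation for a
# hypothetical Kaldi-style directory data/train (containing wav.scp and
# text); the frontend_conf values below are assumptions for an 80-dim fbank
# front end with LFR m=7, n=6:
#   frontend_conf = {'frame_shift': 10, 'lfr_n': 6, 'n_mels': 80, 'lfr_m': 7}
#   calc_shape('data', 'train', frontend_conf, nj=32)   # -> data/train/speech_shape
#   generate_data_list('data', 'train', nj=100)         # -> data/train/data.list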


def filter_wav_text(data_dir, dataset):
    wav_file = os.path.join(data_dir, dataset, 'wav.scp')
    text_file = os.path.join(data_dir, dataset, 'text')
    with open(wav_file) as f_wav, open(text_file) as f_text:
        wav_lines = f_wav.readlines()
        text_lines = f_text.readlines()
    os.rename(wav_file, '{}.bak'.format(wav_file))
    os.rename(text_file, '{}.bak'.format(text_file))

    wav_dict = {}
    for line in wav_lines:
        parts = line.strip().split()
        if len(parts) < 2:
            continue
        sample_name, wav_path = parts
        wav_dict[sample_name] = wav_path

    text_dict = {}
    for line in text_lines:
        parts = line.strip().split()
        if len(parts) < 2:
            continue
        sample_name = parts[0]
        text_dict[sample_name] = ' '.join(parts[1:]).lower()

    filter_count = 0
    with open(wav_file, 'w') as f_wav, open(text_file, 'w') as f_text:
        for sample_name, wav_path in wav_dict.items():
            if sample_name in text_dict.keys():
                f_wav.write(sample_name + ' ' + wav_path + '\n')
                f_text.write(sample_name + ' ' + text_dict[sample_name] + '\n')
            else:
                filter_count += 1
    print('{}/{} samples in {} are filtered because of the mismatch between '
          'wav.scp and text'.format(filter_count, len(wav_lines), dataset))
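

# Example (illustrative): reconciling a hypothetical data/train directory so
# that wav.scp and text contain the same utterance ids; the original files
# are kept as wav.scp.bak and text.bak:
#   filter_wav_text('data', 'train')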