# FunASR/funasr/utils/wav_utils.py
# Copyright (c) Alibaba, Inc. and its affiliates.
import math
import os
import shutil
from multiprocessing import Pool
from typing import Any, Dict, Union

import kaldiio
import librosa
import numpy as np
import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi


def ndarray_resample(audio_in: np.ndarray,
                     fs_in: int = 16000,
                     fs_out: int = 16000) -> np.ndarray:
    audio_out = audio_in
    if fs_in != fs_out:
        audio_out = librosa.resample(audio_in, orig_sr=fs_in, target_sr=fs_out)
    return audio_out


def torch_resample(audio_in: torch.Tensor,
                   fs_in: int = 16000,
                   fs_out: int = 16000) -> torch.Tensor:
    audio_out = audio_in
    if fs_in != fs_out:
        audio_out = torchaudio.transforms.Resample(orig_freq=fs_in,
                                                   new_freq=fs_out)(audio_in)
    return audio_out
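
# Usage sketch (illustrative values, not from the original file): resample a
# 1-second 8 kHz signal to 16 kHz; the output length doubles.
#   x = np.zeros(8000, dtype=np.float32)
#   y = ndarray_resample(x, fs_in=8000, fs_out=16000)        # y.shape == (16000,)
#   t = torch_resample(torch.zeros(1, 8000), 8000, 16000)    # t.shape == (1, 16000)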


def extract_CMVN_featrures(mvn_file):
    """
    Extract CMVN statistics from a Kaldi cmvn.ark file and convert them to
    per-dimension means and inverse standard deviations.
    """
    if not os.path.exists(mvn_file):
        return None
    try:
        cmvn = kaldiio.load_mat(mvn_file)
        means = []
        variance = []
        for i in range(cmvn.shape[1] - 1):
            means.append(float(cmvn[0][i]))
        count = float(cmvn[0][-1])
        for i in range(cmvn.shape[1] - 1):
            variance.append(float(cmvn[1][i]))
        for i in range(len(means)):
            means[i] /= count
            variance[i] = variance[i] / count - means[i] * means[i]
            if variance[i] < 1.0e-20:
                variance[i] = 1.0e-20
            variance[i] = 1.0 / math.sqrt(variance[i])
        cmvn = np.array([means, variance])
        return cmvn
    except Exception:
        # fall back to the Kaldi nnet text format
        cmvn = extract_CMVN_features_txt(mvn_file)
        return cmvn
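
# Layout note (standard Kaldi CMVN statistics, stated here for clarity): row 0
# of the loaded matrix holds per-dimension feature sums with the frame count in
# the last column; row 1 holds per-dimension squared sums. The code above turns
# them into an array of shape (2, dim) containing [means, 1/stddev].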


def extract_CMVN_features_txt(mvn_file):  # noqa
    with open(mvn_file, 'r', encoding='utf-8') as f:
        lines = f.readlines()
    add_shift_list = []
    rescale_list = []
    for i in range(len(lines)):
        line_item = lines[i].split()
        if not line_item:  # skip blank lines to avoid an IndexError
            continue
        if line_item[0] == '<AddShift>':
            line_item = lines[i + 1].split()
            if line_item[0] == '<LearnRateCoef>':
                add_shift_line = line_item[3:(len(line_item) - 1)]
                add_shift_list = list(add_shift_line)
                continue
        elif line_item[0] == '<Rescale>':
            line_item = lines[i + 1].split()
            if line_item[0] == '<LearnRateCoef>':
                rescale_line = line_item[3:(len(line_item) - 1)]
                rescale_list = list(rescale_line)
                continue
    add_shift_list_f = [float(s) for s in add_shift_list]
    rescale_list_f = [float(s) for s in rescale_list]
    cmvn = np.array([add_shift_list_f, rescale_list_f])
    return cmvn
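
# The text format parsed above is a Kaldi nnet-style component dump; a
# hypothetical two-dimensional example (values are illustrative only):
#   <AddShift> 2 2
#   <LearnRateCoef> 0 [ -5.1 -4.7 ]
#   <Rescale> 2 2
#   <LearnRateCoef> 0 [ 0.25 0.31 ]
# The slice line_item[3:len(line_item) - 1] keeps the numbers between '[' and ']'.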


def build_LFR_features(inputs, m=7, n=6):  # noqa
    """
    Stack consecutive frames and skip frames (low frame rate, LFR).
    If m = 1 and n = 1, the original features are returned unchanged.
    If m = 1 and n > 1, it works like skipping.
    If m > 1 and n = 1, it works like stacking, but only right frames are supported.
    If m > 1 and n > 1, it works like LFR.
    Args:
        inputs: T x D np.ndarray
        m: number of frames to stack
        n: number of frames to skip
    """
    LFR_inputs = []
    T = inputs.shape[0]
    T_lfr = int(np.ceil(T / n))
    # pad (m - 1) // 2 copies of the first frame on the left
    left_padding = np.tile(inputs[0], ((m - 1) // 2, 1))
    inputs = np.vstack((left_padding, inputs))
    T = T + (m - 1) // 2
    for i in range(T_lfr):
        if m <= T - i * n:
            LFR_inputs.append(np.hstack(inputs[i * n:i * n + m]))
        else:  # process last LFR frame: pad with copies of the final input frame
            num_padding = m - (T - i * n)
            frame = np.hstack(inputs[i * n:])
            for _ in range(num_padding):
                frame = np.hstack((frame, inputs[-1]))
            LFR_inputs.append(frame)
    return np.vstack(LFR_inputs)
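
# Shape check (a small worked example, not from the original file): with the
# defaults m=7, n=6, an input of shape (100, 80) is left-padded with
# (m - 1) // 2 = 3 copies of the first frame and produces
# ceil(100 / 6) = 17 output frames, each of dimension 7 * 80 = 560:
#   lfr = build_LFR_features(np.random.randn(100, 80))
#   assert lfr.shape == (17, 560)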


def compute_fbank(wav_file,
                  num_mel_bins=80,
                  frame_length=25,
                  frame_shift=10,
                  dither=0.0,
                  is_pcm=False,
                  fs: Union[int, Dict[Any, int]] = 16000):
    audio_sr: int = 16000
    model_sr: int = 16000
    if isinstance(fs, int):
        model_sr = fs
        audio_sr = fs
    else:
        model_sr = fs['model_fs']
        audio_sr = fs['audio_fs']
    if is_pcm is True:
        # raw PCM16 bytes to float32 in [-1, 1), then resample
        value = wav_file
        middle_data = np.frombuffer(value, dtype=np.int16)
        middle_data = np.asarray(middle_data)
        if middle_data.dtype.kind not in 'iu':
            raise TypeError("'middle_data' must be an array of integers")
        dtype = np.dtype('float32')
        if dtype.kind != 'f':
            raise TypeError("'dtype' must be a floating point type")
        i = np.iinfo(middle_data.dtype)
        abs_max = 2 ** (i.bits - 1)
        offset = i.min + abs_max
        waveform = np.frombuffer(
            (middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32)
        waveform = ndarray_resample(waveform, audio_sr, model_sr)
        waveform = torch.from_numpy(waveform.reshape(1, -1))
    else:
        # load PCM from a wav file, rescale to 16-bit range, and resample
        waveform, audio_sr = torchaudio.load(wav_file)
        waveform = waveform * (1 << 15)
        waveform = torch_resample(waveform, audio_sr, model_sr)
    mat = kaldi.fbank(waveform,
                      num_mel_bins=num_mel_bins,
                      frame_length=frame_length,
                      frame_shift=frame_shift,
                      dither=dither,
                      energy_floor=0.0,
                      window_type='hamming',
                      sample_frequency=model_sr)
    input_feats = mat
    return input_feats
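
# Usage sketch (hypothetical file name; assumes a mono 16 kHz WAV on disk):
#   feats = compute_fbank('example.wav', num_mel_bins=80, fs=16000)
#   print(feats.shape)  # (num_frames, 80) with 25 ms windows and 10 ms shift
# For raw PCM16 bytes captured at 8 kHz while the model expects 16 kHz:
#   feats = compute_fbank(pcm_bytes, is_pcm=True,
#                         fs={'audio_fs': 8000, 'model_fs': 16000})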


def wav2num_frame(wav_path, frontend_conf):
    waveform, sampling_rate = torchaudio.load(wav_path)
    # speech length in milliseconds
    speech_length = (waveform.shape[1] / sampling_rate) * 1000.
    # number of LFR output frames: duration_ms / (frame_shift_ms * lfr_n)
    n_frames = (waveform.shape[1] * 1000.0) / (sampling_rate * frontend_conf["frame_shift"] * frontend_conf["lfr_n"])
    feature_dim = frontend_conf["n_mels"] * frontend_conf["lfr_m"]
    return n_frames, feature_dim, speech_length
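
# Worked example (illustrative numbers only): a 2-second 16 kHz file with
# frame_shift=10 (ms), lfr_n=6, lfr_m=7 and n_mels=80 yields
#   n_frames = 2000 / (10 * 6) ~= 33.3  (rounded up to 34 in calc_shape_core)
#   feature_dim = 80 * 7 = 560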


def calc_shape_core(root_path, frontend_conf, speech_length_min, speech_length_max, idx):
    wav_scp_file = os.path.join(root_path, "wav.scp.{}".format(idx))
    shape_file = os.path.join(root_path, "speech_shape.{}".format(idx))
    with open(wav_scp_file) as f:
        lines = f.readlines()
    with open(shape_file, "w") as f:
        for line in lines:
            sample_name, wav_path = line.strip().split()
            n_frames, feature_dim, speech_length = wav2num_frame(wav_path, frontend_conf)
            write_flag = True
            if speech_length_min > 0 and speech_length < speech_length_min:
                write_flag = False
            if speech_length_max > 0 and speech_length > speech_length_max:
                write_flag = False
            if write_flag:
                f.write("{} {},{}\n".format(sample_name, str(int(np.ceil(n_frames))), str(int(feature_dim))))


def calc_shape(data_dir, dataset, frontend_conf, speech_length_min=-1, speech_length_max=-1, nj=32):
    shape_path = os.path.join(data_dir, dataset, "shape_files")
    if os.path.exists(shape_path):
        assert os.path.exists(os.path.join(data_dir, dataset, "speech_shape"))
        print('Shape file for small dataset already exists.')
        return
    os.makedirs(shape_path, exist_ok=True)

    # split wav.scp into nj chunks; the last chunk takes the remainder
    wav_scp_file = os.path.join(data_dir, dataset, "wav.scp")
    with open(wav_scp_file) as f:
        lines = f.readlines()
    num_lines = len(lines)
    num_job_lines = num_lines // nj
    start = 0
    for i in range(nj):
        end = start + num_job_lines
        file = os.path.join(shape_path, "wav.scp.{}".format(str(i + 1)))
        with open(file, "w") as f:
            if i == nj - 1:
                f.writelines(lines[start:])
            else:
                f.writelines(lines[start:end])
        start = end

    # compute shapes for each chunk in parallel
    p = Pool(nj)
    for i in range(nj):
        p.apply_async(calc_shape_core,
                      args=(shape_path, frontend_conf, speech_length_min, speech_length_max, str(i + 1)))
    print('Generating shape files, please wait a few minutes...')
    p.close()
    p.join()

    # combine the per-job outputs into a single speech_shape file
    file = os.path.join(data_dir, dataset, "speech_shape")
    with open(file, "w") as f:
        for i in range(nj):
            job_file = os.path.join(shape_path, "speech_shape.{}".format(str(i + 1)))
            with open(job_file) as job_f:
                lines = job_f.readlines()
            f.writelines(lines)
    print('Generating shape files done.')
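
# Usage sketch (hypothetical layout): expects data_dir/dataset/wav.scp with
# "<sample_name> <wav_path>" lines and writes data_dir/dataset/speech_shape.
#   calc_shape('data', 'train',
#              frontend_conf={'frame_shift': 10, 'lfr_n': 6,
#                             'lfr_m': 7, 'n_mels': 80},
#              nj=4)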


def generate_data_list(data_dir, dataset, nj=100):
    split_dir = os.path.join(data_dir, dataset, "split")
    if os.path.exists(split_dir):
        assert os.path.exists(os.path.join(data_dir, dataset, "data.list"))
        print('Data list for large dataset already exists.')
        return
    os.makedirs(split_dir, exist_ok=True)
    with open(os.path.join(data_dir, dataset, "wav.scp")) as f_wav:
        wav_lines = f_wav.readlines()
    with open(os.path.join(data_dir, dataset, "text")) as f_text:
        text_lines = f_text.readlines()
    total_num_lines = len(wav_lines)
    num_lines = total_num_lines // nj
    start_num = 0
    for i in range(nj):
        end_num = start_num + num_lines
        split_dir_nj = os.path.join(split_dir, str(i + 1))
        os.mkdir(split_dir_nj)
        wav_file = os.path.join(split_dir_nj, 'wav.scp')
        text_file = os.path.join(split_dir_nj, "text")
        with open(wav_file, "w") as fw, open(text_file, "w") as ft:
            if i == nj - 1:
                fw.writelines(wav_lines[start_num:])
                ft.writelines(text_lines[start_num:])
            else:
                fw.writelines(wav_lines[start_num:end_num])
                ft.writelines(text_lines[start_num:end_num])
        start_num = end_num
    data_list_file = os.path.join(data_dir, dataset, "data.list")
    with open(data_list_file, "w") as f_data:
        for i in range(nj):
            wav_path = os.path.join(split_dir, str(i + 1), "wav.scp")
            text_path = os.path.join(split_dir, str(i + 1), "text")
            f_data.write(wav_path + " " + text_path + "\n")
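
# Resulting data.list (hypothetical content, one line per job):
#   data/train/split/1/wav.scp data/train/split/1/text
#   data/train/split/2/wav.scp data/train/split/2/text
#   ...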