FunASR/egs/aishell2/transformer/utils/compute_fbank.py
2023-05-12 17:25:54 +08:00

172 lines
5.1 KiB
Python
Executable File

import argparse
import json
import os

import numpy as np
import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi
from kaldiio import WriteHelper
def compute_fbank(wav_file,
                  num_mel_bins=80,
                  frame_length=25,
                  frame_shift=10,
                  dither=0.0,
                  resample_rate=16000,
                  speed=1.0,
                  window_type="hamming"):
    """Load an audio file and return its Kaldi-compatible filterbank features.

    Args:
        wav_file: path to an audio file readable by ``torchaudio.load``.
        num_mel_bins: number of mel bins, i.e. the feature dimension.
        frame_length: analysis frame length, in milliseconds (kaldi.fbank unit).
        frame_shift: hop between frames, in milliseconds (kaldi.fbank unit).
        dither: dithering constant forwarded to ``kaldi.fbank``; 0 disables it.
        resample_rate: target sample rate. Audio stored at a different rate is
            resampled, and this rate is also used for feature extraction.
        speed: speed-perturbation factor; 1.0 leaves the audio untouched.
        window_type: window name forwarded to ``kaldi.fbank``.

    Returns:
        numpy array with one row per frame and ``num_mel_bins`` columns.
    """
    signal, orig_rate = torchaudio.load(wav_file)

    if orig_rate != resample_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=orig_rate,
                                                   new_freq=resample_rate)
        signal = resampler(signal)

    if speed != 1.0:
        # sox 'speed' alters tempo and pitch together and changes the sample
        # rate; the trailing 'rate' effect brings it back to resample_rate.
        sox_chain = [['speed', str(speed)], ['rate', str(resample_rate)]]
        signal, _ = torchaudio.sox_effects.apply_effects_tensor(
            signal, resample_rate, sox_chain)

    # torchaudio.load returns floats in [-1, 1] by default; Kaldi features are
    # conventionally computed on the 16-bit integer amplitude scale.
    signal = signal * 32768

    feats = kaldi.fbank(
        signal,
        num_mel_bins=num_mel_bins,
        frame_length=frame_length,
        frame_shift=frame_shift,
        dither=dither,
        energy_floor=0.0,
        window_type=window_type,
        sample_frequency=resample_rate,
    )
    return feats.numpy()
def get_parser():
    """Build the command-line parser for the fbank extraction script.

    Required arguments deliberately have no ``default``: a default on a
    required option can never take effect, and ``ArgumentDefaultsHelpFormatter``
    would otherwise print it in ``--help`` as if it applied.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser(
        description="compute features",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # Each line is read as "<utt_id> <wav_path>".
    parser.add_argument(
        "--wav-lists",
        "-w",
        required=True,
        type=str,
        help="input wav lists",
    )
    # Each line is read as "<utt_id> <token> ...", paired with --wav-lists by position.
    parser.add_argument(
        "--text-files",
        "-t",
        required=True,
        type=str,
        help="input text files",
    )
    parser.add_argument(
        "--dims",
        "-d",
        default=80,
        type=int,
        help="feature dims",
    )
    # Utterances with this many frames or more are skipped.
    parser.add_argument(
        "--max-lengths",
        "-m",
        default=1500,
        type=int,
        help="max frame numbers",
    )
    parser.add_argument(
        "--sample-frequency",
        "-s",
        default=16000,
        type=int,
        help="sample frequency",
    )
    # Comma-separated list of speed factors, e.g. "0.9,1.0,1.1".
    parser.add_argument(
        "--speed-perturb",
        "-p",
        default="1.0",
        type=str,
        help="speed perturb",
    )
    # Suffix used in every output file name.
    parser.add_argument(
        "--ark-index",
        "-a",
        required=True,
        type=int,
        help="ark index",
    )
    parser.add_argument(
        "--output-dir",
        "-o",
        required=True,
        type=str,
        help="output dir",
    )
    parser.add_argument(
        "--window-type",
        default="hamming",
        required=False,
        type=str,
        help="window type",
    )
    return parser
def _speed_factors(spec):
    """Split a comma-separated speed list into (label, factor) pairs, skipping blank entries."""
    labels = [part.strip() for part in spec.split(",")]
    return [(label, float(label)) for label in labels if label]


def _read_utterances(wav_list, text_list):
    """Return (utt_id, wav_path, tokens) triples from two line-aligned files.

    ``wav_list`` lines are ``<utt_id> <wav_path>``. ``text_list`` lines are
    ``<utt_id> [token ...]``. Lines are paired by position, as ``zip`` does, so
    any extra lines in the longer file are ignored.

    Raises:
        ValueError: if a paired wav line and text line carry different ids.
    """
    utterances = []
    with open(wav_list, "r", encoding="utf-8") as wav_f:
        with open(text_list, "r", encoding="utf-8") as text_f:
            for lineno, (wav_line, text_line) in enumerate(zip(wav_f, text_f), start=1):
                wav_fields = wav_line.strip().split()
                text_fields = text_line.strip().split()
                utt_id, wav_path = wav_fields[0], wav_fields[1]
                # Misaligned lists would otherwise silently pair audio with the
                # wrong transcript.
                if text_fields[0] != utt_id:
                    raise ValueError(
                        "line {}: wav id {!r} does not match text id {!r}".format(
                            lineno, utt_id, text_fields[0]))
                utterances.append((utt_id, wav_path, text_fields[1:]))
    return utterances


def main():
    """Extract fbank features for every (wav, text) pair and write Kaldi-style outputs.

    For each speed factor in ``--speed-perturb``, every utterance is processed.
    Utterances with ``--max-lengths`` frames or more are skipped. Perturbed
    copies get the id ``<utt_id>_sp<speed>``; the 1.0 copy keeps ``<utt_id>``.
    With N = ``--ark-index``, the outputs are:

    * ``<output_dir>/ark/feats.N.ark`` and ``feats.N.scp``: feature matrices;
    * ``<output_dir>/ark/len.N``: ``<id> <frames>,<dims>`` per utterance;
    * ``<output_dir>/txt/text.N.txt``: ``<id> <tokens>``;
    * ``<output_dir>/txt/len.N``: ``<id> <num_tokens>``.

    Raises:
        ValueError: on a wav/text id mismatch or a non-numeric speed factor.
    """
    args = get_parser().parse_args()
    index = str(args.ark_index)
    ark_dir = os.path.join(args.output_dir, "ark")
    txt_dir = os.path.join(args.output_dir, "txt")
    # Create the output layout instead of failing inside WriteHelper/open.
    os.makedirs(ark_dir, exist_ok=True)
    os.makedirs(txt_dir, exist_ok=True)
    ark_file = os.path.join(ark_dir, "feats." + index + ".ark")
    scp_file = os.path.join(ark_dir, "feats." + index + ".scp")
    text_file = os.path.join(txt_dir, "text." + index + ".txt")
    feats_shape_file = os.path.join(ark_dir, "len." + index)
    text_shape_file = os.path.join(txt_dir, "len." + index)

    speeds = _speed_factors(args.speed_perturb)
    # Read the lists once rather than once per speed factor.
    utterances = _read_utterances(args.wav_lists, args.text_files)

    # Context managers guarantee the ark/scp and text outputs are flushed and closed.
    with WriteHelper('ark,scp:{},{}'.format(ark_file, scp_file)) as ark_writer, \
            open(text_file, 'w', encoding='utf-8') as text_writer, \
            open(feats_shape_file, 'w', encoding='utf-8') as feats_shape_writer, \
            open(text_shape_file, 'w', encoding='utf-8') as text_shape_writer:
        for speed_label, speed in speeds:
            for wav_id, wav_file, txt in utterances:
                fbank = compute_fbank(wav_file,
                                      num_mel_bins=args.dims,
                                      resample_rate=args.sample_frequency,
                                      speed=speed,
                                      window_type=args.window_type)
                feats_lens, feats_dims = fbank.shape
                if feats_lens >= args.max_lengths:
                    continue
                # Compare numerically so "1", "1.0" and " 1.0" all mean "unperturbed".
                if speed == 1.0:
                    wav_id_sp = wav_id
                else:
                    wav_id_sp = wav_id + "_sp" + speed_label
                feats_shape_writer.write(wav_id_sp + " " + str(feats_lens) + "," + str(feats_dims) + '\n')
                text_shape_writer.write(wav_id_sp + " " + str(len(txt)) + '\n')
                text_writer.write(wav_id_sp + " " + " ".join(txt) + '\n')
                ark_writer(wav_id_sp, fbank)
if __name__ == "__main__":
    # Run the extraction only when executed as a script, not on import.
    main()