release timestasmp related tools

This commit is contained in:
shixian.shi 2023-03-15 10:21:32 +08:00
parent 0b06794fde
commit f59a72d24e

View File

@ -1,3 +1,4 @@
from pydoc import TextRepr
from scipy.fftpack import shift
import torch
import copy
@ -5,6 +6,7 @@ import codecs
import logging
import edit_distance
import argparse
import pdb
import numpy as np
from typing import Any, List, Tuple, Union
@ -13,7 +15,8 @@ def ts_prediction_lfr6_standard(us_alphas,
us_peaks,
char_list,
vad_offset=0.0,
force_time_shift=-1.5
force_time_shift=-1.5,
sil_in_str=True
):
if not len(char_list):
return []
@ -66,6 +69,8 @@ def ts_prediction_lfr6_standard(us_alphas,
timestamp_list[i][1] = timestamp_list[i][1] + vad_offset / 1000.0
res_txt = ""
for char, timestamp in zip(new_char_list, timestamp_list):
#if char != '<sil>':
if not sil_in_str and char == '<sil>': continue
res_txt += "{} {} {};".format(char, str(timestamp[0]+0.0005)[:5], str(timestamp[1]+0.0005)[:5])
res = []
for char, timestamp in zip(new_char_list, timestamp_list):
@ -233,13 +238,54 @@ class AverageShiftCalculator():
return self._accumlated_shift / self._accumlated_tokens
SUPPORTED_MODES = ['cal_aas']
def convert_external_alphas(alphas_file, text_file, output_file):
from funasr.models.predictor.cif import cif_wo_hidden
with open(alphas_file, 'r') as f1, open(text_file, 'r') as f2, open(output_file, 'w') as f3:
for line1, line2 in zip(f1.readlines(), f2.readlines()):
line1 = line1.rstrip()
line2 = line2.rstrip()
assert line1.split()[0] == line2.split()[0]
uttid = line1.split()[0]
alphas = [float(i) for i in line1.split()[1:]]
new_alphas = np.array(remove_chunk_padding(alphas))
new_alphas[-1] += 1e-4
text = line2.split()[1:]
if len(text) + 1 != int(new_alphas.sum()):
# force resize
new_alphas *= (len(text) + 1) / int(new_alphas.sum())
peaks = cif_wo_hidden(torch.Tensor(new_alphas).unsqueeze(0), 1.0-1e-4)
if " " in text:
text = text.split()
else:
text = [i for i in text]
res_str, _ = ts_prediction_lfr6_standard(new_alphas, peaks[0], text,
force_time_shift=-7.0,
sil_in_str=False)
f3.write("{} {}\n".format(uttid, res_str))
def remove_chunk_padding(alphas):
# remove the padding part in alphas if using chunk paraformer for GPU
START_ZERO = 45
MID_ZERO = 75
REAL_FRAMES = 360 # for chunk based encoder 10-120-10 and fsmn padding 5
alphas = alphas[START_ZERO:] # remove the padding at beginning
new_alphas = []
while True:
new_alphas = new_alphas + alphas[:REAL_FRAMES]
alphas = alphas[REAL_FRAMES+MID_ZERO:]
if len(alphas) < REAL_FRAMES: break
return new_alphas
SUPPORTED_MODES = ['cal_aas', 'read_ext_alphas']
def main(args):
if args.mode == 'cal_aas':
asc = AverageShiftCalculator()
asc(args.input, args.input2)
elif args.mode == 'read_ext_alphas':
convert_external_alphas(args.input, args.input2, args.output)
else:
logging.error("Mode {} not in SUPPORTED_MODES: {}.".format(args.mode, SUPPORTED_MODES))