mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
release timestasmp related tools
This commit is contained in:
parent
0b06794fde
commit
f59a72d24e
@ -1,3 +1,4 @@
|
||||
from pydoc import TextRepr
|
||||
from scipy.fftpack import shift
|
||||
import torch
|
||||
import copy
|
||||
@ -5,6 +6,7 @@ import codecs
|
||||
import logging
|
||||
import edit_distance
|
||||
import argparse
|
||||
import pdb
|
||||
import numpy as np
|
||||
from typing import Any, List, Tuple, Union
|
||||
|
||||
@ -13,7 +15,8 @@ def ts_prediction_lfr6_standard(us_alphas,
|
||||
us_peaks,
|
||||
char_list,
|
||||
vad_offset=0.0,
|
||||
force_time_shift=-1.5
|
||||
force_time_shift=-1.5,
|
||||
sil_in_str=True
|
||||
):
|
||||
if not len(char_list):
|
||||
return []
|
||||
@ -66,6 +69,8 @@ def ts_prediction_lfr6_standard(us_alphas,
|
||||
timestamp_list[i][1] = timestamp_list[i][1] + vad_offset / 1000.0
|
||||
res_txt = ""
|
||||
for char, timestamp in zip(new_char_list, timestamp_list):
|
||||
#if char != '<sil>':
|
||||
if not sil_in_str and char == '<sil>': continue
|
||||
res_txt += "{} {} {};".format(char, str(timestamp[0]+0.0005)[:5], str(timestamp[1]+0.0005)[:5])
|
||||
res = []
|
||||
for char, timestamp in zip(new_char_list, timestamp_list):
|
||||
@ -233,13 +238,54 @@ class AverageShiftCalculator():
|
||||
return self._accumlated_shift / self._accumlated_tokens
|
||||
|
||||
|
||||
SUPPORTED_MODES = ['cal_aas']
|
||||
def convert_external_alphas(alphas_file, text_file, output_file):
|
||||
from funasr.models.predictor.cif import cif_wo_hidden
|
||||
with open(alphas_file, 'r') as f1, open(text_file, 'r') as f2, open(output_file, 'w') as f3:
|
||||
for line1, line2 in zip(f1.readlines(), f2.readlines()):
|
||||
line1 = line1.rstrip()
|
||||
line2 = line2.rstrip()
|
||||
assert line1.split()[0] == line2.split()[0]
|
||||
uttid = line1.split()[0]
|
||||
alphas = [float(i) for i in line1.split()[1:]]
|
||||
new_alphas = np.array(remove_chunk_padding(alphas))
|
||||
new_alphas[-1] += 1e-4
|
||||
text = line2.split()[1:]
|
||||
if len(text) + 1 != int(new_alphas.sum()):
|
||||
# force resize
|
||||
new_alphas *= (len(text) + 1) / int(new_alphas.sum())
|
||||
peaks = cif_wo_hidden(torch.Tensor(new_alphas).unsqueeze(0), 1.0-1e-4)
|
||||
if " " in text:
|
||||
text = text.split()
|
||||
else:
|
||||
text = [i for i in text]
|
||||
res_str, _ = ts_prediction_lfr6_standard(new_alphas, peaks[0], text,
|
||||
force_time_shift=-7.0,
|
||||
sil_in_str=False)
|
||||
f3.write("{} {}\n".format(uttid, res_str))
|
||||
|
||||
|
||||
def remove_chunk_padding(alphas):
|
||||
# remove the padding part in alphas if using chunk paraformer for GPU
|
||||
START_ZERO = 45
|
||||
MID_ZERO = 75
|
||||
REAL_FRAMES = 360 # for chunk based encoder 10-120-10 and fsmn padding 5
|
||||
alphas = alphas[START_ZERO:] # remove the padding at beginning
|
||||
new_alphas = []
|
||||
while True:
|
||||
new_alphas = new_alphas + alphas[:REAL_FRAMES]
|
||||
alphas = alphas[REAL_FRAMES+MID_ZERO:]
|
||||
if len(alphas) < REAL_FRAMES: break
|
||||
return new_alphas
|
||||
|
||||
SUPPORTED_MODES = ['cal_aas', 'read_ext_alphas']
|
||||
|
||||
|
||||
def main(args):
|
||||
if args.mode == 'cal_aas':
|
||||
asc = AverageShiftCalculator()
|
||||
asc(args.input, args.input2)
|
||||
elif args.mode == 'read_ext_alphas':
|
||||
convert_external_alphas(args.input, args.input2, args.output)
|
||||
else:
|
||||
logging.error("Mode {} not in SUPPORTED_MODES: {}.".format(args.mode, SUPPORTED_MODES))
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user