mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
release timestasmp related tools
This commit is contained in:
parent
0b06794fde
commit
f59a72d24e
@ -1,3 +1,4 @@
|
|||||||
|
from pydoc import TextRepr
|
||||||
from scipy.fftpack import shift
|
from scipy.fftpack import shift
|
||||||
import torch
|
import torch
|
||||||
import copy
|
import copy
|
||||||
@ -5,6 +6,7 @@ import codecs
|
|||||||
import logging
|
import logging
|
||||||
import edit_distance
|
import edit_distance
|
||||||
import argparse
|
import argparse
|
||||||
|
import pdb
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from typing import Any, List, Tuple, Union
|
from typing import Any, List, Tuple, Union
|
||||||
|
|
||||||
@ -13,7 +15,8 @@ def ts_prediction_lfr6_standard(us_alphas,
|
|||||||
us_peaks,
|
us_peaks,
|
||||||
char_list,
|
char_list,
|
||||||
vad_offset=0.0,
|
vad_offset=0.0,
|
||||||
force_time_shift=-1.5
|
force_time_shift=-1.5,
|
||||||
|
sil_in_str=True
|
||||||
):
|
):
|
||||||
if not len(char_list):
|
if not len(char_list):
|
||||||
return []
|
return []
|
||||||
@ -66,6 +69,8 @@ def ts_prediction_lfr6_standard(us_alphas,
|
|||||||
timestamp_list[i][1] = timestamp_list[i][1] + vad_offset / 1000.0
|
timestamp_list[i][1] = timestamp_list[i][1] + vad_offset / 1000.0
|
||||||
res_txt = ""
|
res_txt = ""
|
||||||
for char, timestamp in zip(new_char_list, timestamp_list):
|
for char, timestamp in zip(new_char_list, timestamp_list):
|
||||||
|
#if char != '<sil>':
|
||||||
|
if not sil_in_str and char == '<sil>': continue
|
||||||
res_txt += "{} {} {};".format(char, str(timestamp[0]+0.0005)[:5], str(timestamp[1]+0.0005)[:5])
|
res_txt += "{} {} {};".format(char, str(timestamp[0]+0.0005)[:5], str(timestamp[1]+0.0005)[:5])
|
||||||
res = []
|
res = []
|
||||||
for char, timestamp in zip(new_char_list, timestamp_list):
|
for char, timestamp in zip(new_char_list, timestamp_list):
|
||||||
@ -233,13 +238,54 @@ class AverageShiftCalculator():
|
|||||||
return self._accumlated_shift / self._accumlated_tokens
|
return self._accumlated_shift / self._accumlated_tokens
|
||||||
|
|
||||||
|
|
||||||
SUPPORTED_MODES = ['cal_aas']
|
def convert_external_alphas(alphas_file, text_file, output_file):
|
||||||
|
from funasr.models.predictor.cif import cif_wo_hidden
|
||||||
|
with open(alphas_file, 'r') as f1, open(text_file, 'r') as f2, open(output_file, 'w') as f3:
|
||||||
|
for line1, line2 in zip(f1.readlines(), f2.readlines()):
|
||||||
|
line1 = line1.rstrip()
|
||||||
|
line2 = line2.rstrip()
|
||||||
|
assert line1.split()[0] == line2.split()[0]
|
||||||
|
uttid = line1.split()[0]
|
||||||
|
alphas = [float(i) for i in line1.split()[1:]]
|
||||||
|
new_alphas = np.array(remove_chunk_padding(alphas))
|
||||||
|
new_alphas[-1] += 1e-4
|
||||||
|
text = line2.split()[1:]
|
||||||
|
if len(text) + 1 != int(new_alphas.sum()):
|
||||||
|
# force resize
|
||||||
|
new_alphas *= (len(text) + 1) / int(new_alphas.sum())
|
||||||
|
peaks = cif_wo_hidden(torch.Tensor(new_alphas).unsqueeze(0), 1.0-1e-4)
|
||||||
|
if " " in text:
|
||||||
|
text = text.split()
|
||||||
|
else:
|
||||||
|
text = [i for i in text]
|
||||||
|
res_str, _ = ts_prediction_lfr6_standard(new_alphas, peaks[0], text,
|
||||||
|
force_time_shift=-7.0,
|
||||||
|
sil_in_str=False)
|
||||||
|
f3.write("{} {}\n".format(uttid, res_str))
|
||||||
|
|
||||||
|
|
||||||
|
def remove_chunk_padding(alphas):
|
||||||
|
# remove the padding part in alphas if using chunk paraformer for GPU
|
||||||
|
START_ZERO = 45
|
||||||
|
MID_ZERO = 75
|
||||||
|
REAL_FRAMES = 360 # for chunk based encoder 10-120-10 and fsmn padding 5
|
||||||
|
alphas = alphas[START_ZERO:] # remove the padding at beginning
|
||||||
|
new_alphas = []
|
||||||
|
while True:
|
||||||
|
new_alphas = new_alphas + alphas[:REAL_FRAMES]
|
||||||
|
alphas = alphas[REAL_FRAMES+MID_ZERO:]
|
||||||
|
if len(alphas) < REAL_FRAMES: break
|
||||||
|
return new_alphas
|
||||||
|
|
||||||
|
SUPPORTED_MODES = ['cal_aas', 'read_ext_alphas']
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
if args.mode == 'cal_aas':
|
if args.mode == 'cal_aas':
|
||||||
asc = AverageShiftCalculator()
|
asc = AverageShiftCalculator()
|
||||||
asc(args.input, args.input2)
|
asc(args.input, args.input2)
|
||||||
|
elif args.mode == 'read_ext_alphas':
|
||||||
|
convert_external_alphas(args.input, args.input2, args.output)
|
||||||
else:
|
else:
|
||||||
logging.error("Mode {} not in SUPPORTED_MODES: {}.".format(args.mode, SUPPORTED_MODES))
|
logging.error("Mode {} not in SUPPORTED_MODES: {}.".format(args.mode, SUPPORTED_MODES))
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user