From f59a72d24e917fb2e9560fa646ae80285dba6c95 Mon Sep 17 00:00:00 2001 From: "shixian.shi" Date: Wed, 15 Mar 2023 10:21:32 +0800 Subject: [PATCH] release timestamp related tools --- funasr/utils/timestamp_tools.py | 50 +++++++++++++++++++++++++++++++-- 1 file changed, 48 insertions(+), 2 deletions(-) diff --git a/funasr/utils/timestamp_tools.py b/funasr/utils/timestamp_tools.py index 2bccd50e6..09c3becfc 100644 --- a/funasr/utils/timestamp_tools.py +++ b/funasr/utils/timestamp_tools.py @@ -1,3 +1,4 @@ +from pydoc import TextRepr from scipy.fftpack import shift import torch import copy @@ -5,6 +6,7 @@ import codecs import logging import edit_distance import argparse +import pdb import numpy as np from typing import Any, List, Tuple, Union @@ -13,7 +15,8 @@ def ts_prediction_lfr6_standard(us_alphas, us_peaks, char_list, vad_offset=0.0, - force_time_shift=-1.5 + force_time_shift=-1.5, + sil_in_str=True ): if not len(char_list): return [] @@ -66,6 +69,8 @@ def ts_prediction_lfr6_standard(us_alphas, timestamp_list[i][1] = timestamp_list[i][1] + vad_offset / 1000.0 res_txt = "" for char, timestamp in zip(new_char_list, timestamp_list): + #if char != '': + if not sil_in_str and char == '': continue res_txt += "{} {} {};".format(char, str(timestamp[0]+0.0005)[:5], str(timestamp[1]+0.0005)[:5]) res = [] for char, timestamp in zip(new_char_list, timestamp_list): @@ -233,13 +238,54 @@ class AverageShiftCalculator(): return self._accumlated_shift / self._accumlated_tokens -SUPPORTED_MODES = ['cal_aas'] +def convert_external_alphas(alphas_file, text_file, output_file): + from funasr.models.predictor.cif import cif_wo_hidden + with open(alphas_file, 'r') as f1, open(text_file, 'r') as f2, open(output_file, 'w') as f3: + for line1, line2 in zip(f1.readlines(), f2.readlines()): + line1 = line1.rstrip() + line2 = line2.rstrip() + assert line1.split()[0] == line2.split()[0] + uttid = line1.split()[0] + alphas = [float(i) for i in line1.split()[1:]] + new_alphas =
np.array(remove_chunk_padding(alphas)) + new_alphas[-1] += 1e-4 + text = line2.split()[1:] + if len(text) + 1 != int(new_alphas.sum()): + # force resize + new_alphas *= (len(text) + 1) / int(new_alphas.sum()) + peaks = cif_wo_hidden(torch.Tensor(new_alphas).unsqueeze(0), 1.0-1e-4) + if " " in text: + text = text.split() + else: + text = [i for i in text] + res_str, _ = ts_prediction_lfr6_standard(new_alphas, peaks[0], text, + force_time_shift=-7.0, + sil_in_str=False) + f3.write("{} {}\n".format(uttid, res_str)) + + +def remove_chunk_padding(alphas): + # remove the padding part in alphas if using chunk paraformer for GPU + START_ZERO = 45 + MID_ZERO = 75 + REAL_FRAMES = 360 # for chunk based encoder 10-120-10 and fsmn padding 5 + alphas = alphas[START_ZERO:] # remove the padding at beginning + new_alphas = [] + while True: + new_alphas = new_alphas + alphas[:REAL_FRAMES] + alphas = alphas[REAL_FRAMES+MID_ZERO:] + if len(alphas) < REAL_FRAMES: break + return new_alphas + +SUPPORTED_MODES = ['cal_aas', 'read_ext_alphas'] def main(args): if args.mode == 'cal_aas': asc = AverageShiftCalculator() asc(args.input, args.input2) + elif args.mode == 'read_ext_alphas': + convert_external_alphas(args.input, args.input2, args.output) else: logging.error("Mode {} not in SUPPORTED_MODES: {}.".format(args.mode, SUPPORTED_MODES))