diff --git a/funasr/bin/asr_inference_paraformer.py b/funasr/bin/asr_inference_paraformer.py index 8265fc590..588b1bc55 100644 --- a/funasr/bin/asr_inference_paraformer.py +++ b/funasr/bin/asr_inference_paraformer.py @@ -42,7 +42,7 @@ from funasr.utils import asr_utils, wav_utils, postprocess_utils from funasr.models.frontend.wav_frontend import WavFrontend from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export -from funasr.utils.timestamp_tools import time_stamp_lfr6_pl, time_stamp_sentence +from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard class Speech2Text: @@ -291,7 +291,10 @@ class Speech2Text: text = None if isinstance(self.asr_model, BiCifParaformer): - timestamp = time_stamp_lfr6_pl(us_alphas[i], us_cif_peak[i], copy.copy(token), begin_time, end_time) + _, timestamp = ts_prediction_lfr6_standard(us_alphas[i], + us_cif_peak[i], + copy.copy(token), + vad_offset=begin_time) results.append((text, token, token_int, hyp, timestamp, enc_len_batch_total, lfr_factor)) else: results.append((text, token, token_int, hyp, enc_len_batch_total, lfr_factor)) diff --git a/funasr/bin/asr_inference_paraformer_vad_punc.py b/funasr/bin/asr_inference_paraformer_vad_punc.py index 13208778f..1dc98f6e0 100644 --- a/funasr/bin/asr_inference_paraformer_vad_punc.py +++ b/funasr/bin/asr_inference_paraformer_vad_punc.py @@ -44,11 +44,10 @@ from funasr.utils import asr_utils, wav_utils, postprocess_utils from funasr.models.frontend.wav_frontend import WavFrontend from funasr.tasks.vad import VADTask from funasr.bin.vad_inference import Speech2VadSegment -from funasr.utils.timestamp_tools import time_stamp_lfr6_pl +from funasr.utils.timestamp_tools import time_stamp_sentence, ts_prediction_lfr6_standard from funasr.bin.punctuation_infer import Text2Punc from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer -from 
funasr.utils.timestamp_tools import time_stamp_sentence header_colors = '\033[95m' end_colors = '\033[0m' @@ -303,7 +302,10 @@ class Speech2Text: text = None if isinstance(self.asr_model, BiCifParaformer): - timestamp = time_stamp_lfr6_pl(us_alphas[i], us_cif_peak[i], copy.copy(token), begin_time, end_time) + _, timestamp = ts_prediction_lfr6_standard(us_alphas[i], + us_cif_peak[i], + copy.copy(token), + vad_offset=begin_time) results.append((text, token, token_int, timestamp, enc_len_batch_total, lfr_factor)) else: results.append((text, token, token_int, enc_len_batch_total, lfr_factor)) diff --git a/funasr/bin/tp_inference.py b/funasr/bin/tp_inference.py index e7a1f1b68..e374a227a 100644 --- a/funasr/bin/tp_inference.py +++ b/funasr/bin/tp_inference.py @@ -28,6 +28,8 @@ from funasr.utils.types import str2triple_str from funasr.utils.types import str_or_none from funasr.models.frontend.wav_frontend import WavFrontend from funasr.text.token_id_converter import TokenIDConverter +from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard + header_colors = '\033[95m' end_colors = '\033[0m' @@ -38,61 +40,6 @@ global_sample_rate: Union[int, Dict[Any, int]] = { 'model_fs': 16000 } -def time_stamp_lfr6_advance(us_alphas, us_cif_peak, char_list): - START_END_THRESHOLD = 5 - MAX_TOKEN_DURATION = 12 - TIME_RATE = 10.0 * 6 / 1000 / 3 # 3 times upsampled - if len(us_cif_peak.shape) == 2: - alphas, cif_peak = us_alphas[0], us_cif_peak[0] # support inference batch_size=1 only - else: - alphas, cif_peak = us_alphas, us_cif_peak - num_frames = cif_peak.shape[0] - if char_list[-1] == '': - char_list = char_list[:-1] - # char_list = [i for i in text] - timestamp_list = [] - new_char_list = [] - # for bicif model trained with large data, cif2 actually fires when a character starts - # so treat the frames between two peaks as the duration of the former token - fire_place = torch.where(cif_peak>1.0-1e-4)[0].cpu().numpy() - 3.2 # total offset - num_peak = len(fire_place) - 
assert num_peak == len(char_list) + 1 # number of peaks is supposed to be number of tokens + 1 - # begin silence - if fire_place[0] > START_END_THRESHOLD: - # char_list.insert(0, '') - timestamp_list.append([0.0, fire_place[0]*TIME_RATE]) - new_char_list.append('') - # tokens timestamp - for i in range(len(fire_place)-1): - new_char_list.append(char_list[i]) - if MAX_TOKEN_DURATION < 0 or fire_place[i+1] - fire_place[i] < MAX_TOKEN_DURATION: - timestamp_list.append([fire_place[i]*TIME_RATE, fire_place[i+1]*TIME_RATE]) - else: - # cut the duration to token and sil of the 0-weight frames last long - _split = fire_place[i] + MAX_TOKEN_DURATION - timestamp_list.append([fire_place[i]*TIME_RATE, _split*TIME_RATE]) - timestamp_list.append([_split*TIME_RATE, fire_place[i+1]*TIME_RATE]) - new_char_list.append('') - # tail token and end silence - # new_char_list.append(char_list[-1]) - if num_frames - fire_place[-1] > START_END_THRESHOLD: - _end = (num_frames + fire_place[-1]) * 0.5 - # _end = fire_place[-1] - timestamp_list[-1][1] = _end*TIME_RATE - timestamp_list.append([_end*TIME_RATE, num_frames*TIME_RATE]) - new_char_list.append("") - else: - timestamp_list[-1][1] = num_frames*TIME_RATE - assert len(new_char_list) == len(timestamp_list) - res_str = "" - for char, timestamp in zip(new_char_list, timestamp_list): - res_str += "{} {} {};".format(char, str(timestamp[0]+0.0005)[:5], str(timestamp[1]+0.0005)[:5]) - res = [] - for char, timestamp in zip(new_char_list, timestamp_list): - if char != '': - res.append([int(timestamp[0] * 1000), int(timestamp[1] * 1000)]) - return res_str, res - class SpeechText2Timestamp: def __init__( @@ -315,7 +262,7 @@ def inference_modelscope( for batch_id in range(_bs): key = keys[batch_id] token = speechtext2timestamp.converter.ids2tokens(batch['text'][batch_id]) - ts_str, ts_list = time_stamp_lfr6_advance(us_alphas[batch_id], us_cif_peak[batch_id], token) + ts_str, ts_list = ts_prediction_lfr6_standard(us_alphas[batch_id], 
def ts_prediction_lfr6_standard(us_alphas,
                                us_cif_peak,
                                char_list,
                                vad_offset=0.0,
                                end_time=None,
                                force_time_shift=-1.5
                                ):
    """Turn upsampled CIF peaks into per-token timestamps.

    Each pair of consecutive peaks is treated as the span of one token.
    Over-long spans are cut into a token part and a trailing silence part,
    and leading/trailing silence segments are added when the first/last peak
    is far enough from the sequence boundary.

    Args:
        us_alphas: Upsampled alphas. Only its rank is inspected here: a
            2-D input is taken as a batch and element 0 of ``us_cif_peak``
            is used (batch_size=1 only).
        us_cif_peak: Tensor of peak weights; frames whose value exceeds
            ``1.0 - 1e-4`` count as fired.
        char_list: Tokens to align. A trailing ``''`` entry is dropped.
            The caller's list is never modified.
        vad_offset: Offset added to every timestamp. It is divided by 1000
            before being added to the second-based timestamps, i.e. it is
            given in milliseconds.
        end_time: Unused; kept so existing call sites stay valid.
        force_time_shift: Shift, in upsampled frames, added to every peak
            position before conversion to time.

    Returns:
        A ``(res_txt, res)`` tuple. ``res_txt`` is the concatenation of
        ``"<char> <start> <end>;"`` for every segment (silence included,
        times in seconds cut to five characters). ``res`` is a list of
        ``[start_ms, end_ms]`` integer pairs for the non-silence tokens.
        For an empty ``char_list`` the result is ``("", [])``.

    Raises:
        AssertionError: If the number of peaks is not ``len(char_list) + 1``.
    """
    if not len(char_list):
        # Keep the tuple shape of the normal path: callers unpack two values.
        return "", []
    START_END_THRESHOLD = 5
    MAX_TOKEN_DURATION = 12
    TIME_RATE = 10.0 * 6 / 1000 / 3  # 3 times upsampled
    if len(us_alphas.shape) == 2:
        cif_peak = us_cif_peak[0]  # support inference batch_size=1 only
    else:
        cif_peak = us_cif_peak
    num_frames = cif_peak.shape[0]
    if char_list[-1] == '':
        char_list = char_list[:-1]
    timestamp_list = []
    new_char_list = []
    # for bicif model trained with large data, cif2 actually fires when a character starts
    # so treat the frames between two peaks as the duration of the former token
    fire_place = torch.where(cif_peak > 1.0 - 1e-4)[0].cpu().numpy() + force_time_shift  # total offset
    num_peak = len(fire_place)
    # number of peaks is supposed to be number of tokens + 1
    assert num_peak == len(char_list) + 1, \
        "expected {} peaks for {} tokens, got {}".format(len(char_list) + 1, len(char_list), num_peak)
    # begin silence
    if fire_place[0] > START_END_THRESHOLD:
        timestamp_list.append([0.0, fire_place[0] * TIME_RATE])
        new_char_list.append('')
    # tokens timestamp
    for i, (peak, next_peak) in enumerate(zip(fire_place[:-1], fire_place[1:])):
        new_char_list.append(char_list[i])
        if MAX_TOKEN_DURATION < 0 or next_peak - peak <= MAX_TOKEN_DURATION:
            timestamp_list.append([peak * TIME_RATE, next_peak * TIME_RATE])
        else:
            # cut the duration to token and sil of the 0-weight frames last long
            _split = peak + MAX_TOKEN_DURATION
            timestamp_list.append([peak * TIME_RATE, _split * TIME_RATE])
            timestamp_list.append([_split * TIME_RATE, next_peak * TIME_RATE])
            new_char_list.append('')
    # tail token and end silence
    if num_frames - fire_place[-1] > START_END_THRESHOLD:
        # split the remaining frames evenly between the last segment and a closing silence
        _end = (num_frames + fire_place[-1]) * 0.5
        timestamp_list[-1][1] = _end * TIME_RATE
        timestamp_list.append([_end * TIME_RATE, num_frames * TIME_RATE])
        new_char_list.append("")
    else:
        timestamp_list[-1][1] = num_frames * TIME_RATE
    if vad_offset:  # add offset time in model with vad
        offset_seconds = vad_offset / 1000.0
        for segment in timestamp_list:
            segment[0] = segment[0] + offset_seconds
            segment[1] = segment[1] + offset_seconds
    txt_parts = []
    res = []
    for char, (begin, end) in zip(new_char_list, timestamp_list):
        txt_parts.append("{} {} {};".format(char, str(begin + 0.0005)[:5], str(end + 0.0005)[:5]))
        if char != '':
            res.append([int(begin * 1000), int(end * 1000)])
    return "".join(txt_parts), res