diff --git a/funasr/models/paraformer/model.py b/funasr/models/paraformer/model.py
index 0d9bb2be1..85967af3e 100644
--- a/funasr/models/paraformer/model.py
+++ b/funasr/models/paraformer/model.py
@@ -4,6 +4,7 @@
 # MIT License (https://opensource.org/licenses/MIT)
 
 import time
+import copy
 import torch
 import logging
 from torch.cuda.amp import autocast
@@ -21,6 +22,7 @@
 from funasr.train_utils.device_funcs import force_gatherable
 from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
 from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
 from funasr.models.transformer.utils.nets_utils import make_pad_mask
+from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
 from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
 
@@ -452,6 +454,7 @@ class Paraformer(torch.nn.Module):
         is_use_lm = (
             kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
         )
+        pred_timestamp = kwargs.get("pred_timestamp", False)
         if self.beam_search is None and (is_use_lm or is_use_ctc):
             logging.info("enable beam_search")
             self.init_beam_search(**kwargs)
@@ -506,6 +509,7 @@ class Paraformer(torch.nn.Module):
             predictor_outs[2],
             predictor_outs[3],
         )
+        pre_token_length = pre_token_length.round().long()
         if torch.max(pre_token_length) < 1:
             return []
 
@@ -564,10 +568,22 @@ class Paraformer(torch.nn.Module):
             # Change integer-ids to tokens
             token = tokenizer.ids2tokens(token_int)
             text_postprocessed = tokenizer.tokens2text(token)
-            if not hasattr(tokenizer, "bpemodel"):
-                text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
-
-            result_i = {"key": key[i], "text": text_postprocessed}
+
+            if pred_timestamp:
+                timestamp_str, timestamp = ts_prediction_lfr6_standard(
+                    pre_peak_index[i],
+                    alphas[i],
+                    copy.copy(token),
+                    vad_offset=kwargs.get("begin_time", 0),
+                    upsample_rate=1,
+                )
+                if not hasattr(tokenizer, "bpemodel"):
+                    text_postprocessed, time_stamp_postprocessed, _ = postprocess_utils.sentence_postprocess(token, timestamp)
+                result_i = {"key": key[i], "text": text_postprocessed, "timestamp": time_stamp_postprocessed}
+            else:
+                if not hasattr(tokenizer, "bpemodel"):
+                    text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
+                result_i = {"key": key[i], "text": text_postprocessed}
 
             if ibest_writer is not None:
                 ibest_writer["token"][key[i]] = " ".join(token)
diff --git a/funasr/utils/timestamp_tools.py b/funasr/utils/timestamp_tools.py
index 831d77357..af61e5a8f 100644
--- a/funasr/utils/timestamp_tools.py
+++ b/funasr/utils/timestamp_tools.py
@@ -29,13 +29,13 @@ def cif_wo_hidden(alphas, threshold):
 
 
 def ts_prediction_lfr6_standard(
-    us_alphas, us_peaks, char_list, vad_offset=0.0, force_time_shift=-1.5, sil_in_str=True
+    us_alphas, us_peaks, char_list, vad_offset=0.0, force_time_shift=-1.5, sil_in_str=True, upsample_rate=3,
 ):
     if not len(char_list):
         return "", []
     START_END_THRESHOLD = 5
-    MAX_TOKEN_DURATION = 12
-    TIME_RATE = 10.0 * 6 / 1000 / 3  # 3 times upsampled
+    MAX_TOKEN_DURATION = 12  # 3 times upsampled
+    TIME_RATE = 10.0 * 6 / 1000 / upsample_rate
     if len(us_alphas.shape) == 2:
         alphas, peaks = us_alphas[0], us_peaks[0]  # support inference batch_size=1 only
     else:
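
A minimal usage sketch of the new flag, assuming FunASR's AutoModel front end
forwards extra keyword arguments through model.generate to Paraformer.inference
(where pred_timestamp is read); the model name and wav path are illustrative:

    from funasr import AutoModel

    # Assumed entry point; pred_timestamp=True routes each utterance through
    # ts_prediction_lfr6_standard with upsample_rate=1, so one predictor frame
    # index spans TIME_RATE = 10.0 * 6 / 1000 / 1 = 0.06 s (60 ms).
    model = AutoModel(model="paraformer-zh")
    res = model.generate(input="example.wav", pred_timestamp=True)

    print(res[0]["text"])
    # The "timestamp" field is only present when pred_timestamp=True:
    print(res[0].get("timestamp"))  # e.g. [[start_ms, end_ms], ...] per token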