update paraformer timestamp

维石 2024-06-06 10:08:17 +08:00
parent f64bbaa036
commit ce6b70e479
2 changed files with 23 additions and 7 deletions

View File

@@ -4,6 +4,7 @@
 # MIT License (https://opensource.org/licenses/MIT)
 import time
+import copy
 import torch
 import logging
 from torch.cuda.amp import autocast
@@ -21,6 +22,7 @@ from funasr.train_utils.device_funcs import force_gatherable
 from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
 from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
 from funasr.models.transformer.utils.nets_utils import make_pad_mask
+from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
 from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
@@ -452,6 +454,7 @@ class Paraformer(torch.nn.Module):
         is_use_lm = (
             kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
         )
+        pred_timestamp = kwargs.get("pred_timestamp", False)
         if self.beam_search is None and (is_use_lm or is_use_ctc):
             logging.info("enable beam_search")
             self.init_beam_search(**kwargs)
@@ -506,6 +509,7 @@ class Paraformer(torch.nn.Module):
             predictor_outs[2],
             predictor_outs[3],
         )
+        pre_token_length = pre_token_length.round().long()
         if torch.max(pre_token_length) < 1:
             return []
@@ -564,10 +568,22 @@
                     # Change integer-ids to tokens
                     token = tokenizer.ids2tokens(token_int)
                     text_postprocessed = tokenizer.tokens2text(token)
-                    if not hasattr(tokenizer, "bpemodel"):
-                        text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
-                    result_i = {"key": key[i], "text": text_postprocessed}
+                    if pred_timestamp:
+                        timestamp_str, timestamp = ts_prediction_lfr6_standard(
+                            pre_peak_index[i],
+                            alphas[i],
+                            copy.copy(token),
+                            vad_offset=kwargs.get("begin_time", 0),
+                            upsample_rate=1,
+                        )
+                        if not hasattr(tokenizer, "bpemodel"):
+                            text_postprocessed, time_stamp_postprocessed, _ = postprocess_utils.sentence_postprocess(token, timestamp)
+                        result_i = {"key": key[i], "text": text_postprocessed, "timestamp": time_stamp_postprocessed,}
+                    else:
+                        if not hasattr(tokenizer, "bpemodel"):
+                            text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
+                        result_i = {"key": key[i], "text": text_postprocessed}
                     if ibest_writer is not None:
                         ibest_writer["token"][key[i]] = " ".join(token)

View File

@@ -29,13 +29,13 @@ def cif_wo_hidden(alphas, threshold):
 def ts_prediction_lfr6_standard(
-    us_alphas, us_peaks, char_list, vad_offset=0.0, force_time_shift=-1.5, sil_in_str=True
+    us_alphas, us_peaks, char_list, vad_offset=0.0, force_time_shift=-1.5, sil_in_str=True, upsample_rate=3,
 ):
     if not len(char_list):
         return "", []
     START_END_THRESHOLD = 5
-    MAX_TOKEN_DURATION = 12
-    TIME_RATE = 10.0 * 6 / 1000 / 3 # 3 times upsampled
+    MAX_TOKEN_DURATION = 12 # 3 times upsampled
+    TIME_RATE=10.0 * 6 / 1000 / upsample_rate
     if len(us_alphas.shape) == 2:
         alphas, peaks = us_alphas[0], us_peaks[0] # support inference batch_size=1 only
     else:
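
For reference, the TIME_RATE change works out as follows, assuming the 10 ms frame hop and LFR-6 stacking implied by the 10.0 * 6 / 1000 factor: the previous hard-coded divisor of 3 assumed 3x-upsampled CIF peaks (0.02 s per step), while the new upsample_rate parameter lets plain Paraformer pass its non-upsampled peaks with upsample_rate=1 (0.06 s per step), as in the call from Paraformer.inference above. A small sketch of the arithmetic:

# Seconds per alpha/peak step implied by TIME_RATE = 10.0 * 6 / 1000 / upsample_rate
for upsample_rate in (3, 1):
    time_rate = 10.0 * 6 / 1000 / upsample_rate
    print(f"upsample_rate={upsample_rate}: {time_rate:.2f} s per step")
# upsample_rate=3 (default, 3x-upsampled peaks): 0.02 s per step
# upsample_rate=1 (peaks passed by Paraformer.inference): 0.06 s per step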