Mirror of https://github.com/modelscope/FunASR (synced 2025-09-15 14:48:36 +08:00)

Merge pull request #1788 from modelscope/dev_sx2: update paraformer timestamp

Commit c1b5583000
funasr/models/paraformer/model.py

@@ -4,6 +4,7 @@
 # MIT License (https://opensource.org/licenses/MIT)
 
 import time
+import copy
 import torch
 import logging
 from torch.cuda.amp import autocast
@@ -21,6 +22,7 @@ from funasr.train_utils.device_funcs import force_gatherable
 from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
 from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
 from funasr.models.transformer.utils.nets_utils import make_pad_mask
+from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
 from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
 
 
@@ -452,6 +454,7 @@ class Paraformer(torch.nn.Module):
         is_use_lm = (
             kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
         )
+        pred_timestamp = kwargs.get("pred_timestamp", False)
         if self.beam_search is None and (is_use_lm or is_use_ctc):
             logging.info("enable beam_search")
             self.init_beam_search(**kwargs)
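For context, a minimal sketch of how the new flag could be exercised from user code, assuming FunASR's AutoModel.generate forwards extra keyword arguments (pred_timestamp, begin_time) through to Paraformer.inference; the model id and audio path below are placeholders:

    # Sketch only: assumes generate() kwargs reach Paraformer.inference unchanged.
    from funasr import AutoModel

    model = AutoModel(model="paraformer-zh")  # placeholder model id
    res = model.generate(
        input="example.wav",   # placeholder audio path
        pred_timestamp=True,   # read via kwargs.get("pred_timestamp", False)
        begin_time=0,          # forwarded as vad_offset to ts_prediction_lfr6_standard
    )
    print(res[0]["text"])
    print(res[0].get("timestamp"))  # present only when the timestamp branch ran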
@@ -506,6 +509,7 @@ class Paraformer(torch.nn.Module):
                 predictor_outs[2],
                 predictor_outs[3],
             )
 
             pre_token_length = pre_token_length.round().long()
             if torch.max(pre_token_length) < 1:
                 return []
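As a quick check on the guard above: the CIF predictor emits a floating-point token count per utterance, so rounding it to an integer and bailing out when the maximum is below 1 skips decoding when no tokens are predicted. A self-contained sketch with dummy values:

    import torch

    pre_token_length = torch.tensor([0.2, 0.4])         # dummy predictor output
    pre_token_length = pre_token_length.round().long()  # tensor([0, 0])
    if torch.max(pre_token_length) < 1:
        print("no tokens predicted; inference returns []")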
@@ -564,10 +568,22 @@ class Paraformer(torch.nn.Module):
             # Change integer-ids to tokens
             token = tokenizer.ids2tokens(token_int)
             text_postprocessed = tokenizer.tokens2text(token)
-            if not hasattr(tokenizer, "bpemodel"):
-                text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
-
-            result_i = {"key": key[i], "text": text_postprocessed}
+            if pred_timestamp:
+                timestamp_str, timestamp = ts_prediction_lfr6_standard(
+                    pre_peak_index[i],
+                    alphas[i],
+                    copy.copy(token),
+                    vad_offset=kwargs.get("begin_time", 0),
+                    upsample_rate=1,
+                )
+                if not hasattr(tokenizer, "bpemodel"):
+                    text_postprocessed, time_stamp_postprocessed, _ = postprocess_utils.sentence_postprocess(token, timestamp)
+                result_i = {"key": key[i], "text": text_postprocessed, "timestamp": time_stamp_postprocessed}
+            else:
+                if not hasattr(tokenizer, "bpemodel"):
+                    text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
+                result_i = {"key": key[i], "text": text_postprocessed}
 
             if ibest_writer is not None:
                 ibest_writer["token"][key[i]] = " ".join(token)
 
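A side note on the copy.copy(token) argument above: passing a shallow copy keeps the caller's token list intact in case the timestamp routine modifies the char_list it receives in place. The diff itself does not spell out the reason, so the helper below is purely illustrative of that defensive-copy pattern:

    import copy

    def annotate_in_place(char_list):
        # hypothetical routine that appends a marker to the list it receives
        char_list.append("<sil>")
        return char_list

    tokens = ["hel", "lo"]
    annotate_in_place(copy.copy(tokens))  # the copy is mutated, not the original
    assert tokens == ["hel", "lo"]        # caller's list is untouched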
funasr/utils/timestamp_tools.py

@@ -29,13 +29,13 @@ def cif_wo_hidden(alphas, threshold):
 
 
 def ts_prediction_lfr6_standard(
-    us_alphas, us_peaks, char_list, vad_offset=0.0, force_time_shift=-1.5, sil_in_str=True
+    us_alphas, us_peaks, char_list, vad_offset=0.0, force_time_shift=-1.5, sil_in_str=True, upsample_rate=3,
 ):
     if not len(char_list):
         return "", []
     START_END_THRESHOLD = 5
-    MAX_TOKEN_DURATION = 12
-    TIME_RATE = 10.0 * 6 / 1000 / 3  # 3 times upsampled
+    MAX_TOKEN_DURATION = 12  # 3 times upsampled
+    TIME_RATE = 10.0 * 6 / 1000 / upsample_rate
     if len(us_alphas.shape) == 2:
         alphas, peaks = us_alphas[0], us_peaks[0]  # support inference batch_size=1 only
     else:
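For reference, the parameterized TIME_RATE is the seconds-per-frame factor used when converting peak/frame indices into times: presumably 10 ms base frames times the LFR-6 stacking factor, divided by however much the alphas/peaks were upsampled. A quick check of the two rates in play (upsample_rate=3 reproduces the previous hard-coded behaviour; the Paraformer call above passes upsample_rate=1):

    def time_rate(upsample_rate: int) -> float:
        # seconds per frame, mirroring TIME_RATE = 10.0 * 6 / 1000 / upsample_rate
        return 10.0 * 6 / 1000 / upsample_rate

    assert abs(time_rate(3) - 0.02) < 1e-9  # 20 ms per frame (3x upsampled peaks)
    assert abs(time_rate(1) - 0.06) < 1e-9  # 60 ms per frame (no upsampling)
    print(round(100 * time_rate(1), 2))     # a peak at frame 100 would map to 6.0 s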