Mirror of https://github.com/modelscope/FunASR
commit 74f4f7b4c7 (parent 500197b8ad)
@@ -30,7 +30,7 @@ from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
 from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
 from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank

 import pdb

 if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
     from torch.cuda.amp import autocast
 else:
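Review note: the `else` branch is clipped by the hunk. On torch < 1.6, code in this style usually installs a no-op stand-in so call sites can still write `with autocast(...)`. A minimal sketch of that assumed fallback (not shown in the diff):

# Assumed fallback for torch < 1.6 (the diff clips the else branch):
# a no-op context manager with the same name, so call sites stay uniform.
from contextlib import contextmanager

@contextmanager
def autocast(enabled=True):
    yield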
@@ -99,6 +99,7 @@ class SeacoParaformer(BiCifParaformer, Paraformer):
         )
         self.train_decoder = kwargs.get("train_decoder", False)
+        self.NO_BIAS = kwargs.get("NO_BIAS", 8377)
         self.predictor_name = kwargs.get("predictor")

     def forward(
         self,
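Review note: the new NO_BIAS attribute makes the "no hotword fires" token id configurable; the 8377 default preserves behaviour for existing checkpoints. A hedged construction sketch (only the key names shown in the hunk are real, values illustrative):

# Hypothetical kwargs dict; key names are from the hunk, values illustrative.
kwargs = {"train_decoder": False, "NO_BIAS": 8377, "predictor": "CifPredictorV3"}
no_bias = kwargs.get("NO_BIAS", 8377)  # falls back to the historical constant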
@@ -170,6 +171,16 @@ class SeacoParaformer(BiCifParaformer, Paraformer):
     def _merge(self, cif_attended, dec_attended):
         return cif_attended + dec_attended

+    def calc_predictor(self, encoder_out, encoder_out_lens):
+        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+            encoder_out.device)
+        predictor_outs = self.predictor(encoder_out, None, encoder_out_mask, ignore_id=self.ignore_id)
+        if len(predictor_outs) == 4:
+            pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs
+        else:
+            pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index, pre_token_length2 = predictor_outs
+        return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
+
     def _calc_seaco_loss(
         self,
         encoder_out: torch.Tensor,
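Review note: the new calc_predictor relies on make_pad_mask (imported in the first hunk), whose convention is True on padded positions; the ~mask[:, None, :] therefore yields a (B, 1, T) validity mask. A minimal sketch of the assumed semantics:

import torch

def make_pad_mask_sketch(lengths, maxlen=None):
    # True marks positions past each sequence's length, i.e. padding.
    maxlen = maxlen or int(max(lengths))
    idx = torch.arange(maxlen).unsqueeze(0)              # (1, T)
    return idx >= torch.as_tensor(lengths).unsqueeze(1)  # (B, T)

encoder_out_lens = torch.tensor([3, 5])
encoder_out_mask = ~make_pad_mask_sketch(encoder_out_lens, maxlen=5)[:, None, :]
# encoder_out_mask[0] -> [[True, True, True, False, False]], shape (B, 1, T)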
@@ -248,7 +259,7 @@ class SeacoParaformer(BiCifParaformer, Paraformer):
         def _merge_res(dec_output, dha_output):
             lmbd = torch.Tensor([seaco_weight] * dha_output.shape[0])
             dha_ids = dha_output.max(-1)[-1]  # [0]
-            dha_mask = (dha_ids == 8377).int().unsqueeze(-1)
+            dha_mask = (dha_ids == self.NO_BIAS).int().unsqueeze(-1)
             a = (1 - lmbd) / lmbd
             b = 1 / lmbd
             a, b = a.to(dec_output.device), b.to(dec_output.device)
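Review note: this hunk only swaps the hard-coded 8377 for the self.NO_BIAS attribute added in the constructor hunk. For orientation, _merge_res blends the plain decoder head with the hotword (dha) head; a hedged sketch of the assumed blend, since the combining line sits outside the hunk:

import torch

def merge_res_sketch(dec_output, dha_output, seaco_weight=0.5, no_bias_id=8377):
    # Per-utterance interpolation weight, as in the hunk.
    lmbd = torch.Tensor([seaco_weight] * dha_output.shape[0])
    dha_ids = dha_output.max(-1)[-1]                        # argmax of the bias head
    dha_mask = (dha_ids == no_bias_id).int().unsqueeze(-1)  # 1 where "no hotword" predicted
    a = ((1 - lmbd) / lmbd).view(-1, 1, 1).to(dec_output.device)
    b = (1 / lmbd).view(-1, 1, 1).to(dec_output.device)
    # Assumed combination (outside the shown lines): keep the scaled decoder
    # distribution where no hotword fires, the scaled bias head elsewhere.
    return dec_output * a * dha_mask + dha_output * b * (1 - dha_mask)

dec = torch.randn(2, 4, 10)  # (batch, tokens, vocab)
dha = torch.randn(2, 4, 10)
assert merge_res_sketch(dec, dha).shape == (2, 4, 10)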
@@ -332,23 +343,28 @@ class SeacoParaformer(BiCifParaformer, Paraformer):
         if isinstance(encoder_out, tuple):
             encoder_out = encoder_out[0]

         # predictor
         predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens)
-        pre_acoustic_embeds, pre_token_length, _, _ = predictor_outs[0], predictor_outs[1], \
-            predictor_outs[2], predictor_outs[3]
+        pre_acoustic_embeds, pre_token_length = predictor_outs[0], predictor_outs[1]
         pre_token_length = pre_token_length.round().long()
         if torch.max(pre_token_length) < 1:
             return []

-        decoder_out = self._seaco_decode_with_ASF(encoder_out, encoder_out_lens,
-                                                  pre_acoustic_embeds,
-                                                  pre_token_length,
-                                                  hw_list=self.hotword_list)
+        decoder_out = self._seaco_decode_with_ASF(encoder_out,
+                                                  encoder_out_lens,
+                                                  pre_acoustic_embeds,
+                                                  pre_token_length,
+                                                  hw_list=self.hotword_list
+                                                  )

         # decoder_out, _ = decoder_outs[0], decoder_outs[1]
-        _, _, us_alphas, us_peaks = self.calc_predictor_timestamp(encoder_out, encoder_out_lens,
-                                                                  pre_token_length)
+        if self.predictor_name == "CifPredictorV3":
+            _, _, us_alphas, us_peaks = self.calc_predictor_timestamp(encoder_out,
+                                                                      encoder_out_lens,
+                                                                      pre_token_length)
+        else:
+            us_alphas = None

         results = []
         b, n, d = decoder_out.size()
         for i in range(b):
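Review note on the two guards in this hunk: pre_token_length is CIF's float token-count estimate, so it is rounded to integers and an all-zero batch aborts early; and us_alphas is now None for any predictor other than CifPredictorV3, which the per-sample loop must check before computing timestamps. A tiny illustration of the rounding guard:

import torch

# CIF emits a float expected token count per utterance; near-silent input can
# round to zero, which the hunk's early return treats as "nothing to decode".
pre_token_length = torch.tensor([0.4, 0.3]).round().long()  # tensor([0, 0])
if torch.max(pre_token_length) < 1:
    results = []  # matches the `return []` in the hunk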
@@ -393,23 +409,25 @@ class SeacoParaformer(BiCifParaformer, Paraformer):
                 # Change integer-ids to tokens
                 token = tokenizer.ids2tokens(token_int)
                 text = tokenizer.tokens2text(token)

-                _, timestamp = ts_prediction_lfr6_standard(us_alphas[i][:encoder_out_lens[i] * 3],
-                                                           us_peaks[i][:encoder_out_lens[i] * 3],
-                                                           copy.copy(token),
-                                                           vad_offset=kwargs.get("begin_time", 0))
-
-                text_postprocessed, time_stamp_postprocessed, word_lists = postprocess_utils.sentence_postprocess(
-                    token, timestamp)
-
-                result_i = {"key": key[i], "text": text_postprocessed,
-                            "timestamp": time_stamp_postprocessed
-                            }
-
-                if ibest_writer is not None:
-                    ibest_writer["token"][key[i]] = " ".join(token)
-                    ibest_writer["timestamp"][key[i]] = time_stamp_postprocessed
-                    ibest_writer["text"][key[i]] = text_postprocessed
+                if us_alphas is not None:
+                    _, timestamp = ts_prediction_lfr6_standard(us_alphas[i][:encoder_out_lens[i] * 3],
+                                                               us_peaks[i][:encoder_out_lens[i] * 3],
+                                                               copy.copy(token),
+                                                               vad_offset=kwargs.get("begin_time", 0))
+                    text_postprocessed, time_stamp_postprocessed, _ = \
+                        postprocess_utils.sentence_postprocess(token, timestamp)
+                    result_i = {"key": key[i], "text": text_postprocessed,
+                                "timestamp": time_stamp_postprocessed}
+                    if ibest_writer is not None:
+                        ibest_writer["token"][key[i]] = " ".join(token)
+                        ibest_writer["timestamp"][key[i]] = time_stamp_postprocessed
+                        ibest_writer["text"][key[i]] = text_postprocessed
+                else:
+                    text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
+                    result_i = {"key": key[i], "text": text_postprocessed}
+                    if ibest_writer is not None:
+                        ibest_writer["token"][key[i]] = " ".join(token)
+                        ibest_writer["text"][key[i]] = text_postprocessed
             else:
                 result_i = {"key": key[i], "token_int": token_int}
             results.append(result_i)
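Review note: the `* 3` slice on us_alphas/us_peaks assumes the timestamp predictor emits three frames per encoder frame (inferred from the slicing, not stated in the diff), and the timestamp field is now optional in the results. A hedged consumption sketch with illustrative values:

# Illustrative values only; field names follow the hunk above.
results = [
    {"key": "utt1", "text": "hello world", "timestamp": [[0, 480], [480, 900]]},
    {"key": "utt2", "token_int": [12, 7, 99]},  # fallback path without a tokenizer
]
for result_i in results:
    print(result_i["key"], result_i.get("text", result_i.get("token_int")))
    if "timestamp" in result_i:  # present only on the CifPredictorV3 path
        print(result_i["timestamp"])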