mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
parent
500197b8ad
commit
74f4f7b4c7
@ -30,7 +30,7 @@ from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
|
|||||||
from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
|
from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
|
||||||
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
|
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
|
||||||
|
|
||||||
import pdb
|
|
||||||
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
|
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
|
||||||
from torch.cuda.amp import autocast
|
from torch.cuda.amp import autocast
|
||||||
else:
|
else:
|
||||||
@ -99,6 +99,7 @@ class SeacoParaformer(BiCifParaformer, Paraformer):
|
|||||||
)
|
)
|
||||||
self.train_decoder = kwargs.get("train_decoder", False)
|
self.train_decoder = kwargs.get("train_decoder", False)
|
||||||
self.NO_BIAS = kwargs.get("NO_BIAS", 8377)
|
self.NO_BIAS = kwargs.get("NO_BIAS", 8377)
|
||||||
|
self.predictor_name = kwargs.get("predictor")
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -170,6 +171,16 @@ class SeacoParaformer(BiCifParaformer, Paraformer):
|
|||||||
def _merge(self, cif_attended, dec_attended):
|
def _merge(self, cif_attended, dec_attended):
|
||||||
return cif_attended + dec_attended
|
return cif_attended + dec_attended
|
||||||
|
|
||||||
|
def calc_predictor(self, encoder_out, encoder_out_lens):
|
||||||
|
encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
|
||||||
|
encoder_out.device)
|
||||||
|
predictor_outs = self.predictor(encoder_out, None, encoder_out_mask, ignore_id=self.ignore_id)
|
||||||
|
if len(predictor_outs) == 4:
|
||||||
|
pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs
|
||||||
|
else:
|
||||||
|
pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index, pre_token_length2 = predictor_outs
|
||||||
|
return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
|
||||||
|
|
||||||
def _calc_seaco_loss(
|
def _calc_seaco_loss(
|
||||||
self,
|
self,
|
||||||
encoder_out: torch.Tensor,
|
encoder_out: torch.Tensor,
|
||||||
@ -248,7 +259,7 @@ class SeacoParaformer(BiCifParaformer, Paraformer):
|
|||||||
def _merge_res(dec_output, dha_output):
|
def _merge_res(dec_output, dha_output):
|
||||||
lmbd = torch.Tensor([seaco_weight] * dha_output.shape[0])
|
lmbd = torch.Tensor([seaco_weight] * dha_output.shape[0])
|
||||||
dha_ids = dha_output.max(-1)[-1]# [0]
|
dha_ids = dha_output.max(-1)[-1]# [0]
|
||||||
dha_mask = (dha_ids == 8377).int().unsqueeze(-1)
|
dha_mask = (dha_ids == self.NO_BIAS).int().unsqueeze(-1)
|
||||||
a = (1 - lmbd) / lmbd
|
a = (1 - lmbd) / lmbd
|
||||||
b = 1 / lmbd
|
b = 1 / lmbd
|
||||||
a, b = a.to(dec_output.device), b.to(dec_output.device)
|
a, b = a.to(dec_output.device), b.to(dec_output.device)
|
||||||
@ -332,23 +343,28 @@ class SeacoParaformer(BiCifParaformer, Paraformer):
|
|||||||
if isinstance(encoder_out, tuple):
|
if isinstance(encoder_out, tuple):
|
||||||
encoder_out = encoder_out[0]
|
encoder_out = encoder_out[0]
|
||||||
|
|
||||||
|
|
||||||
# predictor
|
# predictor
|
||||||
predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens)
|
predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens)
|
||||||
pre_acoustic_embeds, pre_token_length, _, _ = predictor_outs[0], predictor_outs[1], \
|
pre_acoustic_embeds, pre_token_length = predictor_outs[0], predictor_outs[1]
|
||||||
predictor_outs[2], predictor_outs[3]
|
|
||||||
pre_token_length = pre_token_length.round().long()
|
pre_token_length = pre_token_length.round().long()
|
||||||
if torch.max(pre_token_length) < 1:
|
if torch.max(pre_token_length) < 1:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
decoder_out = self._seaco_decode_with_ASF(encoder_out, encoder_out_lens,
|
decoder_out = self._seaco_decode_with_ASF(encoder_out,
|
||||||
pre_acoustic_embeds,
|
encoder_out_lens,
|
||||||
pre_token_length,
|
pre_acoustic_embeds,
|
||||||
hw_list=self.hotword_list)
|
pre_token_length,
|
||||||
|
hw_list=self.hotword_list
|
||||||
|
)
|
||||||
|
|
||||||
# decoder_out, _ = decoder_outs[0], decoder_outs[1]
|
# decoder_out, _ = decoder_outs[0], decoder_outs[1]
|
||||||
_, _, us_alphas, us_peaks = self.calc_predictor_timestamp(encoder_out, encoder_out_lens,
|
if self.predictor_name == "CifPredictorV3":
|
||||||
pre_token_length)
|
_, _, us_alphas, us_peaks = self.calc_predictor_timestamp(encoder_out,
|
||||||
|
encoder_out_lens,
|
||||||
|
pre_token_length)
|
||||||
|
else:
|
||||||
|
us_alphas = None
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
b, n, d = decoder_out.size()
|
b, n, d = decoder_out.size()
|
||||||
for i in range(b):
|
for i in range(b):
|
||||||
@ -393,23 +409,25 @@ class SeacoParaformer(BiCifParaformer, Paraformer):
|
|||||||
# Change integer-ids to tokens
|
# Change integer-ids to tokens
|
||||||
token = tokenizer.ids2tokens(token_int)
|
token = tokenizer.ids2tokens(token_int)
|
||||||
text = tokenizer.tokens2text(token)
|
text = tokenizer.tokens2text(token)
|
||||||
|
if us_alphas is not None:
|
||||||
_, timestamp = ts_prediction_lfr6_standard(us_alphas[i][:encoder_out_lens[i] * 3],
|
_, timestamp = ts_prediction_lfr6_standard(us_alphas[i][:encoder_out_lens[i] * 3],
|
||||||
us_peaks[i][:encoder_out_lens[i] * 3],
|
us_peaks[i][:encoder_out_lens[i] * 3],
|
||||||
copy.copy(token),
|
copy.copy(token),
|
||||||
vad_offset=kwargs.get("begin_time", 0))
|
vad_offset=kwargs.get("begin_time", 0))
|
||||||
|
text_postprocessed, time_stamp_postprocessed, _ = \
|
||||||
text_postprocessed, time_stamp_postprocessed, word_lists = postprocess_utils.sentence_postprocess(
|
postprocess_utils.sentence_postprocess(token, timestamp)
|
||||||
token, timestamp)
|
result_i = {"key": key[i], "text": text_postprocessed,
|
||||||
|
"timestamp": time_stamp_postprocessed}
|
||||||
result_i = {"key": key[i], "text": text_postprocessed,
|
if ibest_writer is not None:
|
||||||
"timestamp": time_stamp_postprocessed
|
ibest_writer["token"][key[i]] = " ".join(token)
|
||||||
}
|
ibest_writer["timestamp"][key[i]] = time_stamp_postprocessed
|
||||||
|
ibest_writer["text"][key[i]] = text_postprocessed
|
||||||
if ibest_writer is not None:
|
else:
|
||||||
ibest_writer["token"][key[i]] = " ".join(token)
|
text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
|
||||||
ibest_writer["timestamp"][key[i]] = time_stamp_postprocessed
|
result_i = {"key": key[i], "text": text_postprocessed}
|
||||||
ibest_writer["text"][key[i]] = text_postprocessed
|
if ibest_writer is not None:
|
||||||
|
ibest_writer["token"][key[i]] = " ".join(token)
|
||||||
|
ibest_writer["text"][key[i]] = text_postprocessed
|
||||||
else:
|
else:
|
||||||
result_i = {"key": key[i], "token_int": token_int}
|
result_i = {"key": key[i], "token_int": token_int}
|
||||||
results.append(result_i)
|
results.append(result_i)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user