Mirror of https://github.com/modelscope/FunASR, synced 2025-09-15 14:48:36 +08:00
update bicifparaformer forward

Commit: 2d65e5e754
Parent: cf8646cd92
@@ -978,6 +978,59 @@ class BiCifParaformer(Paraformer):

        return loss_pre2

    def _calc_att_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
            encoder_out.device)
        if self.predictor_bias == 1:
            _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
            ys_pad_lens = ys_pad_lens + self.predictor_bias
        pre_acoustic_embeds, pre_token_length, _, pre_peak_index, _ = self.predictor(encoder_out, ys_pad, encoder_out_mask,
                                                                                     ignore_id=self.ignore_id)

        # 0. sampler
        decoder_out_1st = None
        if self.sampling_ratio > 0.0:
            if self.step_cur < 2:
                logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
            sematic_embeds, decoder_out_1st = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
                                                           pre_acoustic_embeds)
        else:
            if self.step_cur < 2:
                logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
            sematic_embeds = pre_acoustic_embeds

        # 1. Forward decoder
        decoder_outs = self.decoder(
            encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
        )
        decoder_out, _ = decoder_outs[0], decoder_outs[1]

        if decoder_out_1st is None:
            decoder_out_1st = decoder_out
        # 2. Compute attention loss
        loss_att = self.criterion_att(decoder_out, ys_pad)
        acc_att = th_accuracy(
            decoder_out_1st.view(-1, self.vocab_size),
            ys_pad,
            ignore_label=self.ignore_id,
        )
        loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)

        # Compute cer/wer using attention-decoder
        if self.training or self.error_calculator is None:
            cer_att, wer_att = None, None
        else:
            ys_hat = decoder_out_1st.argmax(dim=-1)
            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        return loss_att, acc_att, cer_att, wer_att, loss_pre

    def calc_predictor(self, encoder_out, encoder_out_lens):

        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
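For context (not part of this commit), below is a minimal sketch of how the tuple returned by _calc_att_loss — (loss_att, acc_att, cer_att, wer_att, loss_pre) — is typically combined into a single training loss in Paraformer-style models. The helper name combine_paraformer_losses and the predictor_weight parameter are illustrative assumptions, not code from this diff.

import torch

def combine_paraformer_losses(
    loss_att: torch.Tensor,
    loss_pre: torch.Tensor,
    predictor_weight: float = 1.0,
) -> torch.Tensor:
    # Hypothetical helper: blend the attention (cross-entropy) loss with the
    # CIF predictor's token-length loss; the weighting scheme is an assumption.
    return loss_att + predictor_weight * loss_pre

# Usage with dummy scalars standing in for the tensors returned by _calc_att_loss:
loss_att = torch.tensor(2.31)
loss_pre = torch.tensor(0.42)
loss = combine_paraformer_losses(loss_att, loss_pre, predictor_weight=1.0)
print(loss)  # tensor(2.7300)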