diff --git a/funasr/models/e2e_asr_paraformer.py b/funasr/models/e2e_asr_paraformer.py
index 5c8560d00..fcef34270 100644
--- a/funasr/models/e2e_asr_paraformer.py
+++ b/funasr/models/e2e_asr_paraformer.py
@@ -977,6 +977,59 @@ class BiCifParaformer(Paraformer):
         loss_pre2 = self.criterion_pre(ys_pad_lens.type_as(pre_token_length2), pre_token_length2)
 
         return loss_pre2
+
+    def _calc_att_loss(
+        self,
+        encoder_out: torch.Tensor,
+        encoder_out_lens: torch.Tensor,
+        ys_pad: torch.Tensor,
+        ys_pad_lens: torch.Tensor,
+    ):
+        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+            encoder_out.device)
+        if self.predictor_bias == 1:
+            _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
+            ys_pad_lens = ys_pad_lens + self.predictor_bias
+        pre_acoustic_embeds, pre_token_length, _, pre_peak_index, _ = self.predictor(encoder_out, ys_pad, encoder_out_mask,
+                                                                                     ignore_id=self.ignore_id)
+
+        # 0. sampler
+        decoder_out_1st = None
+        if self.sampling_ratio > 0.0:
+            if self.step_cur < 2:
+                logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
+            sematic_embeds, decoder_out_1st = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
+                                                           pre_acoustic_embeds)
+        else:
+            if self.step_cur < 2:
+                logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
+            sematic_embeds = pre_acoustic_embeds
+
+        # 1. Forward decoder
+        decoder_outs = self.decoder(
+            encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
+        )
+        decoder_out, _ = decoder_outs[0], decoder_outs[1]
+
+        if decoder_out_1st is None:
+            decoder_out_1st = decoder_out
+        # 2. Compute attention loss
+        loss_att = self.criterion_att(decoder_out, ys_pad)
+        acc_att = th_accuracy(
+            decoder_out_1st.view(-1, self.vocab_size),
+            ys_pad,
+            ignore_label=self.ignore_id,
+        )
+        loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
+
+        # Compute cer/wer using attention-decoder
+        if self.training or self.error_calculator is None:
+            cer_att, wer_att = None, None
+        else:
+            ys_hat = decoder_out_1st.argmax(dim=-1)
+            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
+
+        return loss_att, acc_att, cer_att, wer_att, loss_pre
 
     def calc_predictor(self, encoder_out, encoder_out_lens):
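
For reference, the tuple returned by _calc_att_loss (attention CE loss, token accuracy, optional CER/WER, and the predictor length loss) is typically folded into a single training objective together with an optional CTC term. The sketch below is illustrative only: the helper name combine_paraformer_losses and its weight arguments are assumptions made for this example, not part of the diff or of the FunASR API.

    from typing import Optional
    import torch

    def combine_paraformer_losses(
        loss_att: torch.Tensor,
        loss_pre: torch.Tensor,
        loss_ctc: Optional[torch.Tensor] = None,
        ctc_weight: float = 0.0,
        predictor_weight: float = 1.0,
    ) -> torch.Tensor:
        # Weighted sum of the attention (CE) loss, the predictor length loss,
        # and an optional CTC loss. The weight names here are hypothetical;
        # the actual weighting lives in the model's forward/config.
        loss = (1.0 - ctc_weight) * loss_att + predictor_weight * loss_pre
        if loss_ctc is not None and ctc_weight > 0.0:
            loss = loss + ctc_weight * loss_ctc
        return loss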