From 34b2682fba92bf4f8e163d22094e7de57e3290ca Mon Sep 17 00:00:00 2001 From: "shixian.shi" Date: Fri, 24 Nov 2023 09:46:51 +0800 Subject: [PATCH] fix bug for contextual train --- funasr/models/e2e_asr_contextual_paraformer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/funasr/models/e2e_asr_contextual_paraformer.py b/funasr/models/e2e_asr_contextual_paraformer.py index 171a6c60a..a2f7078ae 100644 --- a/funasr/models/e2e_asr_contextual_paraformer.py +++ b/funasr/models/e2e_asr_contextual_paraformer.py @@ -207,7 +207,7 @@ class NeatContextualParaformer(Paraformer): # 2b. Attention decoder branch if self.ctc_weight != 1.0: loss_att, acc_att, cer_att, wer_att, loss_pre, loss_ideal = self._calc_att_clas_loss( - encoder_out, encoder_out_lens, text, text_lengths, hotword_pad, hotword_lengths, ideal_attn + encoder_out, encoder_out_lens, text, text_lengths, hotword_pad, hotword_lengths ) # 3. CTC-Att loss definition @@ -242,7 +242,6 @@ class NeatContextualParaformer(Paraformer): ys_pad_lens: torch.Tensor, hotword_pad: torch.Tensor, hotword_lengths: torch.Tensor, - ideal_attn: torch.Tensor, ): encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to( encoder_out.device)