diff --git a/funasr/models/e2e_asr_paraformer.py b/funasr/models/e2e_asr_paraformer.py index 924127117..82acef2ae 100644 --- a/funasr/models/e2e_asr_paraformer.py +++ b/funasr/models/e2e_asr_paraformer.py @@ -236,7 +236,7 @@ class Paraformer(FunASRModel): loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight if self.use_1st_decoder_loss and pre_loss_att is not None: - loss = loss + pre_loss_att + loss = loss + (1 - self.ctc_weight) * pre_loss_att # Collect Attn branch stats stats["loss_att"] = loss_att.detach() if loss_att is not None else None