update repo

This commit is contained in:
嘉渊 2023-05-24 20:06:58 +08:00
parent 2ac90f9e11
commit 2efd780568

View File

@ -236,7 +236,7 @@ class Paraformer(FunASRModel):
loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
if self.use_1st_decoder_loss and pre_loss_att is not None:
loss = loss + pre_loss_att
loss = loss + (1 - self.ctc_weight) * pre_loss_att
# Collect Attn branch stats
stats["loss_att"] = loss_att.detach() if loss_att is not None else None