mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
update repo
This commit is contained in:
parent
2ac90f9e11
commit
2efd780568
@ -236,7 +236,7 @@ class Paraformer(FunASRModel):
|
||||
loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
|
||||
|
||||
if self.use_1st_decoder_loss and pre_loss_att is not None:
|
||||
loss = loss + pre_loss_att
|
||||
loss = loss + (1 - self.ctc_weight) * pre_loss_att
|
||||
|
||||
# Collect Attn branch stats
|
||||
stats["loss_att"] = loss_att.detach() if loss_att is not None else None
|
||||
|
||||
Loading…
Reference in New Issue
Block a user