diff --git a/funasr/train_utils/trainer_ds.py b/funasr/train_utils/trainer_ds.py index 78cfceb37..88a853c58 100644 --- a/funasr/train_utils/trainer_ds.py +++ b/funasr/train_utils/trainer_ds.py @@ -577,7 +577,7 @@ class Trainer: self.val_loss_avg = ( self.val_loss_avg * batch_idx + loss_dict["loss"].detach().cpu().item() ) / (batch_idx + 1) - if "acc" in stats: + if "acc" in loss_dict["stats"]: self.val_acc_avg = ( self.val_acc_avg * batch_idx + loss_dict["stats"]["acc"].detach().cpu().item() ) / (batch_idx + 1) @@ -740,7 +740,7 @@ class Trainer: self.val_loss_avg = ( self.val_loss_avg * batch_idx + loss_dict["loss"].detach().cpu().item() ) / (batch_idx + 1) - if "acc" in stats: + if "acc" in loss_dict["stats"]: self.val_acc_avg = ( self.val_acc_avg * batch_idx + loss_dict["stats"]["acc"].detach().cpu().item()