From 1e1500adadf5c7ed3622efa0f48f51b48a78b31e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Mon, 20 May 2024 11:33:14 +0800 Subject: [PATCH] ds --- funasr/train_utils/trainer_ds.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/funasr/train_utils/trainer_ds.py b/funasr/train_utils/trainer_ds.py index 78cfceb37..88a853c58 100644 --- a/funasr/train_utils/trainer_ds.py +++ b/funasr/train_utils/trainer_ds.py @@ -577,7 +577,7 @@ class Trainer: self.val_loss_avg = ( self.val_loss_avg * batch_idx + loss_dict["loss"].detach().cpu().item() ) / (batch_idx + 1) - if "acc" in stats: + if "acc" in loss_dict["stats"]: self.val_acc_avg = ( self.val_acc_avg * batch_idx + loss_dict["stats"]["acc"].detach().cpu().item() ) / (batch_idx + 1) @@ -740,7 +740,7 @@ class Trainer: self.val_loss_avg = ( self.val_loss_avg * batch_idx + loss_dict["loss"].detach().cpu().item() ) / (batch_idx + 1) - if "acc" in stats: + if "acc" in loss_dict["stats"]: self.val_acc_avg = ( self.val_acc_avg * batch_idx + loss_dict["stats"]["acc"].detach().cpu().item()