From e6fe602db3eb1209543e55f1aafa2932dfda3310 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Fri, 10 Jan 2025 10:14:30 +0800 Subject: [PATCH] step_or_epoch bugfix --- funasr/train_utils/trainer_ds.py | 76 ++++++++++++++++---------------- 1 file changed, 38 insertions(+), 38 deletions(-) diff --git a/funasr/train_utils/trainer_ds.py b/funasr/train_utils/trainer_ds.py index 85513a5a7..0b104da6e 100644 --- a/funasr/train_utils/trainer_ds.py +++ b/funasr/train_utils/trainer_ds.py @@ -122,8 +122,8 @@ class Trainer: self.saved_ckpts = {} self.step_or_epoch = -1 self.best_step_or_epoch = "" - self.val_acc_step_or_eoch = {} - self.val_loss_step_or_eoch = {} + self.val_acc_step_or_epoch = {} + self.val_loss_step_or_epoch = {} self.reset_gpu_cache = kwargs.get("reset_gpu_cache", False) self.start_data_split_i = 0 @@ -195,8 +195,8 @@ class Trainer: # "optimizer": optim.state_dict(), # "scheduler": scheduler.state_dict(), "saved_ckpts": self.saved_ckpts, - "val_acc_step_or_eoch": self.val_acc_step_or_eoch, - "val_loss_step_or_eoch": self.val_loss_step_or_eoch, + "val_acc_step_or_epoch": self.val_acc_step_or_epoch, + "val_loss_step_or_epoch": self.val_loss_step_or_epoch, "best_step_or_epoch": self.best_step_or_epoch, "avg_keep_nbest_models_type": self.avg_keep_nbest_models_type, "step": step, @@ -234,8 +234,8 @@ class Trainer: if self.avg_keep_nbest_models_type == "acc": if ( - self.val_acc_step_or_eoch[ckpt_name] - >= self.val_acc_step_or_eoch[self.best_step_or_epoch] + self.val_acc_step_or_epoch[ckpt_name] + >= self.val_acc_step_or_epoch[self.best_step_or_epoch] ): self.best_step_or_epoch = ckpt_name best_ckpt = Path(os.path.join(self.output_dir, f"model.pt.best")) @@ -245,16 +245,16 @@ class Trainer: save_dir=self.output_dir, tag=f"model.pt.best", client_state=state ) logging.info( - f"Update best acc: {self.val_acc_step_or_eoch[self.best_step_or_epoch]:.4f}, {best_ckpt}" + f"Update best acc: {self.val_acc_step_or_epoch[self.best_step_or_epoch]:.4f}, {best_ckpt}" ) else: logging.info( - f"No improvement in acc: {self.val_acc_step_or_eoch[ckpt_name]:.4f} < {self.val_acc_step_or_eoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}" + f"No improvement in acc: {self.val_acc_step_or_epoch[ckpt_name]:.4f} < {self.val_acc_step_or_epoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}" ) elif self.avg_keep_nbest_models_type == "loss": if ( - self.val_loss_step_or_eoch[ckpt_name] - <= self.val_loss_step_or_eoch[self.best_step_or_epoch] + self.val_loss_step_or_epoch[ckpt_name] + <= self.val_loss_step_or_epoch[self.best_step_or_epoch] ): self.best_step_or_epoch = ckpt_name best_ckpt = Path(os.path.join(self.output_dir, f"model.pt.best")) @@ -264,16 +264,16 @@ class Trainer: save_dir=self.output_dir, tag=f"model.pt.best", client_state=state ) logging.info( - f"Update best loss: {self.val_loss_step_or_eoch[self.best_step_or_epoch]:.4f}, {best_ckpt}" + f"Update best loss: {self.val_loss_step_or_epoch[self.best_step_or_epoch]:.4f}, {best_ckpt}" ) else: logging.info( - f"No improvement in loss: {self.val_loss_step_or_eoch[ckpt_name]:.4f} > {self.val_loss_step_or_eoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}" + f"No improvement in loss: {self.val_loss_step_or_epoch[ckpt_name]:.4f} > {self.val_loss_step_or_epoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}" ) else: print("Undo") self.saved_ckpts[ckpt_name] = getattr( - self, f"val_{self.avg_keep_nbest_models_type}_step_or_eoch" + self, f"val_{self.avg_keep_nbest_models_type}_step_or_epoch" )[ckpt_name] if self.keep_nbest_models > 0: if len(self.saved_ckpts) > self.keep_nbest_models: @@ -301,8 +301,8 @@ class Trainer: "optimizer": optim.state_dict(), "scheduler": scheduler.state_dict(), "saved_ckpts": self.saved_ckpts, - "val_acc_step_or_eoch": self.val_acc_step_or_eoch, - "val_loss_step_or_eoch": self.val_loss_step_or_eoch, + "val_acc_step_or_epoch": self.val_acc_step_or_epoch, + "val_loss_step_or_epoch": self.val_loss_step_or_epoch, "best_step_or_epoch": self.best_step_or_epoch, "avg_keep_nbest_models_type": self.avg_keep_nbest_models_type, "step": step, @@ -353,38 +353,38 @@ class Trainer: if self.avg_keep_nbest_models_type == "acc": if ( - self.val_acc_step_or_eoch[ckpt_name] - >= self.val_acc_step_or_eoch[self.best_step_or_epoch] + self.val_acc_step_or_epoch[ckpt_name] + >= self.val_acc_step_or_epoch[self.best_step_or_epoch] ): self.best_step_or_epoch = ckpt_name best_ckpt = Path(os.path.join(self.output_dir, f"model.pt.best")) torch.save(state, best_ckpt) logging.info( - f"Update best acc: {self.val_acc_step_or_eoch[self.best_step_or_epoch]:.4f}, {best_ckpt}" + f"Update best acc: {self.val_acc_step_or_epoch[self.best_step_or_epoch]:.4f}, {best_ckpt}" ) else: logging.info( - f"No improvement in acc: {self.val_acc_step_or_eoch[ckpt_name]:.4f} < {self.val_acc_step_or_eoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}" + f"No improvement in acc: {self.val_acc_step_or_epoch[ckpt_name]:.4f} < {self.val_acc_step_or_epoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}" ) elif self.avg_keep_nbest_models_type == "loss": if ( - self.val_loss_step_or_eoch[ckpt_name] - <= self.val_loss_step_or_eoch[self.best_step_or_epoch] + self.val_loss_step_or_epoch[ckpt_name] + <= self.val_loss_step_or_epoch[self.best_step_or_epoch] ): self.best_step_or_epoch = ckpt_name best_ckpt = Path(os.path.join(self.output_dir, f"model.pt.best")) torch.save(state, best_ckpt) logging.info( - f"Update best loss: {self.val_loss_step_or_eoch[self.best_step_or_epoch]:.4f}, {best_ckpt}" + f"Update best loss: {self.val_loss_step_or_epoch[self.best_step_or_epoch]:.4f}, {best_ckpt}" ) else: logging.info( - f"No improvement in loss: {self.val_loss_step_or_eoch[ckpt_name]:.4f} > {self.val_loss_step_or_eoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}" + f"No improvement in loss: {self.val_loss_step_or_epoch[ckpt_name]:.4f} > {self.val_loss_step_or_epoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}" ) else: print("Undo") self.saved_ckpts[ckpt_name] = getattr( - self, f"val_{self.avg_keep_nbest_models_type}_step_or_eoch" + self, f"val_{self.avg_keep_nbest_models_type}_step_or_epoch" )[ckpt_name] if self.keep_nbest_models > 0: if len(self.saved_ckpts) > self.keep_nbest_models: @@ -425,14 +425,14 @@ class Trainer: _, checkpoint = model.load_checkpoint(self.output_dir, "model.pt") self.start_epoch = checkpoint["epoch"] self.saved_ckpts = checkpoint["saved_ckpts"] - self.val_acc_step_or_eoch = ( - checkpoint["val_acc_step_or_eoch"] - if "val_acc_step_or_eoch" in checkpoint + self.val_acc_step_or_epoch = ( + checkpoint["val_acc_step_or_epoch"] + if "val_acc_step_or_epoch" in checkpoint else {} ) - self.val_loss_step_or_eoch = ( - checkpoint["val_loss_step_or_eoch"] - if "val_loss_step_or_eoch" in checkpoint + self.val_loss_step_or_epoch = ( + checkpoint["val_loss_step_or_epoch"] + if "val_loss_step_or_epoch" in checkpoint else {} ) self.best_step_or_epoch = ( @@ -501,14 +501,14 @@ class Trainer: scaler.load_state_dict(checkpoint["scaler_state"]) self.saved_ckpts = checkpoint["saved_ckpts"] - self.val_acc_step_or_eoch = ( - checkpoint["val_acc_step_or_eoch"] - if "val_acc_step_or_eoch" in checkpoint + self.val_acc_step_or_epoch = ( + checkpoint["val_acc_step_or_epoch"] + if "val_acc_step_or_epoch" in checkpoint else {} ) - self.val_loss_step_or_eoch = ( - checkpoint["val_loss_step_or_eoch"] - if "val_loss_step_or_eoch" in checkpoint + self.val_loss_step_or_epoch = ( + checkpoint["val_loss_step_or_epoch"] + if "val_loss_step_or_epoch" in checkpoint else {} ) self.best_step_or_epoch = ( @@ -803,8 +803,8 @@ class Trainer: ckpt_name = f"model.pt.ep{epoch}" else: ckpt_name = f'model.pt.ep{epoch}.{kwargs.get("step_in_epoch")}' - self.val_acc_step_or_eoch[ckpt_name] = self.val_acc_avg - self.val_loss_step_or_eoch[ckpt_name] = self.val_loss_avg + self.val_acc_step_or_epoch[ckpt_name] = self.val_acc_avg + self.val_loss_step_or_epoch[ckpt_name] = self.val_loss_avg if self.use_ddp or self.use_fsdp or self.use_deepspeed: dist.barrier()