From 72438f38052982ec6bc11a70b9c3e862d8a13582 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Tue, 6 Aug 2024 00:57:04 +0800 Subject: [PATCH] deepspeed --- funasr/train_utils/trainer_ds.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/funasr/train_utils/trainer_ds.py b/funasr/train_utils/trainer_ds.py index 9048a2f2d..d9f41b2bc 100644 --- a/funasr/train_utils/trainer_ds.py +++ b/funasr/train_utils/trainer_ds.py @@ -227,11 +227,11 @@ class Trainer: model.save_checkpoint(save_dir=self.output_dir, tag=ckpt_name, client_state=state) logging.info(f"\nCheckpoint saved to {filename}\n") - with torch.no_grad(): - model.save_checkpoint( - save_dir=self.output_dir, tag=f"ds-model.pt", client_state=state - ) if not (step is None and epoch == 0): + with torch.no_grad(): + model.save_checkpoint( + save_dir=self.output_dir, tag=f"ds-model.pt", client_state=state + ) if self.best_step_or_epoch == "": self.best_step_or_epoch = ckpt_name @@ -361,9 +361,10 @@ class Trainer: torch.save(state, filename) logging.info(f"\nCheckpoint saved to {filename}\n") - latest = Path(os.path.join(self.output_dir, f"model.pt")) - torch.save(state, latest) - if not (step is None and epoch != 0): + + if not (step is None and epoch == 0): + latest = Path(os.path.join(self.output_dir, f"model.pt")) + torch.save(state, latest) if self.best_step_or_epoch == "": self.best_step_or_epoch = ckpt_name