diff --git a/funasr/train_utils/trainer_ds.py b/funasr/train_utils/trainer_ds.py index a0cd0df0d..9048a2f2d 100644 --- a/funasr/train_utils/trainer_ds.py +++ b/funasr/train_utils/trainer_ds.py @@ -231,7 +231,7 @@ class Trainer: model.save_checkpoint( save_dir=self.output_dir, tag=f"ds-model.pt", client_state=state ) - if not (step is None and epoch != 0): + if not (step is None and epoch == 0): if self.best_step_or_epoch == "": self.best_step_or_epoch = ckpt_name