diff --git a/funasr/train_utils/average_nbest_models.py b/funasr/train_utils/average_nbest_models.py
index 0f0880425..b10c23145 100644
--- a/funasr/train_utils/average_nbest_models.py
+++ b/funasr/train_utils/average_nbest_models.py
@@ -62,7 +62,8 @@ def average_checkpoints(output_dir: str, last_n: int = 5, **kwargs):
 
     # Check if we have any state_dicts to average
     if len(state_dicts) < 1:
-        raise RuntimeError("No checkpoints found for averaging.")
+        print("No checkpoints found for averaging.")
+        return
 
     # Average or sum weights
     avg_state_dict = OrderedDict()
diff --git a/funasr/train_utils/trainer_ds.py b/funasr/train_utils/trainer_ds.py
index 8a52746d0..fead9ca08 100644
--- a/funasr/train_utils/trainer_ds.py
+++ b/funasr/train_utils/trainer_ds.py
@@ -168,8 +168,7 @@ class Trainer:
         """
         step_in_epoch = None if step is None else step_in_epoch
         if self.use_deepspeed:
-            with torch.no_grad():
-                model.save_checkpoint(save_dir=model_dir, tag=tag, client_state=info_dict)
+            logging.info(f"Save checkpoint: {epoch}, rank: {self.local_rank}\n")
 
             # self.step_or_epoch += 1
             state = {
@@ -273,8 +272,7 @@ class Trainer:
         elif self.use_fsdp:
             pass
-        step_in_epoch = None if step is None else step_in_epoch
-        if self.rank == 0:
+        elif self.rank == 0:
             logging.info(f"Save checkpoint: {epoch}, rank: {self.local_rank}\n")
 
             # self.step_or_epoch += 1
             state = {
@@ -385,7 +383,7 @@ class Trainer:
 
         if self.use_deepspeed:
             ckpt = os.path.join(self.output_dir, "model.pt")
-            if os.path.isfile(ckpt):
+            if os.path.exists(ckpt):
                 _, checkpoint = model_engine.load_checkpoint(self.output_dir, "model.pt")
                 self.saved_ckpts = checkpoint["saved_ckpts"]
 
@@ -712,7 +710,7 @@ class Trainer:
             "data_split_num": kwargs.get("data_split_num", 1),
             "log_step": batch_idx + kwargs.get("start_step", 0),
             "batch_total": batch_idx,
-            "step_in_epoch": step_in_epoch,
+            "step_in_epoch": batch_idx,
             "lr": 0.0,
         }
 