diff --git a/funasr/train_utils/average_nbest_models.py b/funasr/train_utils/average_nbest_models.py index 20da13022..67f1e55e8 100644 --- a/funasr/train_utils/average_nbest_models.py +++ b/funasr/train_utils/average_nbest_models.py @@ -22,7 +22,13 @@ def _get_checkpoint_paths(output_dir: str, last_n: int = 5, use_deepspeed=False, in the output directory. """ try: - checkpoint = torch.load(os.path.join(output_dir, "model.pt"), map_location="cpu") + if not use_deepspeed: + checkpoint = torch.load(os.path.join(output_dir, "model.pt"), map_location="cpu") + else: + checkpoint = torch.load( + os.path.join(output_dir, "model.pt", "mp_rank_00_model_states.pt"), + map_location="cpu", + ) avg_keep_nbest_models_type = checkpoint["avg_keep_nbest_models_type"] val_step_or_eoch = checkpoint[f"val_{avg_keep_nbest_models_type}_step_or_eoch"] sorted_items = sorted(val_step_or_eoch.items(), key=lambda x: x[1], reverse=True) @@ -35,6 +41,7 @@ def _get_checkpoint_paths(output_dir: str, last_n: int = 5, use_deepspeed=False, ckpt = os.path.join(output_dir, key) else: ckpt = os.path.join(output_dir, key, "mp_rank_00_model_states.pt") + checkpoint_paths.append(ckpt) except: print(f"{checkpoint} does not exist, avg the lastet checkpoint.") diff --git a/funasr/train_utils/trainer_ds.py b/funasr/train_utils/trainer_ds.py index bb9fca66c..1a553f812 100644 --- a/funasr/train_utils/trainer_ds.py +++ b/funasr/train_utils/trainer_ds.py @@ -15,6 +15,7 @@ from funasr.train_utils.device_funcs import to_device from funasr.train_utils.recursive_op import recursive_average from funasr.train_utils.average_nbest_models import average_checkpoints from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler +import funasr.utils.misc as misc_utils try: import wandb @@ -268,7 +269,8 @@ class Trainer: filename = os.path.join(self.output_dir, key) logging.info(f"Delete: {filename}") if os.path.exists(filename): - os.remove(filename) + # os.remove(filename) + misc_utils.smart_remove(filename) elif self.use_fsdp: pass @@ -360,7 +362,8 @@ class Trainer: filename = os.path.join(self.output_dir, key) logging.info(f"Delete: {filename}") if os.path.exists(filename): - os.remove(filename) + # os.remove(filename) + misc_utils.smart_remove(filename) if self.use_ddp or self.use_fsdp: dist.barrier() @@ -385,7 +388,7 @@ class Trainer: ckpt = os.path.join(self.output_dir, "model.pt") if os.path.exists(ckpt): _, checkpoint = model.load_checkpoint(self.output_dir, "model.pt") - + self.start_epoch = checkpoint["epoch"] self.saved_ckpts = checkpoint["saved_ckpts"] self.val_acc_step_or_eoch = ( checkpoint["val_acc_step_or_eoch"] @@ -709,8 +712,8 @@ class Trainer: "data_split_i": kwargs.get("data_split_i", 0), "data_split_num": kwargs.get("data_split_num", 1), "log_step": batch_idx + kwargs.get("start_step", 0), - "batch_total": batch_idx, - "step_in_epoch": batch_idx, + "batch_total": batch_idx + 1, + "step_in_epoch": batch_idx + 1, "lr": 0.0, } diff --git a/funasr/utils/misc.py b/funasr/utils/misc.py index 4613cb316..eb17f9723 100644 --- a/funasr/utils/misc.py +++ b/funasr/utils/misc.py @@ -94,3 +94,26 @@ def extract_filename_without_extension(file_path): filename, extension = os.path.splitext(filename_with_extension) # 返回不包含扩展名的文件名 return filename + + +def smart_remove(path): + """Intelligently removes files, empty directories, and non-empty directories recursively.""" + # Check if the provided path exists + if not os.path.exists(path): + print(f"{path} does not exist.") + return + + # If the path is a file, delete it + if os.path.isfile(path): + os.remove(path) + print(f"File {path} has been deleted.") + # If the path is a directory + elif os.path.isdir(path): + try: + # Attempt to remove an empty directory + os.rmdir(path) + print(f"Empty directory {path} has been deleted.") + except OSError: + # If the directory is not empty, remove it along with all its contents + shutil.rmtree(path) + print(f"Non-empty directory {path} has been recursively deleted.")