From 963ba1a7717c785d6e20ccb0d3cee9b59d5365e3 Mon Sep 17 00:00:00 2001 From: zhifu gao Date: Mon, 20 May 2024 17:11:41 +0800 Subject: [PATCH] Dev gzf deepspeed (#1737) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * resume from step * batch * batch * batch * batch * batch * batch * batch * batch * batch * batch * batch * batch * batch * batch * batch * train_loss_avg train_acc_avg * train_loss_avg train_acc_avg * train_loss_avg train_acc_avg * log step * wav is not exist * wav is not exist * decoding * decoding * decoding * wechat * decoding key * decoding key * decoding key * decoding key * decoding key * decoding key * dynamic batch * start_data_split_i=0 * total_time/accum_grad * total_time/accum_grad * total_time/accum_grad * update avg slice * update avg slice * sensevoice sanm * sensevoice sanm * add * add * add * add * deepspeed * update with main (#1731) * c++ runtime adapt to 1.0 (#1724) * adapt vad runtime to 1.0 * add json * change yml name * add func LoadVocabFromJson * add token file for InitAsr * add token path for OfflineStream * add funcOpenYaml * add token file for InitPunc * add token file for stream * update punc-model * update funasr-wss-server * update runtime_sdk_download_tool.py * update docker list * Delete docs/images/wechat.png * Add files via upload * Emo2Vec限定选择的情感类别 (#1730) * 限定选择的情感类别 * 使用none来禁用情感标签输出 * 修改输出接口 * 使用unuse来禁用token --------- Co-authored-by: 常材 * bugfix * v1.0.27 * update docs * hf hub * Fix incorrect assignment of 'end' attribute to 'start' in sentences list comprehension (#1680) --------- Co-authored-by: Yabin Li Co-authored-by: gaochangfeng <54253717+gaochangfeng@users.noreply.github.com> Co-authored-by: 常材 Co-authored-by: nsdou <168500039+nsdou@users.noreply.github.com> * docs * docs * deepspeed * deepspeed * deepspeed * deepspeed * update * ds * ds * ds * ds * ds * ds * ds * add * add * bugfix --------- Co-authored-by: Yabin Li Co-authored-by: gaochangfeng <54253717+gaochangfeng@users.noreply.github.com> Co-authored-by: 常材 Co-authored-by: nsdou <168500039+nsdou@users.noreply.github.com> --- funasr/train_utils/average_nbest_models.py | 9 ++++++++- funasr/train_utils/trainer_ds.py | 13 +++++++----- funasr/utils/misc.py | 23 ++++++++++++++++++++++ 3 files changed, 39 insertions(+), 6 deletions(-) diff --git a/funasr/train_utils/average_nbest_models.py b/funasr/train_utils/average_nbest_models.py index 20da13022..67f1e55e8 100644 --- a/funasr/train_utils/average_nbest_models.py +++ b/funasr/train_utils/average_nbest_models.py @@ -22,7 +22,13 @@ def _get_checkpoint_paths(output_dir: str, last_n: int = 5, use_deepspeed=False, in the output directory. """ try: - checkpoint = torch.load(os.path.join(output_dir, "model.pt"), map_location="cpu") + if not use_deepspeed: + checkpoint = torch.load(os.path.join(output_dir, "model.pt"), map_location="cpu") + else: + checkpoint = torch.load( + os.path.join(output_dir, "model.pt", "mp_rank_00_model_states.pt"), + map_location="cpu", + ) avg_keep_nbest_models_type = checkpoint["avg_keep_nbest_models_type"] val_step_or_eoch = checkpoint[f"val_{avg_keep_nbest_models_type}_step_or_eoch"] sorted_items = sorted(val_step_or_eoch.items(), key=lambda x: x[1], reverse=True) @@ -35,6 +41,7 @@ def _get_checkpoint_paths(output_dir: str, last_n: int = 5, use_deepspeed=False, ckpt = os.path.join(output_dir, key) else: ckpt = os.path.join(output_dir, key, "mp_rank_00_model_states.pt") + checkpoint_paths.append(ckpt) except: print(f"{checkpoint} does not exist, avg the lastet checkpoint.") diff --git a/funasr/train_utils/trainer_ds.py b/funasr/train_utils/trainer_ds.py index bb9fca66c..1a553f812 100644 --- a/funasr/train_utils/trainer_ds.py +++ b/funasr/train_utils/trainer_ds.py @@ -15,6 +15,7 @@ from funasr.train_utils.device_funcs import to_device from funasr.train_utils.recursive_op import recursive_average from funasr.train_utils.average_nbest_models import average_checkpoints from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler +import funasr.utils.misc as misc_utils try: import wandb @@ -268,7 +269,8 @@ class Trainer: filename = os.path.join(self.output_dir, key) logging.info(f"Delete: {filename}") if os.path.exists(filename): - os.remove(filename) + # os.remove(filename) + misc_utils.smart_remove(filename) elif self.use_fsdp: pass @@ -360,7 +362,8 @@ class Trainer: filename = os.path.join(self.output_dir, key) logging.info(f"Delete: {filename}") if os.path.exists(filename): - os.remove(filename) + # os.remove(filename) + misc_utils.smart_remove(filename) if self.use_ddp or self.use_fsdp: dist.barrier() @@ -385,7 +388,7 @@ class Trainer: ckpt = os.path.join(self.output_dir, "model.pt") if os.path.exists(ckpt): _, checkpoint = model.load_checkpoint(self.output_dir, "model.pt") - + self.start_epoch = checkpoint["epoch"] self.saved_ckpts = checkpoint["saved_ckpts"] self.val_acc_step_or_eoch = ( checkpoint["val_acc_step_or_eoch"] @@ -709,8 +712,8 @@ class Trainer: "data_split_i": kwargs.get("data_split_i", 0), "data_split_num": kwargs.get("data_split_num", 1), "log_step": batch_idx + kwargs.get("start_step", 0), - "batch_total": batch_idx, - "step_in_epoch": batch_idx, + "batch_total": batch_idx + 1, + "step_in_epoch": batch_idx + 1, "lr": 0.0, } diff --git a/funasr/utils/misc.py b/funasr/utils/misc.py index 4613cb316..eb17f9723 100644 --- a/funasr/utils/misc.py +++ b/funasr/utils/misc.py @@ -94,3 +94,26 @@ def extract_filename_without_extension(file_path): filename, extension = os.path.splitext(filename_with_extension) # 返回不包含扩展名的文件名 return filename + + +def smart_remove(path): + """Intelligently removes files, empty directories, and non-empty directories recursively.""" + # Check if the provided path exists + if not os.path.exists(path): + print(f"{path} does not exist.") + return + + # If the path is a file, delete it + if os.path.isfile(path): + os.remove(path) + print(f"File {path} has been deleted.") + # If the path is a directory + elif os.path.isdir(path): + try: + # Attempt to remove an empty directory + os.rmdir(path) + print(f"Empty directory {path} has been deleted.") + except OSError: + # If the directory is not empty, remove it along with all its contents + shutil.rmtree(path) + print(f"Non-empty directory {path} has been recursively deleted.")