mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
Dev gzf deepspeed (#1737)
* resume from step * batch * batch * batch * batch * batch * batch * batch * batch * batch * batch * batch * batch * batch * batch * batch * train_loss_avg train_acc_avg * train_loss_avg train_acc_avg * train_loss_avg train_acc_avg * log step * wav is not exist * wav is not exist * decoding * decoding * decoding * wechat * decoding key * decoding key * decoding key * decoding key * decoding key * decoding key * dynamic batch * start_data_split_i=0 * total_time/accum_grad * total_time/accum_grad * total_time/accum_grad * update avg slice * update avg slice * sensevoice sanm * sensevoice sanm * add * add * add * add * deepspeed * update with main (#1731) * c++ runtime adapt to 1.0 (#1724) * adapt vad runtime to 1.0 * add json * change yml name * add func LoadVocabFromJson * add token file for InitAsr * add token path for OfflineStream * add funcOpenYaml * add token file for InitPunc * add token file for stream * update punc-model * update funasr-wss-server * update runtime_sdk_download_tool.py * update docker list * Delete docs/images/wechat.png * Add files via upload * Emo2Vec限定选择的情感类别 (#1730) * 限定选择的情感类别 * 使用none来禁用情感标签输出 * 修改输出接口 * 使用unuse来禁用token --------- Co-authored-by: 常材 <gaochangfeng.gcf@alibaba-inc.com> * bugfix * v1.0.27 * update docs * hf hub * Fix incorrect assignment of 'end' attribute to 'start' in sentences list comprehension (#1680) --------- Co-authored-by: Yabin Li <wucong.lyb@alibaba-inc.com> Co-authored-by: gaochangfeng <54253717+gaochangfeng@users.noreply.github.com> Co-authored-by: 常材 <gaochangfeng.gcf@alibaba-inc.com> Co-authored-by: nsdou <168500039+nsdou@users.noreply.github.com> * docs * docs * deepspeed * deepspeed * deepspeed * deepspeed * update * ds * ds * ds * ds * ds * ds * ds * add * add * bugfix --------- Co-authored-by: Yabin Li <wucong.lyb@alibaba-inc.com> Co-authored-by: gaochangfeng <54253717+gaochangfeng@users.noreply.github.com> Co-authored-by: 常材 <gaochangfeng.gcf@alibaba-inc.com> Co-authored-by: nsdou <168500039+nsdou@users.noreply.github.com>
This commit is contained in:
parent
961ec280af
commit
963ba1a771
@ -22,7 +22,13 @@ def _get_checkpoint_paths(output_dir: str, last_n: int = 5, use_deepspeed=False,
|
||||
in the output directory.
|
||||
"""
|
||||
try:
|
||||
checkpoint = torch.load(os.path.join(output_dir, "model.pt"), map_location="cpu")
|
||||
if not use_deepspeed:
|
||||
checkpoint = torch.load(os.path.join(output_dir, "model.pt"), map_location="cpu")
|
||||
else:
|
||||
checkpoint = torch.load(
|
||||
os.path.join(output_dir, "model.pt", "mp_rank_00_model_states.pt"),
|
||||
map_location="cpu",
|
||||
)
|
||||
avg_keep_nbest_models_type = checkpoint["avg_keep_nbest_models_type"]
|
||||
val_step_or_eoch = checkpoint[f"val_{avg_keep_nbest_models_type}_step_or_eoch"]
|
||||
sorted_items = sorted(val_step_or_eoch.items(), key=lambda x: x[1], reverse=True)
|
||||
@ -35,6 +41,7 @@ def _get_checkpoint_paths(output_dir: str, last_n: int = 5, use_deepspeed=False,
|
||||
ckpt = os.path.join(output_dir, key)
|
||||
else:
|
||||
ckpt = os.path.join(output_dir, key, "mp_rank_00_model_states.pt")
|
||||
checkpoint_paths.append(ckpt)
|
||||
|
||||
except:
|
||||
print(f"{checkpoint} does not exist, avg the lastet checkpoint.")
|
||||
|
||||
@ -15,6 +15,7 @@ from funasr.train_utils.device_funcs import to_device
|
||||
from funasr.train_utils.recursive_op import recursive_average
|
||||
from funasr.train_utils.average_nbest_models import average_checkpoints
|
||||
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
|
||||
import funasr.utils.misc as misc_utils
|
||||
|
||||
try:
|
||||
import wandb
|
||||
@ -268,7 +269,8 @@ class Trainer:
|
||||
filename = os.path.join(self.output_dir, key)
|
||||
logging.info(f"Delete: {filename}")
|
||||
if os.path.exists(filename):
|
||||
os.remove(filename)
|
||||
# os.remove(filename)
|
||||
misc_utils.smart_remove(filename)
|
||||
|
||||
elif self.use_fsdp:
|
||||
pass
|
||||
@ -360,7 +362,8 @@ class Trainer:
|
||||
filename = os.path.join(self.output_dir, key)
|
||||
logging.info(f"Delete: {filename}")
|
||||
if os.path.exists(filename):
|
||||
os.remove(filename)
|
||||
# os.remove(filename)
|
||||
misc_utils.smart_remove(filename)
|
||||
|
||||
if self.use_ddp or self.use_fsdp:
|
||||
dist.barrier()
|
||||
@ -385,7 +388,7 @@ class Trainer:
|
||||
ckpt = os.path.join(self.output_dir, "model.pt")
|
||||
if os.path.exists(ckpt):
|
||||
_, checkpoint = model.load_checkpoint(self.output_dir, "model.pt")
|
||||
|
||||
self.start_epoch = checkpoint["epoch"]
|
||||
self.saved_ckpts = checkpoint["saved_ckpts"]
|
||||
self.val_acc_step_or_eoch = (
|
||||
checkpoint["val_acc_step_or_eoch"]
|
||||
@ -709,8 +712,8 @@ class Trainer:
|
||||
"data_split_i": kwargs.get("data_split_i", 0),
|
||||
"data_split_num": kwargs.get("data_split_num", 1),
|
||||
"log_step": batch_idx + kwargs.get("start_step", 0),
|
||||
"batch_total": batch_idx,
|
||||
"step_in_epoch": batch_idx,
|
||||
"batch_total": batch_idx + 1,
|
||||
"step_in_epoch": batch_idx + 1,
|
||||
"lr": 0.0,
|
||||
}
|
||||
|
||||
|
||||
@ -94,3 +94,26 @@ def extract_filename_without_extension(file_path):
|
||||
filename, extension = os.path.splitext(filename_with_extension)
|
||||
# 返回不包含扩展名的文件名
|
||||
return filename
|
||||
|
||||
|
||||
def smart_remove(path):
|
||||
"""Intelligently removes files, empty directories, and non-empty directories recursively."""
|
||||
# Check if the provided path exists
|
||||
if not os.path.exists(path):
|
||||
print(f"{path} does not exist.")
|
||||
return
|
||||
|
||||
# If the path is a file, delete it
|
||||
if os.path.isfile(path):
|
||||
os.remove(path)
|
||||
print(f"File {path} has been deleted.")
|
||||
# If the path is a directory
|
||||
elif os.path.isdir(path):
|
||||
try:
|
||||
# Attempt to remove an empty directory
|
||||
os.rmdir(path)
|
||||
print(f"Empty directory {path} has been deleted.")
|
||||
except OSError:
|
||||
# If the directory is not empty, remove it along with all its contents
|
||||
shutil.rmtree(path)
|
||||
print(f"Non-empty directory {path} has been recursively deleted.")
|
||||
|
||||
Loading…
Reference in New Issue
Block a user