Mirror of https://github.com/modelscope/FunASR, synced 2025-09-15 14:48:36 +08:00
Dev gzf deepspeed (#1737)
* resume from step
* batch
* batch
* batch
* batch
* batch
* batch
* batch
* batch
* batch
* batch
* batch
* batch
* batch
* batch
* batch
* train_loss_avg train_acc_avg
* train_loss_avg train_acc_avg
* train_loss_avg train_acc_avg
* log step
* wav does not exist
* wav does not exist
* decoding
* decoding
* decoding
* wechat
* decoding key
* decoding key
* decoding key
* decoding key
* decoding key
* decoding key
* dynamic batch
* start_data_split_i=0
* total_time/accum_grad
* total_time/accum_grad
* total_time/accum_grad
* update avg slice
* update avg slice
* sensevoice sanm
* sensevoice sanm
* add
* add
* add
* add
* deepspeed
* update with main (#1731)
* c++ runtime adapt to 1.0 (#1724)
* adapt vad runtime to 1.0
* add json
* change yml name
* add func LoadVocabFromJson
* add token file for InitAsr
* add token path for OfflineStream
* add funcOpenYaml
* add token file for InitPunc
* add token file for stream
* update punc-model
* update funasr-wss-server
* update runtime_sdk_download_tool.py
* update docker list
* Delete docs/images/wechat.png
* Add files via upload
* Emo2Vec: restrict the selectable emotion categories (#1730)
* restrict the selectable emotion categories
* use none to disable emotion-label output
* modify the output interface
* use unuse to disable tokens
---------
Co-authored-by: 常材 <gaochangfeng.gcf@alibaba-inc.com>
* bugfix
* v1.0.27
* update docs
* hf hub
* Fix incorrect assignment of 'end' attribute to 'start' in sentences list comprehension (#1680)
---------
Co-authored-by: Yabin Li <wucong.lyb@alibaba-inc.com>
Co-authored-by: gaochangfeng <54253717+gaochangfeng@users.noreply.github.com>
Co-authored-by: 常材 <gaochangfeng.gcf@alibaba-inc.com>
Co-authored-by: nsdou <168500039+nsdou@users.noreply.github.com>
* docs
* docs
* deepspeed
* deepspeed
* deepspeed
* deepspeed
* update
* ds
* ds
* ds
* ds
* ds
* ds
* ds
* add
* add
* bugfix
---------
Co-authored-by: Yabin Li <wucong.lyb@alibaba-inc.com>
Co-authored-by: gaochangfeng <54253717+gaochangfeng@users.noreply.github.com>
Co-authored-by: 常材 <gaochangfeng.gcf@alibaba-inc.com>
Co-authored-by: nsdou <168500039+nsdou@users.noreply.github.com>
This commit is contained in:
parent 961ec280af
commit 963ba1a771
@@ -22,7 +22,13 @@ def _get_checkpoint_paths(output_dir: str, last_n: int = 5, use_deepspeed=False,
     in the output directory.
     """
     try:
-        checkpoint = torch.load(os.path.join(output_dir, "model.pt"), map_location="cpu")
+        if not use_deepspeed:
+            checkpoint = torch.load(os.path.join(output_dir, "model.pt"), map_location="cpu")
+        else:
+            checkpoint = torch.load(
+                os.path.join(output_dir, "model.pt", "mp_rank_00_model_states.pt"),
+                map_location="cpu",
+            )
         avg_keep_nbest_models_type = checkpoint["avg_keep_nbest_models_type"]
         val_step_or_eoch = checkpoint[f"val_{avg_keep_nbest_models_type}_step_or_eoch"]
         sorted_items = sorted(val_step_or_eoch.items(), key=lambda x: x[1], reverse=True)
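Note: the branch above encodes the two on-disk layouts this commit has to handle. A plain PyTorch save writes model.pt as a single serialized file, while a DeepSpeed save writes model.pt as a checkpoint directory whose model-parallel rank-0 weights live in mp_rank_00_model_states.pt. A minimal sketch of the same path resolution, assuming only those two conventions (the helper name is illustrative, not part of the diff):

import os
import torch

# Illustrative helper (not in the diff): resolve and load the checkpoint
# for either layout, mirroring the if/else added above.
def _load_ckpt(output_dir: str, use_deepspeed: bool = False):
    if not use_deepspeed:
        # Plain PyTorch: "model.pt" is a single file.
        path = os.path.join(output_dir, "model.pt")
    else:
        # DeepSpeed: "model.pt" is a directory; the rank-0 model states
        # are stored in "mp_rank_00_model_states.pt" inside it.
        path = os.path.join(output_dir, "model.pt", "mp_rank_00_model_states.pt")
    return torch.load(path, map_location="cpu")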
@@ -35,6 +41,7 @@ def _get_checkpoint_paths(output_dir: str, last_n: int = 5, use_deepspeed=False,
                 ckpt = os.path.join(output_dir, key)
             else:
                 ckpt = os.path.join(output_dir, key, "mp_rank_00_model_states.pt")
+            checkpoint_paths.append(ckpt)

     except:
         print(f"{checkpoint} does not exist, avg the lastet checkpoint.")
@@ -15,6 +15,7 @@ from funasr.train_utils.device_funcs import to_device
 from funasr.train_utils.recursive_op import recursive_average
 from funasr.train_utils.average_nbest_models import average_checkpoints
 from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
+import funasr.utils.misc as misc_utils

 try:
     import wandb
@@ -268,7 +269,8 @@ class Trainer:
                     filename = os.path.join(self.output_dir, key)
                     logging.info(f"Delete: {filename}")
                     if os.path.exists(filename):
-                        os.remove(filename)
+                        # os.remove(filename)
+                        misc_utils.smart_remove(filename)

             elif self.use_fsdp:
                 pass
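Note on the swap from os.remove to misc_utils.smart_remove: with DeepSpeed enabled, each saved checkpoint is a directory rather than a file, and os.remove refuses to delete a directory. A minimal sketch of the failure mode and the fallback smart_remove provides (the temporary path is illustrative only):

import os
import shutil
import tempfile

# Illustrative reproduction: os.remove cannot delete a directory checkpoint.
ckpt_dir = tempfile.mkdtemp()  # stands in for a DeepSpeed checkpoint directory
try:
    os.remove(ckpt_dir)
except (IsADirectoryError, PermissionError):  # PermissionError on some platforms
    shutil.rmtree(ckpt_dir)  # what smart_remove falls back to for non-empty dirs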
@@ -360,7 +362,8 @@ class Trainer:
                     filename = os.path.join(self.output_dir, key)
                     logging.info(f"Delete: {filename}")
                     if os.path.exists(filename):
-                        os.remove(filename)
+                        # os.remove(filename)
+                        misc_utils.smart_remove(filename)

         if self.use_ddp or self.use_fsdp:
             dist.barrier()
@@ -385,7 +388,7 @@ class Trainer:
         ckpt = os.path.join(self.output_dir, "model.pt")
         if os.path.exists(ckpt):
             _, checkpoint = model.load_checkpoint(self.output_dir, "model.pt")
+            self.start_epoch = checkpoint["epoch"]
             self.saved_ckpts = checkpoint["saved_ckpts"]
             self.val_acc_step_or_eoch = (
                 checkpoint["val_acc_step_or_eoch"]
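For context, model.load_checkpoint here is DeepSpeed's engine API: it returns a (load_path, client_state) tuple, where client_state is the dict of extra trainer bookkeeping that was passed to save_checkpoint. A hedged sketch of that round trip, assuming "engine" is an already-initialized deepspeed.DeepSpeedEngine and "output_dir" exists (setup omitted; the keys mirror the diff):

# Sketch only: "engine" must be an initialized deepspeed.DeepSpeedEngine.
client_state = {
    "epoch": 3,
    "saved_ckpts": {},
    "val_acc_step_or_eoch": {},
}
# save_checkpoint persists the model/optimizer shards plus client_state...
engine.save_checkpoint(output_dir, tag="model.pt", client_state=client_state)

# ...and load_checkpoint returns (load_path, client_state), which is why the
# diff unpacks "_, checkpoint" and reads checkpoint["epoch"] to resume.
_, state = engine.load_checkpoint(output_dir, "model.pt")
start_epoch = state["epoch"]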
@@ -709,8 +712,8 @@ class Trainer:
                "data_split_i": kwargs.get("data_split_i", 0),
                "data_split_num": kwargs.get("data_split_num", 1),
                "log_step": batch_idx + kwargs.get("start_step", 0),
-               "batch_total": batch_idx,
-               "step_in_epoch": batch_idx,
+               "batch_total": batch_idx + 1,
+               "step_in_epoch": batch_idx + 1,
                "lr": 0.0,
            }

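The +1 fixes an off-by-one in the progress counters: batch_idx comes from a zero-based enumerate, so after the body of iteration batch_idx, the number of batches actually completed is batch_idx + 1. A tiny self-contained illustration:

# Zero-based index vs. count of completed batches: after iterating over
# 3 batches, batch_idx is 2 but 3 batches were processed.
batches = ["b0", "b1", "b2"]  # stand-in for a real DataLoader
for batch_idx, batch in enumerate(batches):
    pass
print(batch_idx)      # 2  (what was reported before the fix)
print(batch_idx + 1)  # 3  (what "batch_total"/"step_in_epoch" now report)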
@@ -94,3 +94,26 @@ def extract_filename_without_extension(file_path):
     filename, extension = os.path.splitext(filename_with_extension)
     # Return the filename without its extension
     return filename
+
+
+def smart_remove(path):
+    """Intelligently removes files, empty directories, and non-empty directories recursively."""
+    # Check if the provided path exists
+    if not os.path.exists(path):
+        print(f"{path} does not exist.")
+        return
+
+    # If the path is a file, delete it
+    if os.path.isfile(path):
+        os.remove(path)
+        print(f"File {path} has been deleted.")
+    # If the path is a directory
+    elif os.path.isdir(path):
+        try:
+            # Attempt to remove an empty directory
+            os.rmdir(path)
+            print(f"Empty directory {path} has been deleted.")
+        except OSError:
+            # If the directory is not empty, remove it along with all its contents
+            shutil.rmtree(path)
+            print(f"Non-empty directory {path} has been recursively deleted.")
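A quick usage sketch for the new helper, assuming smart_remove lands in funasr/utils/misc as the trainer's "import funasr.utils.misc as misc_utils" suggests, and that the module imports shutil (the temporary paths below are illustrative only):

import os
import tempfile

from funasr.utils.misc import smart_remove  # module path per the import added above

base = tempfile.mkdtemp()

file_ckpt = os.path.join(base, "model.pt")  # single-file checkpoint
open(file_ckpt, "w").close()
smart_remove(file_ckpt)   # file -> os.remove

dir_ckpt = os.path.join(base, "ckpt_dir")   # DeepSpeed-style directory checkpoint
os.makedirs(dir_ckpt)
open(os.path.join(dir_ckpt, "mp_rank_00_model_states.pt"), "w").close()
smart_remove(dir_ckpt)    # non-empty directory -> shutil.rmtree

smart_remove(file_ckpt)   # already gone -> prints "... does not exist."
smart_remove(base)        # now-empty directory -> os.rmdir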