From 29fa4e47899d53d34b68523a901c12f2f339214b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?=
Date: Wed, 7 Aug 2024 12:48:02 +0800
Subject: [PATCH] deepspeed

---
 funasr/train_utils/load_pretrained_model.py | 58 ++++++++++-----------
 1 file changed, 27 insertions(+), 31 deletions(-)

diff --git a/funasr/train_utils/load_pretrained_model.py b/funasr/train_utils/load_pretrained_model.py
index 54f3a871b..e8b070100 100644
--- a/funasr/train_utils/load_pretrained_model.py
+++ b/funasr/train_utils/load_pretrained_model.py
@@ -35,41 +35,37 @@ def load_pretrained_model(
     logging.info(f"ckpt: {path}, use_deepspeed: {use_deepspeed}")
 
-    if oss_bucket is None:
-        if use_deepspeed:
-            ckpt_dir = os.path.dirname(path)
-            ckpt_name = os.path.basename(path)
-            if os.path.exists(f"{ckpt_dir}/zero_to_fp32.py"):
-                print("Detect zero_to_fp32, begin to convert fp32 model")
-                ckpt_fp32 = f"{ckpt_dir}/{ckpt_name[3:]}"
-                if os.path.exists(ckpt_fp32):
-                    print(f"Detect zero_to_fp32 already exist! Loading it directly. {ckpt_fp32}")
-                    src_state = torch.load(ckpt_fp32, map_location=map_location)
-                else:
-                    with open(f"{ckpt_dir}/latest", "w") as latest:
-                        latest.write(ckpt_name)
-                        latest.flush()
-                    from deepspeed.utils.zero_to_fp32 import (
-                        get_fp32_state_dict_from_zero_checkpoint,
-                    )
-
-                    src_state = get_fp32_state_dict_from_zero_checkpoint(ckpt_dir)  # already on cpu
-                    if kwargs.get("save_deepspeed_zero_fp32", False):
-                        print(
-                            f'save_deepspeed_zero_fp32: {kwargs.get("save_deepspeed_zero_fp32", False)}, {ckpt_fp32}'
-                        )
-                        torch.save({"state_dict": src_state}, ckpt_fp32)
+    if use_deepspeed and os.path.isdir(path):
+        ckpt_dir = os.path.dirname(path)
+        ckpt_name = os.path.basename(path)
+        if os.path.exists(f"{ckpt_dir}/zero_to_fp32.py"):
+            print("Detect zero_to_fp32, begin to convert fp32 model")
+            ckpt_fp32 = f"{ckpt_dir}/{ckpt_name[3:]}"
+            if os.path.exists(ckpt_fp32):
+                print(f"Detect zero_to_fp32 already exist! Loading it directly. {ckpt_fp32}")
+                src_state = torch.load(ckpt_fp32, map_location=map_location)
             else:
-                print("Detect deepspeed without zero, load fp32 model directly")
-                for item in os.listdir(path):
-                    if item.endswith(".pt"):
-                        src_state = torch.load(f"{path}/{item}", map_location=map_location)
+                with open(f"{ckpt_dir}/latest", "w") as latest:
+                    latest.write(ckpt_name)
+                    latest.flush()
+                from deepspeed.utils.zero_to_fp32 import (
+                    get_fp32_state_dict_from_zero_checkpoint,
+                )
+                src_state = get_fp32_state_dict_from_zero_checkpoint(ckpt_dir)  # already on cpu
+                if kwargs.get("save_deepspeed_zero_fp32", False):
+                    print(
+                        f'save_deepspeed_zero_fp32: {kwargs.get("save_deepspeed_zero_fp32", False)}, {ckpt_fp32}'
+                    )
+                    torch.save({"state_dict": src_state}, ckpt_fp32)
         else:
-            src_state = torch.load(path, map_location=map_location)
+            print("Detect deepspeed without zero, load fp32 model directly")
+            for item in os.listdir(path):
+                if item.endswith(".pt"):
+                    src_state = torch.load(f"{path}/{item}", map_location=map_location)
+
     else:
-        buffer = BytesIO(oss_bucket.get_object(path).read())
-        src_state = torch.load(buffer, map_location=map_location)
+        src_state = torch.load(path, map_location=map_location)
 
     src_state = src_state["state_dict"] if "state_dict" in src_state else src_state
     src_state = src_state["model_state_dict"] if "model_state_dict" in src_state else src_state