This commit is contained in:
游雁 2024-06-12 15:18:42 +08:00
parent a56980a26f
commit 407625a734
2 changed files with 11 additions and 0 deletions

View File

@@ -124,6 +124,7 @@ def main(**kwargs):
use_ddp=use_ddp,
use_fsdp=use_fsdp,
device=kwargs["device"],
excludes=kwargs.get("excludes", None),
output_dir=kwargs.get("output_dir", "./exp"),
**kwargs.get("train_conf"),
)

View File

@@ -147,6 +147,10 @@ class Trainer:
self.use_deepspeed = use_deepspeed
self.deepspeed_config = kwargs.get("deepspeed_config", "")
self.excludes = kwargs.get("excludes", None)
if self.excludes is not None:
if isinstance(self.excludes, str):
self.excludes = self.excludes.split(",")
def save_checkpoint(
self,
@@ -440,6 +444,12 @@ class Trainer:
src_state = checkpoint["state_dict"]
dst_state = model.state_dict()
for k in dst_state.keys():
if excludes is not None:
for k_ex in excludes:
k_tmp = k.replace("module.", "")
if k_tmp.startswith(k_ex):
logging.info(f"key: {{k}} matching: {k_ex}, excluded")
continue
if not k.startswith("module.") and "module." + k in src_state.keys():
k_ddp = "module." + k
elif k.startswith("module.") and "module." + k not in src_state.keys():