From d878df49fdccebd21ce7752643b35d995bafcf55 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Thu, 22 Feb 2024 13:08:14 +0800 Subject: [PATCH] v1.0.10 --- funasr/train_utils/load_pretrained_model.py | 9 +++++---- funasr/version.txt | 2 +- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/funasr/train_utils/load_pretrained_model.py b/funasr/train_utils/load_pretrained_model.py index 5ba9bb7dc..8493bf58c 100644 --- a/funasr/train_utils/load_pretrained_model.py +++ b/funasr/train_utils/load_pretrained_model.py @@ -96,8 +96,7 @@ def load_pretrained_model( obj = model dst_state = obj.state_dict() - # import pdb; - # pdb.set_trace() + print(f"ckpt: {path}") if oss_bucket is None: src_state = torch.load(path, map_location=map_location) @@ -106,7 +105,9 @@ def load_pretrained_model( src_state = torch.load(buffer, map_location=map_location) if "state_dict" in src_state: src_state = src_state["state_dict"] - + + src_state = src_state["model"] if "model" in src_state else src_state + for k in dst_state.keys(): if not k.startswith("module.") and "module." + k in src_state.keys(): k_ddp = "module." + k @@ -115,7 +116,7 @@ def load_pretrained_model( if k_ddp in src_state: dst_state[k] = src_state[k_ddp] else: - print(f"Miss key in ckpt: model: {k}, ckpt: {k_ddp}") + print(f"Warning: missing key in ckpt: {k}, mapped: {k_ddp}") flag = obj.load_state_dict(dst_state, strict=True) # print(flag) diff --git a/funasr/version.txt b/funasr/version.txt index b0f3d96f8..7ee7020b3 100644 --- a/funasr/version.txt +++ b/funasr/version.txt @@ -1 +1 @@ -1.0.8 +1.0.10