This commit is contained in:
游雁 2024-02-21 11:30:59 +08:00
parent 4cf44a89f8
commit 497edf4c9d
3 changed files with 88 additions and 18 deletions

View File

@ -0,0 +1,12 @@
# Debug inference run for the AISHELL-1 Paraformer baseline:
# loads config + checkpoint and decodes a single training wav on CPU.
# NOTE(review): all paths are machine-specific (/mnt/...) — parameterize
# before committing this as a shared example.
python funasr/bin/inference.py \
    --config-path="/mnt/workspace/FunASR/examples/aishell/paraformer/exp/baseline_paraformer_conformer_12e_6d_2048_256_zh_char_exp3" \
    --config-name="config.yaml" \
    ++init_param="/mnt/workspace/FunASR/examples/aishell/paraformer/exp/baseline_paraformer_conformer_12e_6d_2048_256_zh_char_exp3/model.pt.ep38" \
    ++tokenizer_conf.token_list="/mnt/nfs/zhifu.gzf/data/AISHELL-1-feats/DATA/data/zh_token_list/char/tokens.txt" \
    ++frontend_conf.cmvn_file="/mnt/nfs/zhifu.gzf/data/AISHELL-1-feats/DATA/data/train/am.mvn" \
    ++input="/mnt/nfs/zhifu.gzf/data/AISHELL-1/data_aishell/wav/train/S0002/BAC009S0002W0122.wav" \
    ++output_dir="./outputs/debug" \
    ++device="cpu"

View File

@ -75,6 +75,7 @@ def assigment_scope_map(dst_state: dict, src_state: dict, scope_map: str=None):
return assignment_map
def load_pretrained_model(
path: str,
model: torch.nn.Module,
@ -94,25 +95,69 @@ def load_pretrained_model(
"""
# Load checkpoint weights from `path` (local file or OSS bucket) into `model`,
# tolerating a DDP "module." prefix mismatch between checkpoint and model keys.
# NOTE(review): diff rendering has stripped indentation; code kept byte-identical.
obj = model
dst_state = obj.state_dict()
# import pdb;
# pdb.set_trace()
# NOTE(review): leftover debug print — prefer logging.
print(f"ckpt: {path}")
# Read the raw checkpoint either from the local filesystem or from OSS.
if oss_bucket is None:
src_state = torch.load(path, map_location=map_location)
else:
buffer = BytesIO(oss_bucket.get_object(path).read())
src_state = torch.load(buffer, map_location=map_location)
# Some checkpoints nest the weights under a "model" key.
src_state = src_state["model"] if "model" in src_state else src_state
# Drop any parameter groups whose names start with an excluded prefix
# (comma-separated prefixes in `excludes`).
if excludes is not None:
for e in excludes.split(","):
src_state = {k: v for k, v in src_state.items() if not k.startswith(e)}
# NOTE(review): duplicate of the earlier `obj.state_dict()` call — redundant.
dst_state = obj.state_dict()
# Remap source keys onto destination keys per `scope_map` (typo "assigment"
# is the helper's actual name elsewhere in this file).
src_state = assigment_scope_map(dst_state, src_state, scope_map)
# Optionally drop entries whose shapes don't match the model.
if ignore_init_mismatch:
src_state = filter_state_dict(dst_state, src_state)
logging.debug("Loaded src_state keys: {}".format(src_state.keys()))
logging.debug("Loaded dst_state keys: {}".format(dst_state.keys()))
dst_state.update(src_state)
obj.load_state_dict(dst_state, strict=True)
# NOTE(review): this "state_dict" unwrap happens AFTER scope mapping/filtering
# and after the load above — confirm the intended ordering; it looks like it
# should occur right after torch.load, alongside the "model" unwrap.
if "state_dict" in src_state:
src_state = src_state["state_dict"]
# Copy matching tensors, trying a DDP-style "module."-prefixed key when the
# plain key is absent from the checkpoint.
for k in dst_state.keys():
if not k.startswith("module.") and "module." + k in src_state.keys():
k_ddp = "module." + k
else:
k_ddp = k
if k_ddp in src_state:
dst_state[k] = src_state[k_ddp]
else:
# Missing key: keep the model's initialized value and report it.
print(f"Miss key in ckpt: model: {k}, ckpt: {k_ddp}")
flag = obj.load_state_dict(dst_state, strict=True)
# NOTE(review): debug print of load_state_dict's IncompatibleKeys result —
# remove or switch to logging before release.
print(flag)
# def load_pretrained_model(
# path: str,
# model: torch.nn.Module,
# ignore_init_mismatch: bool,
# map_location: str = "cpu",
# oss_bucket=None,
# scope_map=None,
# excludes=None,
# ):
# """Load a model state and set it to the model.
#
# Args:
# init_param: <file_path>:<src_key>:<dst_key>:<exclude_Keys>
#
# Examples:
#
# """
#
# obj = model
#
# if oss_bucket is None:
# src_state = torch.load(path, map_location=map_location)
# else:
# buffer = BytesIO(oss_bucket.get_object(path).read())
# src_state = torch.load(buffer, map_location=map_location)
# src_state = src_state["model"] if "model" in src_state else src_state
#
# if excludes is not None:
# for e in excludes.split(","):
# src_state = {k: v for k, v in src_state.items() if not k.startswith(e)}
#
# dst_state = obj.state_dict()
# src_state = assigment_scope_map(dst_state, src_state, scope_map)
#
# if ignore_init_mismatch:
# src_state = filter_state_dict(dst_state, src_state)
#
# logging.debug("Loaded src_state keys: {}".format(src_state.keys()))
# logging.debug("Loaded dst_state keys: {}".format(dst_state.keys()))
# dst_state.update(src_state)
# obj.load_state_dict(dst_state, strict=True)

View File

@ -128,7 +128,20 @@ class Trainer:
# Resume training state from `ckpt` if the file exists: model weights,
# optimizer, scheduler, and the next epoch index.
# NOTE(review): diff rendering has stripped indentation; code kept byte-identical.
if os.path.isfile(ckpt):
checkpoint = torch.load(ckpt)
# Resume from the epoch after the one that was saved.
self.start_epoch = checkpoint['epoch'] + 1
# self.model.load_state_dict(checkpoint['state_dict'])
src_state = checkpoint['state_dict']
dst_state = self.model.state_dict()
# Remap keys so a checkpoint saved from a DDP-wrapped model (keys prefixed
# with "module.") loads into an unwrapped model, and vice versa.
for k in dst_state.keys():
if not k.startswith("module.") and "module."+k in src_state.keys():
k_ddp = "module."+k
else:
k_ddp = k
if k_ddp in src_state.keys():
dst_state[k] = src_state[k_ddp]
else:
# Key absent from checkpoint: keep the freshly initialized weight.
print(f"Miss key in ckpt: model: {k}, ckpt: {k_ddp}")
# dst_state now holds checkpoint weights for every matched key.
self.model.load_state_dict(dst_state)
self.optim.load_state_dict(checkpoint['optimizer'])
self.scheduler.load_state_dict(checkpoint['scheduler'])
print(f"Checkpoint loaded successfully from '{ckpt}'")