Mirror of https://github.com/modelscope/FunASR, synced 2025-09-15 14:48:36 +08:00
commit 497edf4c9d
parent 4cf44a89f8

bugfix
examples/aishell/conformer/infer.sh (new file, 12 lines)
@@ -0,0 +1,12 @@
python funasr/bin/inference.py \
--config-path="/mnt/workspace/FunASR/examples/aishell/paraformer/exp/baseline_paraformer_conformer_12e_6d_2048_256_zh_char_exp3" \
--config-name="config.yaml" \
++init_param="/mnt/workspace/FunASR/examples/aishell/paraformer/exp/baseline_paraformer_conformer_12e_6d_2048_256_zh_char_exp3/model.pt.ep38" \
++tokenizer_conf.token_list="/mnt/nfs/zhifu.gzf/data/AISHELL-1-feats/DATA/data/zh_token_list/char/tokens.txt" \
++frontend_conf.cmvn_file="/mnt/nfs/zhifu.gzf/data/AISHELL-1-feats/DATA/data/train/am.mvn" \
++input="/mnt/nfs/zhifu.gzf/data/AISHELL-1/data_aishell/wav/train/S0002/BAC009S0002W0122.wav" \
++output_dir="./outputs/debug" \
++device="cpu"
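For reference, the --config-path/--config-name pair selects the experiment's config.yaml, and each ++key=value flag is a Hydra-style force-add override layered on top of it (++ adds the key if missing, overrides it if present). A minimal sketch of those override semantics, assuming an OmegaConf-backed config; the merge below is illustrative, not FunASR's actual entrypoint:

# Sketch of "++key=value" override semantics (assumption: the entrypoint
# merges dot-list overrides into the YAML config, OmegaConf style).
from omegaconf import OmegaConf

cfg = OmegaConf.load("config.yaml")  # the base config named by --config-name
overrides = OmegaConf.from_dotlist([
    "init_param=/mnt/workspace/FunASR/examples/aishell/paraformer/exp/baseline_paraformer_conformer_12e_6d_2048_256_zh_char_exp3/model.pt.ep38",
    "device=cpu",
])
# merge() adds missing keys and overrides existing ones, which is
# exactly what the '++' prefix requests.
cfg = OmegaConf.merge(cfg, overrides)
print(OmegaConf.to_yaml(cfg))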
@@ -75,6 +75,7 @@ def assigment_scope_map(dst_state: dict, src_state: dict, scope_map: str=None):

    return assignment_map


def load_pretrained_model(
    path: str,
    model: torch.nn.Module,
@@ -94,25 +95,69 @@ def load_pretrained_model(
    """

    obj = model

    print(f"ckpt: {path}")
    if oss_bucket is None:
        src_state = torch.load(path, map_location=map_location)
    else:
        buffer = BytesIO(oss_bucket.get_object(path).read())
        src_state = torch.load(buffer, map_location=map_location)
    # Some checkpoints wrap the weights in a "model" entry.
    src_state = src_state["model"] if "model" in src_state else src_state

    if excludes is not None:
        for e in excludes.split(","):
            src_state = {k: v for k, v in src_state.items() if not k.startswith(e)}

    dst_state = obj.state_dict()
    src_state = assigment_scope_map(dst_state, src_state, scope_map)

    if ignore_init_mismatch:
        src_state = filter_state_dict(dst_state, src_state)

    logging.debug("Loaded src_state keys: {}".format(src_state.keys()))
    logging.debug("Loaded dst_state keys: {}".format(dst_state.keys()))
    if "state_dict" in src_state:
        src_state = src_state["state_dict"]

    for k in dst_state.keys():
        # DDP checkpoints prefix every key with "module."; map each
        # model key to its prefixed twin when one exists.
        if not k.startswith("module.") and "module." + k in src_state.keys():
            k_ddp = "module." + k
        else:
            k_ddp = k
        if k_ddp in src_state:
            dst_state[k] = src_state[k_ddp]
        else:
            print(f"Miss key in ckpt: model: {k}, ckpt: {k_ddp}")

    flag = obj.load_state_dict(dst_state, strict=True)
    print(flag)

# def load_pretrained_model(
#     path: str,
#     model: torch.nn.Module,
#     ignore_init_mismatch: bool,
#     map_location: str = "cpu",
#     oss_bucket=None,
#     scope_map=None,
#     excludes=None,
# ):
#     """Load a model state and set it to the model.
#
#     Args:
#         init_param: <file_path>:<src_key>:<dst_key>:<exclude_Keys>
#
#     Examples:
#
#     """
#
#     obj = model
#
#     if oss_bucket is None:
#         src_state = torch.load(path, map_location=map_location)
#     else:
#         buffer = BytesIO(oss_bucket.get_object(path).read())
#         src_state = torch.load(buffer, map_location=map_location)
#     src_state = src_state["model"] if "model" in src_state else src_state
#
#     if excludes is not None:
#         for e in excludes.split(","):
#             src_state = {k: v for k, v in src_state.items() if not k.startswith(e)}
#
#     dst_state = obj.state_dict()
#     src_state = assigment_scope_map(dst_state, src_state, scope_map)
#
#     if ignore_init_mismatch:
#         src_state = filter_state_dict(dst_state, src_state)
#
#     logging.debug("Loaded src_state keys: {}".format(src_state.keys()))
#     logging.debug("Loaded dst_state keys: {}".format(dst_state.keys()))
#     dst_state.update(src_state)
#     obj.load_state_dict(dst_state, strict=True)
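The substance of the bugfix is tolerance for DataParallel/DistributedDataParallel checkpoints, whose keys all carry a "module." prefix. Below is a self-contained sketch of the same remapping; remap_ddp_keys is an illustrative name, not FunASR API:

import torch

def remap_ddp_keys(dst_state: dict, src_state: dict) -> dict:
    # DDP saves parameters as "module.<name>"; copy each checkpoint
    # entry into the unwrapped model's key set, reporting misses.
    for k in dst_state.keys():
        k_ddp = "module." + k if "module." + k in src_state else k
        if k_ddp in src_state:
            dst_state[k] = src_state[k_ddp]
        else:
            print(f"Miss key in ckpt: model: {k}, ckpt: {k_ddp}")
    return dst_state

# Usage: load a DDP-produced checkpoint into a bare (unwrapped) model.
model = torch.nn.Linear(4, 2)
ddp_ckpt = {"module." + k: v for k, v in model.state_dict().items()}  # simulated DDP ckpt
model.load_state_dict(remap_ddp_keys(model.state_dict(), ddp_ckpt), strict=True)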
@@ -128,7 +128,20 @@ class Trainer:
        if os.path.isfile(ckpt):
            checkpoint = torch.load(ckpt)
            self.start_epoch = checkpoint['epoch'] + 1
            # self.model.load_state_dict(checkpoint['state_dict'])
            src_state = checkpoint['state_dict']
            dst_state = self.model.state_dict()
            for k in dst_state.keys():
                # Tolerate the DDP "module." prefix when resuming.
                if not k.startswith("module.") and "module." + k in src_state.keys():
                    k_ddp = "module." + k
                else:
                    k_ddp = k
                if k_ddp in src_state.keys():
                    dst_state[k] = src_state[k_ddp]
                else:
                    print(f"Miss key in ckpt: model: {k}, ckpt: {k_ddp}")

            self.model.load_state_dict(dst_state)
            self.optim.load_state_dict(checkpoint['optimizer'])
            self.scheduler.load_state_dict(checkpoint['scheduler'])
            print(f"Checkpoint loaded successfully from '{ckpt}'")
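The resume path above expects checkpoints carrying epoch, state_dict, optimizer, and scheduler entries. A minimal sketch of a matching save side; save_checkpoint is a hypothetical helper, with field names taken from the loader above:

import torch

def save_checkpoint(path, epoch, model, optim, scheduler):
    # Keys mirror exactly what Trainer reads back on resume.
    torch.save({
        "epoch": epoch,
        "state_dict": model.state_dict(),
        "optimizer": optim.state_dict(),
        "scheduler": scheduler.state_dict(),
    }, path)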