From 497edf4c9d6c1565a4bcf1a3edfcd47ffec8c10d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Wed, 21 Feb 2024 11:30:59 +0800 Subject: [PATCH] bugfix --- examples/aishell/conformer/infer.sh | 12 ++++ funasr/train_utils/load_pretrained_model.py | 79 ++++++++++++++++----- funasr/train_utils/trainer.py | 15 +++- 3 files changed, 88 insertions(+), 18 deletions(-) create mode 100644 examples/aishell/conformer/infer.sh diff --git a/examples/aishell/conformer/infer.sh b/examples/aishell/conformer/infer.sh new file mode 100644 index 000000000..a64df547e --- /dev/null +++ b/examples/aishell/conformer/infer.sh @@ -0,0 +1,12 @@ + + +python funasr/bin/inference.py \ +--config-path="/mnt/workspace/FunASR/examples/aishell/paraformer/exp/baseline_paraformer_conformer_12e_6d_2048_256_zh_char_exp3" \ +--config-name="config.yaml" \ +++init_param="/mnt/workspace/FunASR/examples/aishell/paraformer/exp/baseline_paraformer_conformer_12e_6d_2048_256_zh_char_exp3/model.pt.ep38" \ +++tokenizer_conf.token_list="/mnt/nfs/zhifu.gzf/data/AISHELL-1-feats/DATA/data/zh_token_list/char/tokens.txt" \ +++frontend_conf.cmvn_file="/mnt/nfs/zhifu.gzf/data/AISHELL-1-feats/DATA/data/train/am.mvn" \ +++input="/mnt/nfs/zhifu.gzf/data/AISHELL-1/data_aishell/wav/train/S0002/BAC009S0002W0122.wav" \ +++output_dir="./outputs/debug" \ +++device="cpu" \ + diff --git a/funasr/train_utils/load_pretrained_model.py b/funasr/train_utils/load_pretrained_model.py index ceab4ee1e..ff96ebbfb 100644 --- a/funasr/train_utils/load_pretrained_model.py +++ b/funasr/train_utils/load_pretrained_model.py @@ -75,6 +75,7 @@ def assigment_scope_map(dst_state: dict, src_state: dict, scope_map: str=None): return assignment_map + def load_pretrained_model( path: str, model: torch.nn.Module, @@ -94,25 +95,69 @@ def load_pretrained_model( """ obj = model - + dst_state = obj.state_dict() + # import pdb; + # pdb.set_trace() + print(f"ckpt: {path}") if oss_bucket is None: src_state = torch.load(path, map_location=map_location) else: buffer = BytesIO(oss_bucket.get_object(path).read()) src_state = torch.load(buffer, map_location=map_location) - src_state = src_state["model"] if "model" in src_state else src_state - - if excludes is not None: - for e in excludes.split(","): - src_state = {k: v for k, v in src_state.items() if not k.startswith(e)} - - dst_state = obj.state_dict() - src_state = assigment_scope_map(dst_state, src_state, scope_map) - - if ignore_init_mismatch: - src_state = filter_state_dict(dst_state, src_state) - - logging.debug("Loaded src_state keys: {}".format(src_state.keys())) - logging.debug("Loaded dst_state keys: {}".format(dst_state.keys())) - dst_state.update(src_state) - obj.load_state_dict(dst_state, strict=True) + if "state_dict" in src_state: + src_state = src_state["state_dict"] + + for k in dst_state.keys(): + if not k.startswith("module.") and "module." + k in src_state.keys(): + k_ddp = "module." + k + else: + k_ddp = k + if k_ddp in src_state: + dst_state[k] = src_state[k_ddp] + else: + print(f"Miss key in ckpt: model: {k}, ckpt: {k_ddp}") + + flag = obj.load_state_dict(dst_state, strict=True) + print(flag) + +# def load_pretrained_model( +# path: str, +# model: torch.nn.Module, +# ignore_init_mismatch: bool, +# map_location: str = "cpu", +# oss_bucket=None, +# scope_map=None, +# excludes=None, +# ): +# """Load a model state and set it to the model. +# +# Args: +# init_param: ::: +# +# Examples: +# +# """ +# +# obj = model +# +# if oss_bucket is None: +# src_state = torch.load(path, map_location=map_location) +# else: +# buffer = BytesIO(oss_bucket.get_object(path).read()) +# src_state = torch.load(buffer, map_location=map_location) +# src_state = src_state["model"] if "model" in src_state else src_state +# +# if excludes is not None: +# for e in excludes.split(","): +# src_state = {k: v for k, v in src_state.items() if not k.startswith(e)} +# +# dst_state = obj.state_dict() +# src_state = assigment_scope_map(dst_state, src_state, scope_map) +# +# if ignore_init_mismatch: +# src_state = filter_state_dict(dst_state, src_state) +# +# logging.debug("Loaded src_state keys: {}".format(src_state.keys())) +# logging.debug("Loaded dst_state keys: {}".format(dst_state.keys())) +# dst_state.update(src_state) +# obj.load_state_dict(dst_state, strict=True) diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py index b3c99539b..4b85a66a6 100644 --- a/funasr/train_utils/trainer.py +++ b/funasr/train_utils/trainer.py @@ -128,7 +128,20 @@ class Trainer: if os.path.isfile(ckpt): checkpoint = torch.load(ckpt) self.start_epoch = checkpoint['epoch'] + 1 - self.model.load_state_dict(checkpoint['state_dict']) + # self.model.load_state_dict(checkpoint['state_dict']) + src_state = checkpoint['state_dict'] + dst_state = self.model.state_dict() + for k in dst_state.keys(): + if not k.startswith("module.") and "module."+k in src_state.keys(): + k_ddp = "module."+k + else: + k_ddp = k + if k_ddp in src_state.keys(): + dst_state[k] = src_state[k_ddp] + else: + print(f"Miss key in ckpt: model: {k}, ckpt: {k_ddp}") + + self.model.load_state_dict(dst_state) self.optim.load_state_dict(checkpoint['optimizer']) self.scheduler.load_state_dict(checkpoint['scheduler']) print(f"Checkpoint loaded successfully from '{ckpt}'")