diff --git a/examples/industrial_data_pretraining/lcbnet/demo2.sh b/examples/industrial_data_pretraining/lcbnet/demo2.sh index 36a692856..3e4d22393 100755 --- a/examples/industrial_data_pretraining/lcbnet/demo2.sh +++ b/examples/industrial_data_pretraining/lcbnet/demo2.sh @@ -7,8 +7,7 @@ python -m funasr.bin.inference \ ++init_param=${file_dir}/model.pb \ ++tokenizer_conf.token_list=${file_dir}/tokens.txt \ ++frontend_conf.cmvn_file=${file_dir}/am.mvn \ -++input=${file_dir}/wav.scp \ -++input=${file_dir}/ocr_text \ +++input=[${file_dir}/wav.scp,${file_dir}/ocr_text] \ +data_type='["sound", "text"]' \ ++tokenizer_conf.bpemodel=${file_dir}/bpe.model \ ++output_dir="./outputs/debug" \ diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py index 23b80d72a..87c7e2d03 100644 --- a/funasr/auto/auto_model.py +++ b/funasr/auto/auto_model.py @@ -172,14 +172,11 @@ class AutoModel: # build model model_class = tables.model_classes.get(kwargs["model"]) - pdb.set_trace() model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=vocab_size) - pdb.set_trace() model.to(device) # init_param init_param = kwargs.get("init_param", None) - pdb.set_trace() if init_param is not None: logging.info(f"Loading pretrained params from {init_param}") load_pretrained_model( diff --git a/funasr/train_utils/load_pretrained_model.py b/funasr/train_utils/load_pretrained_model.py index aec31e3cc..9127e2fe1 100644 --- a/funasr/train_utils/load_pretrained_model.py +++ b/funasr/train_utils/load_pretrained_model.py @@ -96,19 +96,17 @@ def load_pretrained_model( obj = model dst_state = obj.state_dict() - # import pdb; - # pdb.set_trace() print(f"ckpt: {path}") - pdb.set_trace() + if oss_bucket is None: src_state = torch.load(path, map_location=map_location) else: buffer = BytesIO(oss_bucket.get_object(path).read()) src_state = torch.load(buffer, map_location=map_location) - pdb.set_trace() + if "state_dict" in src_state: src_state = src_state["state_dict"] - pdb.set_trace() + for k in dst_state.keys(): if not k.startswith("module.") and "module." + k in src_state.keys(): k_ddp = "module." + k @@ -118,7 +116,6 @@ def load_pretrained_model( dst_state[k] = src_state[k_ddp] else: print(f"Miss key in ckpt: model: {k}, ckpt: {k_ddp}") - pdb.set_trace() flag = obj.load_state_dict(dst_state, strict=True) # print(flag)