diff --git a/funasr/train_utils/load_pretrained_model.py b/funasr/train_utils/load_pretrained_model.py index 16feabd70..1705115c0 100644 --- a/funasr/train_utils/load_pretrained_model.py +++ b/funasr/train_utils/load_pretrained_model.py @@ -107,7 +107,7 @@ def load_pretrained_model( src_state = {k: v for k, v in src_state.items() if not k.startswith(e)} dst_state = obj.state_dict() - src_state = assigment_scope_map(dst_state, src_state, scope_map) + dst_state = assigment_scope_map(dst_state, src_state, scope_map) if ignore_init_mismatch: src_state = filter_state_dict(dst_state, src_state) @@ -115,4 +115,4 @@ def load_pretrained_model( logging.debug("Loaded src_state keys: {}".format(src_state.keys())) logging.debug("Loaded dst_state keys: {}".format(dst_state.keys())) # dst_state.update(src_state) - obj.load_state_dict(dst_state) \ No newline at end of file + obj.load_state_dict(dst_state)