diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py
index 2a55cd6bb..fb0bee3ca 100644
--- a/funasr/models/llm_asr/model.py
+++ b/funasr/models/llm_asr/model.py
@@ -427,30 +427,36 @@ class LLMASR2(nn.Module):
         self.audio_encoder = audio_encoder
 
         # llm
-        hub = llm_conf.get("hub", "hf")
         self.llm = None
-        if hub == "hf":
-            from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
-            init_param_path = llm_conf.get("init_param_path", "vicuna-7b-v1.5")
+        from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
 
-            model = AutoModelForCausalLM.from_pretrained(
-                init_param_path,
-                load_in_8bit=None,
-                device_map=None,
-                use_cache=None,
-            )
-            freeze = llm_conf.get("freeze", True)
-            if freeze:
-                for name, param in model.named_parameters():
-                    param.requires_grad = False
-                model.eval()
-            self.llm = model
+        init_param_path = llm_conf.get("init_param_path", "vicuna-7b-v1.5")
+
+        model = AutoModelForCausalLM.from_pretrained(
+            init_param_path,
+            load_in_8bit=None,
+            device_map=None,
+            use_cache=None,
+        )
+        freeze = llm_conf.get("freeze", True)
+        if freeze:
+            for name, param in model.named_parameters():
+                param.requires_grad = False
+            model.eval()
+        self.llm = model
+        llm_dim = model.get_input_embeddings().weight.shape[-1]
 
         # adaptor
         adaptor_class = tables.adaptor_classes.get(audio_adaptor)
         audio_adaptor_conf["encoder_dim"] = audio_encoder_output_size
+        audio_adaptor_conf["llm_dim"] = llm_dim
         audio_adaptor = adaptor_class(**audio_adaptor_conf)
+        init_param_path = audio_adaptor_conf.get("init_param_path", None)
+        if init_param_path is not None:
+            src_state = torch.load(init_param_path, map_location="cpu")
+            flag = audio_adaptor.load_state_dict(src_state, strict=False)
+            logging.info(f"Loading audio_adaptor ckpt: {init_param_path}, status: {flag}")
         self.audio_adaptor = audio_adaptor
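
A minimal sketch of how the changed constructor paths might be driven from config, assuming FunASR-style `llm_conf`/`audio_adaptor_conf` dicts; the paths below are illustrative placeholders, not values from this change:

```python
# Illustrative config only (assumed values, not part of the patch).
llm_conf = {
    # Passed straight to AutoModelForCausalLM.from_pretrained();
    # "vicuna-7b-v1.5" remains the default when the key is absent.
    "init_param_path": "/path/to/vicuna-7b-v1.5",
    # With freeze=True (the default) every LLM parameter gets
    # requires_grad=False and the model is switched to eval() mode.
    "freeze": True,
}

audio_adaptor_conf = {
    # "encoder_dim" and "llm_dim" are injected by the constructor itself;
    # llm_dim now comes from model.get_input_embeddings().weight.shape[-1].
    # Optional adaptor checkpoint, loaded on CPU with strict=False:
    "init_param_path": "/path/to/audio_adaptor.pt",
}
```

Deriving `llm_dim` from the loaded embedding matrix means the adaptor's output projection tracks whichever LLM is configured instead of a hard-coded hidden size, and `strict=False` lets a partial adaptor checkpoint load while the mismatch status is logged.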