This commit is contained in:
游雁 2024-06-12 17:44:12 +08:00
parent 9afcf0ea7d
commit 2518f03d20

View File

@ -427,30 +427,36 @@ class LLMASR2(nn.Module):
self.audio_encoder = audio_encoder self.audio_encoder = audio_encoder
# llm # llm
hub = llm_conf.get("hub", "hf")
self.llm = None self.llm = None
if hub == "hf":
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
init_param_path = llm_conf.get("init_param_path", "vicuna-7b-v1.5") from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
model = AutoModelForCausalLM.from_pretrained( init_param_path = llm_conf.get("init_param_path", "vicuna-7b-v1.5")
init_param_path,
load_in_8bit=None, model = AutoModelForCausalLM.from_pretrained(
device_map=None, init_param_path,
use_cache=None, load_in_8bit=None,
) device_map=None,
freeze = llm_conf.get("freeze", True) use_cache=None,
if freeze: )
for name, param in model.named_parameters(): freeze = llm_conf.get("freeze", True)
param.requires_grad = False if freeze:
model.eval() for name, param in model.named_parameters():
self.llm = model param.requires_grad = False
model.eval()
self.llm = model
llm_dim = model.get_input_embeddings().weight.shape[-1]
# adaptor # adaptor
adaptor_class = tables.adaptor_classes.get(audio_adaptor) adaptor_class = tables.adaptor_classes.get(audio_adaptor)
audio_adaptor_conf["encoder_dim"] = audio_encoder_output_size audio_adaptor_conf["encoder_dim"] = audio_encoder_output_size
audio_adaptor_conf["llm_dim"] = llm_dim
audio_adaptor = adaptor_class(**audio_adaptor_conf) audio_adaptor = adaptor_class(**audio_adaptor_conf)
init_param_path = audio_adaptor_conf.get("init_param_path", None)
if init_param_path is not None:
src_state = torch.load(init_param_path, map_location="cpu")
flag = audio_adaptor.load_state_dict(src_state, strict=False)
logging.info(f"Loading audio_adaptor ckpt: {init_param_path}, status: {flag}")
self.audio_adaptor = audio_adaptor self.audio_adaptor = audio_adaptor