Mirror of https://github.com/modelscope/FunASR (synced 2025-09-15 14:48:36 +08:00)

Commit: 2518f03d20 ("decoding")
Parent: 9afcf0ea7d
@@ -427,30 +427,36 @@ class LLMASR2(nn.Module):

        self.audio_encoder = audio_encoder

        # llm
        hub = llm_conf.get("hub", "hf")
        self.llm = None
        if hub == "hf":
            from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig

            init_param_path = llm_conf.get("init_param_path", "vicuna-7b-v1.5")

            model = AutoModelForCausalLM.from_pretrained(
                init_param_path,
                load_in_8bit=None,
                device_map=None,
                use_cache=None,
            )

            freeze = llm_conf.get("freeze", True)
            if freeze:
                for name, param in model.named_parameters():
                    param.requires_grad = False
                model.eval()
            self.llm = model

            llm_dim = model.get_input_embeddings().weight.shape[-1]

        # adaptor
        adaptor_class = tables.adaptor_classes.get(audio_adaptor)
        audio_adaptor_conf["encoder_dim"] = audio_encoder_output_size
        audio_adaptor_conf["llm_dim"] = llm_dim
        audio_adaptor = adaptor_class(**audio_adaptor_conf)

        init_param_path = audio_adaptor_conf.get("init_param_path", None)
        if init_param_path is not None:
            src_state = torch.load(init_param_path, map_location="cpu")
            flag = audio_adaptor.load_state_dict(src_state, strict=False)
            logging.info(f"Loading audio_adaptor ckpt: {init_param_path}, status: {flag}")

        self.audio_adaptor = audio_adaptor
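For context, a minimal sketch of the two config dicts this constructor reads. The key names (hub, init_param_path, freeze) come from the diff above; the concrete values are illustrative assumptions, not taken from FunASR's shipped configs.

    # Illustrative only: key names mirror the llm_conf.get(...) / audio_adaptor_conf.get(...)
    # calls in the diff; the values are assumed defaults.
    llm_conf = {
        "hub": "hf",                          # load the LLM through HuggingFace transformers
        "init_param_path": "vicuna-7b-v1.5",  # local path or HF repo id passed to from_pretrained
        "freeze": True,                       # freeze all LLM parameters and call model.eval()
    }

    audio_adaptor_conf = {
        # encoder_dim and llm_dim are filled in by the constructor itself:
        #   encoder_dim <- audio_encoder_output_size
        #   llm_dim     <- model.get_input_embeddings().weight.shape[-1]
        "init_param_path": None,              # optional adaptor checkpoint, loaded with strict=False
    }

With freeze=True the LLM contributes no trainable parameters, so only the audio encoder and the adaptor (plus whatever else the model trains) receive gradients.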
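The diff injects encoder_dim and llm_dim into audio_adaptor_conf before instantiating the adaptor via tables.adaptor_classes. As a rough illustration of the interface that implies, here is a hypothetical linear-projection adaptor; FunASR's actual adaptor classes are registered in its tables module and may be more involved (e.g. with downsampling).

    import torch
    import torch.nn as nn

    class LinearAdaptorSketch(nn.Module):
        """Hypothetical adaptor sketch: projects audio-encoder features to the LLM embedding size.

        Only the constructor keys the diff fills in (encoder_dim, llm_dim) and the optional
        init_param_path are mirrored here; this is not the class FunASR registers.
        """

        def __init__(self, encoder_dim: int, llm_dim: int, init_param_path: str = None, **kwargs):
            super().__init__()
            self.proj = nn.Linear(encoder_dim, llm_dim)

        def forward(self, enc_out: torch.Tensor) -> torch.Tensor:
            # (batch, time, encoder_dim) -> (batch, time, llm_dim)
            return self.proj(enc_out)

Under this reading, adaptor_class(**audio_adaptor_conf) works for any class whose constructor accepts encoder_dim and llm_dim, which is why the diff can set those two keys generically before instantiation.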