From e969be589e6270d69906fca252609aec8530321c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Thu, 4 Jul 2024 23:30:30 +0800 Subject: [PATCH] update --- funasr/models/llm_asr/model.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py index 88eb8c001..a6a05ca0b 100644 --- a/funasr/models/llm_asr/model.py +++ b/funasr/models/llm_asr/model.py @@ -2117,11 +2117,15 @@ class LLMASR5(nn.Module): self.eos = kwargs.get("eos", 151645) # audio decoder related + self.concat_emb_hidden = audio_decoder_conf.get("concat_emb_hidden", False) self.codebook_dim = audio_decoder_conf.get("codebook_dim", 1024) self.codebook_size = audio_decoder_conf.get("codebook_size", 4096) self.lm_out_voc_size = self.codebook_size + 1 self.audio_decoder = self.build_audio_decoder(name=audio_decoder, conf=audio_decoder_conf) - self.audio_decoder_in_proj = torch.nn.Linear(llm_dim, self.audio_decoder.embed_unit) + audio_decoder_in_proj_dim = llm_dim * 2 if self.concat_emb_hidden else llm_dim + self.audio_decoder_in_proj = torch.nn.Linear( + audio_decoder_in_proj_dim, self.audio_decoder.embed_unit + ) self.codec_embedder = torch.nn.Embedding(self.codebook_size, self.codebook_dim) self.audio_decoder_embedding = torch.nn.Embedding(2, self.audio_decoder.embed_unit) self.ad_sos_eos = 0 @@ -2395,7 +2399,11 @@ class LLMASR5(nn.Module): ) target_ids_len = torch.tensor(target_ids_len, dtype=torch.int32, device=input_ids.device) target_ids = target_ids.to(device=input_ids.device) + target_ids[target_ids < 0] = 0 + target_emb = self.llm.model.get_input_embeddings()(target_ids) hidden_states_select = hidden_states_select.to(device=input_ids.device) + if self.concat_emb_hidden: + hidden_states_select = torch.concat((hidden_states_select, target_emb), dim=-1) nll, logits, target, target_lengths = self.nll( hidden_states_select, target_ids_len, codec[:, :, None], codec_len )