This commit is contained in:
游雁 2024-07-04 23:30:30 +08:00
parent 256defef10
commit e969be589e

View File

@@ -2117,11 +2117,15 @@ class LLMASR5(nn.Module):
self.eos = kwargs.get("eos", 151645)
# audio decoder related
self.concat_emb_hidden = audio_decoder_conf.get("concat_emb_hidden", False)
self.codebook_dim = audio_decoder_conf.get("codebook_dim", 1024)
self.codebook_size = audio_decoder_conf.get("codebook_size", 4096)
self.lm_out_voc_size = self.codebook_size + 1
self.audio_decoder = self.build_audio_decoder(name=audio_decoder, conf=audio_decoder_conf)
self.audio_decoder_in_proj = torch.nn.Linear(llm_dim, self.audio_decoder.embed_unit)
audio_decoder_in_proj_dim = llm_dim * 2 if self.concat_emb_hidden else llm_dim
self.audio_decoder_in_proj = torch.nn.Linear(
audio_decoder_in_proj_dim, self.audio_decoder.embed_unit
)
self.codec_embedder = torch.nn.Embedding(self.codebook_size, self.codebook_dim)
self.audio_decoder_embedding = torch.nn.Embedding(2, self.audio_decoder.embed_unit)
self.ad_sos_eos = 0
@@ -2395,7 +2399,11 @@ class LLMASR5(nn.Module):
)
target_ids_len = torch.tensor(target_ids_len, dtype=torch.int32, device=input_ids.device)
target_ids = target_ids.to(device=input_ids.device)
target_ids[target_ids < 0] = 0
target_emb = self.llm.model.get_input_embeddings()(target_ids)
hidden_states_select = hidden_states_select.to(device=input_ids.device)
if self.concat_emb_hidden:
hidden_states_select = torch.concat((hidden_states_select, target_emb), dim=-1)
nll, logits, target, target_lengths = self.nll(
hidden_states_select, target_ids_len, codec[:, :, None], codec_len
)