mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
update
This commit is contained in:
parent
256defef10
commit
e969be589e
@ -2117,11 +2117,15 @@ class LLMASR5(nn.Module):
|
||||
self.eos = kwargs.get("eos", 151645)
|
||||
|
||||
# audio decoder related
|
||||
self.concat_emb_hidden = audio_decoder_conf.get("concat_emb_hidden", False)
|
||||
self.codebook_dim = audio_decoder_conf.get("codebook_dim", 1024)
|
||||
self.codebook_size = audio_decoder_conf.get("codebook_size", 4096)
|
||||
self.lm_out_voc_size = self.codebook_size + 1
|
||||
self.audio_decoder = self.build_audio_decoder(name=audio_decoder, conf=audio_decoder_conf)
|
||||
self.audio_decoder_in_proj = torch.nn.Linear(llm_dim, self.audio_decoder.embed_unit)
|
||||
audio_decoder_in_proj_dim = llm_dim * 2 if self.concat_emb_hidden else llm_dim
|
||||
self.audio_decoder_in_proj = torch.nn.Linear(
|
||||
audio_decoder_in_proj_dim, self.audio_decoder.embed_unit
|
||||
)
|
||||
self.codec_embedder = torch.nn.Embedding(self.codebook_size, self.codebook_dim)
|
||||
self.audio_decoder_embedding = torch.nn.Embedding(2, self.audio_decoder.embed_unit)
|
||||
self.ad_sos_eos = 0
|
||||
@ -2395,7 +2399,11 @@ class LLMASR5(nn.Module):
|
||||
)
|
||||
target_ids_len = torch.tensor(target_ids_len, dtype=torch.int32, device=input_ids.device)
|
||||
target_ids = target_ids.to(device=input_ids.device)
|
||||
target_ids[target_ids < 0] = 0
|
||||
target_emb = self.llm.model.get_input_embeddings()(target_ids)
|
||||
hidden_states_select = hidden_states_select.to(device=input_ids.device)
|
||||
if self.concat_emb_hidden:
|
||||
hidden_states_select = torch.concat((hidden_states_select, target_emb), dim=-1)
|
||||
nll, logits, target, target_lengths = self.nll(
|
||||
hidden_states_select, target_ids_len, codec[:, :, None], codec_len
|
||||
)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user