This commit is contained in:
游雁 2024-07-10 10:10:38 +08:00
parent 0798219669
commit 7fe1b3c0e4

View File

@@ -2205,7 +2205,16 @@ class LLMASR5(nn.Module):
target_ids = generated_ids["sequences"]
target_emb = self.llm.model.get_input_embeddings()(target_ids)
if self.concat_emb_hidden:
hidden_states_select = torch.concat((hidden_states_select, target_emb), dim=-1)
if not self.concat_emb_hidden_norm:
hidden_states_select = torch.concat((hidden_states_select, target_emb), dim=-1)
hidden_states_select = self.audio_decoder_in_proj(hidden_states_select)
else:
outs = self.hidden_norm(hidden_states_select)
outs = self.fusion_dropout(self.fusion_act(outs))
# emb = model_outputs.hidden_states[0]
emb = self.fusion_dropout(self.fusion_act(self.emb_norm(target_emb)))
outs = self.audio_decoder_in_proj(torch.cat([outs, emb], dim=-1))
hidden_states_select = self.fusion_act(self.fusion_norm(outs))
speech_tokens = self.audio_decode(hidden_states_select, hidden_states_out_len)[
:, :, 0
@@ -2263,7 +2272,7 @@ class LLMASR5(nn.Module):
decoding_length=None,
):
# 1. encode text
text = self.audio_decoder_in_proj(text)
# text = self.audio_decoder_in_proj(text)
device = text.device
out_tokens = []
sos_eos_emb = self.audio_decoder_embedding(