mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
update
This commit is contained in:
parent
0798219669
commit
7fe1b3c0e4
@ -2205,7 +2205,16 @@ class LLMASR5(nn.Module):
|
||||
target_ids = generated_ids["sequences"]
|
||||
target_emb = self.llm.model.get_input_embeddings()(target_ids)
|
||||
if self.concat_emb_hidden:
|
||||
hidden_states_select = torch.concat((hidden_states_select, target_emb), dim=-1)
|
||||
if not self.concat_emb_hidden_norm:
|
||||
hidden_states_select = torch.concat((hidden_states_select, target_emb), dim=-1)
|
||||
hidden_states_select = self.audio_decoder_in_proj(hidden_states_select)
|
||||
else:
|
||||
outs = self.hidden_norm(hidden_states_select)
|
||||
outs = self.fusion_dropout(self.fusion_act(outs))
|
||||
# emb = model_outputs.hidden_states[0]
|
||||
emb = self.fusion_dropout(self.fusion_act(self.emb_norm(target_emb)))
|
||||
outs = self.audio_decoder_in_proj(torch.cat([outs, emb], dim=-1))
|
||||
hidden_states_select = self.fusion_act(self.fusion_norm(outs))
|
||||
|
||||
speech_tokens = self.audio_decode(hidden_states_select, hidden_states_out_len)[
|
||||
:, :, 0
|
||||
@ -2263,7 +2272,7 @@ class LLMASR5(nn.Module):
|
||||
decoding_length=None,
|
||||
):
|
||||
# 1. encode text
|
||||
text = self.audio_decoder_in_proj(text)
|
||||
# text = self.audio_decoder_in_proj(text)
|
||||
device = text.device
|
||||
out_tokens = []
|
||||
sos_eos_emb = self.audio_decoder_embedding(
|
||||
|
||||
Loading…
Reference in New Issue
Block a user