mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
update
This commit is contained in:
parent
496ca8eddb
commit
8f6d2787f0
@ -2744,7 +2744,7 @@ class LLMASR5(nn.Module):
|
||||
] # hidden_states: (t1, t2, ..., tn, ..., tN), tn=(l1, l2, ..., ln, ..., lN), ln: shape: 1x1x3584
|
||||
|
||||
token_num = len(hidden_states)
|
||||
hidden_states_out = torch.zeros((1, token_num, 3584), dtype=torch.float32).to(
|
||||
hidden_states_select = torch.zeros((1, token_num, 3584), dtype=torch.float32).to(
|
||||
inputs_embeds.device
|
||||
)
|
||||
hidden_states_out_len = torch.tensor(
|
||||
@ -2754,19 +2754,23 @@ class LLMASR5(nn.Module):
|
||||
dtype=torch.int32,
|
||||
).to(inputs_embeds.device)
|
||||
for i in range(token_num):
|
||||
hidden_states_out[0, i, :] = hidden_states[i][-1][0, 0, :].to(torch.float32)
|
||||
hidden_states_select[0, i, :] = hidden_states[i][-1][0, 0, :].to(torch.float32)
|
||||
|
||||
speech_tokens = self.audio_decode(hidden_states_out, hidden_states_out_len)[
|
||||
target_ids = generated_ids["sequences"]
|
||||
target_emb = self.llm.model.get_input_embeddings()(target_ids)
|
||||
if self.concat_emb_hidden:
|
||||
hidden_states_select = torch.concat((hidden_states_select, target_emb), dim=-1)
|
||||
|
||||
speech_tokens = self.audio_decode(hidden_states_select, hidden_states_out_len)[
|
||||
:, :, 0
|
||||
] # 1xlx1: 2,10,1023
|
||||
|
||||
sequences = generated_ids["sequences"]
|
||||
# generated_ids = [
|
||||
# output_ids[len(input_id) :]
|
||||
# for input_id, output_ids in zip(input_ids, generated_ids)
|
||||
# ]
|
||||
response = tokenizer.batch_decode(
|
||||
sequences, skip_special_tokens=kwargs.get("skip_special_tokens", True)
|
||||
target_ids, skip_special_tokens=kwargs.get("skip_special_tokens", True)
|
||||
)[0]
|
||||
|
||||
loss = None
|
||||
|
||||
Loading…
Reference in New Issue
Block a user