This commit is contained in:
游雁 2024-07-05 20:53:32 +08:00
parent 496ca8eddb
commit 8f6d2787f0

View File

@ -2744,7 +2744,7 @@ class LLMASR5(nn.Module):
] # hidden_states: (t1, t2, ..., tn, ..., tN), tn=(l1, l2, ..., ln, ..., lN), ln: shape: 1x1x3584
token_num = len(hidden_states)
hidden_states_out = torch.zeros((1, token_num, 3584), dtype=torch.float32).to(
hidden_states_select = torch.zeros((1, token_num, 3584), dtype=torch.float32).to(
inputs_embeds.device
)
hidden_states_out_len = torch.tensor(
@ -2754,19 +2754,23 @@ class LLMASR5(nn.Module):
dtype=torch.int32,
).to(inputs_embeds.device)
for i in range(token_num):
hidden_states_out[0, i, :] = hidden_states[i][-1][0, 0, :].to(torch.float32)
hidden_states_select[0, i, :] = hidden_states[i][-1][0, 0, :].to(torch.float32)
speech_tokens = self.audio_decode(hidden_states_out, hidden_states_out_len)[
target_ids = generated_ids["sequences"]
target_emb = self.llm.model.get_input_embeddings()(target_ids)
if self.concat_emb_hidden:
hidden_states_select = torch.concat((hidden_states_select, target_emb), dim=-1)
speech_tokens = self.audio_decode(hidden_states_select, hidden_states_out_len)[
:, :, 0
] # 1xlx1: 2,10,1023
sequences = generated_ids["sequences"]
# generated_ids = [
# output_ids[len(input_id) :]
# for input_id, output_ids in zip(input_ids, generated_ids)
# ]
response = tokenizer.batch_decode(
sequences, skip_special_tokens=kwargs.get("skip_special_tokens", True)
target_ids, skip_special_tokens=kwargs.get("skip_special_tokens", True)
)[0]
loss = None