This commit is contained in:
游雁 2024-07-04 09:43:29 +08:00
parent 1bf66d044f
commit e8fe5711a2

View File

@ -2722,9 +2722,17 @@ class LLMASR5(nn.Module):
return_dict_in_generate=True,
output_scores=True,
)
hidden_states = generated_ids["hidden_states"]
hidden_states = generated_ids[
"hidden_states"
] # hidden_states: (t1, t2, ..., tn, ..., tN), tn=(l1, l2, ..., ln, ..., lN), ln: shape: 1x1x3584
# hidden_states: (t1, t2, ..., tn, ..., tN), tn=(l1, l2, ..., ln, ..., lN), ln: shape: 1x1x3584
token_num = len(hidden_states)
hidden_states_out = torch.zeros((1, token_num, 3584), dtype=torch.float32).to(
inputs_embeds.device
)
for i in range(token_num):
hidden_states_out[0, i, :] = hidden_states[1, -1][0, 0, :].to(torch.float32)
# generated_ids = [
# output_ids[len(input_id) :]