mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
update
This commit is contained in:
parent
1bf66d044f
commit
e8fe5711a2
@ -2722,9 +2722,17 @@ class LLMASR5(nn.Module):
|
||||
return_dict_in_generate=True,
|
||||
output_scores=True,
|
||||
)
|
||||
hidden_states = generated_ids["hidden_states"]
|
||||
hidden_states = generated_ids[
|
||||
"hidden_states"
|
||||
] # hidden_states: (t1, t2, ..., tn, ..., tN), tn=(l1, l2, ..., ln, ..., lN), ln: shape: 1x1x3584
|
||||
|
||||
# hidden_states: (t1, t2, ..., tn, ..., tN), tn=(l1, l2, ..., ln, ..., lN), ln: shape: 1x1x3584
|
||||
token_num = len(hidden_states)
|
||||
hidden_states_out = torch.zeros((1, token_num, 3584), dtype=torch.float32).to(
|
||||
inputs_embeds.device
|
||||
)
|
||||
|
||||
for i in range(token_num):
|
||||
hidden_states_out[0, i, :] = hidden_states[1, -1][0, 0, :].to(torch.float32)
|
||||
|
||||
# generated_ids = [
|
||||
# output_ids[len(input_id) :]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user