mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
decoding
This commit is contained in:
parent
c553a8db17
commit
7355e20503
@ -496,11 +496,14 @@ class LLMASR2(nn.Module):
|
|||||||
|
|
||||||
batch_size, frames, _ = speech.shape
|
batch_size, frames, _ = speech.shape
|
||||||
|
|
||||||
# audio encoder
|
with torch.cuda.amp.autocast(enabled=False):
|
||||||
encoder_out, encoder_out_lens = self.audio_encoder(speech.permute(0, 2, 1), speech_lengths)
|
# audio encoder
|
||||||
|
encoder_out, encoder_out_lens = self.audio_encoder(
|
||||||
|
speech.permute(0, 2, 1), speech_lengths
|
||||||
|
)
|
||||||
|
|
||||||
# audio_adaptor
|
# audio_adaptor
|
||||||
encoder_out, encoder_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens)
|
encoder_out, encoder_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens)
|
||||||
|
|
||||||
input_ids[input_ids < 0] = 0
|
input_ids[input_ids < 0] = 0
|
||||||
inputs_embeds = self.llm.model.get_input_embeddings()(input_ids)
|
inputs_embeds = self.llm.model.get_input_embeddings()(input_ids)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user