mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
fp16
This commit is contained in:
parent
23008c7cac
commit
63e60cc43d
@ -233,6 +233,8 @@ class AutoModel:
|
|||||||
# fp16
|
# fp16
|
||||||
if kwargs.get("fp16", False):
|
if kwargs.get("fp16", False):
|
||||||
model.to(torch.float16)
|
model.to(torch.float16)
|
||||||
|
elif kwargs.get("bf16", False):
|
||||||
|
model.to(torch.bfloat16)
|
||||||
return model, kwargs
|
return model, kwargs
|
||||||
|
|
||||||
def __call__(self, *args, **cfg):
|
def __call__(self, *args, **cfg):
|
||||||
|
|||||||
@ -684,6 +684,13 @@ class LLMASR2(nn.Module):
|
|||||||
# audio encoder
|
# audio encoder
|
||||||
speech = batch["speech"]
|
speech = batch["speech"]
|
||||||
speech_lengths = batch["speech_lengths"][:, 0]
|
speech_lengths = batch["speech_lengths"][:, 0]
|
||||||
|
# fp16
|
||||||
|
if kwargs.get("fp16", False):
|
||||||
|
speech = speech.to(torch.float16)
|
||||||
|
encoder_out_lens = encoder_out_lens.to(torch.float16)
|
||||||
|
elif kwargs.get("bf16", False):
|
||||||
|
speech = speech.to(torch.bfloat16)
|
||||||
|
encoder_out_lens = encoder_out_lens.to(torch.bfloat16)
|
||||||
encoder_out, encoder_out_lens = self.audio_encoder(speech.permute(0, 2, 1), speech_lengths)
|
encoder_out, encoder_out_lens = self.audio_encoder(speech.permute(0, 2, 1), speech_lengths)
|
||||||
|
|
||||||
# audio_adaptor
|
# audio_adaptor
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user