mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
fp16
This commit is contained in:
parent
23008c7cac
commit
63e60cc43d
@ -233,6 +233,8 @@ class AutoModel:
|
||||
# fp16
|
||||
if kwargs.get("fp16", False):
|
||||
model.to(torch.float16)
|
||||
elif kwargs.get("bf16", False):
|
||||
model.to(torch.bfloat16)
|
||||
return model, kwargs
|
||||
|
||||
def __call__(self, *args, **cfg):
|
||||
|
||||
@ -684,6 +684,13 @@ class LLMASR2(nn.Module):
|
||||
# audio encoder
|
||||
speech = batch["speech"]
|
||||
speech_lengths = batch["speech_lengths"][:, 0]
|
||||
# fp16
|
||||
if kwargs.get("fp16", False):
|
||||
speech = speech.to(torch.float16)
|
||||
encoder_out_lens = encoder_out_lens.to(torch.float16)
|
||||
elif kwargs.get("bf16", False):
|
||||
speech = speech.to(torch.bfloat16)
|
||||
encoder_out_lens = encoder_out_lens.to(torch.bfloat16)
|
||||
encoder_out, encoder_out_lens = self.audio_encoder(speech.permute(0, 2, 1), speech_lengths)
|
||||
|
||||
# audio_adaptor
|
||||
|
||||
Loading…
Reference in New Issue
Block a user