mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
add mac m1 mps support (#2477)
This commit is contained in:
parent
8c336fed79
commit
e7237d8cb4
@ -184,6 +184,7 @@ class AutoModel:
|
||||
device = kwargs.get("device", "cuda")
|
||||
if ((device =="cuda" and not torch.cuda.is_available())
|
||||
or (device == "xpu" and not torch.xpu.is_available())
|
||||
or (device == "mps" and not torch.backends.mps.is_available())
|
||||
or kwargs.get("ngpu", 1) == 0):
|
||||
device = "cpu"
|
||||
kwargs["batch_size"] = 1
|
||||
|
||||
@ -80,6 +80,8 @@ class FusedFrontends(nn.Module):
|
||||
dev = "cuda"
|
||||
elif torch.xpu.is_available():
|
||||
dev = "xpu"
|
||||
elif torch.backends.mps.is_available():
|
||||
dev = "mps"
|
||||
else:
|
||||
dev = "cpu"
|
||||
if self.align_method == "linear_projection":
|
||||
|
||||
@ -28,12 +28,12 @@ def export(
|
||||
**kwargs,
|
||||
)
|
||||
elif type == "torchscript":
|
||||
device = "cuda" if torch.cuda.is_available() else "xpu" if torch.xpu.is_available() else "cpu"
|
||||
device = "cuda" if torch.cuda.is_available() else "xpu" if torch.xpu.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
||||
print("Exporting torchscripts on device {}".format(device))
|
||||
_torchscripts(m, path=export_dir, device=device)
|
||||
elif type == "bladedisc":
|
||||
assert (
|
||||
torch.cuda.is_available() or torch.xpu.is_available()
|
||||
torch.cuda.is_available() or torch.xpu.is_available() or torch.backends.mps.is_available()
|
||||
), "Currently bladedisc optimization for FunASR only supports GPU"
|
||||
# bladedisc only optimizes encoder/decoder modules
|
||||
if hasattr(m, "encoder") and hasattr(m, "decoder"):
|
||||
@ -44,7 +44,7 @@ def export(
|
||||
|
||||
elif type == "onnx_fp16":
|
||||
assert (
|
||||
torch.cuda.is_available() or torch.xpu.is_available()
|
||||
torch.cuda.is_available() or torch.xpu.is_available() or torch.backends.mps.is_available()
|
||||
), "Currently onnx_fp16 optimization for FunASR only supports GPU"
|
||||
|
||||
if hasattr(m, "encoder") and hasattr(m, "decoder"):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user