add mac m1 mps support (#2477)

This commit is contained in:
xmx0632 2025-04-14 00:40:12 -05:00 committed by GitHub
parent 8c336fed79
commit e7237d8cb4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 6 additions and 3 deletions

View File

@ -184,6 +184,7 @@ class AutoModel:
device = kwargs.get("device", "cuda")
if ((device =="cuda" and not torch.cuda.is_available())
or (device == "xpu" and not torch.xpu.is_available())
or (device == "mps" and not torch.backends.mps.is_available())
or kwargs.get("ngpu", 1) == 0):
device = "cpu"
kwargs["batch_size"] = 1

View File

@ -80,6 +80,8 @@ class FusedFrontends(nn.Module):
dev = "cuda"
elif torch.xpu.is_available():
dev = "xpu"
elif torch.backends.mps.is_available():
dev = "mps"
else:
dev = "cpu"
if self.align_method == "linear_projection":

View File

@ -28,12 +28,12 @@ def export(
**kwargs,
)
elif type == "torchscript":
device = "cuda" if torch.cuda.is_available() else "xpu" if torch.xpu.is_available() else "cpu"
device = "cuda" if torch.cuda.is_available() else "xpu" if torch.xpu.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print("Exporting torchscripts on device {}".format(device))
_torchscripts(m, path=export_dir, device=device)
elif type == "bladedisc":
assert (
torch.cuda.is_available() or torch.xpu.is_available()
torch.cuda.is_available() or torch.xpu.is_available() or torch.backends.mps.is_available()
), "Currently bladedisc optimization for FunASR only supports GPU"
# bladedisc only optimizes encoder/decoder modules
if hasattr(m, "encoder") and hasattr(m, "decoder"):
@ -44,7 +44,7 @@ def export(
elif type == "onnx_fp16":
assert (
torch.cuda.is_available() or torch.xpu.is_available()
torch.cuda.is_available() or torch.xpu.is_available() or torch.backends.mps.is_available()
), "Currently onnx_fp16 optimization for FunASR only supports GPU"
if hasattr(m, "encoder") and hasattr(m, "decoder"):