Mirror of https://github.com/modelscope/FunASR (synced 2025-09-15 14:48:36 +08:00)
fix device bug
This commit is contained in:
parent 4f224c8806
commit c47ad73e46
@@ -134,8 +134,6 @@ class AutoModel:
         self.spk_model = spk_model
         self.spk_kwargs = spk_kwargs
         self.model_path = kwargs.get("model_path")
 
-
-
     def build_model(self, **kwargs):
         assert "model" in kwargs
@@ -146,7 +144,7 @@ class AutoModel:
         set_all_random_seed(kwargs.get("seed", 0))
 
         device = kwargs.get("device", "cuda")
-        if not torch.cuda.is_available() or kwargs.get("ngpu", 0) == 0:
+        if not torch.cuda.is_available() or kwargs.get("ngpu", 1) == 0:
             device = "cpu"
             kwargs["batch_size"] = 1
         kwargs["device"] = device
@@ -200,8 +198,6 @@ class AutoModel:
         res = self.model(*args, kwargs)
         return res
 
-
-
     def generate(self, input, input_len=None, **cfg):
         if self.vad_model is None:
            return self.inference(input, input_len=input_len, **cfg)
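The substantive change is the ngpu default in build_model's device selection: with the old kwargs.get("ngpu", 0), a caller who never passed ngpu was always pushed to CPU (and batch_size forced to 1) even when CUDA was available; defaulting to 1 only falls back when CUDA is missing or ngpu=0 is explicit. A minimal standalone sketch of the post-fix logic, for illustration only (the resolve_device helper and its kwargs handling are assumptions, not FunASR API):

    import torch


    def resolve_device(**kwargs):
        # Mirrors the selection in AutoModel.build_model after this commit:
        # honour the requested device, but fall back to CPU (with batch_size=1)
        # when CUDA is unavailable or the caller explicitly passes ngpu=0.
        device = kwargs.get("device", "cuda")
        # Pre-fix code used kwargs.get("ngpu", 0), so an unspecified ngpu
        # silently forced CPU; defaulting to 1 keeps the GPU path by default.
        if not torch.cuda.is_available() or kwargs.get("ngpu", 1) == 0:
            device = "cpu"
            kwargs["batch_size"] = 1
        kwargs["device"] = device
        return kwargs


    # No ngpu given -> stays on "cuda" when a GPU is present;
    # ngpu=0 -> forced to "cpu" with batch_size reset to 1.
    print(resolve_device())
    print(resolve_device(ngpu=0, batch_size=16))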