diff --git a/funasr/tasks/abs_task.py b/funasr/tasks/abs_task.py index 0d848a97d..0fb77a93e 100644 --- a/funasr/tasks/abs_task.py +++ b/funasr/tasks/abs_task.py @@ -1150,7 +1150,6 @@ class AbsTask(ABC): def main_worker(cls, args: argparse.Namespace): assert check_argument_types() - args.ngpu = 0 # 0. Init distributed process distributed_option = build_dataclass(DistributedOption, args) # Setting distributed_option.dist_rank, etc. @@ -1253,13 +1252,9 @@ class AbsTask(ABC): raise RuntimeError( f"model must inherit {FunASRModel.__name__}, but got {type(model)}" ) - #model = model.to( - # dtype=getattr(torch, args.train_dtype), - # device="cuda" if args.ngpu > 0 else "cpu", - #) model = model.to( dtype=getattr(torch, args.train_dtype), - device="cpu", + device="cuda" if args.ngpu > 0 else "cpu", ) for t in args.freeze_param: for k, p in model.named_parameters():