This commit is contained in:
shixian.shi 2023-06-27 20:03:13 +08:00
parent 007a8dc289
commit 6569b95025

View File

@ -1150,7 +1150,6 @@ class AbsTask(ABC):
def main_worker(cls, args: argparse.Namespace):
assert check_argument_types()
args.ngpu = 0
# 0. Init distributed process
distributed_option = build_dataclass(DistributedOption, args)
# Setting distributed_option.dist_rank, etc.
@ -1253,13 +1252,9 @@ class AbsTask(ABC):
raise RuntimeError(
f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
)
#model = model.to(
# dtype=getattr(torch, args.train_dtype),
# device="cuda" if args.ngpu > 0 else "cpu",
#)
model = model.to(
dtype=getattr(torch, args.train_dtype),
device="cpu",
device="cuda" if args.ngpu > 0 else "cpu",
)
for t in args.freeze_param:
for k, p in model.named_parameters():