This commit is contained in:
嘉渊 2023-04-25 14:11:39 +08:00
parent ec5e15d47c
commit 74a2059e9c

View File

@ -520,6 +520,10 @@ if __name__ == '__main__':
prepare_data(args, distributed_option)
model = build_model(args)
model = model.to(
dtype=getattr(torch, args.train_dtype),
device="cuda" if args.ngpu > 0 else "cpu",
)
optimizers = build_optimizer(args, model=model)
schedulers = build_scheduler(args, optimizers)