This commit is contained in:
游雁 2024-08-18 14:04:02 +08:00
parent 682b14d8d5
commit d0e4e2ad21

View File

@ -136,7 +136,7 @@ def main(**kwargs):
**kwargs.get("train_conf"),
)
model = trainer.warp_model(model)
model = trainer.warp_model(model, **kwargs)
kwargs["device"] = int(os.environ.get("LOCAL_RANK", 0))
trainer.device = int(os.environ.get("LOCAL_RANK", 0))