deepspeed

This commit is contained in:
游雁 2024-05-17 12:38:24 +08:00
parent d3ff05837b
commit 86ada491e0

View File

@ -130,8 +130,8 @@ def main(**kwargs):
model = trainer.warp_model(model)
kwargs["device"] = next(model.parameters()).device
trainer.device = kwargs["device"]
kwargs["device"] = int(os.environ.get("LOCAL_RANK", 0))
trainer.device = int(os.environ.get("LOCAL_RANK", 0))
model, optim, scheduler = trainer.warp_optim_scheduler(model, **kwargs)