deepspeed

This commit is contained in:
游雁 2024-08-06 01:26:20 +08:00
parent a581e50e30
commit 4121b8907c

View File

@ -439,7 +439,7 @@ class Trainer:
if self.use_deepspeed:
ckpt = os.path.join(self.output_dir, "ds-model.pt")
if os.path.exists(ckpt):
_, checkpoint = model.load_checkpoint(ckpt)
_, checkpoint = model.load_checkpoint(self.output_dir, "ds-model.pt")
self.start_epoch = checkpoint["epoch"]
self.saved_ckpts = checkpoint["saved_ckpts"]
self.val_acc_step_or_eoch = (