deepspeed

This commit is contained in:
游雁 2024-08-06 00:28:41 +08:00
parent ec057eff64
commit 1f46a4aefc

View File

@ -229,7 +229,9 @@ class Trainer:
latest = Path(os.path.join(self.output_dir, f"model.pt"))
# torch.save(state, latest)
with torch.no_grad():
model.save_checkpoint(save_dir=self.output_dir, tag=f"model.pt", client_state=state)
model.save_checkpoint(
save_dir=self.output_dir, tag=f"ds-model.pt", client_state=state
)
if self.best_step_or_epoch == "":
self.best_step_or_epoch = ckpt_name
@ -429,9 +431,9 @@ class Trainer:
if self.resume:
if self.use_deepspeed:
ckpt = os.path.join(self.output_dir, "model.pt")
ckpt = os.path.join(self.output_dir, "ds-model.pt")
if os.path.exists(ckpt):
_, checkpoint = model.load_checkpoint(self.output_dir, "model.pt")
_, checkpoint = model.load_checkpoint(ckpt)
self.start_epoch = checkpoint["epoch"]
self.saved_ckpts = checkpoint["saved_ckpts"]
self.val_acc_step_or_eoch = (