mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
deepspeed
This commit is contained in:
parent
ec057eff64
commit
1f46a4aefc
@ -229,7 +229,9 @@ class Trainer:
|
||||
latest = Path(os.path.join(self.output_dir, f"model.pt"))
|
||||
# torch.save(state, latest)
|
||||
with torch.no_grad():
|
||||
model.save_checkpoint(save_dir=self.output_dir, tag=f"model.pt", client_state=state)
|
||||
model.save_checkpoint(
|
||||
save_dir=self.output_dir, tag=f"ds-model.pt", client_state=state
|
||||
)
|
||||
if self.best_step_or_epoch == "":
|
||||
self.best_step_or_epoch = ckpt_name
|
||||
|
||||
@ -429,9 +431,9 @@ class Trainer:
|
||||
if self.resume:
|
||||
|
||||
if self.use_deepspeed:
|
||||
ckpt = os.path.join(self.output_dir, "model.pt")
|
||||
ckpt = os.path.join(self.output_dir, "ds-model.pt")
|
||||
if os.path.exists(ckpt):
|
||||
_, checkpoint = model.load_checkpoint(self.output_dir, "model.pt")
|
||||
_, checkpoint = model.load_checkpoint(ckpt)
|
||||
self.start_epoch = checkpoint["epoch"]
|
||||
self.saved_ckpts = checkpoint["saved_ckpts"]
|
||||
self.val_acc_step_or_eoch = (
|
||||
|
||||
Loading…
Reference in New Issue
Block a user