mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
deepspeed
This commit is contained in:
parent
47fbbb8fdc
commit
66411e2b5b
@ -226,71 +226,75 @@ class Trainer:
|
||||
with torch.no_grad():
|
||||
model.save_checkpoint(save_dir=self.output_dir, tag=ckpt_name, client_state=state)
|
||||
logging.info(f"\nCheckpoint saved to {filename}\n")
|
||||
latest = Path(os.path.join(self.output_dir, f"model.pt"))
|
||||
# torch.save(state, latest)
|
||||
|
||||
with torch.no_grad():
|
||||
model.save_checkpoint(
|
||||
save_dir=self.output_dir, tag=f"ds-model.pt", client_state=state
|
||||
)
|
||||
if self.best_step_or_epoch == "":
|
||||
self.best_step_or_epoch = ckpt_name
|
||||
if not (step is None and epoch != 0):
|
||||
if self.best_step_or_epoch == "":
|
||||
self.best_step_or_epoch = ckpt_name
|
||||
|
||||
if self.avg_keep_nbest_models_type == "acc":
|
||||
if (
|
||||
self.val_acc_step_or_eoch[ckpt_name]
|
||||
>= self.val_acc_step_or_eoch[self.best_step_or_epoch]
|
||||
):
|
||||
self.best_step_or_epoch = ckpt_name
|
||||
best_ckpt = Path(os.path.join(self.output_dir, f"ds-model.pt.best"))
|
||||
# torch.save(state, best_ckpt)
|
||||
with torch.no_grad():
|
||||
model.save_checkpoint(
|
||||
save_dir=self.output_dir, tag=f"ds-model.pt.best", client_state=state
|
||||
if self.avg_keep_nbest_models_type == "acc":
|
||||
if (
|
||||
self.val_acc_step_or_eoch[ckpt_name]
|
||||
>= self.val_acc_step_or_eoch[self.best_step_or_epoch]
|
||||
):
|
||||
self.best_step_or_epoch = ckpt_name
|
||||
best_ckpt = Path(os.path.join(self.output_dir, f"ds-model.pt.best"))
|
||||
# torch.save(state, best_ckpt)
|
||||
with torch.no_grad():
|
||||
model.save_checkpoint(
|
||||
save_dir=self.output_dir,
|
||||
tag=f"ds-model.pt.best",
|
||||
client_state=state,
|
||||
)
|
||||
logging.info(
|
||||
f"Update best acc: {self.val_acc_step_or_eoch[self.best_step_or_epoch]:.4f}, {best_ckpt}"
|
||||
)
|
||||
logging.info(
|
||||
f"Update best acc: {self.val_acc_step_or_eoch[self.best_step_or_epoch]:.4f}, {best_ckpt}"
|
||||
)
|
||||
else:
|
||||
logging.info(
|
||||
f"No improvement in acc: {self.val_acc_step_or_eoch[ckpt_name]:.4f} < {self.val_acc_step_or_eoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}"
|
||||
)
|
||||
elif self.avg_keep_nbest_models_type == "loss":
|
||||
if (
|
||||
self.val_loss_step_or_eoch[ckpt_name]
|
||||
<= self.val_loss_step_or_eoch[self.best_step_or_epoch]
|
||||
):
|
||||
self.best_step_or_epoch = ckpt_name
|
||||
best_ckpt = Path(os.path.join(self.output_dir, f"ds-model.pt.best"))
|
||||
# torch.save(state, best_ckpt)
|
||||
with torch.no_grad():
|
||||
model.save_checkpoint(
|
||||
save_dir=self.output_dir, tag=f"ds-model.pt.best", client_state=state
|
||||
)
|
||||
logging.info(
|
||||
f"Update best loss: {self.val_loss_step_or_eoch[self.best_step_or_epoch]:.4f}, {best_ckpt}"
|
||||
)
|
||||
else:
|
||||
logging.info(
|
||||
f"No improvement in loss: {self.val_loss_step_or_eoch[ckpt_name]:.4f} > {self.val_loss_step_or_eoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}"
|
||||
)
|
||||
else:
|
||||
print("Undo")
|
||||
self.saved_ckpts[ckpt_name] = getattr(
|
||||
self, f"val_{self.avg_keep_nbest_models_type}_step_or_eoch"
|
||||
)[ckpt_name]
|
||||
if self.keep_nbest_models > 0:
|
||||
if len(self.saved_ckpts) > self.keep_nbest_models:
|
||||
if self.avg_keep_nbest_models_type == "acc":
|
||||
key = min(self.saved_ckpts, key=self.saved_ckpts.get)
|
||||
else:
|
||||
key = max(self.saved_ckpts, key=self.saved_ckpts.get)
|
||||
if key in self.saved_ckpts:
|
||||
del self.saved_ckpts[key]
|
||||
filename = os.path.join(self.output_dir, key)
|
||||
logging.info(f"Delete: {filename}")
|
||||
if os.path.exists(filename):
|
||||
# os.remove(filename)
|
||||
misc_utils.smart_remove(filename)
|
||||
logging.info(
|
||||
f"No improvement in acc: {self.val_acc_step_or_eoch[ckpt_name]:.4f} < {self.val_acc_step_or_eoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}"
|
||||
)
|
||||
elif self.avg_keep_nbest_models_type == "loss":
|
||||
if (
|
||||
self.val_loss_step_or_eoch[ckpt_name]
|
||||
<= self.val_loss_step_or_eoch[self.best_step_or_epoch]
|
||||
):
|
||||
self.best_step_or_epoch = ckpt_name
|
||||
best_ckpt = Path(os.path.join(self.output_dir, f"ds-model.pt.best"))
|
||||
# torch.save(state, best_ckpt)
|
||||
with torch.no_grad():
|
||||
model.save_checkpoint(
|
||||
save_dir=self.output_dir,
|
||||
tag=f"ds-model.pt.best",
|
||||
client_state=state,
|
||||
)
|
||||
logging.info(
|
||||
f"Update best loss: {self.val_loss_step_or_eoch[self.best_step_or_epoch]:.4f}, {best_ckpt}"
|
||||
)
|
||||
else:
|
||||
logging.info(
|
||||
f"No improvement in loss: {self.val_loss_step_or_eoch[ckpt_name]:.4f} > {self.val_loss_step_or_eoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}"
|
||||
)
|
||||
else:
|
||||
print("Undo")
|
||||
self.saved_ckpts[ckpt_name] = getattr(
|
||||
self, f"val_{self.avg_keep_nbest_models_type}_step_or_eoch"
|
||||
)[ckpt_name]
|
||||
if self.keep_nbest_models > 0:
|
||||
if len(self.saved_ckpts) > self.keep_nbest_models:
|
||||
if self.avg_keep_nbest_models_type == "acc":
|
||||
key = min(self.saved_ckpts, key=self.saved_ckpts.get)
|
||||
else:
|
||||
key = max(self.saved_ckpts, key=self.saved_ckpts.get)
|
||||
if key in self.saved_ckpts:
|
||||
del self.saved_ckpts[key]
|
||||
filename = os.path.join(self.output_dir, key)
|
||||
logging.info(f"Delete: {filename}")
|
||||
if os.path.exists(filename):
|
||||
# os.remove(filename)
|
||||
misc_utils.smart_remove(filename)
|
||||
|
||||
elif self.use_fsdp:
|
||||
pass
|
||||
@ -359,57 +363,58 @@ class Trainer:
|
||||
logging.info(f"\nCheckpoint saved to {filename}\n")
|
||||
latest = Path(os.path.join(self.output_dir, f"model.pt"))
|
||||
torch.save(state, latest)
|
||||
if self.best_step_or_epoch == "":
|
||||
self.best_step_or_epoch = ckpt_name
|
||||
if not (step is None and epoch != 0):
|
||||
if self.best_step_or_epoch == "":
|
||||
self.best_step_or_epoch = ckpt_name
|
||||
|
||||
if self.avg_keep_nbest_models_type == "acc":
|
||||
if (
|
||||
self.val_acc_step_or_eoch[ckpt_name]
|
||||
>= self.val_acc_step_or_eoch[self.best_step_or_epoch]
|
||||
):
|
||||
self.best_step_or_epoch = ckpt_name
|
||||
best_ckpt = Path(os.path.join(self.output_dir, f"model.pt.best"))
|
||||
torch.save(state, best_ckpt)
|
||||
logging.info(
|
||||
f"Update best acc: {self.val_acc_step_or_eoch[self.best_step_or_epoch]:.4f}, {best_ckpt}"
|
||||
)
|
||||
else:
|
||||
logging.info(
|
||||
f"No improvement in acc: {self.val_acc_step_or_eoch[ckpt_name]:.4f} < {self.val_acc_step_or_eoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}"
|
||||
)
|
||||
elif self.avg_keep_nbest_models_type == "loss":
|
||||
if (
|
||||
self.val_loss_step_or_eoch[ckpt_name]
|
||||
<= self.val_loss_step_or_eoch[self.best_step_or_epoch]
|
||||
):
|
||||
self.best_step_or_epoch = ckpt_name
|
||||
best_ckpt = Path(os.path.join(self.output_dir, f"model.pt.best"))
|
||||
torch.save(state, best_ckpt)
|
||||
logging.info(
|
||||
f"Update best loss: {self.val_loss_step_or_eoch[self.best_step_or_epoch]:.4f}, {best_ckpt}"
|
||||
)
|
||||
else:
|
||||
logging.info(
|
||||
f"No improvement in loss: {self.val_loss_step_or_eoch[ckpt_name]:.4f} > {self.val_loss_step_or_eoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}"
|
||||
)
|
||||
else:
|
||||
print("Undo")
|
||||
self.saved_ckpts[ckpt_name] = getattr(
|
||||
self, f"val_{self.avg_keep_nbest_models_type}_step_or_eoch"
|
||||
)[ckpt_name]
|
||||
if self.keep_nbest_models > 0:
|
||||
if len(self.saved_ckpts) > self.keep_nbest_models:
|
||||
if self.avg_keep_nbest_models_type == "acc":
|
||||
key = min(self.saved_ckpts, key=self.saved_ckpts.get)
|
||||
if self.avg_keep_nbest_models_type == "acc":
|
||||
if (
|
||||
self.val_acc_step_or_eoch[ckpt_name]
|
||||
>= self.val_acc_step_or_eoch[self.best_step_or_epoch]
|
||||
):
|
||||
self.best_step_or_epoch = ckpt_name
|
||||
best_ckpt = Path(os.path.join(self.output_dir, f"model.pt.best"))
|
||||
torch.save(state, best_ckpt)
|
||||
logging.info(
|
||||
f"Update best acc: {self.val_acc_step_or_eoch[self.best_step_or_epoch]:.4f}, {best_ckpt}"
|
||||
)
|
||||
else:
|
||||
key = max(self.saved_ckpts, key=self.saved_ckpts.get)
|
||||
if key in self.saved_ckpts:
|
||||
del self.saved_ckpts[key]
|
||||
filename = os.path.join(self.output_dir, key)
|
||||
logging.info(f"Delete: {filename}")
|
||||
if os.path.exists(filename):
|
||||
# os.remove(filename)
|
||||
misc_utils.smart_remove(filename)
|
||||
logging.info(
|
||||
f"No improvement in acc: {self.val_acc_step_or_eoch[ckpt_name]:.4f} < {self.val_acc_step_or_eoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}"
|
||||
)
|
||||
elif self.avg_keep_nbest_models_type == "loss":
|
||||
if (
|
||||
self.val_loss_step_or_eoch[ckpt_name]
|
||||
<= self.val_loss_step_or_eoch[self.best_step_or_epoch]
|
||||
):
|
||||
self.best_step_or_epoch = ckpt_name
|
||||
best_ckpt = Path(os.path.join(self.output_dir, f"model.pt.best"))
|
||||
torch.save(state, best_ckpt)
|
||||
logging.info(
|
||||
f"Update best loss: {self.val_loss_step_or_eoch[self.best_step_or_epoch]:.4f}, {best_ckpt}"
|
||||
)
|
||||
else:
|
||||
logging.info(
|
||||
f"No improvement in loss: {self.val_loss_step_or_eoch[ckpt_name]:.4f} > {self.val_loss_step_or_eoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}"
|
||||
)
|
||||
else:
|
||||
print("Undo")
|
||||
self.saved_ckpts[ckpt_name] = getattr(
|
||||
self, f"val_{self.avg_keep_nbest_models_type}_step_or_eoch"
|
||||
)[ckpt_name]
|
||||
if self.keep_nbest_models > 0:
|
||||
if len(self.saved_ckpts) > self.keep_nbest_models:
|
||||
if self.avg_keep_nbest_models_type == "acc":
|
||||
key = min(self.saved_ckpts, key=self.saved_ckpts.get)
|
||||
else:
|
||||
key = max(self.saved_ckpts, key=self.saved_ckpts.get)
|
||||
if key in self.saved_ckpts:
|
||||
del self.saved_ckpts[key]
|
||||
filename = os.path.join(self.output_dir, key)
|
||||
logging.info(f"Delete: {filename}")
|
||||
if os.path.exists(filename):
|
||||
# os.remove(filename)
|
||||
misc_utils.smart_remove(filename)
|
||||
|
||||
if self.use_ddp or self.use_fsdp:
|
||||
dist.barrier()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user