step_or_epoch bugfix

This commit is contained in:
游雁 2025-01-10 10:16:11 +08:00
parent e6fe602db3
commit d4f13c2e44

View File

@ -161,8 +161,8 @@ class Trainer:
# self.step_or_epoch += 1
state = {
"epoch": epoch,
'step': step,
'total_step': self.batch_total,
"step": step,
"total_step": self.batch_total,
"state_dict": model.state_dict(),
"optimizer": optim.state_dict(),
"scheduler": scheduler.state_dict(),
@ -171,7 +171,6 @@ class Trainer:
"val_loss_step_or_epoch": self.val_loss_step_or_epoch,
"best_step_or_epoch": self.best_step_or_epoch,
"avg_keep_nbest_models_type": self.avg_keep_nbest_models_type,
"step": step,
"step_in_epoch": step_in_epoch,
"data_split_i": kwargs.get("data_split_i", 0),
"data_split_num": kwargs.get("data_split_num", 1),
@ -194,9 +193,9 @@ class Trainer:
ckpt_name = f"model.pt.ep{epoch}.{step}"
filename = os.path.join(self.output_dir, ckpt_name)
torch.save(state, filename)
logging.info(f'Checkpoint saved to {filename}')
logging.info(f"Checkpoint saved to {filename}")
latest = Path(os.path.join(self.output_dir, f'model.pt'))
latest = Path(os.path.join(self.output_dir, f"model.pt"))
torch.save(state, latest)
if self.best_step_or_epoch == "":
@ -333,7 +332,6 @@ class Trainer:
if self.use_ddp or self.use_fsdp:
dist.barrier()
def train_epoch(
self,
model=None,
@ -591,9 +589,9 @@ class Trainer:
time4 = time.perf_counter()
if torch.isfinite(loss):
self.val_loss_avg = (self.val_loss_avg * batch_idx + loss.detach().cpu().item()) / (
batch_idx + 1
)
self.val_loss_avg = (
self.val_loss_avg * batch_idx + loss.detach().cpu().item()
) / (batch_idx + 1)
if "acc" in stats:
self.val_acc_avg = (