mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
step_or_epoch bugfix
This commit is contained in:
parent
e6fe602db3
commit
d4f13c2e44
@ -161,8 +161,8 @@ class Trainer:
|
||||
# self.step_or_epoch += 1
|
||||
state = {
|
||||
"epoch": epoch,
|
||||
'step': step,
|
||||
'total_step': self.batch_total,
|
||||
"step": step,
|
||||
"total_step": self.batch_total,
|
||||
"state_dict": model.state_dict(),
|
||||
"optimizer": optim.state_dict(),
|
||||
"scheduler": scheduler.state_dict(),
|
||||
@ -171,7 +171,6 @@ class Trainer:
|
||||
"val_loss_step_or_epoch": self.val_loss_step_or_epoch,
|
||||
"best_step_or_epoch": self.best_step_or_epoch,
|
||||
"avg_keep_nbest_models_type": self.avg_keep_nbest_models_type,
|
||||
"step": step,
|
||||
"step_in_epoch": step_in_epoch,
|
||||
"data_split_i": kwargs.get("data_split_i", 0),
|
||||
"data_split_num": kwargs.get("data_split_num", 1),
|
||||
@ -194,9 +193,9 @@ class Trainer:
|
||||
ckpt_name = f"model.pt.ep{epoch}.{step}"
|
||||
filename = os.path.join(self.output_dir, ckpt_name)
|
||||
torch.save(state, filename)
|
||||
logging.info(f'Checkpoint saved to {filename}')
|
||||
logging.info(f"Checkpoint saved to {filename}")
|
||||
|
||||
latest = Path(os.path.join(self.output_dir, f'model.pt'))
|
||||
latest = Path(os.path.join(self.output_dir, f"model.pt"))
|
||||
torch.save(state, latest)
|
||||
|
||||
if self.best_step_or_epoch == "":
|
||||
@ -333,7 +332,6 @@ class Trainer:
|
||||
if self.use_ddp or self.use_fsdp:
|
||||
dist.barrier()
|
||||
|
||||
|
||||
def train_epoch(
|
||||
self,
|
||||
model=None,
|
||||
@ -591,9 +589,9 @@ class Trainer:
|
||||
time4 = time.perf_counter()
|
||||
|
||||
if torch.isfinite(loss):
|
||||
self.val_loss_avg = (self.val_loss_avg * batch_idx + loss.detach().cpu().item()) / (
|
||||
batch_idx + 1
|
||||
)
|
||||
self.val_loss_avg = (
|
||||
self.val_loss_avg * batch_idx + loss.detach().cpu().item()
|
||||
) / (batch_idx + 1)
|
||||
|
||||
if "acc" in stats:
|
||||
self.val_acc_avg = (
|
||||
|
||||
Loading…
Reference in New Issue
Block a user