This commit is contained in:
游雁 2024-05-20 11:41:53 +08:00
parent 1e1500adad
commit b3b1015809

View File

@@ -574,12 +574,12 @@ class Trainer:
loss_dict["lr"] = scheduler.get_last_lr()[0]
loss_dict["batch_num_epoch"] = len(dataloader_train)
self.val_loss_avg = (
self.val_loss_avg * batch_idx + loss_dict["loss"].detach().cpu().item()
self.train_loss_avg = (
self.train_loss_avg * batch_idx + loss_dict["loss"].detach().cpu().item()
) / (batch_idx + 1)
if "acc" in loss_dict["stats"]:
self.val_acc_avg = (
self.val_acc_avg * batch_idx + loss_dict["stats"]["acc"].detach().cpu().item()
self.train_acc_avg = (
self.train_acc_avg * batch_idx + loss_dict["stats"]["acc"].detach().cpu().item()
) / (batch_idx + 1)
self.log(loss_dict, tag="train")
@@ -612,12 +612,12 @@ class Trainer:
time_beg = time.perf_counter()
if self.use_ddp or self.use_fsdp or self.use_deepspeed:
val_loss_avg = torch.tensor(self.val_loss_avg, dtype=torch.float32).to(self.device)
val_acc_avg = torch.tensor(self.val_acc_avg, dtype=torch.float32).to(self.device)
dist.all_reduce(val_loss_avg, op=dist.ReduceOp.SUM)
dist.all_reduce(val_acc_avg, op=dist.ReduceOp.SUM)
self.val_loss_avg = val_loss_avg.detach().cpu().item() / self.world_size
self.val_acc_avg = val_acc_avg.detach().cpu().item() / self.world_size
train_loss_avg = torch.tensor(self.train_loss_avg, dtype=torch.float32).to(self.device)
train_acc_avg = torch.tensor(self.train_acc_avg, dtype=torch.float32).to(self.device)
dist.all_reduce(train_loss_avg, op=dist.ReduceOp.SUM)
dist.all_reduce(train_acc_avg, op=dist.ReduceOp.SUM)
self.train_loss_avg = train_loss_avg.detach().cpu().item() / self.world_size
self.train_acc_avg = train_acc_avg.detach().cpu().item() / self.world_size
def forward_step(self, model, batch, loss_dict={}):
dtype = torch.bfloat16