Mirror of https://github.com/modelscope/FunASR, synced 2025-09-15 14:48:36 +08:00
ds

commit b3b1015809
parent 1e1500adad
@@ -574,12 +574,12 @@ class Trainer:
             loss_dict["lr"] = scheduler.get_last_lr()[0]
             loss_dict["batch_num_epoch"] = len(dataloader_train)

-            self.val_loss_avg = (
-                self.val_loss_avg * batch_idx + loss_dict["loss"].detach().cpu().item()
+            self.train_loss_avg = (
+                self.train_loss_avg * batch_idx + loss_dict["loss"].detach().cpu().item()
             ) / (batch_idx + 1)
             if "acc" in loss_dict["stats"]:
-                self.val_acc_avg = (
-                    self.val_acc_avg * batch_idx + loss_dict["stats"]["acc"].detach().cpu().item()
+                self.train_acc_avg = (
+                    self.train_acc_avg * batch_idx + loss_dict["stats"]["acc"].detach().cpu().item()
                 ) / (batch_idx + 1)

             self.log(loss_dict, tag="train")
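Both sides of the change above accumulate a running mean over batches without storing per-batch values: after batch_idx + 1 batches, avg = (avg * batch_idx + x) / (batch_idx + 1). Below is a minimal standalone sketch of that update outside the FunASR Trainer; the RunningMean helper is illustrative, not part of the repository.

# Illustrative sketch, not FunASR code: incremental running-average update
# equivalent to averaging the first (batch_idx + 1) values.
class RunningMean:
    def __init__(self):
        self.avg = 0.0

    def update(self, value: float, batch_idx: int) -> float:
        # avg_{k+1} = (avg_k * k + x_{k+1}) / (k + 1)
        self.avg = (self.avg * batch_idx + value) / (batch_idx + 1)
        return self.avg


loss_avg = RunningMean()
for batch_idx, loss in enumerate([0.9, 0.7, 0.5]):
    loss_avg.update(loss, batch_idx)
print(loss_avg.avg)  # 0.7, the plain mean of the three losses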
@@ -612,12 +612,12 @@ class Trainer:
             time_beg = time.perf_counter()

         if self.use_ddp or self.use_fsdp or self.use_deepspeed:
-            val_loss_avg = torch.tensor(self.val_loss_avg, dtype=torch.float32).to(self.device)
-            val_acc_avg = torch.tensor(self.val_acc_avg, dtype=torch.float32).to(self.device)
-            dist.all_reduce(val_loss_avg, op=dist.ReduceOp.SUM)
-            dist.all_reduce(val_acc_avg, op=dist.ReduceOp.SUM)
-            self.val_loss_avg = val_loss_avg.detach().cpu().item() / self.world_size
-            self.val_acc_avg = val_acc_avg.detach().cpu().item() / self.world_size
+            train_loss_avg = torch.tensor(self.train_loss_avg, dtype=torch.float32).to(self.device)
+            train_acc_avg = torch.tensor(self.train_acc_avg, dtype=torch.float32).to(self.device)
+            dist.all_reduce(train_loss_avg, op=dist.ReduceOp.SUM)
+            dist.all_reduce(train_acc_avg, op=dist.ReduceOp.SUM)
+            self.train_loss_avg = train_loss_avg.detach().cpu().item() / self.world_size
+            self.train_acc_avg = train_acc_avg.detach().cpu().item() / self.world_size

     def forward_step(self, model, batch, loss_dict={}):
         dtype = torch.bfloat16
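The block above turns per-rank averages into a global mean by summing the scalar across ranks and dividing by world_size. Here is a minimal standalone sketch of the same all_reduce pattern; the script name and the gloo backend are assumptions for a CPU demo, launched for example with torchrun --nproc_per_node=2 reduce_avg_demo.py.

# Illustrative sketch, not FunASR code: average a per-rank scalar across
# processes with all_reduce(SUM) followed by division by world_size.
import torch
import torch.distributed as dist


def main():
    # torchrun sets RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT for us.
    dist.init_process_group(backend="gloo")  # CPU-friendly backend for the demo
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # Pretend each rank finished the epoch with a different local loss average.
    loss_avg = torch.tensor(0.5 + 0.1 * rank, dtype=torch.float32)

    # Sum the scalar across all ranks, then divide by world_size for the mean.
    dist.all_reduce(loss_avg, op=dist.ReduceOp.SUM)
    global_loss_avg = loss_avg.item() / world_size

    if rank == 0:
        print(f"global loss avg: {global_loss_avg:.4f}")

    dist.destroy_process_group()


if __name__ == "__main__":
    main()

Dividing the summed value by world_size gives the unweighted mean of the per-rank averages, which matches the exact epoch mean only when every rank processes a comparable number of batches.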