mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
total_time/accum_grad
This commit is contained in:
parent
64bf6dd8a1
commit
e299cfecaf
@ -384,18 +384,19 @@ class Trainer:
|
||||
|
||||
loss, stats, weight = retval
|
||||
stats = {k: v for k, v in stats.items() if v is not None}
|
||||
if self.use_ddp or self.use_fsdp:
|
||||
# Apply weighted averaging for loss and stats
|
||||
loss = (loss * weight.type(loss.dtype)).sum()
|
||||
# if distributed, this method can also apply all_reduce()
|
||||
# stats, weight = recursive_average(stats, weight, distributed=True)
|
||||
if self.use_ddp or self.use_fsdp:
|
||||
dist.all_reduce(weight, op=dist.ReduceOp.SUM)
|
||||
# Now weight is summation over all workers
|
||||
loss /= weight.sum() # shape:[1] -> shape:[]
|
||||
# Multiply world_size because DistributedDataParallel
|
||||
# automatically normalizes the gradient by world_size.
|
||||
loss *= self.world_size
|
||||
# if self.use_ddp or self.use_fsdp:
|
||||
# # Apply weighted averaging for loss and stats
|
||||
# loss = (loss * weight.type(loss.dtype)).sum()
|
||||
# # if distributed, this method can also apply all_reduce()
|
||||
# # stats, weight = recursive_average(stats, weight, distributed=True)
|
||||
# if self.use_ddp or self.use_fsdp:
|
||||
# dist.all_reduce(weight, op=dist.ReduceOp.SUM)
|
||||
# # Now weight is summation over all workers
|
||||
# loss /= weight.sum() # shape:[1] -> shape:[]
|
||||
# # Multiply world_size because DistributedDataParallel
|
||||
# # automatically normalizes the gradient by world_size.
|
||||
# loss *= self.world_size
|
||||
loss *= self.world_size
|
||||
# Scale the loss since we're not updating for every mini-batch
|
||||
loss = loss / accum_grad
|
||||
|
||||
@ -416,17 +417,6 @@ class Trainer:
|
||||
self.train_acc_avg * (self.step_in_epoch - 1)
|
||||
+ stats["acc"].detach().cpu().item()
|
||||
) / self.step_in_epoch
|
||||
if self.use_ddp or self.use_fsdp:
|
||||
train_loss_avg = torch.tensor(self.train_loss_avg, dtype=torch.float32).to(
|
||||
self.device
|
||||
)
|
||||
train_acc_avg = torch.tensor(self.train_acc_avg, dtype=torch.float32).to(
|
||||
self.device
|
||||
)
|
||||
dist.all_reduce(train_loss_avg, op=dist.ReduceOp.SUM)
|
||||
dist.all_reduce(train_acc_avg, op=dist.ReduceOp.SUM)
|
||||
self.train_loss_avg = train_loss_avg.detach().cpu().item() / self.world_size
|
||||
self.train_acc_avg = train_acc_avg.detach().cpu().item() / self.world_size
|
||||
|
||||
# Perform an optimizer step only after accumulating enough gradients
|
||||
if (batch_idx + 1) % accum_grad == 0:
|
||||
@ -457,6 +447,19 @@ class Trainer:
|
||||
optim.zero_grad(set_to_none=True)
|
||||
total_time = f"{(time.perf_counter() - time5)/accum_grad:0.3f}"
|
||||
time5 = time.perf_counter()
|
||||
|
||||
if self.use_ddp or self.use_fsdp:
|
||||
train_loss_avg = torch.tensor(self.train_loss_avg, dtype=torch.float32).to(
|
||||
self.device
|
||||
)
|
||||
train_acc_avg = torch.tensor(self.train_acc_avg, dtype=torch.float32).to(
|
||||
self.device
|
||||
)
|
||||
dist.all_reduce(train_loss_avg, op=dist.ReduceOp.SUM)
|
||||
dist.all_reduce(train_acc_avg, op=dist.ReduceOp.SUM)
|
||||
self.train_loss_avg = train_loss_avg.detach().cpu().item() / self.world_size
|
||||
self.train_acc_avg = train_acc_avg.detach().cpu().item() / self.world_size
|
||||
|
||||
speed_stats["optim_time"] = f"{time5 - time4:0.3f}"
|
||||
|
||||
speed_stats["total_time"] = total_time
|
||||
|
||||
Loading…
Reference in New Issue
Block a user