游雁 2024-03-24 01:27:08 +08:00
parent 5d74e107fc
commit 16a976a01d
2 changed files with 42 additions and 41 deletions

View File

@@ -128,7 +128,8 @@ def main(**kwargs):
     else:
         model = model.to(device=kwargs.get("device", "cuda"))
-    logging.info(f"{model}")
+    if local_rank == 0:
+        logging.info(f"{model}")
     kwargs["device"] = next(model.parameters()).device
     # optim
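Net effect of this hunk: the model summary is logged once rather than once per process. A minimal standalone sketch of the rank-guarded pattern, assuming a torchrun-style launcher that sets LOCAL_RANK (the toy Linear model is an assumption of the sketch, not the repository's model):

    import logging
    import os

    import torch

    # Standalone sketch: every rank builds and moves the model, but only
    # rank 0 logs its structure, so the summary is printed once instead of
    # once per process. LOCAL_RANK is the variable torchrun-style launchers set.
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))

    model = torch.nn.Linear(8, 2)  # stand-in for the real model
    model = model.to(device="cuda" if torch.cuda.is_available() else "cpu")

    if local_rank == 0:
        logging.info(f"{model}")

    # Every rank still records the resolved device, as the diff does via kwargs.
    device = next(model.parameters()).device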

View File

@@ -239,6 +239,8 @@ class Trainer:
         Args:
             epoch (int): The current epoch number.
         """
+        if self.use_ddp or self.use_fsdp:
+            dist.barrier()
         logging.info(f"Train epoch: {epoch}, rank: {self.local_rank}\n")
         model.train()
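The added barrier keeps ranks aligned at the epoch boundary, so no process runs ahead while another is still finishing the previous epoch (for example, saving a checkpoint). A small sketch of the idea, assuming the default process group is already initialized; the helper name epoch_prologue is hypothetical:

    import logging

    import torch.distributed as dist

    def epoch_prologue(epoch: int, use_ddp: bool, use_fsdp: bool, local_rank: int) -> None:
        # Hypothetical helper mirroring the inserted lines: block until every
        # rank reaches the epoch boundary before any of them proceeds.
        if use_ddp or use_fsdp:
            dist.barrier()
        logging.info(f"Train epoch: {epoch}, rank: {local_rank}\n")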
@@ -248,15 +250,14 @@ class Trainer:
         optim.zero_grad()
         speed_stats = {}
         time5 = time.perf_counter()
-        iterator_stop = torch.tensor(0).to(self.device)
-        dist.barrier()
-        print(f"before iter, iterator_stop: {iterator_stop}\n")
+        # iterator_stop = torch.tensor(0).to(self.device)
         dataloader_train.batch_sampler.set_epoch(epoch)
         for batch_idx, batch in enumerate(dataloader_train):
-            if self.use_ddp or self.use_fsdp:
-                dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
-                if iterator_stop > 0:
-                    break
+            # if self.use_ddp or self.use_fsdp:
+            #     dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
+            #     if iterator_stop > 0:
+            #         break
             self.batch_total += 1
             time1 = time.perf_counter()
             speed_stats["data_load"] = f"{time1-time5:0.3f}"
@@ -297,13 +298,13 @@ class Trainer:
             self.train_loss_avg = (self.train_loss_avg*batch_idx + loss.detach().cpu().item())/(batch_idx+1)
             if "acc" in stats:
                 self.train_acc_avg = (self.train_acc_avg * batch_idx + stats["acc"].detach().cpu().item()) / (batch_idx + 1)
-            # if self.use_ddp or self.use_fsdp:
-            #     train_loss_avg = torch.tensor(self.train_loss_avg, dtype=torch.float32).to(self.device)
-            #     train_acc_avg = torch.tensor(self.train_acc_avg, dtype=torch.float32).to(self.device)
-            #     dist.all_reduce(train_loss_avg, op=dist.ReduceOp.SUM)
-            #     dist.all_reduce(train_acc_avg, op=dist.ReduceOp.SUM)
-            #     self.train_loss_avg = train_loss_avg.detach().cpu().item() / self.world_size
-            #     self.train_acc_avg = train_acc_avg.detach().cpu().item() / self.world_size
+            if self.use_ddp or self.use_fsdp:
+                train_loss_avg = torch.tensor(self.train_loss_avg, dtype=torch.float32).to(self.device)
+                train_acc_avg = torch.tensor(self.train_acc_avg, dtype=torch.float32).to(self.device)
+                dist.all_reduce(train_loss_avg, op=dist.ReduceOp.SUM)
+                dist.all_reduce(train_acc_avg, op=dist.ReduceOp.SUM)
+                self.train_loss_avg = train_loss_avg.detach().cpu().item() / self.world_size
+                self.train_acc_avg = train_acc_avg.detach().cpu().item() / self.world_size
             # Perform an optimizer step only after accumulating enough gradients
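The re-enabled block is the usual scalar-averaging step: each rank contributes its local running mean, all_reduce sums them, and dividing by the world size yields the global mean. A minimal sketch of that step as a standalone helper, with a hypothetical name and an assumed initialized process group:

    import torch
    import torch.distributed as dist

    def mean_across_ranks(local_value: float, device: torch.device) -> float:
        # Sum the per-rank running mean in place, then divide by the number
        # of ranks; every rank ends up with the same global mean.
        t = torch.tensor(local_value, dtype=torch.float32, device=device)
        dist.all_reduce(t, op=dist.ReduceOp.SUM)
        return t.item() / dist.get_world_size()

Note that the diff performs this pair of all_reduce calls on every batch, so it adds two collectives per step; that cost is the trade-off for loss and accuracy logs that already reflect all ranks.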
@@ -363,10 +364,10 @@ class Trainer:
             if (batch_idx+1) % self.save_checkpoint_interval == 0:
                 self.save_checkpoint(epoch, model=model, optim=optim, scheduler=scheduler, scaler=scaler, step=batch_idx+1)
-        else:
-            if self.use_ddp or self.use_fsdp:
-                iterator_stop.fill_(1)
-                dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
+        # else:
+        #     if self.use_ddp or self.use_fsdp:
+        #         iterator_stop.fill_(1)
+        #         dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
         if self.use_ddp or self.use_fsdp:
             dist.barrier()
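For reference, the mechanism retired here and in the earlier hunk is a for/else stop-flag handshake: every rank all-reduces a flag at the top of each iteration, and a rank whose iterator runs dry raises the flag in the loop's else clause so its peers break instead of hanging in the next collective. A self-contained sketch of that pattern, with train_step as a placeholder:

    import torch
    import torch.distributed as dist

    def train_step(batch) -> None:
        pass  # placeholder for the real forward/backward/optimizer step

    def run_epoch(dataloader, device: torch.device, distributed: bool) -> None:
        iterator_stop = torch.tensor(0, device=device)
        for batch in dataloader:
            if distributed:
                # Pairs with the all_reduce a finished rank issues below; a
                # nonzero sum means some shard is exhausted, so stop together.
                dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
                if iterator_stop > 0:
                    break
            train_step(batch)
        else:
            # The for/else branch only runs when the loop was not broken out
            # of, i.e. this rank drained its shard first: raise the flag.
            if distributed:
                iterator_stop.fill_(1)
                dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)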
@@ -387,6 +388,8 @@ class Trainer:
         Args:
             epoch (int): The current epoch number.
         """
+        if self.use_ddp or self.use_fsdp:
+            dist.barrier()
         logging.info(f"Validate epoch: {epoch}, rank: {self.local_rank}\n")
         model.eval()
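Validation now starts with the same barrier as training; the method also flips the model to eval mode and, in the final hunk below, back to train mode. A minimal sketch of that mode handshake (the torch.no_grad() context is this sketch's assumption and does not appear in the diff):

    import torch

    def validate(model: torch.nn.Module, dataloader_val) -> None:
        model.eval()  # disable dropout, use running batch-norm statistics
        with torch.no_grad():  # sketch-only addition: skip autograd bookkeeping
            for batch in dataloader_val:
                pass  # forward pass and metric accumulation go here
        model.train()  # restore training behavior, as the diff's last hunk does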
@@ -394,16 +397,15 @@ class Trainer:
         speed_stats = {}
         time5 = time.perf_counter()
-        iterator_stop = torch.tensor(0).to(self.device)
-        dist.barrier()
-        print(f"before iter, iterator_stop: {iterator_stop}\n")
+        # iterator_stop = torch.tensor(0).to(self.device)
         for batch_idx, batch in enumerate(dataloader_val):
-            if self.use_ddp or self.use_fsdp:
-                dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
-                if epoch >= 1:
-                    print(f"iterator_stop: {iterator_stop}\n")
-                if iterator_stop > 0:
-                    break
+            # if self.use_ddp or self.use_fsdp:
+            #     dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
+            #     if epoch >= 1:
+            #         print(f"iterator_stop: {iterator_stop}\n")
+            #     if iterator_stop > 0:
+            #         break
             time1 = time.perf_counter()
             speed_stats["data_load"] = f"{time1 - time5:0.3f}"
             batch = to_device(batch, self.device)
@@ -432,13 +434,13 @@ class Trainer:
             self.val_loss_avg = (self.val_loss_avg*batch_idx + loss.detach().cpu().item())/(batch_idx+1)
             if "acc" in stats:
                 self.val_acc_avg = (self.val_acc_avg * batch_idx + stats["acc"].detach().cpu().item()) / (batch_idx + 1)
-            # if self.use_ddp or self.use_fsdp:
-            #     val_loss_avg = torch.tensor(self.val_loss_avg, dtype=torch.float32).to(self.device)
-            #     val_acc_avg = torch.tensor(self.val_acc_avg, dtype=torch.float32).to(self.device)
-            #     dist.all_reduce(val_loss_avg, op=dist.ReduceOp.SUM)
-            #     dist.all_reduce(val_acc_avg, op=dist.ReduceOp.SUM)
-            #     self.val_loss_avg = val_loss_avg.detach().cpu().item() / self.world_size
-            #     self.val_acc_avg = val_acc_avg.detach().cpu().item() / self.world_size
+            if self.use_ddp or self.use_fsdp:
+                val_loss_avg = torch.tensor(self.val_loss_avg, dtype=torch.float32).to(self.device)
+                val_acc_avg = torch.tensor(self.val_acc_avg, dtype=torch.float32).to(self.device)
+                dist.all_reduce(val_loss_avg, op=dist.ReduceOp.SUM)
+                dist.all_reduce(val_acc_avg, op=dist.ReduceOp.SUM)
+                self.val_loss_avg = val_loss_avg.detach().cpu().item() / self.world_size
+                self.val_acc_avg = val_acc_avg.detach().cpu().item() / self.world_size
             batch_num_epoch = -1
             if hasattr(dataloader_val, "__len__"):
@@ -453,16 +455,14 @@ class Trainer:
                     tag="val",
                 )
-        else:
-            if self.use_ddp or self.use_fsdp:
-                iterator_stop.fill_(1)
-                dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
+        # else:
+        #     if self.use_ddp or self.use_fsdp:
+        #         iterator_stop.fill_(1)
+        #         dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
         self.val_acc_list.append(self.val_acc_avg)
         model.train()
         if self.use_ddp or self.use_fsdp:
             dist.barrier()