This commit is contained in:
游雁 2024-03-24 01:44:18 +08:00
parent 16a976a01d
commit a70f5b3edf
2 changed files with 12 additions and 11 deletions

View File

@ -173,10 +173,10 @@ def main(**kwargs):
except:
writer = None
# if use_ddp or use_fsdp:
# context = Join([model])
# else:
context = nullcontext()
if use_ddp or use_fsdp:
context = Join([model])
else:
context = nullcontext()
for epoch in range(trainer.start_epoch, trainer.max_epoch + 1):
time1 = time.perf_counter()
@ -192,13 +192,14 @@ def main(**kwargs):
epoch=epoch,
writer=writer
)
with context:
trainer.validate_epoch(
model=model,
dataloader_val=dataloader_val,
epoch=epoch,
writer=writer
)
scheduler.step()
trainer.validate_epoch(
model=model,
dataloader_val=dataloader_val,
epoch=epoch,
writer=writer
)
trainer.save_checkpoint(epoch, model=model, optim=optim, scheduler=scheduler, scaler=scaler)

View File

@ -398,7 +398,7 @@ class Trainer:
speed_stats = {}
time5 = time.perf_counter()
# iterator_stop = torch.tensor(0).to(self.device)
dataloader_val.batch_sampler.set_epoch(epoch)
for batch_idx, batch in enumerate(dataloader_val):
# if self.use_ddp or self.use_fsdp:
# dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)