From ed22e34d654c47017962d3e5758d3a351d8826ab Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?=
Date: Sun, 24 Mar 2024 15:03:54 +0800
Subject: [PATCH] finetune

---
 funasr/bin/train.py           | 10 +++++-----
 funasr/train_utils/trainer.py | 35 +++++++++++++++++++----------------
 2 files changed, 24 insertions(+), 21 deletions(-)

diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index e446e5404..6cb486bf2 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -173,11 +173,11 @@ def main(**kwargs):
     except:
         writer = None
 
-    if use_ddp or use_fsdp:
-        context = Join([model])
-    else:
-        context = nullcontext()
-
+    # if use_ddp or use_fsdp:
+    #     context = Join([model])
+    # else:
+    #     context = nullcontext()
+    context = nullcontext()
     for epoch in range(trainer.start_epoch, trainer.max_epoch + 1):
         time1 = time.perf_counter()
         with context:
diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index e554aca6d..c66539418 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -250,14 +250,14 @@ class Trainer:
         optim.zero_grad()
         speed_stats = {}
         time5 = time.perf_counter()
-        # iterator_stop = torch.tensor(0).to(self.device)
+        iterator_stop = torch.tensor(0).to(self.device)
 
         dataloader_train.batch_sampler.set_epoch(epoch)
         for batch_idx, batch in enumerate(dataloader_train):
-            # if self.use_ddp or self.use_fsdp:
-            #     dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
-            #     if iterator_stop > 0:
-            #         break
+            if self.use_ddp or self.use_fsdp:
+                dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
+                if iterator_stop > 0:
+                    break
             self.batch_total += 1
             time1 = time.perf_counter()
             speed_stats["data_load"] = f"{time1-time5:0.3f}"
@@ -340,7 +340,7 @@ class Trainer:
                 speed_stats["total_time"] = total_time
                 lr = scheduler.get_last_lr()[0]
 
-                batch_num_epoch = -1
+                batch_num_epoch = 1
                 if hasattr(dataloader_train, "__len__"):
                     batch_num_epoch = len(dataloader_train)
                 self.log(epoch, batch_idx,
@@ -364,13 +364,15 @@ class Trainer:
             if (batch_idx+1) % self.save_checkpoint_interval == 0:
                 self.save_checkpoint(epoch, model=model, optim=optim, scheduler=scheduler, scaler=scaler, step=batch_idx+1)
 
-        # else:
-        #     if self.use_ddp or self.use_fsdp:
-        #         iterator_stop.fill_(1)
-        #         dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
+        else:
+            if self.use_ddp or self.use_fsdp:
+                iterator_stop.fill_(1)
+                dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
 
         if self.use_ddp or self.use_fsdp:
             dist.barrier()
+
+        iterator_stop = torch.tensor(0).to(self.device)
@@ -397,7 +399,7 @@ class Trainer:
 
         speed_stats = {}
         time5 = time.perf_counter()
-        # iterator_stop = torch.tensor(0).to(self.device)
+        iterator_stop = torch.tensor(0).to(self.device)
 
         dataloader_val.batch_sampler.set_epoch(epoch)
         for batch_idx, batch in enumerate(dataloader_val):
             # if self.use_ddp or self.use_fsdp:
@@ -442,7 +444,7 @@ class Trainer:
             self.val_loss_avg = val_loss_avg.detach().cpu().item() / self.world_size
             self.val_acc_avg = val_acc_avg.detach().cpu().item() / self.world_size
 
-            batch_num_epoch = -1
+            batch_num_epoch = 1
             if hasattr(dataloader_val, "__len__"):
                 batch_num_epoch = len(dataloader_val)
             self.log(epoch, batch_idx,
@@ -455,16 +457,17 @@ class Trainer:
                 tag="val",
             )
 
-        # else:
-        #     if self.use_ddp or self.use_fsdp:
-        #         iterator_stop.fill_(1)
-        #         dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
+        else:
+            if self.use_ddp or self.use_fsdp:
+                iterator_stop.fill_(1)
+                dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
 
         self.val_acc_list.append(self.val_acc_avg)
         model.train()
 
         if self.use_ddp or self.use_fsdp:
             dist.barrier()
+        iterator_stop = torch.tensor(0).to(self.device)
 
     def log(self,
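
Note on the funasr/bin/train.py hunk: Join([model]) (torch.distributed.algorithms.join.Join) was the context manager that kept DDP collectives matched when ranks receive different numbers of batches. With it commented out, every epoch now runs under a plain nullcontext(), and uneven inputs are presumably handled instead by the iterator_stop handshake that the trainer.py hunks re-enable, sketched below.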
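
For reference, a minimal standalone sketch of that handshake, assuming a per-rank dataloader and an already initialized process group; the function name run_one_epoch and the elided training step are illustrative, not part of FunASR:

    import torch
    import torch.distributed as dist

    def run_one_epoch(dataloader, device, use_ddp: bool):
        # Reset the flag once per epoch, as the patch does after dist.barrier().
        iterator_stop = torch.tensor(0).to(device)

        for batch in dataloader:
            if use_ddp:
                # Every rank contributes its flag each step; a nonzero sum means
                # some rank has already run out of batches, so all ranks break in
                # lockstep and no collective is left unmatched.
                dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
                if iterator_stop > 0:
                    break
            # ... forward / backward / optimizer step on `batch` elided ...
        else:
            # Reached only when this rank's loop ends without `break`: announce
            # exhaustion so the other ranks' next all_reduce sees a nonzero sum.
            if use_ddp:
                iterator_stop.fill_(1)
                dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)

        if use_ddp:
            dist.barrier()

The bookkeeping works out because the rank that exhausts its data first issues one extra all_reduce in the for/else branch, which pairs with the all_reduce the still-running ranks issue at the top of their next iteration, so every rank performs the same number of collectives before reaching the barrier.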