diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index 6cb486bf2..d19b79a14 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -150,8 +150,8 @@ def main(**kwargs):
     # dataset
     logging.info("Build dataloader")
     dataloader_class = tables.dataloader_classes.get(kwargs["dataset_conf"].get("dataloader", "DataloaderMapStyle"))
-    # dataloader = dataloader_class(**kwargs)
-    dataloader_tr, dataloader_val = dataloader_class(**kwargs)
+    dataloader = dataloader_class(**kwargs)
+    # dataloader_tr, dataloader_val = dataloader_class(**kwargs)
     trainer = Trainer(local_rank=local_rank,
                       use_ddp=use_ddp,
                       use_fsdp=use_fsdp,
@@ -181,7 +181,7 @@ def main(**kwargs):
     for epoch in range(trainer.start_epoch, trainer.max_epoch + 1):
         time1 = time.perf_counter()
         with context:
-            # dataloader_tr, dataloader_val = dataloader.build_iter(epoch)
+            dataloader_tr, dataloader_val = dataloader.build_iter(epoch)
             trainer.train_epoch(
                 model=model,
                 optim=optim,
diff --git a/funasr/datasets/dataloader_entry.py b/funasr/datasets/dataloader_entry.py
index 0de7e4053..2222efbf2 100644
--- a/funasr/datasets/dataloader_entry.py
+++ b/funasr/datasets/dataloader_entry.py
@@ -4,7 +4,7 @@ import torch
 
 from funasr.register import tables
 
-@tables.register("dataloader_classes", "DataloaderMapStyle")
+# @tables.register("dataloader_classes", "DataloaderMapStyle")
 def DataloaderMapStyle(frontend=None, tokenizer=None, **kwargs):
     # dataset
     logging.info("Build dataloader")
@@ -25,7 +25,7 @@ def DataloaderMapStyle(frontend=None, tokenizer=None, **kwargs):
     return dataloader_tr, dataloader_val
 
 
-# @tables.register("dataloader_classes", "DataloaderMapStyle")
+@tables.register("dataloader_classes", "DataloaderMapStyle")
 class DataloaderMapStyle:
     def __init__(self, frontend=None, tokenizer=None, **kwargs):
         # dataset
diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index c66539418..2d47fc17e 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -371,8 +371,7 @@ class Trainer:
         if self.use_ddp or self.use_fsdp:
             dist.barrier()
-
-            iterator_stop = torch.tensor(0).to(self.device)
+        iterator_stop = torch.tensor(0).to(self.device)
@@ -402,12 +401,10 @@ class Trainer:
         iterator_stop = torch.tensor(0).to(self.device)
         dataloader_val.batch_sampler.set_epoch(epoch)
         for batch_idx, batch in enumerate(dataloader_val):
-            # if self.use_ddp or self.use_fsdp:
-            #     dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
-            #     if epoch >= 1:
-            #         print(f"iterator_stop: {iterator_stop}\n")
-            #     if iterator_stop > 0:
-            #         break
+            if self.use_ddp or self.use_fsdp:
+                dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
+                if iterator_stop > 0:
+                    break
             time1 = time.perf_counter()
             speed_stats["data_load"] = f"{time1 - time5:0.3f}"
             batch = to_device(batch, self.device)
@@ -467,7 +464,7 @@ class Trainer:
         if self.use_ddp or self.use_fsdp:
             dist.barrier()
-            iterator_stop = torch.tensor(0).to(self.device)
+        iterator_stop = torch.tensor(0).to(self.device)
 
     def log(self,
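
A note on the first two files, since those diffs only make sense together: `dataloader_entry.py` moves the `@tables.register("dataloader_classes", "DataloaderMapStyle")` decorator from the function-style builder to the class of the same name, so the lookup in `train.py` (`tables.dataloader_classes.get(...)`) now resolves to the class. `train.py` then constructs that object once and calls `dataloader.build_iter(epoch)` inside the epoch loop, so fresh train/val iterators are built every epoch instead of once up front. Below is a minimal sketch of the contract this assumes; `EpochDataloader` and its placeholder batches are illustrative stand-ins, not FunASR code:

```python
class EpochDataloader:  # hypothetical stand-in for the registered class
    def __init__(self, frontend=None, tokenizer=None, **kwargs):
        # Hold on to the config; the actual iterators are built per epoch.
        self.kwargs = kwargs

    def build_iter(self, epoch: int):
        # Rebuild (or re-seed) both iterators so shuffling and any
        # epoch-dependent sampling restart cleanly at each epoch boundary.
        dataloader_tr = iter([f"train_batch_{epoch}_{i}" for i in range(3)])
        dataloader_val = iter([f"val_batch_{epoch}_{i}" for i in range(2)])
        return dataloader_tr, dataloader_val


dataloader = EpochDataloader()
for epoch in range(1, 3):
    # Matches the call re-enabled in train.py's epoch loop.
    dataloader_tr, dataloader_val = dataloader.build_iter(epoch)
```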
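The `trainer.py` hunks do two related things. First, `iterator_stop` is dedented out of the `if self.use_ddp or self.use_fsdp:` block (and the stray blank line dropped), so the tensor also exists on single-process runs. Second, the cross-rank early-stop check in the validation loop is re-enabled, minus the old `epoch >= 1` debug print. Under DDP/FSDP, ranks can receive different numbers of validation batches; without this check, the rank that exhausts its iterator first leaves the loop while the others block indefinitely in their next collective op. A sketch of the handshake as I read it; the final signal in the `else` branch is assumed from the pattern and is not shown in these hunks:

```python
import torch
import torch.distributed as dist


def validate_epoch_sketch(dataloader_val, device, use_ddp):
    # Defined unconditionally, matching the dedent in the trainer.py hunks.
    iterator_stop = torch.tensor(0).to(device)
    for batch in dataloader_val:
        if use_ddp:
            # Each rank contributes 0 while it still has data; the SUM turns
            # nonzero on every rank as soon as any rank has run out.
            dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
            if iterator_stop > 0:
                break  # stop together on the same iteration
        _ = batch  # forward pass / metric accumulation would go here
    else:
        # Loop ended normally on this rank: signal the other ranks
        # (assumed follow-up step, not part of the shown diff).
        if use_ddp:
            iterator_stop.fill_(1)
            dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
```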