游雁 2024-03-24 15:11:02 +08:00
parent ed22e34d65
commit ed952ff630
3 changed files with 11 additions and 14 deletions


@@ -150,8 +150,8 @@ def main(**kwargs):
     # dataset
     logging.info("Build dataloader")
     dataloader_class = tables.dataloader_classes.get(kwargs["dataset_conf"].get("dataloader", "DataloaderMapStyle"))
-    # dataloader = dataloader_class(**kwargs)
-    dataloader_tr, dataloader_val = dataloader_class(**kwargs)
+    dataloader = dataloader_class(**kwargs)
+    # dataloader_tr, dataloader_val = dataloader_class(**kwargs)
     trainer = Trainer(local_rank=local_rank,
                       use_ddp=use_ddp,
                       use_fsdp=use_fsdp,
@@ -181,7 +181,7 @@ def main(**kwargs):
     for epoch in range(trainer.start_epoch, trainer.max_epoch + 1):
         time1 = time.perf_counter()
         with context:
-            # dataloader_tr, dataloader_val = dataloader.build_iter(epoch)
+            dataloader_tr, dataloader_val = dataloader.build_iter(epoch)
             trainer.train_epoch(
                 model=model,
                 optim=optim,


@@ -4,7 +4,7 @@ import torch
 from funasr.register import tables
 
 
-@tables.register("dataloader_classes", "DataloaderMapStyle")
+# @tables.register("dataloader_classes", "DataloaderMapStyle")
 def DataloaderMapStyle(frontend=None, tokenizer=None, **kwargs):
     # dataset
     logging.info("Build dataloader")
@@ -25,7 +25,7 @@ def DataloaderMapStyle(frontend=None, tokenizer=None, **kwargs):
     return dataloader_tr, dataloader_val
 
 
-# @tables.register("dataloader_classes", "DataloaderMapStyle")
+@tables.register("dataloader_classes", "DataloaderMapStyle")
 class DataloaderMapStyle:
     def __init__(self, frontend=None, tokenizer=None, **kwargs):
         # dataset
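
Why moving the decorator matters: @tables.register("dataloader_classes", "DataloaderMapStyle") files whatever object it decorates under that string key, so after this change tables.dataloader_classes.get("DataloaderMapStyle") returns the class (whose instances expose build_iter) rather than the old function. A simplified stand-in registry showing the mechanism, not FunASR's implementation:

class Tables:
    # Simplified stand-in for funasr.register.tables: register() returns
    # a decorator that stores the decorated object under a string key.
    def __init__(self):
        self.dataloader_classes = {}

    def register(self, table_name, key):
        def wrapper(obj):
            getattr(self, table_name)[key] = obj
            return obj
        return wrapper

tables = Tables()

@tables.register("dataloader_classes", "DataloaderMapStyle")
class DataloaderMapStyle:
    def __init__(self, frontend=None, tokenizer=None, **kwargs):
        self.kwargs = kwargs

    def build_iter(self, epoch):
        ...

cls = tables.dataloader_classes.get("DataloaderMapStyle")
assert cls is DataloaderMapStyle  # the lookup now yields the class itself

With the registration on the class, main()'s dataloader_class(**kwargs) call produces an object carrying build_iter, which is what the per-epoch rebuild in the first file relies on.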


@@ -371,8 +371,7 @@ class Trainer:
         if self.use_ddp or self.use_fsdp:
             dist.barrier()
-            iterator_stop = torch.tensor(0).to(self.device)
-            iterator_stop = torch.tensor(0).to(self.device)
+        iterator_stop = torch.tensor(0).to(self.device)
@@ -402,12 +401,10 @@
         iterator_stop = torch.tensor(0).to(self.device)
         dataloader_val.batch_sampler.set_epoch(epoch)
         for batch_idx, batch in enumerate(dataloader_val):
-            # if self.use_ddp or self.use_fsdp:
-            #     dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
-            #     if epoch >= 1:
-            #         print(f"iterator_stop: {iterator_stop}\n")
-            #     if iterator_stop > 0:
-            #         break
+            if self.use_ddp or self.use_fsdp:
+                dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
+                if iterator_stop > 0:
+                    break
             time1 = time.perf_counter()
             speed_stats["data_load"] = f"{time1 - time5:0.3f}"
             batch = to_device(batch, self.device)
@@ -467,7 +464,7 @@
         if self.use_ddp or self.use_fsdp:
             dist.barrier()
-            iterator_stop = torch.tensor(0).to(self.device)
+        iterator_stop = torch.tensor(0).to(self.device)
 
     def log(self,
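
The block uncommented in the validation loop is a common DDP idiom: each rank holds an iterator_stop flag, and all_reduce with ReduceOp.SUM makes the summed flag visible on every rank, so all ranks break out of the loop on the same step once any rank's data shard is exhausted, instead of one rank exiting while the others hang in a collective call. A minimal, single-process-safe sketch of the idiom (illustrative only, not the Trainer's full loop):

import torch
import torch.distributed as dist

# Runs as-is in a single process; under DDP every rank executes this loop.
use_ddp = dist.is_available() and dist.is_initialized()
device = "cpu"
iterator_stop = torch.tensor(0).to(device)

for batch_idx, batch in enumerate(range(10)):
    if use_ddp:
        # Sum the stop flags across ranks; a result > 0 means at least one
        # rank set its flag to 1 on exhausting its shard (not shown here),
        # so every rank breaks together and no collective call deadlocks.
        dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
        if iterator_stop > 0:
            break
    # ... move batch to device and run the forward pass ...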