commit ed952ff630
parent ed22e34d65
Author: 游雁
Date:   2024-03-24 15:11:02 +08:00

3 changed files with 11 additions and 14 deletions

View File

@@ -150,8 +150,8 @@ def main(**kwargs):
     # dataset
     logging.info("Build dataloader")
     dataloader_class = tables.dataloader_classes.get(kwargs["dataset_conf"].get("dataloader", "DataloaderMapStyle"))
-    # dataloader = dataloader_class(**kwargs)
-    dataloader_tr, dataloader_val = dataloader_class(**kwargs)
+    dataloader = dataloader_class(**kwargs)
+    # dataloader_tr, dataloader_val = dataloader_class(**kwargs)
     trainer = Trainer(local_rank=local_rank,
                       use_ddp=use_ddp,
                       use_fsdp=use_fsdp,
@@ -181,7 +181,7 @@ def main(**kwargs):
     for epoch in range(trainer.start_epoch, trainer.max_epoch + 1):
         time1 = time.perf_counter()
         with context:
-            # dataloader_tr, dataloader_val = dataloader.build_iter(epoch)
+            dataloader_tr, dataloader_val = dataloader.build_iter(epoch)
             trainer.train_epoch(
                 model=model,
                 optim=optim,
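This file's change replaces the one-shot `dataloader_tr, dataloader_val = dataloader_class(**kwargs)` call with a dataloader object whose `build_iter(epoch)` is invoked at the top of every epoch, so the train/val iterators (and any shuffling state) are rebuilt per epoch. A minimal sketch of that pattern, assuming plain `torch.utils.data` datasets; `EpochDataloader`, `dataset_tr`, `dataset_val`, and `batch_size` are illustrative stand-ins, not FunASR's actual implementation:

import torch

class EpochDataloader:  # hypothetical stand-in for DataloaderMapStyle
    def __init__(self, dataset_tr, dataset_val, batch_size=8, **kwargs):
        self.dataset_tr = dataset_tr
        self.dataset_val = dataset_val
        self.batch_size = batch_size

    def build_iter(self, epoch):
        # Re-seeding per epoch gives each epoch a fresh shuffle order.
        gen = torch.Generator().manual_seed(epoch)
        dataloader_tr = torch.utils.data.DataLoader(
            self.dataset_tr, batch_size=self.batch_size,
            shuffle=True, generator=gen)
        dataloader_val = torch.utils.data.DataLoader(
            self.dataset_val, batch_size=self.batch_size, shuffle=False)
        return dataloader_tr, dataloader_val

# Usage mirroring the updated main():
# loader = EpochDataloader(train_set, val_set)
# for epoch in range(start_epoch, max_epoch + 1):
#     dataloader_tr, dataloader_val = loader.build_iter(epoch)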

View File

@@ -4,7 +4,7 @@ import torch
 from funasr.register import tables
 
-@tables.register("dataloader_classes", "DataloaderMapStyle")
+# @tables.register("dataloader_classes", "DataloaderMapStyle")
 def DataloaderMapStyle(frontend=None, tokenizer=None, **kwargs):
     # dataset
     logging.info("Build dataloader")
@@ -25,7 +25,7 @@ def DataloaderMapStyle(frontend=None, tokenizer=None, **kwargs):
     return dataloader_tr, dataloader_val
 
-# @tables.register("dataloader_classes", "DataloaderMapStyle")
+@tables.register("dataloader_classes", "DataloaderMapStyle")
 class DataloaderMapStyle:
     def __init__(self, frontend=None, tokenizer=None, **kwargs):
         # dataset
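Here the `@tables.register` decorator moves from the legacy function to the class of the same name, so the `tables.dataloader_classes.get(...)` lookup in `main` now resolves to the class. A hedged sketch of how such a registration table can work; this is a toy re-implementation for illustration, not the code of `funasr.register`:

# Name-to-object registry, roughly what the dataloader lookup implies.
tables = {"dataloader_classes": {}}

def register(table_name, key):
    # Decorator factory: stores the decorated object under the given
    # table and key, then returns it unchanged.
    def decorator(obj):
        tables[table_name][key] = obj
        return obj
    return decorator

@register("dataloader_classes", "DataloaderMapStyle")
class DataloaderMapStyle:
    def __init__(self, frontend=None, tokenizer=None, **kwargs):
        self.kwargs = kwargs

# main() then resolves the entry by its configured name:
cls = tables["dataloader_classes"].get("DataloaderMapStyle")
assert cls is DataloaderMapStyle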

View File

@@ -371,8 +371,7 @@ class Trainer:
         if self.use_ddp or self.use_fsdp:
             dist.barrier()
-        iterator_stop = torch.tensor(0).to(self.device)
         iterator_stop = torch.tensor(0).to(self.device)
@@ -402,12 +401,10 @@ class Trainer:
         iterator_stop = torch.tensor(0).to(self.device)
         dataloader_val.batch_sampler.set_epoch(epoch)
         for batch_idx, batch in enumerate(dataloader_val):
-            # if self.use_ddp or self.use_fsdp:
-            #     dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
-            # if epoch >= 1:
-            #     print(f"iterator_stop: {iterator_stop}\n")
-            # if iterator_stop > 0:
-            #     break
+            if self.use_ddp or self.use_fsdp:
+                dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
+                if iterator_stop > 0:
+                    break
             time1 = time.perf_counter()
             speed_stats["data_load"] = f"{time1 - time5:0.3f}"
             batch = to_device(batch, self.device)
@@ -467,7 +464,7 @@ class Trainer:
         if self.use_ddp or self.use_fsdp:
             dist.barrier()
         iterator_stop = torch.tensor(0).to(self.device)
 
     def log(self,
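The trainer change re-enables the previously commented-out synchronized early stop in the validation loop (and drops the stray `epoch >= 1` debug print). The point of the per-iteration `all_reduce` is that under DDP/FSDP every rank must leave the batch loop together: a rank whose iterator is exhausted first would otherwise leave its peers blocked inside a collective. A minimal sketch of the pattern, assuming `torch.distributed` is initialized and ranks may see different batch counts; the `run_loop` wrapper and the for/else flag-raising are illustrative, not the exact Trainer code:

import torch
import torch.distributed as dist

def run_loop(dataloader, device, use_ddp=True):
    iterator_stop = torch.tensor(0).to(device)
    for batch_idx, batch in enumerate(dataloader):
        if use_ddp:
            # Every rank contributes its flag each iteration; once any
            # rank has raised it, the SUM is > 0 everywhere and all
            # ranks break together, so none hangs in a collective.
            dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
            if iterator_stop > 0:
                break
        ...  # per-batch work (forward pass, metric accumulation)
    else:
        # The loop ended because this rank ran out of data: raise the
        # flag and make one matching all_reduce call so the ranks still
        # looping observe it on their next iteration.
        if use_ddp:
            iterator_stop.fill_(1)
            dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)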