Mirror of https://github.com/modelscope/FunASR, synced 2025-09-15 14:48:36 +08:00
commit ed952ff630
parent ed22e34d65

    finetune
@@ -150,8 +150,8 @@ def main(**kwargs):
     # dataset
     logging.info("Build dataloader")
     dataloader_class = tables.dataloader_classes.get(kwargs["dataset_conf"].get("dataloader", "DataloaderMapStyle"))
-    # dataloader = dataloader_class(**kwargs)
-    dataloader_tr, dataloader_val = dataloader_class(**kwargs)
+    dataloader = dataloader_class(**kwargs)
+    # dataloader_tr, dataloader_val = dataloader_class(**kwargs)
     trainer = Trainer(local_rank=local_rank,
                       use_ddp=use_ddp,
                       use_fsdp=use_fsdp,
@@ -181,7 +181,7 @@ def main(**kwargs):
     for epoch in range(trainer.start_epoch, trainer.max_epoch + 1):
         time1 = time.perf_counter()
         with context:
-            # dataloader_tr, dataloader_val = dataloader.build_iter(epoch)
+            dataloader_tr, dataloader_val = dataloader.build_iter(epoch)
             trainer.train_epoch(
                 model=model,
                 optim=optim,
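Taken together, the two hunks above change main() so that the registered dataloader object is constructed once, and fresh train/validation iterators are rebuilt at the top of every epoch via build_iter(epoch). A minimal sketch of that build-once, rebuild-per-epoch pattern; the class body here is illustrative rather than FunASR's actual implementation, and only the build_iter(epoch) contract is taken from the diff:

import torch
from torch.utils.data import DataLoader, Dataset


class EpochDataloader:
    # Hypothetical stand-in for the registered dataloader class: it holds
    # the datasets once and hands out freshly seeded iterators per epoch.
    def __init__(self, dataset_tr: Dataset, dataset_val: Dataset, batch_size: int = 8):
        self.dataset_tr = dataset_tr
        self.dataset_val = dataset_val
        self.batch_size = batch_size

    def build_iter(self, epoch: int):
        # Re-seed shuffling each epoch; under DDP this is also where a
        # DistributedSampler would receive set_epoch(epoch).
        gen = torch.Generator().manual_seed(epoch)
        dataloader_tr = DataLoader(self.dataset_tr, batch_size=self.batch_size,
                                   shuffle=True, generator=gen)
        dataloader_val = DataLoader(self.dataset_val, batch_size=self.batch_size)
        return dataloader_tr, dataloader_val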
@@ -4,7 +4,7 @@ import torch
 
 from funasr.register import tables
 
-@tables.register("dataloader_classes", "DataloaderMapStyle")
+# @tables.register("dataloader_classes", "DataloaderMapStyle")
 def DataloaderMapStyle(frontend=None, tokenizer=None, **kwargs):
     # dataset
     logging.info("Build dataloader")
@@ -25,7 +25,7 @@ def DataloaderMapStyle(frontend=None, tokenizer=None, **kwargs):
 
     return dataloader_tr, dataloader_val
 
-# @tables.register("dataloader_classes", "DataloaderMapStyle")
+@tables.register("dataloader_classes", "DataloaderMapStyle")
 class DataloaderMapStyle:
     def __init__(self, frontend=None, tokenizer=None, **kwargs):
         # dataset
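These two dataloader hunks swap which implementation owns the "DataloaderMapStyle" key: the decorator comes off the old builder function and goes onto the new class, so the tables.dataloader_classes.get(...) lookup in main() now resolves to the class. A rough sketch of how a decorator-driven registry of this shape typically works; funasr.register.tables' real internals may differ:

# Hypothetical minimal registry; only the register/get shape mirrors the diff.
dataloader_classes = {}


def register(table: str, key: str):
    def decorator(obj):
        if table == "dataloader_classes":
            dataloader_classes[key] = obj  # stores whatever object is decorated
        return obj
    return decorator


@register("dataloader_classes", "DataloaderMapStyle")
class DataloaderMapStyle:
    def __init__(self, **kwargs):
        self.kwargs = kwargs


# main() then resolves the key and instantiates whatever was registered:
dataloader = dataloader_classes.get("DataloaderMapStyle")(batch_size=8)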
@@ -371,8 +371,7 @@ class Trainer:
 
         if self.use_ddp or self.use_fsdp:
             dist.barrier()
-
-            iterator_stop = torch.tensor(0).to(self.device)
+        iterator_stop = torch.tensor(0).to(self.device)
 
 
 
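This hunk (and the matching one at -467 below) moves the iterator_stop reset out of the DDP/FSDP-only branch, so the flag is defined on every code path rather than only when distributed training is active. A small sketch of the scoping issue, with hypothetical names:

import torch
import torch.distributed as dist


def begin_loop(device: torch.device, use_ddp: bool) -> torch.Tensor:
    if use_ddp:
        dist.barrier()  # keep ranks aligned before entering the loop
        # Before this commit the flag was created only inside this branch,
        # leaving it undefined (or stale) for single-process runs.
    return torch.tensor(0).to(device)  # after: defined for every run mode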
@@ -402,12 +401,10 @@ class Trainer:
         iterator_stop = torch.tensor(0).to(self.device)
         dataloader_val.batch_sampler.set_epoch(epoch)
         for batch_idx, batch in enumerate(dataloader_val):
-            # if self.use_ddp or self.use_fsdp:
-            #     dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
-            #     if epoch >= 1:
-            #         print(f"iterator_stop: {iterator_stop}\n")
-            #     if iterator_stop > 0:
-            #         break
+            if self.use_ddp or self.use_fsdp:
+                dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
+                if iterator_stop > 0:
+                    break
             time1 = time.perf_counter()
             speed_stats["data_load"] = f"{time1 - time5:0.3f}"
             batch = to_device(batch, self.device)
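The validation hunk re-enables the cooperative early exit (minus the old debug print): each rank all-reduces a stop flag every step, so when any rank's dataloader runs dry, all ranks leave the loop in the same iteration instead of one rank hanging in a collective. A minimal, self-contained sketch of the pattern, assuming torch.distributed is already initialized; the loop body is a placeholder:

import torch
import torch.distributed as dist


def loop_until_any_rank_exhausted(my_num_batches: int, device: torch.device,
                                  use_ddp: bool = True) -> None:
    iterator_stop = torch.tensor(0, device=device)
    step = 0
    while True:
        if step >= my_num_batches:
            iterator_stop.fill_(1)  # this rank has no more data
        if use_ddp:
            # A summed flag > 0 means some rank ran out; every rank sees the
            # same value this step, so they all break together: no deadlock.
            dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
        if iterator_stop > 0:
            break
        # ... process batch `step` here (forward/backward or validation) ...
        step += 1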
@@ -467,7 +464,7 @@ class Trainer:
 
         if self.use_ddp or self.use_fsdp:
             dist.barrier()
-            iterator_stop = torch.tensor(0).to(self.device)
+        iterator_stop = torch.tensor(0).to(self.device)
 
 
     def log(self,