mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
finetune
This commit is contained in:
parent
ed22e34d65
commit
ed952ff630
@ -150,8 +150,8 @@ def main(**kwargs):
|
|||||||
# dataset
|
# dataset
|
||||||
logging.info("Build dataloader")
|
logging.info("Build dataloader")
|
||||||
dataloader_class = tables.dataloader_classes.get(kwargs["dataset_conf"].get("dataloader", "DataloaderMapStyle"))
|
dataloader_class = tables.dataloader_classes.get(kwargs["dataset_conf"].get("dataloader", "DataloaderMapStyle"))
|
||||||
# dataloader = dataloader_class(**kwargs)
|
dataloader = dataloader_class(**kwargs)
|
||||||
dataloader_tr, dataloader_val = dataloader_class(**kwargs)
|
# dataloader_tr, dataloader_val = dataloader_class(**kwargs)
|
||||||
trainer = Trainer(local_rank=local_rank,
|
trainer = Trainer(local_rank=local_rank,
|
||||||
use_ddp=use_ddp,
|
use_ddp=use_ddp,
|
||||||
use_fsdp=use_fsdp,
|
use_fsdp=use_fsdp,
|
||||||
@ -181,7 +181,7 @@ def main(**kwargs):
|
|||||||
for epoch in range(trainer.start_epoch, trainer.max_epoch + 1):
|
for epoch in range(trainer.start_epoch, trainer.max_epoch + 1):
|
||||||
time1 = time.perf_counter()
|
time1 = time.perf_counter()
|
||||||
with context:
|
with context:
|
||||||
# dataloader_tr, dataloader_val = dataloader.build_iter(epoch)
|
dataloader_tr, dataloader_val = dataloader.build_iter(epoch)
|
||||||
trainer.train_epoch(
|
trainer.train_epoch(
|
||||||
model=model,
|
model=model,
|
||||||
optim=optim,
|
optim=optim,
|
||||||
|
|||||||
@ -4,7 +4,7 @@ import torch
|
|||||||
|
|
||||||
from funasr.register import tables
|
from funasr.register import tables
|
||||||
|
|
||||||
@tables.register("dataloader_classes", "DataloaderMapStyle")
|
# @tables.register("dataloader_classes", "DataloaderMapStyle")
|
||||||
def DataloaderMapStyle(frontend=None, tokenizer=None, **kwargs):
|
def DataloaderMapStyle(frontend=None, tokenizer=None, **kwargs):
|
||||||
# dataset
|
# dataset
|
||||||
logging.info("Build dataloader")
|
logging.info("Build dataloader")
|
||||||
@ -25,7 +25,7 @@ def DataloaderMapStyle(frontend=None, tokenizer=None, **kwargs):
|
|||||||
|
|
||||||
return dataloader_tr, dataloader_val
|
return dataloader_tr, dataloader_val
|
||||||
|
|
||||||
# @tables.register("dataloader_classes", "DataloaderMapStyle")
|
@tables.register("dataloader_classes", "DataloaderMapStyle")
|
||||||
class DataloaderMapStyle:
|
class DataloaderMapStyle:
|
||||||
def __init__(self, frontend=None, tokenizer=None, **kwargs):
|
def __init__(self, frontend=None, tokenizer=None, **kwargs):
|
||||||
# dataset
|
# dataset
|
||||||
|
|||||||
@ -371,8 +371,7 @@ class Trainer:
|
|||||||
|
|
||||||
if self.use_ddp or self.use_fsdp:
|
if self.use_ddp or self.use_fsdp:
|
||||||
dist.barrier()
|
dist.barrier()
|
||||||
|
iterator_stop = torch.tensor(0).to(self.device)
|
||||||
iterator_stop = torch.tensor(0).to(self.device)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ -402,12 +401,10 @@ class Trainer:
|
|||||||
iterator_stop = torch.tensor(0).to(self.device)
|
iterator_stop = torch.tensor(0).to(self.device)
|
||||||
dataloader_val.batch_sampler.set_epoch(epoch)
|
dataloader_val.batch_sampler.set_epoch(epoch)
|
||||||
for batch_idx, batch in enumerate(dataloader_val):
|
for batch_idx, batch in enumerate(dataloader_val):
|
||||||
# if self.use_ddp or self.use_fsdp:
|
if self.use_ddp or self.use_fsdp:
|
||||||
# dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
|
dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
|
||||||
# if epoch >= 1:
|
if iterator_stop > 0:
|
||||||
# print(f"iterator_stop: {iterator_stop}\n")
|
break
|
||||||
# if iterator_stop > 0:
|
|
||||||
# break
|
|
||||||
time1 = time.perf_counter()
|
time1 = time.perf_counter()
|
||||||
speed_stats["data_load"] = f"{time1 - time5:0.3f}"
|
speed_stats["data_load"] = f"{time1 - time5:0.3f}"
|
||||||
batch = to_device(batch, self.device)
|
batch = to_device(batch, self.device)
|
||||||
@ -467,7 +464,7 @@ class Trainer:
|
|||||||
|
|
||||||
if self.use_ddp or self.use_fsdp:
|
if self.use_ddp or self.use_fsdp:
|
||||||
dist.barrier()
|
dist.barrier()
|
||||||
iterator_stop = torch.tensor(0).to(self.device)
|
iterator_stop = torch.tensor(0).to(self.device)
|
||||||
|
|
||||||
|
|
||||||
def log(self,
|
def log(self,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user