This commit is contained in:
嘉渊 2023-04-24 10:00:56 +08:00
parent ccd4c4d240
commit a4ab665d30
2 changed files with 15 additions and 4 deletions

View File

@ -12,6 +12,7 @@ from funasr.build_utils.build_distributed import build_distributed
from funasr.build_utils.build_model import build_model
from funasr.build_utils.build_optimizer import build_optimizer
from funasr.build_utils.build_scheduler import build_scheduler
from funasr.build_utils.build_trainer import build_trainer
from funasr.text.phoneme_tokenizer import g2p_choices
from funasr.torch_utils.model_summary import model_summary
from funasr.torch_utils.pytorch_version import pytorch_cudnn_version
@ -443,4 +444,18 @@ if __name__ == '__main__':
else:
yaml_no_alias_safe_dump(vars(args), f, indent=4, sort_keys=False)
# dataloader for training/validation
train_dataloader, valid_dataloader = build_dataloader(args)
# Trainer, including model, optimizers, etc.
trainer = build_trainer(
args=args,
model=model,
optimizers=optimizers,
schedulers=schedulers,
train_dataloader=train_dataloader,
valid_dataloader=valid_dataloader,
distributed_option=distributed_option
)
trainer.run()

View File

@ -107,7 +107,6 @@ class Trainer:
schedulers: Sequence[Optional[AbsScheduler]],
train_dataloader: AbsIterFactory,
valid_dataloader: AbsIterFactory,
trainer_options,
distributed_option: DistributedOption):
self.trainer_options = self.build_options(args)
self.model = model
@ -115,7 +114,6 @@ class Trainer:
self.schedulers = schedulers
self.train_dataloader = train_dataloader
self.valid_dataloader = valid_dataloader
self.trainer_options = trainer_options
self.distributed_option = distributed_option
def build_options(self, args: argparse.Namespace) -> TrainerOptions:
@ -808,7 +806,6 @@ def build_trainer(
schedulers: Sequence[Optional[AbsScheduler]],
train_dataloader: AbsIterFactory,
valid_dataloader: AbsIterFactory,
trainer_options,
distributed_option: DistributedOption
):
trainer = Trainer(
@ -818,7 +815,6 @@ def build_trainer(
schedulers=schedulers,
train_dataloader=train_dataloader,
valid_dataloader=valid_dataloader,
trainer_options=trainer_options,
distributed_option=distributed_option
)
return trainer