mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
update
This commit is contained in:
parent
ccd4c4d240
commit
a4ab665d30
@ -12,6 +12,7 @@ from funasr.build_utils.build_distributed import build_distributed
|
||||
from funasr.build_utils.build_model import build_model
|
||||
from funasr.build_utils.build_optimizer import build_optimizer
|
||||
from funasr.build_utils.build_scheduler import build_scheduler
|
||||
from funasr.build_utils.build_trainer import build_trainer
|
||||
from funasr.text.phoneme_tokenizer import g2p_choices
|
||||
from funasr.torch_utils.model_summary import model_summary
|
||||
from funasr.torch_utils.pytorch_version import pytorch_cudnn_version
|
||||
@ -443,4 +444,18 @@ if __name__ == '__main__':
|
||||
else:
|
||||
yaml_no_alias_safe_dump(vars(args), f, indent=4, sort_keys=False)
|
||||
|
||||
# dataloader for training/validation
|
||||
train_dataloader, valid_dataloader = build_dataloader(args)
|
||||
|
||||
# Trainer, including model, optimizers, etc.
|
||||
trainer = build_trainer(
|
||||
args=args,
|
||||
model=model,
|
||||
optimizers=optimizers,
|
||||
schedulers=schedulers,
|
||||
train_dataloader=train_dataloader,
|
||||
valid_dataloader=valid_dataloader,
|
||||
distributed_option=distributed_option
|
||||
)
|
||||
|
||||
trainer.run()
|
||||
|
||||
@ -107,7 +107,6 @@ class Trainer:
|
||||
schedulers: Sequence[Optional[AbsScheduler]],
|
||||
train_dataloader: AbsIterFactory,
|
||||
valid_dataloader: AbsIterFactory,
|
||||
trainer_options,
|
||||
distributed_option: DistributedOption):
|
||||
self.trainer_options = self.build_options(args)
|
||||
self.model = model
|
||||
@ -115,7 +114,6 @@ class Trainer:
|
||||
self.schedulers = schedulers
|
||||
self.train_dataloader = train_dataloader
|
||||
self.valid_dataloader = valid_dataloader
|
||||
self.trainer_options = trainer_options
|
||||
self.distributed_option = distributed_option
|
||||
|
||||
def build_options(self, args: argparse.Namespace) -> TrainerOptions:
|
||||
@ -808,7 +806,6 @@ def build_trainer(
|
||||
schedulers: Sequence[Optional[AbsScheduler]],
|
||||
train_dataloader: AbsIterFactory,
|
||||
valid_dataloader: AbsIterFactory,
|
||||
trainer_options,
|
||||
distributed_option: DistributedOption
|
||||
):
|
||||
trainer = Trainer(
|
||||
@ -818,7 +815,6 @@ def build_trainer(
|
||||
schedulers=schedulers,
|
||||
train_dataloader=train_dataloader,
|
||||
valid_dataloader=valid_dataloader,
|
||||
trainer_options=trainer_options,
|
||||
distributed_option=distributed_option
|
||||
)
|
||||
return trainer
|
||||
|
||||
Loading…
Reference in New Issue
Block a user