diff --git a/funasr/build_utils/build_trainer.py b/funasr/build_utils/build_trainer.py
index 55bc89c25..dd592e437 100644
--- a/funasr/build_utils/build_trainer.py
+++ b/funasr/build_utils/build_trainer.py
@@ -128,16 +128,15 @@ class Trainer:
         """Reserved for future development of another Trainer"""
         pass
 
-    @staticmethod
-    def resume(
-        checkpoint: Union[str, Path],
-        model: torch.nn.Module,
-        reporter: Reporter,
-        optimizers: Sequence[torch.optim.Optimizer],
-        schedulers: Sequence[Optional[AbsScheduler]],
-        scaler: Optional[GradScaler],
-        ngpu: int = 0,
-    ):
+    def resume(self,
+               checkpoint: Union[str, Path],
+               model: torch.nn.Module,
+               reporter: Reporter,
+               optimizers: Sequence[torch.optim.Optimizer],
+               schedulers: Sequence[Optional[AbsScheduler]],
+               scaler: Optional[GradScaler],
+               ngpu: int = 0,
+               ):
         states = torch.load(
             checkpoint,
             map_location=f"cuda:{torch.cuda.current_device()}" if ngpu > 0 else "cpu",
@@ -800,3 +799,26 @@ class Trainer:
             if distributed:
                 iterator_stop.fill_(1)
                 torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
+
+
+def build_trainer(
+        args,
+        model: FunASRModel,
+        optimizers: Sequence[torch.optim.Optimizer],
+        schedulers: Sequence[Optional[AbsScheduler]],
+        train_dataloader: AbsIterFactory,
+        valid_dataloader: AbsIterFactory,
+        trainer_options,
+        distributed_option: DistributedOption
+):
+    trainer = Trainer(
+        args=args,
+        model=model,
+        optimizers=optimizers,
+        schedulers=schedulers,
+        train_dataloader=train_dataloader,
+        valid_dataloader=valid_dataloader,
+        trainer_options=trainer_options,
+        distributed_option=distributed_option
+    )
+    return trainer
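
Usage note for reviewers: a minimal sketch of how the two API changes in this patch would be exercised. Only the build_trainer and resume signatures come from the diff itself; every other name below (the argument objects, the checkpoint path, the import of the factory) is a placeholder assumption about the surrounding training script, not part of the patch.

# Hypothetical wiring (not part of the patch): all variables below are
# assumed to have been constructed earlier by the training entry point.
from funasr.build_utils.build_trainer import build_trainer

trainer = build_trainer(
    args=args,                            # parsed CLI/config arguments
    model=model,                          # a FunASRModel instance
    optimizers=[optimizer],               # Sequence[torch.optim.Optimizer]
    schedulers=[scheduler],               # Sequence[Optional[AbsScheduler]]
    train_dataloader=train_iter_factory,  # AbsIterFactory
    valid_dataloader=valid_iter_factory,  # AbsIterFactory
    trainer_options=trainer_options,
    distributed_option=distributed_option,
)

# resume() is no longer a @staticmethod, so after this patch it is called
# on a Trainer instance rather than on the Trainer class:
trainer.resume(
    checkpoint="exp/checkpoint.pth",      # assumed path
    model=model,
    reporter=reporter,                    # a Reporter instance
    optimizers=[optimizer],
    schedulers=[scheduler],
    scaler=None,                          # Optional[GradScaler]
    ngpu=1,
)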