Commit ccd4c4d240 by 嘉渊, 2023-04-23 17:47:12 +08:00 (parent 887039e9d3)


@@ -128,16 +128,15 @@ class Trainer:
"""Reserved for future development of another Trainer"""
pass
@staticmethod
def resume(
checkpoint: Union[str, Path],
model: torch.nn.Module,
reporter: Reporter,
optimizers: Sequence[torch.optim.Optimizer],
schedulers: Sequence[Optional[AbsScheduler]],
scaler: Optional[GradScaler],
ngpu: int = 0,
):
def resume(self,
checkpoint: Union[str, Path],
model: torch.nn.Module,
reporter: Reporter,
optimizers: Sequence[torch.optim.Optimizer],
schedulers: Sequence[Optional[AbsScheduler]],
scaler: Optional[GradScaler],
ngpu: int = 0,
):
states = torch.load(
checkpoint,
map_location=f"cuda:{torch.cuda.current_device()}" if ngpu > 0 else "cpu",
@@ -800,3 +799,26 @@ class Trainer:
         if distributed:
             iterator_stop.fill_(1)
             torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
+
+
+def build_trainer(
+    args,
+    model: FunASRModel,
+    optimizers: Sequence[torch.optim.Optimizer],
+    schedulers: Sequence[Optional[AbsScheduler]],
+    train_dataloader: AbsIterFactory,
+    valid_dataloader: AbsIterFactory,
+    trainer_options,
+    distributed_option: DistributedOption
+):
+    trainer = Trainer(
+        args=args,
+        model=model,
+        optimizers=optimizers,
+        schedulers=schedulers,
+        train_dataloader=train_dataloader,
+        valid_dataloader=valid_dataloader,
+        trainer_options=trainer_options,
+        distributed_option=distributed_option
+    )
+    return trainer
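
The new build_trainer() is a thin module-level factory: it forwards every argument unchanged to the Trainer constructor and returns the instance, which pairs with the resume() change above, since a Trainer instance now carries the training state. A hedged usage sketch from a hypothetical entry script (variable names and the run() call are assumptions, not shown in this diff):

    # names such as train_iter_factory / valid_iter_factory are assumptions
    trainer = build_trainer(
        args=args,
        model=model,                          # a FunASRModel
        optimizers=[optimizer],
        schedulers=[scheduler],
        train_dataloader=train_iter_factory,  # AbsIterFactory
        valid_dataloader=valid_iter_factory,  # AbsIterFactory
        trainer_options=trainer_options,
        distributed_option=distributed_option,
    )
    trainer.run()  # assumed entry point; not part of this diff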