Mirror of https://github.com/modelscope/FunASR

commit ccd4c4d240 ("update")
parent 887039e9d3
@@ -128,16 +128,15 @@ class Trainer:
     """Reserved for future development of another Trainer"""

     pass
-    @staticmethod
-    def resume(
-        checkpoint: Union[str, Path],
-        model: torch.nn.Module,
-        reporter: Reporter,
-        optimizers: Sequence[torch.optim.Optimizer],
-        schedulers: Sequence[Optional[AbsScheduler]],
-        scaler: Optional[GradScaler],
-        ngpu: int = 0,
-    ):
+    def resume(self,
+               checkpoint: Union[str, Path],
+               model: torch.nn.Module,
+               reporter: Reporter,
+               optimizers: Sequence[torch.optim.Optimizer],
+               schedulers: Sequence[Optional[AbsScheduler]],
+               scaler: Optional[GradScaler],
+               ngpu: int = 0,
+               ):
         states = torch.load(
             checkpoint,
             map_location=f"cuda:{torch.cuda.current_device()}" if ngpu > 0 else "cpu",
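This hunk turns Trainer.resume from a @staticmethod into an instance method by adding self, so it is now called on a constructed trainer object; the visible context shows that the loaded states are mapped onto the current CUDA device when ngpu > 0 and onto the CPU otherwise. A minimal call-site sketch under that reading; the trainer instance, the surrounding objects, and the checkpoint path are hypothetical stand-ins, not taken from this diff:

# Hypothetical call site for the new instance-method signature.
# `trainer`, `model`, `reporter`, `optimizers`, `schedulers`, and `scaler`
# are assumed to have been built by the surrounding training script;
# "exp/asr_train/checkpoint.pth" is a made-up path.
trainer.resume(
    checkpoint="exp/asr_train/checkpoint.pth",
    model=model,
    reporter=reporter,
    optimizers=optimizers,
    schedulers=schedulers,
    scaler=scaler,  # may be None when mixed precision is disabled
    ngpu=1,         # ngpu > 0 loads tensors onto the current CUDA device
)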
@@ -800,3 +799,26 @@ class Trainer:
                 if distributed:
                     iterator_stop.fill_(1)
                     torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
+
+
+def build_trainer(
+    args,
+    model: FunASRModel,
+    optimizers: Sequence[torch.optim.Optimizer],
+    schedulers: Sequence[Optional[AbsScheduler]],
+    train_dataloader: AbsIterFactory,
+    valid_dataloader: AbsIterFactory,
+    trainer_options,
+    distributed_option: DistributedOption
+):
+    trainer = Trainer(
+        args=args,
+        model=model,
+        optimizers=optimizers,
+        schedulers=schedulers,
+        train_dataloader=train_dataloader,
+        valid_dataloader=valid_dataloader,
+        trainer_options=trainer_options,
+        distributed_option=distributed_option
+    )
+    return trainer
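The new module-level build_trainer is a thin factory: it forwards its arguments unchanged to the Trainer constructor and returns the instance, giving callers a single entry point instead of constructing Trainer directly. A minimal usage sketch, assuming a training script that has already built the model, optimizers, iterator factories, and option objects; every name below is a stand-in, and since the diff only shows construction, the final call on the returned trainer is hypothetical:

# Hypothetical wiring inside a training entry point. All inputs are assumed
# to exist already; only build_trainer itself comes from this commit.
trainer = build_trainer(
    args=args,
    model=model,                      # a FunASRModel instance
    optimizers=[optimizer],           # sequences, per the type annotations
    schedulers=[scheduler],           # entries may be None
    train_dataloader=train_iter_factory,
    valid_dataloader=valid_iter_factory,
    trainer_options=trainer_options,
    distributed_option=distributed_option,
)
trainer.run()  # hypothetical: the training-loop entry method is not shown in this diff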