mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
Support resuming a model from PAI (#544)
Co-authored-by: aky15 <ankeyu.aky@11.17.44.249>
This commit is contained in:
parent
21bcf7085b
commit
2f9685797b
@ -143,11 +143,23 @@ class Trainer:
|
||||
schedulers: Sequence[Optional[AbsScheduler]],
|
||||
scaler: Optional[GradScaler],
|
||||
ngpu: int = 0,
|
||||
oss_bucket=None,
|
||||
):
|
||||
states = torch.load(
|
||||
checkpoint,
|
||||
map_location=f"cuda:{torch.cuda.current_device()}" if ngpu > 0 else "cpu",
|
||||
)
|
||||
if oss_bucket is None:
|
||||
if os.path.exists(checkpoint):
|
||||
states = torch.load(
|
||||
checkpoint,
|
||||
map_location=f"cuda:{torch.cuda.current_device()}" if ngpu > 0 else "cpu",
|
||||
)
|
||||
|
||||
else:
|
||||
return 0
|
||||
else:
|
||||
if oss_bucket.object_exists(checkpoint):
|
||||
buffer = BytesIO(oss_bucket.get_object(checkpoint).read())
|
||||
states = torch.load(buffer, map_location=f"cuda:{torch.cuda.current_device()}" if ngpu > 0 else "cpu",)
|
||||
else:
|
||||
return 0
|
||||
model.load_state_dict(states["model"])
|
||||
reporter.load_state_dict(states["reporter"])
|
||||
for optimizer, state in zip(optimizers, states["optimizers"]):
|
||||
@ -206,15 +218,16 @@ class Trainer:
|
||||
else:
|
||||
scaler = None
|
||||
|
||||
if trainer_options.resume and (output_dir / "checkpoint.pb").exists():
|
||||
if trainer_options.resume:
|
||||
cls.resume(
|
||||
checkpoint=output_dir / "checkpoint.pb",
|
||||
checkpoint=os.path.join(trainer_options.output_dir, "checkpoint.pb") if trainer_options.use_pai else output_dir / "checkpoint.pb",
|
||||
model=model,
|
||||
optimizers=optimizers,
|
||||
schedulers=schedulers,
|
||||
reporter=reporter,
|
||||
scaler=scaler,
|
||||
ngpu=trainer_options.ngpu,
|
||||
oss_bucket=trainer_options.oss_bucket if trainer_options.use_pai else None,
|
||||
)
|
||||
|
||||
start_epoch = reporter.get_epoch() + 1
|
||||
|
||||
Loading…
Reference in New Issue
Block a user