Support resuming the model from PAI (#544)

Co-authored-by: aky15 <ankeyu.aky@11.17.44.249>
This commit is contained in:
aky15 2023-05-24 14:04:54 +08:00 committed by GitHub
parent 21bcf7085b
commit 2f9685797b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -143,11 +143,23 @@ class Trainer:
schedulers: Sequence[Optional[AbsScheduler]],
scaler: Optional[GradScaler],
ngpu: int = 0,
oss_bucket=None,
):
states = torch.load(
checkpoint,
map_location=f"cuda:{torch.cuda.current_device()}" if ngpu > 0 else "cpu",
)
if oss_bucket is None:
if os.path.exists(checkpoint):
states = torch.load(
checkpoint,
map_location=f"cuda:{torch.cuda.current_device()}" if ngpu > 0 else "cpu",
)
else:
return 0
else:
if oss_bucket.object_exists(checkpoint):
buffer = BytesIO(oss_bucket.get_object(checkpoint).read())
states = torch.load(buffer, map_location=f"cuda:{torch.cuda.current_device()}" if ngpu > 0 else "cpu",)
else:
return 0
model.load_state_dict(states["model"])
reporter.load_state_dict(states["reporter"])
for optimizer, state in zip(optimizers, states["optimizers"]):
@@ -206,15 +218,16 @@ class Trainer:
else:
scaler = None
if trainer_options.resume and (output_dir / "checkpoint.pb").exists():
if trainer_options.resume:
cls.resume(
checkpoint=output_dir / "checkpoint.pb",
checkpoint=os.path.join(trainer_options.output_dir, "checkpoint.pb") if trainer_options.use_pai else output_dir / "checkpoint.pb",
model=model,
optimizers=optimizers,
schedulers=schedulers,
reporter=reporter,
scaler=scaler,
ngpu=trainer_options.ngpu,
oss_bucket=trainer_options.oss_bucket if trainer_options.use_pai else None,
)
start_epoch = reporter.get_epoch() + 1