Merge pull request #376 from alibaba-damo-academy/dev_wjm3

update saving model by num_update
This commit is contained in:
hnluo 2023-04-18 16:52:05 +08:00 committed by GitHub
commit 975ac752f7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 11 additions and 9 deletions

View File

@@ -467,7 +467,7 @@ class AbsTask(ABC):
parser.add_argument(
"--batch_interval",
type=int,
default=10000,
default=-1,
help="The batch interval for saving model.",
)
group.add_argument(

View File

@@ -571,8 +571,7 @@ class Trainer:
#ouput dir
output_dir = Path(options.output_dir)
#batch interval
batch_interval = options.batch_interval
assert batch_interval > 0
batch_interval = options.batch_interval
start_time = time.perf_counter()
for iiter, (_, batch) in enumerate(
@@ -580,14 +579,17 @@
):
assert isinstance(batch, dict), type(batch)
if rank == 0:
if batch_interval > 0 and (not distributed_option.distributed or rank == 0):
if hasattr(model, "num_updates") or (hasattr(model, "module") and hasattr(model.module, "num_updates")):
num_batch_updates = model.get_num_updates() if hasattr(model,"num_updates") else model.module.get_num_updates()
if (num_batch_updates%batch_interval == 0) and (options.oss_bucket is not None) and options.use_pai:
buffer = BytesIO()
torch.save(model.state_dict(), buffer)
options.oss_bucket.put_object(os.path.join(output_dir, f"{num_batch_updates}batch.pth"), buffer.getvalue())
if num_batch_updates % batch_interval == 0:
if options.use_pai and options.oss_bucket is not None:
buffer = BytesIO()
torch.save(model.state_dict(), buffer)
options.oss_bucket.put_object(os.path.join(output_dir, f"{num_batch_updates}step.pb"), buffer.getvalue())
else:
torch.save(model.state_dict(), os.path.join(output_dir, f"{num_batch_updates}step.pb"))
if distributed:
torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
if iterator_stop > 0: