This commit is contained in:
speech_asr 2023-04-18 16:46:32 +08:00
parent a1a79bbe3e
commit 3c4ee89de9
2 changed files with 5 additions and 6 deletions

View File

@ -467,7 +467,7 @@ class AbsTask(ABC):
parser.add_argument(
"--batch_interval",
type=int,
default=10000,
default=-1,
help="The batch interval for saving model.",
)
group.add_argument(

View File

@ -571,8 +571,7 @@ class Trainer:
#ouput dir
output_dir = Path(options.output_dir)
#batch interval
batch_interval = options.batch_interval
assert batch_interval > 0
batch_interval = options.batch_interval
start_time = time.perf_counter()
for iiter, (_, batch) in enumerate(
@ -580,11 +579,11 @@ class Trainer:
):
assert isinstance(batch, dict), type(batch)
if rank == 0:
if batch_interval > 0 and (not distributed_option.distributed or rank == 0):
if hasattr(model, "num_updates") or (hasattr(model, "module") and hasattr(model.module, "num_updates")):
num_batch_updates = model.get_num_updates() if hasattr(model,"num_updates") else model.module.get_num_updates()
if (num_batch_updates%batch_interval == 0) and (options.oss_bucket is not None):
if options.use_pai:
if num_batch_updates % batch_interval == 0:
if options.use_pai and options.oss_bucket is not None:
buffer = BytesIO()
torch.save(model.state_dict(), buffer)
options.oss_bucket.put_object(os.path.join(output_dir, f"{num_batch_updates}step.pb"), buffer.getvalue())