From 3c4ee89de9cc9dd0a5abc8ffce872296906138ef Mon Sep 17 00:00:00 2001 From: speech_asr Date: Tue, 18 Apr 2023 16:46:32 +0800 Subject: [PATCH] update --- funasr/tasks/abs_task.py | 2 +- funasr/train/trainer.py | 9 ++++----- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/funasr/tasks/abs_task.py b/funasr/tasks/abs_task.py index 777513e7e..3d2004c2d 100644 --- a/funasr/tasks/abs_task.py +++ b/funasr/tasks/abs_task.py @@ -467,7 +467,7 @@ class AbsTask(ABC): parser.add_argument( "--batch_interval", type=int, - default=10000, + default=-1, help="The batch interval for saving model.", ) group.add_argument( diff --git a/funasr/train/trainer.py b/funasr/train/trainer.py index 2260f0064..9574a0dad 100644 --- a/funasr/train/trainer.py +++ b/funasr/train/trainer.py @@ -571,8 +571,7 @@ class Trainer: #ouput dir output_dir = Path(options.output_dir) #batch interval - batch_interval = options.batch_interval - assert batch_interval > 0 + batch_interval = options.batch_interval start_time = time.perf_counter() for iiter, (_, batch) in enumerate( @@ -580,11 +579,11 @@ class Trainer: ): assert isinstance(batch, dict), type(batch) - if rank == 0: + if batch_interval > 0 and (not distributed_option.distributed or rank == 0): if hasattr(model, "num_updates") or (hasattr(model, "module") and hasattr(model.module, "num_updates")): num_batch_updates = model.get_num_updates() if hasattr(model,"num_updates") else model.module.get_num_updates() - if (num_batch_updates%batch_interval == 0) and (options.oss_bucket is not None): - if options.use_pai: + if num_batch_updates % batch_interval == 0: + if options.use_pai and options.oss_bucket is not None: buffer = BytesIO() torch.save(model.state_dict(), buffer) options.oss_bucket.put_object(os.path.join(output_dir, f"{num_batch_updates}step.pb"), buffer.getvalue())