diff --git a/funasr/tasks/abs_task.py b/funasr/tasks/abs_task.py
index 777513e7e..3d2004c2d 100644
--- a/funasr/tasks/abs_task.py
+++ b/funasr/tasks/abs_task.py
@@ -467,7 +467,7 @@ class AbsTask(ABC):
         parser.add_argument(
             "--batch_interval",
             type=int,
-            default=10000,
+            default=-1,
             help="The batch interval for saving model.",
         )
         group.add_argument(
diff --git a/funasr/train/trainer.py b/funasr/train/trainer.py
index b12bdeda0..9574a0dad 100644
--- a/funasr/train/trainer.py
+++ b/funasr/train/trainer.py
@@ -571,8 +571,7 @@ class Trainer:
         #ouput dir
         output_dir = Path(options.output_dir)
         #batch interval
-        batch_interval = options.batch_interval
-        assert batch_interval > 0
+        batch_interval = options.batch_interval
 
         start_time = time.perf_counter()
         for iiter, (_, batch) in enumerate(
@@ -580,14 +579,17 @@ class Trainer:
         ):
             assert isinstance(batch, dict), type(batch)
 
-            if rank == 0:
+            if batch_interval > 0 and (not distributed_option.distributed or rank == 0):
                 if hasattr(model, "num_updates") or (hasattr(model, "module") and hasattr(model.module, "num_updates")):
                     num_batch_updates = model.get_num_updates() if hasattr(model,"num_updates") else model.module.get_num_updates()
-                    if (num_batch_updates%batch_interval == 0) and (options.oss_bucket is not None) and options.use_pai:
-                        buffer = BytesIO()
-                        torch.save(model.state_dict(), buffer)
-                        options.oss_bucket.put_object(os.path.join(output_dir, f"{num_batch_updates}batch.pth"), buffer.getvalue())
-
+                    if num_batch_updates % batch_interval == 0:
+                        if options.use_pai and options.oss_bucket is not None:
+                            buffer = BytesIO()
+                            torch.save(model.state_dict(), buffer)
+                            options.oss_bucket.put_object(os.path.join(output_dir, f"{num_batch_updates}step.pb"), buffer.getvalue())
+                        else:
+                            torch.save(model.state_dict(), os.path.join(output_dir, f"{num_batch_updates}step.pb"))
+
             if distributed:
                 torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
                 if iterator_stop > 0: