From a1a79bbe3e971a00bc315d011a2e0764b3bc3111 Mon Sep 17 00:00:00 2001 From: speech_asr Date: Tue, 18 Apr 2023 16:41:04 +0800 Subject: [PATCH 1/2] update --- funasr/train/trainer.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/funasr/train/trainer.py b/funasr/train/trainer.py index b12bdeda0..2260f0064 100644 --- a/funasr/train/trainer.py +++ b/funasr/train/trainer.py @@ -583,11 +583,14 @@ class Trainer: if rank == 0: if hasattr(model, "num_updates") or (hasattr(model, "module") and hasattr(model.module, "num_updates")): num_batch_updates = model.get_num_updates() if hasattr(model,"num_updates") else model.module.get_num_updates() - if (num_batch_updates%batch_interval == 0) and (options.oss_bucket is not None) and options.use_pai: - buffer = BytesIO() - torch.save(model.state_dict(), buffer) - options.oss_bucket.put_object(os.path.join(output_dir, f"{num_batch_updates}batch.pth"), buffer.getvalue()) - + if (num_batch_updates%batch_interval == 0) and (options.oss_bucket is not None): + if options.use_pai: + buffer = BytesIO() + torch.save(model.state_dict(), buffer) + options.oss_bucket.put_object(os.path.join(output_dir, f"{num_batch_updates}step.pb"), buffer.getvalue()) + else: + torch.save(model.state_dict(), os.path.join(output_dir, f"{num_batch_updates}step.pb")) + if distributed: torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM) if iterator_stop > 0: From 3c4ee89de9cc9dd0a5abc8ffce872296906138ef Mon Sep 17 00:00:00 2001 From: speech_asr Date: Tue, 18 Apr 2023 16:46:32 +0800 Subject: [PATCH 2/2] update --- funasr/tasks/abs_task.py | 2 +- funasr/train/trainer.py | 9 ++++----- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/funasr/tasks/abs_task.py b/funasr/tasks/abs_task.py index 777513e7e..3d2004c2d 100644 --- a/funasr/tasks/abs_task.py +++ b/funasr/tasks/abs_task.py @@ -467,7 +467,7 @@ class AbsTask(ABC): parser.add_argument( "--batch_interval", type=int, - default=10000, + default=-1, help="The batch interval for saving model.", ) group.add_argument( diff --git a/funasr/train/trainer.py b/funasr/train/trainer.py index 2260f0064..9574a0dad 100644 --- a/funasr/train/trainer.py +++ b/funasr/train/trainer.py @@ -571,8 +571,7 @@ class Trainer: #ouput dir output_dir = Path(options.output_dir) #batch interval - batch_interval = options.batch_interval - assert batch_interval > 0 + batch_interval = options.batch_interval start_time = time.perf_counter() for iiter, (_, batch) in enumerate( @@ -580,11 +579,11 @@ class Trainer: ): assert isinstance(batch, dict), type(batch) - if rank == 0: + if batch_interval > 0 and (not distributed_option.distributed or rank == 0): if hasattr(model, "num_updates") or (hasattr(model, "module") and hasattr(model.module, "num_updates")): num_batch_updates = model.get_num_updates() if hasattr(model,"num_updates") else model.module.get_num_updates() - if (num_batch_updates%batch_interval == 0) and (options.oss_bucket is not None): - if options.use_pai: + if num_batch_updates % batch_interval == 0: + if options.use_pai and options.oss_bucket is not None: buffer = BytesIO() torch.save(model.state_dict(), buffer) options.oss_bucket.put_object(os.path.join(output_dir, f"{num_batch_updates}step.pb"), buffer.getvalue())