From 37c45ee8d7e4db18d95677c203b3432f3e6dde80 Mon Sep 17 00:00:00 2001 From: nichongjia-2007 Date: Thu, 23 Mar 2023 18:33:21 +0800 Subject: [PATCH] add batch interval for saving model --- funasr/tasks/asr.py | 6 ++++++ funasr/train/trainer.py | 30 +++++++++++++++++++++++++----- 2 files changed, 31 insertions(+), 5 deletions(-) diff --git a/funasr/tasks/asr.py b/funasr/tasks/asr.py index e15147332..6e0f16acf 100644 --- a/funasr/tasks/asr.py +++ b/funasr/tasks/asr.py @@ -412,6 +412,12 @@ class ASRTask(AbsTask): default="13_15", help="The range of noise decibel level.", ) + parser.add_argument( + "--batch_interval", + type=int, + default=10000, + help="The batch interval for saving model.", + ) for class_choices in cls.class_choices_list: # Append -- and --_conf. diff --git a/funasr/train/trainer.py b/funasr/train/trainer.py index efe2009c4..4fbdcd94e 100644 --- a/funasr/train/trainer.py +++ b/funasr/train/trainer.py @@ -94,7 +94,7 @@ class TrainerOptions: wandb_model_log_interval: int use_pai: bool oss_bucket: Union[oss2.Bucket, None] - + batch_interval: int class Trainer: """Trainer having a optimizer. @@ -186,7 +186,10 @@ class Trainer: logging.warning("No keep_nbest_models is given. Change to [1]") trainer_options.keep_nbest_models = [1] keep_nbest_models = trainer_options.keep_nbest_models - + + #assert batch_interval is set and >0 + assert trainer_options.batch_interval > 0 + output_dir = Path(trainer_options.output_dir) reporter = Reporter() if trainer_options.use_amp: @@ -560,13 +563,30 @@ class Trainer: # [For distributed] Because iteration counts are not always equals between # processes, send stop-flag to the other processes if iterator is finished iterator_stop = torch.tensor(0).to("cuda" if ngpu > 0 else "cpu") - + + #get the rank + rank = distributed_option.dist_rank + #get the num batch updates + num_batch_updates = 0 + #ouput dir + output_dir = Path(options.output_dir) + #batch interval + batch_interval = options.batch_interval + assert batch_interval > 0 + start_time = time.perf_counter() for iiter, (_, batch) in enumerate( reporter.measure_iter_time(iterator, "iter_time"), 1 ): assert isinstance(batch, dict), type(batch) - + + if rank == 0 and hasattr(model.module, "num_updates"): + num_batch_updates = model.module.get_num_updates() + if (num_batch_updates%batch_interval == 0) and (options.oss_bucket is not None) and options.use_pai: + buffer = BytesIO() + torch.save(model.state_dict(), buffer) + options.oss_bucket.put_object(os.path.join(output_dir, f"{num_batch_updates}batch.pth"), buffer.getvalue()) + if distributed: torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM) if iterator_stop > 0: @@ -811,4 +831,4 @@ class Trainer: else: if distributed: iterator_stop.fill_(1) - torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM) \ No newline at end of file + torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)