From a1a79bbe3e971a00bc315d011a2e0764b3bc3111 Mon Sep 17 00:00:00 2001 From: speech_asr Date: Tue, 18 Apr 2023 16:41:04 +0800 Subject: [PATCH] update --- funasr/train/trainer.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/funasr/train/trainer.py b/funasr/train/trainer.py index b12bdeda0..2260f0064 100644 --- a/funasr/train/trainer.py +++ b/funasr/train/trainer.py @@ -583,11 +583,14 @@ class Trainer: if rank == 0: if hasattr(model, "num_updates") or (hasattr(model, "module") and hasattr(model.module, "num_updates")): num_batch_updates = model.get_num_updates() if hasattr(model,"num_updates") else model.module.get_num_updates() - if (num_batch_updates%batch_interval == 0) and (options.oss_bucket is not None) and options.use_pai: - buffer = BytesIO() - torch.save(model.state_dict(), buffer) - options.oss_bucket.put_object(os.path.join(output_dir, f"{num_batch_updates}batch.pth"), buffer.getvalue()) - + if (num_batch_updates%batch_interval == 0) and (options.oss_bucket is not None): + if options.use_pai: + buffer = BytesIO() + torch.save(model.state_dict(), buffer) + options.oss_bucket.put_object(os.path.join(output_dir, f"{num_batch_updates}step.pb"), buffer.getvalue()) + else: + torch.save(model.state_dict(), os.path.join(output_dir, f"{num_batch_updates}step.pb")) + if distributed: torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM) if iterator_stop > 0: