From 1c2eb051cdcc6890af9ba64b10b9a0152288469a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Wed, 6 Dec 2023 19:45:49 +0800 Subject: [PATCH] funasr2 --- funasr/cli/trainer.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/funasr/cli/trainer.py b/funasr/cli/trainer.py index 74e058f71..ee5af0fdc 100644 --- a/funasr/cli/trainer.py +++ b/funasr/cli/trainer.py @@ -4,6 +4,7 @@ from funasr.torch_utils.device_funcs import to_device import logging from tqdm import tqdm from contextlib import nullcontext +import torch.distributed as dist class Trainer: """ @@ -80,7 +81,7 @@ class Trainer: } # Create output directory if it does not exist os.makedirs(self.output_dir, exist_ok=True) - filename = os.path.join(self.output_dir, f'model.{epoch}.pb') + filename = os.path.join(self.output_dir, f'model.e{epoch}.pb') torch.save(state, filename) print(f'Checkpoint saved to {filename}') @@ -110,8 +111,9 @@ class Trainer: for epoch in range(self.start_epoch, self.max_epoch + 1): self._train_epoch(epoch) # self._validate_epoch(epoch) - self._save_checkpoint(epoch) - self.scheduler.step() + if dist.get_rank() == 0: + self._save_checkpoint(epoch) + # self.scheduler.step() def _train_epoch(self, epoch): """