This commit is contained in:
游雁 2023-12-06 19:45:49 +08:00
parent e98e10639d
commit 1c2eb051cd

View File

@ -4,6 +4,7 @@ from funasr.torch_utils.device_funcs import to_device
import logging
from tqdm import tqdm
from contextlib import nullcontext
import torch.distributed as dist
class Trainer:
"""
@ -80,7 +81,7 @@ class Trainer:
}
# Create output directory if it does not exist
os.makedirs(self.output_dir, exist_ok=True)
filename = os.path.join(self.output_dir, f'model.{epoch}.pb')
filename = os.path.join(self.output_dir, f'model.e{epoch}.pb')
torch.save(state, filename)
print(f'Checkpoint saved to {filename}')
@ -110,8 +111,9 @@ class Trainer:
for epoch in range(self.start_epoch, self.max_epoch + 1):
self._train_epoch(epoch)
# self._validate_epoch(epoch)
self._save_checkpoint(epoch)
self.scheduler.step()
if dist.get_rank() == 0:
self._save_checkpoint(epoch)
# self.scheduler.step()
def _train_epoch(self, epoch):
"""