mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
funasr2
This commit is contained in:
parent
e98e10639d
commit
1c2eb051cd
@ -4,6 +4,7 @@ from funasr.torch_utils.device_funcs import to_device
|
||||
import logging
|
||||
from tqdm import tqdm
|
||||
from contextlib import nullcontext
|
||||
import torch.distributed as dist
|
||||
|
||||
class Trainer:
|
||||
"""
|
||||
@ -80,7 +81,7 @@ class Trainer:
|
||||
}
|
||||
# Create output directory if it does not exist
|
||||
os.makedirs(self.output_dir, exist_ok=True)
|
||||
filename = os.path.join(self.output_dir, f'model.{epoch}.pb')
|
||||
filename = os.path.join(self.output_dir, f'model.e{epoch}.pb')
|
||||
torch.save(state, filename)
|
||||
print(f'Checkpoint saved to {filename}')
|
||||
|
||||
@ -110,8 +111,9 @@ class Trainer:
|
||||
for epoch in range(self.start_epoch, self.max_epoch + 1):
|
||||
self._train_epoch(epoch)
|
||||
# self._validate_epoch(epoch)
|
||||
self._save_checkpoint(epoch)
|
||||
self.scheduler.step()
|
||||
if dist.get_rank() == 0:
|
||||
self._save_checkpoint(epoch)
|
||||
# self.scheduler.step()
|
||||
|
||||
def _train_epoch(self, epoch):
|
||||
"""
|
||||
|
||||
Loading…
Reference in New Issue
Block a user