mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
funasr2
This commit is contained in:
parent
e98e10639d
commit
1c2eb051cd
@ -4,6 +4,7 @@ from funasr.torch_utils.device_funcs import to_device
|
|||||||
import logging
|
import logging
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
|
import torch.distributed as dist
|
||||||
|
|
||||||
class Trainer:
|
class Trainer:
|
||||||
"""
|
"""
|
||||||
@ -80,7 +81,7 @@ class Trainer:
|
|||||||
}
|
}
|
||||||
# Create output directory if it does not exist
|
# Create output directory if it does not exist
|
||||||
os.makedirs(self.output_dir, exist_ok=True)
|
os.makedirs(self.output_dir, exist_ok=True)
|
||||||
filename = os.path.join(self.output_dir, f'model.{epoch}.pb')
|
filename = os.path.join(self.output_dir, f'model.e{epoch}.pb')
|
||||||
torch.save(state, filename)
|
torch.save(state, filename)
|
||||||
print(f'Checkpoint saved to {filename}')
|
print(f'Checkpoint saved to {filename}')
|
||||||
|
|
||||||
@ -110,8 +111,9 @@ class Trainer:
|
|||||||
for epoch in range(self.start_epoch, self.max_epoch + 1):
|
for epoch in range(self.start_epoch, self.max_epoch + 1):
|
||||||
self._train_epoch(epoch)
|
self._train_epoch(epoch)
|
||||||
# self._validate_epoch(epoch)
|
# self._validate_epoch(epoch)
|
||||||
self._save_checkpoint(epoch)
|
if dist.get_rank() == 0:
|
||||||
self.scheduler.step()
|
self._save_checkpoint(epoch)
|
||||||
|
# self.scheduler.step()
|
||||||
|
|
||||||
def _train_epoch(self, epoch):
|
def _train_epoch(self, epoch):
|
||||||
"""
|
"""
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user