mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
ds
This commit is contained in:
parent
3de70601df
commit
bbd300a911
@ -198,7 +198,9 @@ def main(**kwargs):
|
||||
trainer.train_loss_avg = 0.0
|
||||
|
||||
if trainer.rank == 0:
|
||||
average_checkpoints(trainer.output_dir, trainer.avg_nbest_model)
|
||||
average_checkpoints(
|
||||
trainer.output_dir, trainer.avg_nbest_model, use_deepspeed=trainer.use_deepspeed
|
||||
)
|
||||
|
||||
trainer.close()
|
||||
|
||||
|
||||
@ -16,7 +16,7 @@ from collections import OrderedDict
|
||||
from functools import cmp_to_key
|
||||
|
||||
|
||||
def _get_checkpoint_paths(output_dir: str, last_n: int = 5, use_deepspeed=False):
|
||||
def _get_checkpoint_paths(output_dir: str, last_n: int = 5, use_deepspeed=False, **kwargs):
|
||||
"""
|
||||
Get the paths of the last 'last_n' checkpoints by parsing filenames
|
||||
in the output directory.
|
||||
@ -55,7 +55,7 @@ def average_checkpoints(output_dir: str, last_n: int = 5, **kwargs):
|
||||
Average the last 'last_n' checkpoints' model state_dicts.
|
||||
If a tensor is of type torch.int, perform sum instead of average.
|
||||
"""
|
||||
checkpoint_paths = _get_checkpoint_paths(output_dir, last_n)
|
||||
checkpoint_paths = _get_checkpoint_paths(output_dir, last_n, **kwargs)
|
||||
print(f"average_checkpoints: {checkpoint_paths}")
|
||||
state_dicts = []
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user