mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
ds
This commit is contained in:
parent
3de70601df
commit
bbd300a911
@ -198,7 +198,9 @@ def main(**kwargs):
|
|||||||
trainer.train_loss_avg = 0.0
|
trainer.train_loss_avg = 0.0
|
||||||
|
|
||||||
if trainer.rank == 0:
|
if trainer.rank == 0:
|
||||||
average_checkpoints(trainer.output_dir, trainer.avg_nbest_model)
|
average_checkpoints(
|
||||||
|
trainer.output_dir, trainer.avg_nbest_model, use_deepspeed=trainer.use_deepspeed
|
||||||
|
)
|
||||||
|
|
||||||
trainer.close()
|
trainer.close()
|
||||||
|
|
||||||
|
|||||||
@ -16,7 +16,7 @@ from collections import OrderedDict
|
|||||||
from functools import cmp_to_key
|
from functools import cmp_to_key
|
||||||
|
|
||||||
|
|
||||||
def _get_checkpoint_paths(output_dir: str, last_n: int = 5, use_deepspeed=False):
|
def _get_checkpoint_paths(output_dir: str, last_n: int = 5, use_deepspeed=False, **kwargs):
|
||||||
"""
|
"""
|
||||||
Get the paths of the last 'last_n' checkpoints by parsing filenames
|
Get the paths of the last 'last_n' checkpoints by parsing filenames
|
||||||
in the output directory.
|
in the output directory.
|
||||||
@ -55,7 +55,7 @@ def average_checkpoints(output_dir: str, last_n: int = 5, **kwargs):
|
|||||||
Average the last 'last_n' checkpoints' model state_dicts.
|
Average the last 'last_n' checkpoints' model state_dicts.
|
||||||
If a tensor is of type torch.int, perform sum instead of average.
|
If a tensor is of type torch.int, perform sum instead of average.
|
||||||
"""
|
"""
|
||||||
checkpoint_paths = _get_checkpoint_paths(output_dir, last_n)
|
checkpoint_paths = _get_checkpoint_paths(output_dir, last_n, **kwargs)
|
||||||
print(f"average_checkpoints: {checkpoint_paths}")
|
print(f"average_checkpoints: {checkpoint_paths}")
|
||||||
state_dicts = []
|
state_dicts = []
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user