This commit is contained in:
游雁 2024-05-20 15:27:24 +08:00
parent 3de70601df
commit bbd300a911
2 changed files with 5 additions and 3 deletions

View File

@ -198,7 +198,9 @@ def main(**kwargs):
trainer.train_loss_avg = 0.0
if trainer.rank == 0:
average_checkpoints(trainer.output_dir, trainer.avg_nbest_model)
average_checkpoints(
trainer.output_dir, trainer.avg_nbest_model, use_deepspeed=trainer.use_deepspeed
)
trainer.close()

View File

@ -16,7 +16,7 @@ from collections import OrderedDict
from functools import cmp_to_key
def _get_checkpoint_paths(output_dir: str, last_n: int = 5, use_deepspeed=False):
def _get_checkpoint_paths(output_dir: str, last_n: int = 5, use_deepspeed=False, **kwargs):
"""
Get the paths of the last 'last_n' checkpoints by parsing filenames
in the output directory.
@ -55,7 +55,7 @@ def average_checkpoints(output_dir: str, last_n: int = 5, **kwargs):
Average the last 'last_n' checkpoints' model state_dicts.
If a tensor is of type torch.int, perform sum instead of average.
"""
checkpoint_paths = _get_checkpoint_paths(output_dir, last_n)
checkpoint_paths = _get_checkpoint_paths(output_dir, last_n, **kwargs)
print(f"average_checkpoints: {checkpoint_paths}")
state_dicts = []