This commit is contained in:
游雁 2024-05-20 15:27:24 +08:00
parent 3de70601df
commit bbd300a911
2 changed files with 5 additions and 3 deletions

View File

@ -198,7 +198,9 @@ def main(**kwargs):
trainer.train_loss_avg = 0.0 trainer.train_loss_avg = 0.0
if trainer.rank == 0: if trainer.rank == 0:
average_checkpoints(trainer.output_dir, trainer.avg_nbest_model) average_checkpoints(
trainer.output_dir, trainer.avg_nbest_model, use_deepspeed=trainer.use_deepspeed
)
trainer.close() trainer.close()

View File

@ -16,7 +16,7 @@ from collections import OrderedDict
from functools import cmp_to_key from functools import cmp_to_key
def _get_checkpoint_paths(output_dir: str, last_n: int = 5, use_deepspeed=False): def _get_checkpoint_paths(output_dir: str, last_n: int = 5, use_deepspeed=False, **kwargs):
""" """
Get the paths of the last 'last_n' checkpoints by parsing filenames Get the paths of the last 'last_n' checkpoints by parsing filenames
in the output directory. in the output directory.
@ -55,7 +55,7 @@ def average_checkpoints(output_dir: str, last_n: int = 5, **kwargs):
Average the last 'last_n' checkpoints' model state_dicts. Average the last 'last_n' checkpoints' model state_dicts.
If a tensor is of type torch.int, perform sum instead of average. If a tensor is of type torch.int, perform sum instead of average.
""" """
checkpoint_paths = _get_checkpoint_paths(output_dir, last_n) checkpoint_paths = _get_checkpoint_paths(output_dir, last_n, **kwargs)
print(f"average_checkpoints: {checkpoint_paths}") print(f"average_checkpoints: {checkpoint_paths}")
state_dicts = [] state_dicts = []