# FunASR/funasr/train_utils/average_nbest_models.py

import logging
import os
import warnings
from io import BytesIO
from pathlib import Path
from typing import Collection, Optional, Sequence, Union

import torch

from funasr.train.reporter import Reporter


@torch.no_grad()
def average_nbest_models(
    output_dir: Path,
    reporter: Reporter,
    best_model_criterion: Sequence[Sequence[str]],
    nbest: Union[Collection[int], int],
    suffix: Optional[str] = None,
    oss_bucket=None,
    pai_output_dir=None,
) -> None:
"""Generate averaged model from n-best models
Args:
output_dir: The directory contains the model file for each epoch
reporter: Reporter instance
best_model_criterion: Give criterions to decide the best model.
e.g. [("valid", "loss", "min"), ("train", "acc", "max")]
nbest: Number of best model files to be averaged
suffix: A suffix added to the averaged model file name
"""
    if isinstance(nbest, int):
        nbests = [nbest]
    else:
        nbests = list(nbest)
    if len(nbests) == 0:
        warnings.warn("At least 1 nbest value is required; falling back to nbest=1")
        nbests = [1]
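    # e.g. nbest=10 -> nbests=[10]; nbest=(5, 10) -> nbests=[5, 10]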
    if suffix is not None:
        suffix = suffix + "."
    else:
        suffix = ""

    # 1. Get nbests: List[Tuple[str, str, List[Tuple[epoch, value]]]]
    nbest_epochs = [
        (ph, k, reporter.sort_epochs_and_values(ph, k, m)[: max(nbests)])
        for ph, k, m in best_model_criterion
        if reporter.has(ph, k)
    ]
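    # Illustrative shape (values hypothetical): with
    # best_model_criterion=[("valid", "acc", "max")] this might yield
    # [("valid", "acc", [(4, 0.93), (2, 0.91), (7, 0.90)])],
    # i.e. epochs sorted from best to worst by the given criterion.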

    _loaded = {}
    for ph, cr, epoch_and_values in nbest_epochs:
        _nbests = [i for i in nbests if i <= len(epoch_and_values)]
        if len(_nbests) == 0:
            _nbests = [1]

        for n in _nbests:
            if n == 0:
                continue
            elif n == 1:
                # The averaged model is the same as the best model
                e, _ = epoch_and_values[0]
                op = output_dir / f"{e}epoch.pb"
                sym_op = output_dir / f"{ph}.{cr}.ave_1best.{suffix}pb"
                if sym_op.is_symlink() or sym_op.exists():
                    sym_op.unlink()
                sym_op.symlink_to(op.name)
            else:
                op = output_dir / f"{ph}.{cr}.ave_{n}best.{suffix}pb"
                logging.info(f'Averaging {n}best models: criterion="{ph}.{cr}": {op}')

                avg = None
                # 2.a. Averaging model
                for e, _ in epoch_and_values[:n]:
                    if e not in _loaded:
                        if oss_bucket is None:
                            _loaded[e] = torch.load(
                                output_dir / f"{e}epoch.pb",
                                map_location="cpu",
                            )
                        else:
                            buffer = BytesIO(
                                oss_bucket.get_object(
                                    os.path.join(pai_output_dir, f"{e}epoch.pb")
                                ).read()
                            )
                            # map to CPU for consistency with the local-file branch
                            _loaded[e] = torch.load(buffer, map_location="cpu")
                    states = _loaded[e]

                    if avg is None:
                        # Copy the first state dict so the division below does
                        # not overwrite the cached entry in _loaded.
                        avg = {k: v.clone() for k, v in states.items()}
                    else:
                        # Accumulate
                        for k in avg:
                            avg[k] = avg[k] + states[k]
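                # Divide the accumulated sums by n to turn them into means
                # (integer buffers stay as accumulated sums, see below).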
                for k in avg:
                    if str(avg[k].dtype).startswith("torch.int"):
                        # Integer types are not averaged, only accumulated,
                        # e.g. BatchNorm.num_batches_tracked.
                        # (If any case requires averaging or another reduction,
                        # e.g. max/min, for integer types, please report it.)
                        pass
                    else:
                        avg[k] = avg[k] / n

                # 2.b. Save the averaged model and create a symlink
                if oss_bucket is None:
                    torch.save(avg, op)
                else:
                    buffer = BytesIO()
                    torch.save(avg, buffer)
                    oss_bucket.put_object(
                        os.path.join(pai_output_dir, f"{ph}.{cr}.ave_{n}best.{suffix}pb"),
                        buffer.getvalue(),
                    )
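                    # Note: the OSS branch skips the symlink step below,
                    # presumably because object storage has no notion of
                    # symlinks; only the concrete averaged file is uploaded.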

        # 3. *.*.ave.pb is a symlink to the max ave model
        if oss_bucket is None:
            op = output_dir / f"{ph}.{cr}.ave_{max(_nbests)}best.{suffix}pb"
            sym_op = output_dir / f"{ph}.{cr}.ave.{suffix}pb"
            if sym_op.is_symlink() or sym_op.exists():
                sym_op.unlink()
            sym_op.symlink_to(op.name)
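

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of this module's behavior). It
# assumes a Reporter that the training loop has already populated with
# per-epoch metrics; the experiment directory and criterion below are
# hypothetical.
#
#   from pathlib import Path
#   from funasr.train.reporter import Reporter
#
#   reporter = Reporter()  # normally filled in during training
#   average_nbest_models(
#       output_dir=Path("exp/asr_train"),  # holds the {e}epoch.pb checkpoints
#       reporter=reporter,
#       best_model_criterion=[("valid", "acc", "max")],
#       nbest=[5, 10],  # writes ave_5best and ave_10best, plus the ave symlink
#   )
# ---------------------------------------------------------------------------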