# mirror of https://github.com/modelscope/FunASR
# synced 2025-09-15 14:48:36 +08:00
import logging
import os
import warnings
from io import BytesIO
from pathlib import Path
from typing import Collection, Optional, Sequence, Union

import torch

from funasr.train.reporter import Reporter
@torch.no_grad()
def average_nbest_models(
    output_dir: Path,
    reporter: Reporter,
    best_model_criterion: Sequence[Sequence[str]],
    nbest: Union[Collection[int], int],
    suffix: Optional[str] = None,
    oss_bucket=None,
    pai_output_dir=None,
) -> None:
    """Generate averaged model from n-best models.

    Args:
        output_dir: The directory contains the model file for each epoch
        reporter: Reporter instance
        best_model_criterion: Give criterions to decide the best model.
            e.g. [("valid", "loss", "min"), ("train", "acc", "max")]
        nbest: Number of best model files to be averaged
        suffix: A suffix added to the averaged model file name
        oss_bucket: Optional object-storage bucket client; when given,
            checkpoints are read from / written to remote storage instead
            of the local filesystem.
        pai_output_dir: Remote key prefix used together with ``oss_bucket``.
    """
    if isinstance(nbest, int):
        nbests = [nbest]
    else:
        nbests = list(nbest)
    if len(nbests) == 0:
        warnings.warn("At least 1 nbest values are required")
        nbests = [1]
    if suffix is not None:
        suffix = suffix + "."
    else:
        suffix = ""

    # 1. Get nbests: List[Tuple[str, str, List[Tuple[epoch, value]]]]
    nbest_epochs = [
        (ph, k, reporter.sort_epochs_and_values(ph, k, m)[: max(nbests)])
        for ph, k, m in best_model_criterion
        if reporter.has(ph, k)
    ]

    # Cache of raw per-epoch state dicts; must never be mutated, because it
    # is reused across different n values and criteria.
    _loaded = {}
    for ph, cr, epoch_and_values in nbest_epochs:
        # Drop n values that exceed the number of available checkpoints.
        _nbests = [i for i in nbests if i <= len(epoch_and_values)]
        if len(_nbests) == 0:
            _nbests = [1]

        for n in _nbests:
            if n == 0:
                continue
            elif n == 1:
                # The averaged model is same as the best model
                e, _ = epoch_and_values[0]
                op = output_dir / f"{e}epoch.pb"
                sym_op = output_dir / f"{ph}.{cr}.ave_1best.{suffix}pb"
                if sym_op.is_symlink() or sym_op.exists():
                    sym_op.unlink()
                sym_op.symlink_to(op.name)
            else:
                op = output_dir / f"{ph}.{cr}.ave_{n}best.{suffix}pb"
                logging.info(
                    f"Averaging {n}best models: " f'criterion="{ph}.{cr}": {op}'
                )

                avg = None
                # 2.a. Averaging model
                for e, _ in epoch_and_values[:n]:
                    if e not in _loaded:
                        if oss_bucket is None:
                            _loaded[e] = torch.load(
                                output_dir / f"{e}epoch.pb",
                                map_location="cpu",
                            )
                        else:
                            buffer = BytesIO(
                                oss_bucket.get_object(
                                    os.path.join(pai_output_dir, f"{e}epoch.pb")
                                ).read()
                            )
                            # map_location="cpu" keeps the remote branch
                            # consistent with the local-disk branch above.
                            _loaded[e] = torch.load(buffer, map_location="cpu")
                    states = _loaded[e]

                    if avg is None:
                        # Shallow-copy: the accumulation below rebinds keys,
                        # which would otherwise corrupt the cached state dict
                        # in _loaded for any later (larger) n.
                        avg = dict(states)
                    else:
                        # Accumulated
                        for k in avg:
                            avg[k] = avg[k] + states[k]
                for k in avg:
                    if str(avg[k].dtype).startswith("torch.int"):
                        # For int type, not averaged, but only accumulated.
                        # e.g. BatchNorm.num_batches_tracked
                        # (If there are any cases that requires averaging
                        # or the other reducing method, e.g. max/min, for integer type,
                        # please report.)
                        pass
                    else:
                        avg[k] = avg[k] / n

                # 2.b. Save the ave model and create a symlink
                if oss_bucket is None:
                    torch.save(avg, op)
                else:
                    buffer = BytesIO()
                    torch.save(avg, buffer)
                    oss_bucket.put_object(
                        os.path.join(
                            pai_output_dir, f"{ph}.{cr}.ave_{n}best.{suffix}pb"
                        ),
                        buffer.getvalue(),
                    )

        # 3. *.*.ave.pb is a symlink to the max ave model
        if oss_bucket is None:
            op = output_dir / f"{ph}.{cr}.ave_{max(_nbests)}best.{suffix}pb"
            sym_op = output_dir / f"{ph}.{cr}.ave.{suffix}pb"
            if sym_op.is_symlink() or sym_op.exists():
                sym_op.unlink()
            sym_op.symlink_to(op.name)