Mirror of https://github.com/modelscope/FunASR (synced 2025-09-15 14:48:36 +08:00)

commit c652f6814a ("update")
parent 7b6256c86f
@@ -420,16 +420,16 @@ if __name__ == '__main__':
     prepare_data(args, distributed_option)
 
     model = build_model(args)
-    optimizer = build_optimizer(args, model=model)
-    scheduler = build_scheduler(args, optimizer)
+    optimizers = build_optimizer(args, model=model)
+    schedulers = build_scheduler(args, optimizers)
 
     logging.info("world size: {}, rank: {}, local_rank: {}".format(distributed_option.dist_world_size,
                                                                    distributed_option.dist_rank,
                                                                    distributed_option.local_rank))
     logging.info(pytorch_cudnn_version())
     logging.info(model_summary(model))
-    logging.info("Optimizer: {}".format(optimizer))
-    logging.info("Scheduler: {}".format(scheduler))
+    logging.info("Optimizer: {}".format(optimizers))
+    logging.info("Scheduler: {}".format(schedulers))
 
     # dump args to config.yaml
     if not distributed_option.distributed or distributed_option.dist_rank == 0:
@@ -23,4 +23,6 @@ def build_optimizer(args, model):
     if optim_class is None:
         raise ValueError(f"must be one of {list(optim_classes)}: {args.optim}")
     optimizer = optim_class(model.parameters(), **args.optim_conf)
-    return optimizer
+
+    optimizers = [optimizer]
+    return optimizers
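
With this change build_optimizer returns a one-element list instead of a bare optimizer, and build_scheduler (below) consumes that list. A minimal usage sketch of the new list-valued interface, assuming only the builders shown in this diff (the zip over optimizer/scheduler pairs is illustrative, not code from the commit):

    optimizers = build_optimizer(args, model=model)   # e.g. [torch.optim.Adam(...)]
    schedulers = build_scheduler(args, optimizers)    # one entry per optimizer, possibly None
    for optimizer, scheduler in zip(optimizers, schedulers):
        optimizer.step()
        if scheduler is not None:
            scheduler.step()
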
@@ -8,7 +8,7 @@ from funasr.schedulers.tri_stage_scheduler import TriStageLR
 from funasr.schedulers.warmup_lr import WarmupLR
 
 
-def build_scheduler(args, optimizer):
+def build_scheduler(args, optimizers):
     scheduler_classes = dict(
         ReduceLROnPlateau=torch.optim.lr_scheduler.ReduceLROnPlateau,
         lambdalr=torch.optim.lr_scheduler.LambdaLR,
@@ -24,8 +24,21 @@ def build_scheduler(args, optimizer):
         CosineAnnealingWarmRestarts=torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
     )
 
-    scheduler_class = scheduler_classes.get(args.scheduler)
-    if scheduler_class is None:
-        raise ValueError(f"must be one of {list(scheduler_classes)}: {args.scheduler}")
-    scheduler = scheduler_class(optimizer, **args.scheduler_conf)
-    return scheduler
+    schedulers = []
+    for i, optim in enumerate(optimizers, 1):
+        suf = "" if i == 1 else str(i)
+        name = getattr(args, f"scheduler{suf}")
+        conf = getattr(args, f"scheduler{suf}_conf")
+        if name is not None:
+            cls_ = scheduler_classes.get(name)
+            if cls_ is None:
+                raise ValueError(
+                    f"must be one of {list(scheduler_classes)}: {name}"
+                )
+            scheduler = cls_(optim, **conf)
+        else:
+            scheduler = None
+
+        schedulers.append(scheduler)
+
+    return schedulers
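
The new loop resolves one scheduler per optimizer from numbered attribute pairs on args: scheduler / scheduler_conf for the first optimizer, then scheduler2 / scheduler2_conf, and so on. A hypothetical args namespace that satisfies this convention (the attribute names beyond the first pair are inferred from the getattr pattern above, and the chosen scheduler name and conf values are illustrative only):

    import argparse

    # Two optimizers: the first gets ReduceLROnPlateau, the second runs without a scheduler.
    args = argparse.Namespace(
        scheduler="ReduceLROnPlateau",
        scheduler_conf={"mode": "min", "patience": 2},
        scheduler2=None,
        scheduler2_conf={},
    )
    schedulers = build_scheduler(args, optimizers)   # -> [ReduceLROnPlateau(...), None]
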
funasr/build_utils/build_trainer.py (new file, 843 lines)
@@ -0,0 +1,843 @@
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""Trainer module."""
import argparse
from contextlib import contextmanager
import dataclasses
from dataclasses import is_dataclass
from distutils.version import LooseVersion
import logging
from pathlib import Path
import time
from typing import Dict
from typing import Iterable
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union

import humanfriendly
import oss2
from io import BytesIO
import os
import numpy as np
import torch
import torch.nn
import torch.optim
from typeguard import check_argument_types

from funasr.iterators.abs_iter_factory import AbsIterFactory
from funasr.main_funcs.average_nbest_models import average_nbest_models
from funasr.main_funcs.calculate_all_attentions import calculate_all_attentions
from funasr.schedulers.abs_scheduler import AbsBatchStepScheduler
from funasr.schedulers.abs_scheduler import AbsEpochStepScheduler
from funasr.schedulers.abs_scheduler import AbsScheduler
from funasr.schedulers.abs_scheduler import AbsValEpochStepScheduler
from funasr.torch_utils.add_gradient_noise import add_gradient_noise
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.recursive_op import recursive_average
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.models.base_model import FunASRModel
from funasr.train.distributed_utils import DistributedOption
from funasr.train.reporter import Reporter
from funasr.train.reporter import SubReporter
from funasr.utils.build_dataclass import build_dataclass

if torch.distributed.is_available():
    from torch.distributed import ReduceOp

if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
    from torch.cuda.amp import GradScaler
else:
    # Nothing to do if torch<1.6.0
    @contextmanager
    def autocast(enabled=True):
        yield


    GradScaler = None

try:
    import fairscale
except ImportError:
    fairscale = None


@dataclasses.dataclass
class TrainerOptions:
    ngpu: int
    resume: bool
    use_amp: bool
    train_dtype: str
    grad_noise: bool
    accum_grad: int
    grad_clip: float
    grad_clip_type: float
    log_interval: Optional[int]
    no_forward_run: bool
    use_tensorboard: bool
    use_wandb: bool
    output_dir: Union[Path, str]
    max_epoch: int
    max_update: int
    seed: int
    sharded_ddp: bool
    patience: Optional[int]
    keep_nbest_models: Union[int, List[int]]
    nbest_averaging_interval: int
    early_stopping_criterion: Sequence[str]
    best_model_criterion: Sequence[Sequence[str]]
    val_scheduler_criterion: Sequence[str]
    unused_parameters: bool
    wandb_model_log_interval: int
    use_pai: bool
    oss_bucket: Union[oss2.Bucket, None]
    batch_interval: int


class Trainer:
    """Trainer having a optimizer.

    If you'd like to use multiple optimizers, then inherit this class
    and override the methods if necessary - at least "train_one_epoch()"

    >>> class TwoOptimizerTrainer(Trainer):
    ...     @classmethod
    ...     def add_arguments(cls, parser):
    ...         ...
    ...
    ...     @classmethod
    ...     def train_one_epoch(cls, model, optimizers, ...):
    ...         loss1 = model.model1(...)
    ...         loss1.backward()
    ...         optimizers[0].step()
    ...
    ...         loss2 = model.model2(...)
    ...         loss2.backward()
    ...         optimizers[1].step()

    """

    def __init__(self):
        raise RuntimeError("This class can't be instantiated.")

    @classmethod
    def build_options(cls, args: argparse.Namespace) -> TrainerOptions:
        """Build options consumed by train(), eval()"""
        assert check_argument_types()
        return build_dataclass(TrainerOptions, args)

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser):
        """Reserved for future development of another Trainer"""
        pass

    @staticmethod
    def resume(
        checkpoint: Union[str, Path],
        model: torch.nn.Module,
        reporter: Reporter,
        optimizers: Sequence[torch.optim.Optimizer],
        schedulers: Sequence[Optional[AbsScheduler]],
        scaler: Optional[GradScaler],
        ngpu: int = 0,
    ):
        states = torch.load(
            checkpoint,
            map_location=f"cuda:{torch.cuda.current_device()}" if ngpu > 0 else "cpu",
        )
        model.load_state_dict(states["model"])
        reporter.load_state_dict(states["reporter"])
        for optimizer, state in zip(optimizers, states["optimizers"]):
            optimizer.load_state_dict(state)
        for scheduler, state in zip(schedulers, states["schedulers"]):
            if scheduler is not None:
                scheduler.load_state_dict(state)
        if scaler is not None:
            if states["scaler"] is None:
                logging.warning("scaler state is not found")
            else:
                scaler.load_state_dict(states["scaler"])

        logging.info(f"The training was resumed using {checkpoint}")

    @classmethod
    def run(
        cls,
        model: FunASRModel,
        optimizers: Sequence[torch.optim.Optimizer],
        schedulers: Sequence[Optional[AbsScheduler]],
        train_iter_factory: AbsIterFactory,
        valid_iter_factory: AbsIterFactory,
        trainer_options,
        distributed_option: DistributedOption,
    ) -> None:
        """Perform training. This method performs the main process of training."""
        assert check_argument_types()
        # NOTE(kamo): Don't check the type more strictly as far trainer_options
        assert is_dataclass(trainer_options), type(trainer_options)
        assert len(optimizers) == len(schedulers), (len(optimizers), len(schedulers))

        if isinstance(trainer_options.keep_nbest_models, int):
            keep_nbest_models = [trainer_options.keep_nbest_models]
        else:
            if len(trainer_options.keep_nbest_models) == 0:
                logging.warning("No keep_nbest_models is given. Change to [1]")
                trainer_options.keep_nbest_models = [1]
            keep_nbest_models = trainer_options.keep_nbest_models

        # assert batch_interval is set and >0
        assert trainer_options.batch_interval > 0

        output_dir = Path(trainer_options.output_dir)
        reporter = Reporter()
        if trainer_options.use_amp:
            if LooseVersion(torch.__version__) < LooseVersion("1.6.0"):
                raise RuntimeError(
                    "Require torch>=1.6.0 for Automatic Mixed Precision"
                )
            if trainer_options.sharded_ddp:
                if fairscale is None:
                    raise RuntimeError(
                        "Requiring fairscale. Do 'pip install fairscale'"
                    )
                scaler = fairscale.optim.grad_scaler.ShardedGradScaler()
            else:
                scaler = GradScaler()
        else:
            scaler = None

        if trainer_options.resume and (output_dir / "checkpoint.pb").exists():
            cls.resume(
                checkpoint=output_dir / "checkpoint.pb",
                model=model,
                optimizers=optimizers,
                schedulers=schedulers,
                reporter=reporter,
                scaler=scaler,
                ngpu=trainer_options.ngpu,
            )

        start_epoch = reporter.get_epoch() + 1
        if start_epoch == trainer_options.max_epoch + 1:
            logging.warning(
                f"The training has already reached at max_epoch: {start_epoch}"
            )

        if distributed_option.distributed:
            if trainer_options.sharded_ddp:
                dp_model = fairscale.nn.data_parallel.ShardedDataParallel(
                    module=model,
                    sharded_optimizer=optimizers,
                )
            else:
                dp_model = torch.nn.parallel.DistributedDataParallel(
                    model, find_unused_parameters=trainer_options.unused_parameters)
        elif distributed_option.ngpu > 1:
            dp_model = torch.nn.parallel.DataParallel(
                model,
                device_ids=list(range(distributed_option.ngpu)),
            )
        else:
            # NOTE(kamo): DataParallel also should work with ngpu=1,
            # but for debuggability it's better to keep this block.
            dp_model = model

        if trainer_options.use_tensorboard and (
            not distributed_option.distributed or distributed_option.dist_rank == 0
        ):
            from torch.utils.tensorboard import SummaryWriter
            if trainer_options.use_pai:
                train_summary_writer = SummaryWriter(
                    os.path.join(trainer_options.output_dir, "tensorboard/train")
                )
                valid_summary_writer = SummaryWriter(
                    os.path.join(trainer_options.output_dir, "tensorboard/valid")
                )
            else:
                train_summary_writer = SummaryWriter(
                    str(output_dir / "tensorboard" / "train")
                )
                valid_summary_writer = SummaryWriter(
                    str(output_dir / "tensorboard" / "valid")
                )
        else:
            train_summary_writer = None

        start_time = time.perf_counter()
        for iepoch in range(start_epoch, trainer_options.max_epoch + 1):
            if iepoch != start_epoch:
                logging.info(
                    "{}/{}epoch started. Estimated time to finish: {}".format(
                        iepoch,
                        trainer_options.max_epoch,
                        humanfriendly.format_timespan(
                            (time.perf_counter() - start_time)
                            / (iepoch - start_epoch)
                            * (trainer_options.max_epoch - iepoch + 1)
                        ),
                    )
                )
            else:
                logging.info(f"{iepoch}/{trainer_options.max_epoch}epoch started")
            set_all_random_seed(trainer_options.seed + iepoch)

            reporter.set_epoch(iepoch)
            # 1. Train and validation for one-epoch
            with reporter.observe("train") as sub_reporter:
                all_steps_are_invalid, max_update_stop = cls.train_one_epoch(
                    model=dp_model,
                    optimizers=optimizers,
                    schedulers=schedulers,
                    iterator=train_iter_factory.build_iter(iepoch),
                    reporter=sub_reporter,
                    scaler=scaler,
                    summary_writer=train_summary_writer,
                    options=trainer_options,
                    distributed_option=distributed_option,
                )

            with reporter.observe("valid") as sub_reporter:
                cls.validate_one_epoch(
                    model=dp_model,
                    iterator=valid_iter_factory.build_iter(iepoch),
                    reporter=sub_reporter,
                    options=trainer_options,
                    distributed_option=distributed_option,
                )

            # 2. LR Scheduler step
            for scheduler in schedulers:
                if isinstance(scheduler, AbsValEpochStepScheduler):
                    scheduler.step(
                        reporter.get_value(*trainer_options.val_scheduler_criterion)
                    )
                elif isinstance(scheduler, AbsEpochStepScheduler):
                    scheduler.step()
            if trainer_options.sharded_ddp:
                for optimizer in optimizers:
                    if isinstance(optimizer, fairscale.optim.oss.OSS):
                        optimizer.consolidate_state_dict()

            if not distributed_option.distributed or distributed_option.dist_rank == 0:
                # 3. Report the results
                logging.info(reporter.log_message())
                if train_summary_writer is not None:
                    reporter.tensorboard_add_scalar(train_summary_writer, key1="train")
                    reporter.tensorboard_add_scalar(valid_summary_writer, key1="valid")
                if trainer_options.use_wandb:
                    reporter.wandb_log()

                # save tensorboard on oss
                if trainer_options.use_pai and train_summary_writer is not None:
                    def write_tensorboard_summary(summary_writer_path, oss_bucket):
                        file_list = []
                        for root, dirs, files in os.walk(summary_writer_path, topdown=False):
                            for name in files:
                                file_full_path = os.path.join(root, name)
                                file_list.append(file_full_path)

                        for file_full_path in file_list:
                            with open(file_full_path, "rb") as f:
                                oss_bucket.put_object(file_full_path, f)

                    write_tensorboard_summary(os.path.join(trainer_options.output_dir, "tensorboard/train"),
                                              trainer_options.oss_bucket)
                    write_tensorboard_summary(os.path.join(trainer_options.output_dir, "tensorboard/valid"),
                                              trainer_options.oss_bucket)

                # 4. Save/Update the checkpoint
                if trainer_options.use_pai:
                    buffer = BytesIO()
                    torch.save(
                        {
                            "model": model.state_dict(),
                            "reporter": reporter.state_dict(),
                            "optimizers": [o.state_dict() for o in optimizers],
                            "schedulers": [
                                s.state_dict() if s is not None else None
                                for s in schedulers
                            ],
                            "scaler": scaler.state_dict() if scaler is not None else None,
                            "ema_model": model.encoder.ema.model.state_dict()
                            if hasattr(model.encoder, "ema") and model.encoder.ema is not None else None,
                        },
                        buffer,
                    )
                    trainer_options.oss_bucket.put_object(os.path.join(trainer_options.output_dir, "checkpoint.pb"),
                                                          buffer.getvalue())
                else:
                    torch.save(
                        {
                            "model": model.state_dict(),
                            "reporter": reporter.state_dict(),
                            "optimizers": [o.state_dict() for o in optimizers],
                            "schedulers": [
                                s.state_dict() if s is not None else None
                                for s in schedulers
                            ],
                            "scaler": scaler.state_dict() if scaler is not None else None,
                        },
                        output_dir / "checkpoint.pb",
                    )

                # 5. Save and log the model and update the link to the best model
                if trainer_options.use_pai:
                    buffer = BytesIO()
                    torch.save(model.state_dict(), buffer)
                    trainer_options.oss_bucket.put_object(os.path.join(trainer_options.output_dir,
                                                                       f"{iepoch}epoch.pb"), buffer.getvalue())
                else:
                    torch.save(model.state_dict(), output_dir / f"{iepoch}epoch.pb")

                # Creates a sym link latest.pb -> {iepoch}epoch.pb
                if trainer_options.use_pai:
                    p = os.path.join(trainer_options.output_dir, "latest.pb")
                    if trainer_options.oss_bucket.object_exists(p):
                        trainer_options.oss_bucket.delete_object(p)
                    trainer_options.oss_bucket.copy_object(trainer_options.oss_bucket.bucket_name,
                                                           os.path.join(trainer_options.output_dir,
                                                                        f"{iepoch}epoch.pb"), p)
                else:
                    p = output_dir / "latest.pb"
                    if p.is_symlink() or p.exists():
                        p.unlink()
                    p.symlink_to(f"{iepoch}epoch.pb")

                _improved = []
                for _phase, k, _mode in trainer_options.best_model_criterion:
                    # e.g. _phase, k, _mode = "train", "loss", "min"
                    if reporter.has(_phase, k):
                        best_epoch = reporter.get_best_epoch(_phase, k, _mode)
                        # Creates sym links if it's the best result
                        if best_epoch == iepoch:
                            if trainer_options.use_pai:
                                p = os.path.join(trainer_options.output_dir, f"{_phase}.{k}.best.pb")
                                if trainer_options.oss_bucket.object_exists(p):
                                    trainer_options.oss_bucket.delete_object(p)
                                trainer_options.oss_bucket.copy_object(trainer_options.oss_bucket.bucket_name,
                                                                       os.path.join(trainer_options.output_dir,
                                                                                    f"{iepoch}epoch.pb"), p)
                            else:
                                p = output_dir / f"{_phase}.{k}.best.pb"
                                if p.is_symlink() or p.exists():
                                    p.unlink()
                                p.symlink_to(f"{iepoch}epoch.pb")
                            _improved.append(f"{_phase}.{k}")
                if len(_improved) == 0:
                    logging.info("There are no improvements in this epoch")
                else:
                    logging.info(
                        "The best model has been updated: " + ", ".join(_improved)
                    )

                log_model = (
                    trainer_options.wandb_model_log_interval > 0
                    and iepoch % trainer_options.wandb_model_log_interval == 0
                )
                if log_model and trainer_options.use_wandb:
                    import wandb

                    logging.info("Logging Model on this epoch :::::")
                    artifact = wandb.Artifact(
                        name=f"model_{wandb.run.id}",
                        type="model",
                        metadata={"improved": _improved},
                    )
                    artifact.add_file(str(output_dir / f"{iepoch}epoch.pb"))
                    aliases = [
                        f"epoch-{iepoch}",
                        "best" if best_epoch == iepoch else "",
                    ]
                    wandb.log_artifact(artifact, aliases=aliases)

                # 6. Remove the model files excluding n-best epoch and latest epoch
                _removed = []
                # Get the union set of the n-best among multiple criterion
                nbests = set().union(
                    *[
                        set(reporter.sort_epochs(ph, k, m)[: max(keep_nbest_models)])
                        for ph, k, m in trainer_options.best_model_criterion
                        if reporter.has(ph, k)
                    ]
                )

                # Generated n-best averaged model
                if (
                    trainer_options.nbest_averaging_interval > 0
                    and iepoch % trainer_options.nbest_averaging_interval == 0
                ):
                    average_nbest_models(
                        reporter=reporter,
                        output_dir=output_dir,
                        best_model_criterion=trainer_options.best_model_criterion,
                        nbest=keep_nbest_models,
                        suffix=f"till{iepoch}epoch",
                        oss_bucket=trainer_options.oss_bucket,
                        pai_output_dir=trainer_options.output_dir,
                    )

                for e in range(1, iepoch):
                    if trainer_options.use_pai:
                        p = os.path.join(trainer_options.output_dir, f"{e}epoch.pb")
                        if trainer_options.oss_bucket.object_exists(p) and e not in nbests:
                            trainer_options.oss_bucket.delete_object(p)
                            _removed.append(str(p))
                    else:
                        p = output_dir / f"{e}epoch.pb"
                        if p.exists() and e not in nbests:
                            p.unlink()
                            _removed.append(str(p))
                if len(_removed) != 0:
                    logging.info("The model files were removed: " + ", ".join(_removed))

            # 7. If any updating haven't happened, stops the training
            if all_steps_are_invalid:
                logging.warning(
                    f"The gradients at all steps are invalid in this epoch. "
                    f"Something seems wrong. This training was stopped at {iepoch}epoch"
                )
                break

            if max_update_stop:
                logging.info(
                    f"Stopping training due to "
                    f"num_updates: {trainer_options.num_updates} >= max_update: {trainer_options.max_update}"
                )
                break

            # 8. Check early stopping
            if trainer_options.patience is not None:
                if reporter.check_early_stopping(
                    trainer_options.patience, *trainer_options.early_stopping_criterion
                ):
                    break

        else:
            logging.info(
                f"The training was finished at {trainer_options.max_epoch} epochs "
            )

        # Generated n-best averaged model
        if not distributed_option.distributed or distributed_option.dist_rank == 0:
            average_nbest_models(
                reporter=reporter,
                output_dir=output_dir,
                best_model_criterion=trainer_options.best_model_criterion,
                nbest=keep_nbest_models,
                oss_bucket=trainer_options.oss_bucket,
                pai_output_dir=trainer_options.output_dir,
            )

    @classmethod
    def train_one_epoch(
        cls,
        model: torch.nn.Module,
        iterator: Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
        optimizers: Sequence[torch.optim.Optimizer],
        schedulers: Sequence[Optional[AbsScheduler]],
        scaler: Optional[GradScaler],
        reporter: SubReporter,
        summary_writer,
        options: TrainerOptions,
        distributed_option: DistributedOption,
    ) -> Tuple[bool, bool]:
        assert check_argument_types()

        grad_noise = options.grad_noise
        accum_grad = options.accum_grad
        grad_clip = options.grad_clip
        grad_clip_type = options.grad_clip_type
        log_interval = options.log_interval
        no_forward_run = options.no_forward_run
        ngpu = options.ngpu
        use_wandb = options.use_wandb
        distributed = distributed_option.distributed

        if log_interval is None:
            try:
                log_interval = max(len(iterator) // 20, 10)
            except TypeError:
                log_interval = 100

        model.train()
        all_steps_are_invalid = True
        max_update_stop = False
        # [For distributed] Because iteration counts are not always equals between
        # processes, send stop-flag to the other processes if iterator is finished
        iterator_stop = torch.tensor(0).to("cuda" if ngpu > 0 else "cpu")

        # get the rank
        rank = distributed_option.dist_rank
        # get the num batch updates
        num_batch_updates = 0
        # ouput dir
        output_dir = Path(options.output_dir)
        # batch interval
        batch_interval = options.batch_interval
        assert batch_interval > 0

        start_time = time.perf_counter()
        for iiter, (_, batch) in enumerate(
            reporter.measure_iter_time(iterator, "iter_time"), 1
        ):
            assert isinstance(batch, dict), type(batch)

            if rank == 0:
                if hasattr(model, "num_updates") or (hasattr(model, "module") and hasattr(model.module, "num_updates")):
                    num_batch_updates = model.get_num_updates() if hasattr(model,
                                                                           "num_updates") else model.module.get_num_updates()
                if (num_batch_updates % batch_interval == 0) and (options.oss_bucket is not None) and options.use_pai:
                    buffer = BytesIO()
                    torch.save(model.state_dict(), buffer)
                    options.oss_bucket.put_object(os.path.join(output_dir, f"{num_batch_updates}batch.pth"),
                                                  buffer.getvalue())

            if distributed:
                torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
                if iterator_stop > 0:
                    break

            batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")
            if no_forward_run:
                all_steps_are_invalid = False
                continue

            with autocast(scaler is not None):
                with reporter.measure_time("forward_time"):
                    retval = model(**batch)

                    # Note(kamo):
                    # Supporting two patterns for the returned value from the model
                    # a. dict type
                    if isinstance(retval, dict):
                        loss = retval["loss"]
                        stats = retval["stats"]
                        weight = retval["weight"]
                        optim_idx = retval.get("optim_idx")
                        if optim_idx is not None and not isinstance(optim_idx, int):
                            if not isinstance(optim_idx, torch.Tensor):
                                raise RuntimeError(
                                    "optim_idx must be int or 1dim torch.Tensor, "
                                    f"but got {type(optim_idx)}"
                                )
                            if optim_idx.dim() >= 2:
                                raise RuntimeError(
                                    "optim_idx must be int or 1dim torch.Tensor, "
                                    f"but got {optim_idx.dim()}dim tensor"
                                )
                            if optim_idx.dim() == 1:
                                for v in optim_idx:
                                    if v != optim_idx[0]:
                                        raise RuntimeError(
                                            "optim_idx must be 1dim tensor "
                                            "having same values for all entries"
                                        )
                                optim_idx = optim_idx[0].item()
                            else:
                                optim_idx = optim_idx.item()

                    # b. tuple or list type
                    else:
                        loss, stats, weight = retval
                        optim_idx = None

                stats = {k: v for k, v in stats.items() if v is not None}
                if ngpu > 1 or distributed:
                    # Apply weighted averaging for loss and stats
                    loss = (loss * weight.type(loss.dtype)).sum()

                    # if distributed, this method can also apply all_reduce()
                    stats, weight = recursive_average(stats, weight, distributed)

                    # Now weight is summation over all workers
                    loss /= weight
                    if distributed:
                        # NOTE(kamo): Multiply world_size because DistributedDataParallel
                        # automatically normalizes the gradient by world_size.
                        loss *= torch.distributed.get_world_size()

                loss /= accum_grad

            reporter.register(stats, weight)

            with reporter.measure_time("backward_time"):
                if scaler is not None:
                    # Scales loss. Calls backward() on scaled loss
                    # to create scaled gradients.
                    # Backward passes under autocast are not recommended.
                    # Backward ops run in the same dtype autocast chose
                    # for corresponding forward ops.
                    scaler.scale(loss).backward()
                else:
                    loss.backward()

            if iiter % accum_grad == 0:
                if scaler is not None:
                    # Unscales the gradients of optimizer's assigned params in-place
                    for iopt, optimizer in enumerate(optimizers):
                        if optim_idx is not None and iopt != optim_idx:
                            continue
                        scaler.unscale_(optimizer)

                # gradient noise injection
                if grad_noise:
                    add_gradient_noise(
                        model,
                        reporter.get_total_count(),
                        duration=100,
                        eta=1.0,
                        scale_factor=0.55,
                    )

                # compute the gradient norm to check if it is normal or not
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(),
                    max_norm=grad_clip,
                    norm_type=grad_clip_type,
                )
                # PyTorch<=1.4, clip_grad_norm_ returns float value
                if not isinstance(grad_norm, torch.Tensor):
                    grad_norm = torch.tensor(grad_norm)

                if not torch.isfinite(grad_norm):
                    logging.warning(
                        f"The grad norm is {grad_norm}. Skipping updating the model."
                    )

                    # Must invoke scaler.update() if unscale_() is used in the iteration
                    # to avoid the following error:
                    #   RuntimeError: unscale_() has already been called
                    #   on this optimizer since the last update().
                    # Note that if the gradient has inf/nan values,
                    # scaler.step skips optimizer.step().
                    if scaler is not None:
                        for iopt, optimizer in enumerate(optimizers):
                            if optim_idx is not None and iopt != optim_idx:
                                continue
                            scaler.step(optimizer)
                            scaler.update()

                else:
                    all_steps_are_invalid = False
                    with reporter.measure_time("optim_step_time"):
                        for iopt, (optimizer, scheduler) in enumerate(
                            zip(optimizers, schedulers)
                        ):
                            if optim_idx is not None and iopt != optim_idx:
                                continue
                            if scaler is not None:
                                # scaler.step() first unscales the gradients of
                                # the optimizer's assigned params.
                                scaler.step(optimizer)
                                # Updates the scale for next iteration.
                                scaler.update()
                            else:
                                optimizer.step()
                            if isinstance(scheduler, AbsBatchStepScheduler):
                                scheduler.step()
                for iopt, optimizer in enumerate(optimizers):
                    if optim_idx is not None and iopt != optim_idx:
                        continue
                    optimizer.zero_grad()

                # Register lr and train/load time[sec/step],
                # where step refers to accum_grad * mini-batch
                reporter.register(
                    dict(
                        {
                            f"optim{i}_lr{j}": pg["lr"]
                            for i, optimizer in enumerate(optimizers)
                            for j, pg in enumerate(optimizer.param_groups)
                            if "lr" in pg
                        },
                        train_time=time.perf_counter() - start_time,
                    ),
                )
                start_time = time.perf_counter()

                # update num_updates
                if distributed:
                    if hasattr(model.module, "num_updates"):
                        model.module.set_num_updates(model.module.get_num_updates() + 1)
                        options.num_updates = model.module.get_num_updates()
                        if model.module.get_num_updates() >= options.max_update:
                            max_update_stop = True
                else:
                    if hasattr(model, "num_updates"):
                        model.set_num_updates(model.get_num_updates() + 1)
                        options.num_updates = model.get_num_updates()
                        if model.get_num_updates() >= options.max_update:
                            max_update_stop = True

            # NOTE(kamo): Call log_message() after next()
            reporter.next()
            if iiter % log_interval == 0:
                num_updates = options.num_updates if hasattr(options, "num_updates") else None
                logging.info(reporter.log_message(-log_interval, num_updates=num_updates))
                if summary_writer is not None:
                    reporter.tensorboard_add_scalar(summary_writer, -log_interval)
                if use_wandb:
                    reporter.wandb_log()

            if max_update_stop:
                break

        else:
            if distributed:
                iterator_stop.fill_(1)
                torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
        return all_steps_are_invalid, max_update_stop

    @classmethod
    @torch.no_grad()
    def validate_one_epoch(
        cls,
        model: torch.nn.Module,
        iterator: Iterable[Dict[str, torch.Tensor]],
        reporter: SubReporter,
        options: TrainerOptions,
        distributed_option: DistributedOption,
    ) -> None:
        assert check_argument_types()
        ngpu = options.ngpu
        no_forward_run = options.no_forward_run
        distributed = distributed_option.distributed

        model.eval()

        # [For distributed] Because iteration counts are not always equals between
        # processes, send stop-flag to the other processes if iterator is finished
        iterator_stop = torch.tensor(0).to("cuda" if ngpu > 0 else "cpu")
        for (_, batch) in iterator:
            assert isinstance(batch, dict), type(batch)
            if distributed:
                torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
                if iterator_stop > 0:
                    break

            batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")
            if no_forward_run:
                continue

            retval = model(**batch)
            if isinstance(retval, dict):
                stats = retval["stats"]
                weight = retval["weight"]
            else:
                _, stats, weight = retval
            if ngpu > 1 or distributed:
                # Apply weighted averaging for stats.
                # if distributed, this method can also apply all_reduce()
                stats, weight = recursive_average(stats, weight, distributed)

            reporter.register(stats, weight)
            reporter.next()

        else:
            if distributed:
                iterator_stop.fill_(1)
                torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
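
Taken together with the first hunk above, the new Trainer is driven by the updated builders roughly as follows. This is a condensed, hypothetical sketch of the call sequence visible in this commit; the iterator factories and the full argument namespace are elided and only assumed to exist:

    model = build_model(args)
    optimizers = build_optimizer(args, model=model)   # list of torch.optim.Optimizer
    schedulers = build_scheduler(args, optimizers)    # parallel list, entries may be None

    trainer_options = Trainer.build_options(args)     # TrainerOptions dataclass built from args
    Trainer.run(
        model=model,
        optimizers=optimizers,
        schedulers=schedulers,
        train_iter_factory=train_iter_factory,
        valid_iter_factory=valid_iter_factory,
        trainer_options=trainer_options,
        distributed_option=distributed_option,
    )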