From c652f6814ac62eebb5fd1a55a303ee9110c87b58 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=98=89=E6=B8=8A?= Date: Sun, 23 Apr 2023 17:30:38 +0800 Subject: [PATCH] update --- funasr/bin/train.py | 8 +- funasr/build_utils/build_optimizer.py | 4 +- funasr/build_utils/build_scheduler.py | 25 +- funasr/build_utils/build_trainer.py | 843 ++++++++++++++++++++++++++ 4 files changed, 869 insertions(+), 11 deletions(-) create mode 100644 funasr/build_utils/build_trainer.py diff --git a/funasr/bin/train.py b/funasr/bin/train.py index e861199cd..c32a362ad 100644 --- a/funasr/bin/train.py +++ b/funasr/bin/train.py @@ -420,16 +420,16 @@ if __name__ == '__main__': prepare_data(args, distributed_option) model = build_model(args) - optimizer = build_optimizer(args, model=model) - scheduler = build_scheduler(args, optimizer) + optimizers = build_optimizer(args, model=model) + schedulers = build_scheduler(args, optimizers) logging.info("world size: {}, rank: {}, local_rank: {}".format(distributed_option.dist_world_size, distributed_option.dist_rank, distributed_option.local_rank)) logging.info(pytorch_cudnn_version()) logging.info(model_summary(model)) - logging.info("Optimizer: {}".format(optimizer)) - logging.info("Scheduler: {}".format(scheduler)) + logging.info("Optimizer: {}".format(optimizers)) + logging.info("Scheduler: {}".format(schedulers)) # dump args to config.yaml if not distributed_option.distributed or distributed_option.dist_rank == 0: diff --git a/funasr/build_utils/build_optimizer.py b/funasr/build_utils/build_optimizer.py index 3b2799492..bd0b73d88 100644 --- a/funasr/build_utils/build_optimizer.py +++ b/funasr/build_utils/build_optimizer.py @@ -23,4 +23,6 @@ def build_optimizer(args, model): if optim_class is None: raise ValueError(f"must be one of {list(optim_classes)}: {args.optim}") optimizer = optim_class(model.parameters(), **args.optim_conf) - return optimizer \ No newline at end of file + + optimizers = [optimizer] + return optimizers \ No newline at end of file diff --git a/funasr/build_utils/build_scheduler.py b/funasr/build_utils/build_scheduler.py index f0e6d1f54..4b9990e6e 100644 --- a/funasr/build_utils/build_scheduler.py +++ b/funasr/build_utils/build_scheduler.py @@ -8,7 +8,7 @@ from funasr.schedulers.tri_stage_scheduler import TriStageLR from funasr.schedulers.warmup_lr import WarmupLR -def build_scheduler(args, optimizer): +def build_scheduler(args, optimizers): scheduler_classes = dict( ReduceLROnPlateau=torch.optim.lr_scheduler.ReduceLROnPlateau, lambdalr=torch.optim.lr_scheduler.LambdaLR, @@ -24,8 +24,21 @@ def build_scheduler(args, optimizer): CosineAnnealingWarmRestarts=torch.optim.lr_scheduler.CosineAnnealingWarmRestarts, ) - scheduler_class = scheduler_classes.get(args.scheduler) - if scheduler_class is None: - raise ValueError(f"must be one of {list(scheduler_classes)}: {args.scheduler}") - scheduler = scheduler_class(optimizer, **args.scheduler_conf) - return scheduler \ No newline at end of file + schedulers = [] + for i, optim in enumerate(optimizers, 1): + suf = "" if i == 1 else str(i) + name = getattr(args, f"scheduler{suf}") + conf = getattr(args, f"scheduler{suf}_conf") + if name is not None: + cls_ = scheduler_classes.get(name) + if cls_ is None: + raise ValueError( + f"must be one of {list(scheduler_classes)}: {name}" + ) + scheduler = cls_(optim, **conf) + else: + scheduler = None + + schedulers.append(scheduler) + + return schedulers \ No newline at end of file diff --git a/funasr/build_utils/build_trainer.py b/funasr/build_utils/build_trainer.py new 
file mode 100644
index 000000000..8e4ee467c
--- /dev/null
+++ b/funasr/build_utils/build_trainer.py
@@ -0,0 +1,843 @@
+# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""Trainer module."""
+import argparse
+from contextlib import contextmanager
+import dataclasses
+from dataclasses import is_dataclass
+from distutils.version import LooseVersion
+import logging
+from pathlib import Path
+import time
+from typing import Dict
+from typing import Iterable
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Union
+
+import humanfriendly
+import oss2
+from io import BytesIO
+import os
+import numpy as np
+import torch
+import torch.nn
+import torch.optim
+from typeguard import check_argument_types
+
+from funasr.iterators.abs_iter_factory import AbsIterFactory
+from funasr.main_funcs.average_nbest_models import average_nbest_models
+from funasr.main_funcs.calculate_all_attentions import calculate_all_attentions
+from funasr.schedulers.abs_scheduler import AbsBatchStepScheduler
+from funasr.schedulers.abs_scheduler import AbsEpochStepScheduler
+from funasr.schedulers.abs_scheduler import AbsScheduler
+from funasr.schedulers.abs_scheduler import AbsValEpochStepScheduler
+from funasr.torch_utils.add_gradient_noise import add_gradient_noise
+from funasr.torch_utils.device_funcs import to_device
+from funasr.torch_utils.recursive_op import recursive_average
+from funasr.torch_utils.set_all_random_seed import set_all_random_seed
+from funasr.models.base_model import FunASRModel
+from funasr.train.distributed_utils import DistributedOption
+from funasr.train.reporter import Reporter
+from funasr.train.reporter import SubReporter
+from funasr.utils.build_dataclass import build_dataclass
+
+if torch.distributed.is_available():
+    from torch.distributed import ReduceOp
+
+if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
+    from torch.cuda.amp import autocast
+    from torch.cuda.amp import GradScaler
+else:
+    # Fall back to a no-op autocast and no GradScaler for torch<1.6.0
+    @contextmanager
+    def autocast(enabled=True):
+        yield
+
+
+    GradScaler = None
+
+try:
+    import fairscale
+except ImportError:
+    fairscale = None
+
+
+@dataclasses.dataclass
+class TrainerOptions:
+    ngpu: int
+    resume: bool
+    use_amp: bool
+    train_dtype: str
+    grad_noise: bool
+    accum_grad: int
+    grad_clip: float
+    grad_clip_type: float
+    log_interval: Optional[int]
+    no_forward_run: bool
+    use_tensorboard: bool
+    use_wandb: bool
+    output_dir: Union[Path, str]
+    max_epoch: int
+    max_update: int
+    seed: int
+    sharded_ddp: bool
+    patience: Optional[int]
+    keep_nbest_models: Union[int, List[int]]
+    nbest_averaging_interval: int
+    early_stopping_criterion: Sequence[str]
+    best_model_criterion: Sequence[Sequence[str]]
+    val_scheduler_criterion: Sequence[str]
+    unused_parameters: bool
+    wandb_model_log_interval: int
+    use_pai: bool
+    oss_bucket: Union[oss2.Bucket, None]
+    batch_interval: int
+
+
+class Trainer:
+    """Trainer having an optimizer.
+
+    If you'd like to use multiple optimizers, then inherit this class
+    and override the methods if necessary - at least "train_one_epoch()"
+
+    >>> class TwoOptimizerTrainer(Trainer):
+    ...     @classmethod
+    ...     def add_arguments(cls, parser):
+    ...         ...
+    ...
+    ...     @classmethod
+    ...     def train_one_epoch(cls, model, optimizers, ...):
+    ...         loss1 = model.model1(...)
+    ...         loss1.backward()
+    ...         optimizers[0].step()
+    ...
+    ...         loss2 = model.model2(...)
+    ...         loss2.backward()
+    ...         optimizers[1].step()
+
+    """
+
+    def __init__(self):
+        raise RuntimeError("This class can't be instantiated.")
+
+    @classmethod
+    def build_options(cls, args: argparse.Namespace) -> TrainerOptions:
+        """Build options consumed by train(), eval()"""
+        assert check_argument_types()
+        return build_dataclass(TrainerOptions, args)
+
+    @classmethod
+    def add_arguments(cls, parser: argparse.ArgumentParser):
+        """Reserved for future development of another Trainer"""
+        pass
+
+    @staticmethod
+    def resume(
+        checkpoint: Union[str, Path],
+        model: torch.nn.Module,
+        reporter: Reporter,
+        optimizers: Sequence[torch.optim.Optimizer],
+        schedulers: Sequence[Optional[AbsScheduler]],
+        scaler: Optional[GradScaler],
+        ngpu: int = 0,
+    ):
+        states = torch.load(
+            checkpoint,
+            map_location=f"cuda:{torch.cuda.current_device()}" if ngpu > 0 else "cpu",
+        )
+        model.load_state_dict(states["model"])
+        reporter.load_state_dict(states["reporter"])
+        for optimizer, state in zip(optimizers, states["optimizers"]):
+            optimizer.load_state_dict(state)
+        for scheduler, state in zip(schedulers, states["schedulers"]):
+            if scheduler is not None:
+                scheduler.load_state_dict(state)
+        if scaler is not None:
+            if states["scaler"] is None:
+                logging.warning("scaler state is not found")
+            else:
+                scaler.load_state_dict(states["scaler"])
+
+        logging.info(f"The training was resumed from {checkpoint}")
+
+    @classmethod
+    def run(
+        cls,
+        model: FunASRModel,
+        optimizers: Sequence[torch.optim.Optimizer],
+        schedulers: Sequence[Optional[AbsScheduler]],
+        train_iter_factory: AbsIterFactory,
+        valid_iter_factory: AbsIterFactory,
+        trainer_options,
+        distributed_option: DistributedOption,
+    ) -> None:
+        """Perform training. This method runs the main training loop."""
+        assert check_argument_types()
+        # NOTE(kamo): Don't check the type more strictly as far as
+        # trainer_options is concerned
+        assert is_dataclass(trainer_options), type(trainer_options)
+        assert len(optimizers) == len(schedulers), (len(optimizers), len(schedulers))
+
+        if isinstance(trainer_options.keep_nbest_models, int):
+            keep_nbest_models = [trainer_options.keep_nbest_models]
+        else:
+            if len(trainer_options.keep_nbest_models) == 0:
+                logging.warning("No keep_nbest_models is given. Changing to [1]")
+                trainer_options.keep_nbest_models = [1]
+            keep_nbest_models = trainer_options.keep_nbest_models
+
+        # assert batch_interval is set and >0
+        assert trainer_options.batch_interval > 0
+
+        output_dir = Path(trainer_options.output_dir)
+        reporter = Reporter()
+        if trainer_options.use_amp:
+            if LooseVersion(torch.__version__) < LooseVersion("1.6.0"):
+                raise RuntimeError(
+                    "Require torch>=1.6.0 for Automatic Mixed Precision"
+                )
+            if trainer_options.sharded_ddp:
+                if fairscale is None:
+                    raise RuntimeError(
+                        "fairscale is required. Do 'pip install fairscale'"
+                    )
+                scaler = fairscale.optim.grad_scaler.ShardedGradScaler()
+            else:
+                scaler = GradScaler()
+        else:
+            scaler = None
+
+        if trainer_options.resume and (output_dir / "checkpoint.pb").exists():
+            cls.resume(
+                checkpoint=output_dir / "checkpoint.pb",
+                model=model,
+                optimizers=optimizers,
+                schedulers=schedulers,
+                reporter=reporter,
+                scaler=scaler,
+                ngpu=trainer_options.ngpu,
+            )
+
+        start_epoch = reporter.get_epoch() + 1
+        if start_epoch == trainer_options.max_epoch + 1:
+            logging.warning(
+                f"The training has already reached max_epoch: {start_epoch}"
+            )
+
+        if distributed_option.distributed:
+            if trainer_options.sharded_ddp:
+                dp_model = fairscale.nn.data_parallel.ShardedDataParallel(
+                    module=model,
+                    sharded_optimizer=optimizers,
+                )
+            else:
+                dp_model = torch.nn.parallel.DistributedDataParallel(
+                    model, find_unused_parameters=trainer_options.unused_parameters)
+        elif distributed_option.ngpu > 1:
+            dp_model = torch.nn.parallel.DataParallel(
+                model,
+                device_ids=list(range(distributed_option.ngpu)),
+            )
+        else:
+            # NOTE(kamo): DataParallel also should work with ngpu=1,
+            # but for debuggability it's better to keep this block.
+            dp_model = model
+
+        if trainer_options.use_tensorboard and (
+            not distributed_option.distributed or distributed_option.dist_rank == 0
+        ):
+            from torch.utils.tensorboard import SummaryWriter
+            if trainer_options.use_pai:
+                train_summary_writer = SummaryWriter(
+                    os.path.join(trainer_options.output_dir, "tensorboard/train")
+                )
+                valid_summary_writer = SummaryWriter(
+                    os.path.join(trainer_options.output_dir, "tensorboard/valid")
+                )
+            else:
+                train_summary_writer = SummaryWriter(
+                    str(output_dir / "tensorboard" / "train")
+                )
+                valid_summary_writer = SummaryWriter(
+                    str(output_dir / "tensorboard" / "valid")
+                )
+        else:
+            train_summary_writer = None
+
+        start_time = time.perf_counter()
+        for iepoch in range(start_epoch, trainer_options.max_epoch + 1):
+            if iepoch != start_epoch:
+                logging.info(
+                    "{}/{}epoch started. Estimated time to finish: {}".format(
+                        iepoch,
+                        trainer_options.max_epoch,
+                        humanfriendly.format_timespan(
+                            (time.perf_counter() - start_time)
+                            / (iepoch - start_epoch)
+                            * (trainer_options.max_epoch - iepoch + 1)
+                        ),
+                    )
+                )
+            else:
+                logging.info(f"{iepoch}/{trainer_options.max_epoch}epoch started")
+            set_all_random_seed(trainer_options.seed + iepoch)
+
+            reporter.set_epoch(iepoch)
+            # 1. Train and validation for one-epoch
+            with reporter.observe("train") as sub_reporter:
+                all_steps_are_invalid, max_update_stop = cls.train_one_epoch(
+                    model=dp_model,
+                    optimizers=optimizers,
+                    schedulers=schedulers,
+                    iterator=train_iter_factory.build_iter(iepoch),
+                    reporter=sub_reporter,
+                    scaler=scaler,
+                    summary_writer=train_summary_writer,
+                    options=trainer_options,
+                    distributed_option=distributed_option,
+                )
+
+            with reporter.observe("valid") as sub_reporter:
+                cls.validate_one_epoch(
+                    model=dp_model,
+                    iterator=valid_iter_factory.build_iter(iepoch),
+                    reporter=sub_reporter,
+                    options=trainer_options,
+                    distributed_option=distributed_option,
+                )
+
+            # 2.
LR Scheduler step + for scheduler in schedulers: + if isinstance(scheduler, AbsValEpochStepScheduler): + scheduler.step( + reporter.get_value(*trainer_options.val_scheduler_criterion) + ) + elif isinstance(scheduler, AbsEpochStepScheduler): + scheduler.step() + if trainer_options.sharded_ddp: + for optimizer in optimizers: + if isinstance(optimizer, fairscale.optim.oss.OSS): + optimizer.consolidate_state_dict() + + if not distributed_option.distributed or distributed_option.dist_rank == 0: + # 3. Report the results + logging.info(reporter.log_message()) + if train_summary_writer is not None: + reporter.tensorboard_add_scalar(train_summary_writer, key1="train") + reporter.tensorboard_add_scalar(valid_summary_writer, key1="valid") + if trainer_options.use_wandb: + reporter.wandb_log() + + # save tensorboard on oss + if trainer_options.use_pai and train_summary_writer is not None: + def write_tensorboard_summary(summary_writer_path, oss_bucket): + file_list = [] + for root, dirs, files in os.walk(summary_writer_path, topdown=False): + for name in files: + file_full_path = os.path.join(root, name) + file_list.append(file_full_path) + + for file_full_path in file_list: + with open(file_full_path, "rb") as f: + oss_bucket.put_object(file_full_path, f) + + write_tensorboard_summary(os.path.join(trainer_options.output_dir, "tensorboard/train"), + trainer_options.oss_bucket) + write_tensorboard_summary(os.path.join(trainer_options.output_dir, "tensorboard/valid"), + trainer_options.oss_bucket) + + # 4. Save/Update the checkpoint + if trainer_options.use_pai: + buffer = BytesIO() + torch.save( + { + "model": model.state_dict(), + "reporter": reporter.state_dict(), + "optimizers": [o.state_dict() for o in optimizers], + "schedulers": [ + s.state_dict() if s is not None else None + for s in schedulers + ], + "scaler": scaler.state_dict() if scaler is not None else None, + "ema_model": model.encoder.ema.model.state_dict() + if hasattr(model.encoder, "ema") and model.encoder.ema is not None else None, + }, + buffer, + ) + trainer_options.oss_bucket.put_object(os.path.join(trainer_options.output_dir, "checkpoint.pb"), + buffer.getvalue()) + else: + torch.save( + { + "model": model.state_dict(), + "reporter": reporter.state_dict(), + "optimizers": [o.state_dict() for o in optimizers], + "schedulers": [ + s.state_dict() if s is not None else None + for s in schedulers + ], + "scaler": scaler.state_dict() if scaler is not None else None, + }, + output_dir / "checkpoint.pb", + ) + + # 5. Save and log the model and update the link to the best model + if trainer_options.use_pai: + buffer = BytesIO() + torch.save(model.state_dict(), buffer) + trainer_options.oss_bucket.put_object(os.path.join(trainer_options.output_dir, + f"{iepoch}epoch.pb"), buffer.getvalue()) + else: + torch.save(model.state_dict(), output_dir / f"{iepoch}epoch.pb") + + # Creates a sym link latest.pb -> {iepoch}epoch.pb + if trainer_options.use_pai: + p = os.path.join(trainer_options.output_dir, "latest.pb") + if trainer_options.oss_bucket.object_exists(p): + trainer_options.oss_bucket.delete_object(p) + trainer_options.oss_bucket.copy_object(trainer_options.oss_bucket.bucket_name, + os.path.join(trainer_options.output_dir, + f"{iepoch}epoch.pb"), p) + else: + p = output_dir / "latest.pb" + if p.is_symlink() or p.exists(): + p.unlink() + p.symlink_to(f"{iepoch}epoch.pb") + + _improved = [] + for _phase, k, _mode in trainer_options.best_model_criterion: + # e.g. 
_phase, k, _mode = "train", "loss", "min"
+                    if reporter.has(_phase, k):
+                        best_epoch = reporter.get_best_epoch(_phase, k, _mode)
+                        # Creates sym links if it's the best result
+                        if best_epoch == iepoch:
+                            if trainer_options.use_pai:
+                                p = os.path.join(trainer_options.output_dir, f"{_phase}.{k}.best.pb")
+                                if trainer_options.oss_bucket.object_exists(p):
+                                    trainer_options.oss_bucket.delete_object(p)
+                                trainer_options.oss_bucket.copy_object(trainer_options.oss_bucket.bucket_name,
+                                                                       os.path.join(trainer_options.output_dir,
+                                                                                    f"{iepoch}epoch.pb"), p)
+                            else:
+                                p = output_dir / f"{_phase}.{k}.best.pb"
+                                if p.is_symlink() or p.exists():
+                                    p.unlink()
+                                p.symlink_to(f"{iepoch}epoch.pb")
+                            _improved.append(f"{_phase}.{k}")
+                if len(_improved) == 0:
+                    logging.info("There are no improvements in this epoch")
+                else:
+                    logging.info(
+                        "The best model has been updated: " + ", ".join(_improved)
+                    )
+
+                log_model = (
+                    trainer_options.wandb_model_log_interval > 0
+                    and iepoch % trainer_options.wandb_model_log_interval == 0
+                )
+                if log_model and trainer_options.use_wandb:
+                    import wandb
+
+                    logging.info("Logging Model on this epoch")
+                    artifact = wandb.Artifact(
+                        name=f"model_{wandb.run.id}",
+                        type="model",
+                        metadata={"improved": _improved},
+                    )
+                    artifact.add_file(str(output_dir / f"{iepoch}epoch.pb"))
+                    aliases = [
+                        f"epoch-{iepoch}",
+                        # NOTE: best_epoch is only bound if some criterion matched
+                        # above, so check _improved first to avoid a NameError
+                        "best" if _improved and best_epoch == iepoch else "",
+                    ]
+                    wandb.log_artifact(artifact, aliases=aliases)
+
+                # 6. Remove the model files excluding n-best epoch and latest epoch
+                _removed = []
+                # Get the union set of the n-best among multiple criterion
+                nbests = set().union(
+                    *[
+                        set(reporter.sort_epochs(ph, k, m)[: max(keep_nbest_models)])
+                        for ph, k, m in trainer_options.best_model_criterion
+                        if reporter.has(ph, k)
+                    ]
+                )
+
+                # Generate an n-best averaged model
+                if (
+                    trainer_options.nbest_averaging_interval > 0
+                    and iepoch % trainer_options.nbest_averaging_interval == 0
+                ):
+                    average_nbest_models(
+                        reporter=reporter,
+                        output_dir=output_dir,
+                        best_model_criterion=trainer_options.best_model_criterion,
+                        nbest=keep_nbest_models,
+                        suffix=f"till{iepoch}epoch",
+                        oss_bucket=trainer_options.oss_bucket,
+                        pai_output_dir=trainer_options.output_dir,
+                    )
+
+                for e in range(1, iepoch):
+                    if trainer_options.use_pai:
+                        p = os.path.join(trainer_options.output_dir, f"{e}epoch.pb")
+                        if trainer_options.oss_bucket.object_exists(p) and e not in nbests:
+                            trainer_options.oss_bucket.delete_object(p)
+                            _removed.append(str(p))
+                    else:
+                        p = output_dir / f"{e}epoch.pb"
+                        if p.exists() and e not in nbests:
+                            p.unlink()
+                            _removed.append(str(p))
+                if len(_removed) != 0:
+                    logging.info("The model files were removed: " + ", ".join(_removed))
+
+            # 7. If no update happened in this epoch, stop the training
+            if all_steps_are_invalid:
+                logging.warning(
+                    f"The gradients at all steps are invalid in this epoch. "
+                    f"Something seems wrong. This training was stopped at {iepoch}epoch"
+                )
+                break
+
+            if max_update_stop:
+                logging.info(
+                    f"Stopping training due to "
+                    f"num_updates: {trainer_options.num_updates} >= max_update: {trainer_options.max_update}"
+                )
+                break
+
+            # 8. Check early stopping
+            if trainer_options.patience is not None:
+                if reporter.check_early_stopping(
+                    trainer_options.patience, *trainer_options.early_stopping_criterion
+                ):
+                    break
+
+        else:
+            logging.info(
+                f"The training finished at {trainer_options.max_epoch} epochs"
+            )
+
+        # Generate an n-best averaged model
+        if not distributed_option.distributed or distributed_option.dist_rank == 0:
+            average_nbest_models(
+                reporter=reporter,
+                output_dir=output_dir,
+                best_model_criterion=trainer_options.best_model_criterion,
+                nbest=keep_nbest_models,
+                oss_bucket=trainer_options.oss_bucket,
+                pai_output_dir=trainer_options.output_dir,
+            )
+
+    @classmethod
+    def train_one_epoch(
+        cls,
+        model: torch.nn.Module,
+        iterator: Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
+        optimizers: Sequence[torch.optim.Optimizer],
+        schedulers: Sequence[Optional[AbsScheduler]],
+        scaler: Optional[GradScaler],
+        reporter: SubReporter,
+        summary_writer,
+        options: TrainerOptions,
+        distributed_option: DistributedOption,
+    ) -> Tuple[bool, bool]:
+        assert check_argument_types()
+
+        grad_noise = options.grad_noise
+        accum_grad = options.accum_grad
+        grad_clip = options.grad_clip
+        grad_clip_type = options.grad_clip_type
+        log_interval = options.log_interval
+        no_forward_run = options.no_forward_run
+        ngpu = options.ngpu
+        use_wandb = options.use_wandb
+        distributed = distributed_option.distributed
+
+        if log_interval is None:
+            try:
+                log_interval = max(len(iterator) // 20, 10)
+            except TypeError:
+                log_interval = 100
+
+        model.train()
+        all_steps_are_invalid = True
+        max_update_stop = False
+        # [For distributed] Because iteration counts are not always equal between
+        # processes, send a stop-flag to the other processes if the iterator is finished
+        iterator_stop = torch.tensor(0).to("cuda" if ngpu > 0 else "cpu")
+
+        # get the rank
+        rank = distributed_option.dist_rank
+        # get the num batch updates
+        num_batch_updates = 0
+        # output dir
+        output_dir = Path(options.output_dir)
+        # batch interval
+        batch_interval = options.batch_interval
+        assert batch_interval > 0
+
+        start_time = time.perf_counter()
+        for iiter, (_, batch) in enumerate(
+            reporter.measure_iter_time(iterator, "iter_time"), 1
+        ):
+            assert isinstance(batch, dict), type(batch)
+
+            if rank == 0:
+                if hasattr(model, "num_updates") or (
+                    hasattr(model, "module") and hasattr(model.module, "num_updates")
+                ):
+                    num_batch_updates = (
+                        model.get_num_updates()
+                        if hasattr(model, "num_updates")
+                        else model.module.get_num_updates()
+                    )
+                    if (num_batch_updates % batch_interval == 0) and (options.oss_bucket is not None) and options.use_pai:
+                        buffer = BytesIO()
+                        torch.save(model.state_dict(), buffer)
+                        options.oss_bucket.put_object(
+                            os.path.join(output_dir, f"{num_batch_updates}batch.pth"),
+                            buffer.getvalue(),
+                        )
+
+            if distributed:
+                torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
+                if iterator_stop > 0:
+                    break
+
+            batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")
+            if no_forward_run:
+                all_steps_are_invalid = False
+                continue
+
+            with autocast(scaler is not None):
+                with reporter.measure_time("forward_time"):
+                    retval = model(**batch)
+
+                # Note(kamo):
+                # Supporting two patterns for the returned value from the model
+                # a. dict type
+                if isinstance(retval, dict):
+                    loss = retval["loss"]
+                    stats = retval["stats"]
+                    weight = retval["weight"]
+                    optim_idx = retval.get("optim_idx")
+                    if optim_idx is not None and not isinstance(optim_idx, int):
+                        if not isinstance(optim_idx, torch.Tensor):
+                            raise RuntimeError(
+                                "optim_idx must be int or 1dim torch.Tensor, "
+                                f"but got {type(optim_idx)}"
+                            )
+                        if optim_idx.dim() >= 2:
+                            raise RuntimeError(
+                                "optim_idx must be int or 1dim torch.Tensor, "
+                                f"but got {optim_idx.dim()}dim tensor"
+                            )
+                        if optim_idx.dim() == 1:
+                            for v in optim_idx:
+                                if v != optim_idx[0]:
+                                    raise RuntimeError(
+                                        "optim_idx must be 1dim tensor "
+                                        "having same values for all entries"
+                                    )
+                            optim_idx = optim_idx[0].item()
+                        else:
+                            optim_idx = optim_idx.item()
+
+                # b. tuple or list type
+                else:
+                    loss, stats, weight = retval
+                    optim_idx = None
+
+            stats = {k: v for k, v in stats.items() if v is not None}
+            if ngpu > 1 or distributed:
+                # Apply weighted averaging for loss and stats
+                loss = (loss * weight.type(loss.dtype)).sum()
+
+                # if distributed, this method can also apply all_reduce()
+                stats, weight = recursive_average(stats, weight, distributed)
+
+                # Now weight is summation over all workers
+                loss /= weight
+            if distributed:
+                # NOTE(kamo): Multiply world_size because DistributedDataParallel
+                # automatically normalizes the gradient by world_size.
+                loss *= torch.distributed.get_world_size()
+
+            loss /= accum_grad
+
+            reporter.register(stats, weight)
+
+            with reporter.measure_time("backward_time"):
+                if scaler is not None:
+                    # Scales loss. Calls backward() on scaled loss
+                    # to create scaled gradients.
+                    # Backward passes under autocast are not recommended.
+                    # Backward ops run in the same dtype autocast chose
+                    # for corresponding forward ops.
+                    scaler.scale(loss).backward()
+                else:
+                    loss.backward()
+
+            if iiter % accum_grad == 0:
+                if scaler is not None:
+                    # Unscales the gradients of optimizer's assigned params in-place
+                    for iopt, optimizer in enumerate(optimizers):
+                        if optim_idx is not None and iopt != optim_idx:
+                            continue
+                        scaler.unscale_(optimizer)
+
+                # gradient noise injection
+                if grad_noise:
+                    add_gradient_noise(
+                        model,
+                        reporter.get_total_count(),
+                        duration=100,
+                        eta=1.0,
+                        scale_factor=0.55,
+                    )
+
+                # compute the gradient norm and check that it is finite
+                grad_norm = torch.nn.utils.clip_grad_norm_(
+                    model.parameters(),
+                    max_norm=grad_clip,
+                    norm_type=grad_clip_type,
+                )
+                # In PyTorch<=1.4, clip_grad_norm_ returns a float value
+                if not isinstance(grad_norm, torch.Tensor):
+                    grad_norm = torch.tensor(grad_norm)
+
+                if not torch.isfinite(grad_norm):
+                    logging.warning(
+                        f"The grad norm is {grad_norm}. Skipping the model update."
+                    )
+
+                    # Must invoke scaler.update() if unscale_() is used in the iteration
+                    # to avoid the following error:
+                    #   RuntimeError: unscale_() has already been called
+                    #   on this optimizer since the last update().
+                    # Note that if the gradient has inf/nan values,
+                    # scaler.step skips optimizer.step().
+                    if scaler is not None:
+                        for iopt, optimizer in enumerate(optimizers):
+                            if optim_idx is not None and iopt != optim_idx:
+                                continue
+                            scaler.step(optimizer)
+                            scaler.update()
+
+                else:
+                    all_steps_are_invalid = False
+                    with reporter.measure_time("optim_step_time"):
+                        for iopt, (optimizer, scheduler) in enumerate(
+                            zip(optimizers, schedulers)
+                        ):
+                            if optim_idx is not None and iopt != optim_idx:
+                                continue
+                            if scaler is not None:
+                                # scaler.step() first unscales the gradients of
+                                # the optimizer's assigned params.
+                                scaler.step(optimizer)
+                                # Updates the scale for next iteration.
+                                scaler.update()
+                            else:
+                                optimizer.step()
+                            if isinstance(scheduler, AbsBatchStepScheduler):
+                                scheduler.step()
+                for iopt, optimizer in enumerate(optimizers):
+                    if optim_idx is not None and iopt != optim_idx:
+                        continue
+                    optimizer.zero_grad()
+
+                # Register lr and train/load time[sec/step],
+                # where step refers to accum_grad * mini-batch
+                reporter.register(
+                    dict(
+                        {
+                            f"optim{i}_lr{j}": pg["lr"]
+                            for i, optimizer in enumerate(optimizers)
+                            for j, pg in enumerate(optimizer.param_groups)
+                            if "lr" in pg
+                        },
+                        train_time=time.perf_counter() - start_time,
+                    ),
+                )
+                start_time = time.perf_counter()
+
+                # update num_updates
+                if distributed:
+                    if hasattr(model.module, "num_updates"):
+                        model.module.set_num_updates(model.module.get_num_updates() + 1)
+                        options.num_updates = model.module.get_num_updates()
+                        if model.module.get_num_updates() >= options.max_update:
+                            max_update_stop = True
+                else:
+                    if hasattr(model, "num_updates"):
+                        model.set_num_updates(model.get_num_updates() + 1)
+                        options.num_updates = model.get_num_updates()
+                        if model.get_num_updates() >= options.max_update:
+                            max_update_stop = True
+
+            # NOTE(kamo): Call log_message() after next()
+            reporter.next()
+            if iiter % log_interval == 0:
+                num_updates = options.num_updates if hasattr(options, "num_updates") else None
+                logging.info(reporter.log_message(-log_interval, num_updates=num_updates))
+                if summary_writer is not None:
+                    reporter.tensorboard_add_scalar(summary_writer, -log_interval)
+                if use_wandb:
+                    reporter.wandb_log()
+
+            if max_update_stop:
+                break
+
+        else:
+            if distributed:
+                iterator_stop.fill_(1)
+                torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
+        return all_steps_are_invalid, max_update_stop
+
+    @classmethod
+    @torch.no_grad()
+    def validate_one_epoch(
+        cls,
+        model: torch.nn.Module,
+        iterator: Iterable[Dict[str, torch.Tensor]],
+        reporter: SubReporter,
+        options: TrainerOptions,
+        distributed_option: DistributedOption,
+    ) -> None:
+        assert check_argument_types()
+        ngpu = options.ngpu
+        no_forward_run = options.no_forward_run
+        distributed = distributed_option.distributed
+
+        model.eval()
+
+        # [For distributed] Because iteration counts are not always equal between
+        # processes, send a stop-flag to the other processes if the iterator is finished
+        iterator_stop = torch.tensor(0).to("cuda" if ngpu > 0 else "cpu")
+        for (_, batch) in iterator:
+            assert isinstance(batch, dict), type(batch)
+            if distributed:
+                torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
+                if iterator_stop > 0:
+                    break
+
+            batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")
+            if no_forward_run:
+                continue
+
+            retval = model(**batch)
+            if isinstance(retval, dict):
+                stats = retval["stats"]
+                weight = retval["weight"]
+            else:
+                _, stats, weight = retval
+            if ngpu > 1 or distributed:
+                # Apply weighted averaging for stats.
+                # if distributed, this method can also apply all_reduce()
+                stats, weight = recursive_average(stats, weight, distributed)
+
+            reporter.register(stats, weight)
+            reporter.next()
+
+        else:
+            if distributed:
+                iterator_stop.fill_(1)
+                torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
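
Note (not part of the patch): build_scheduler() now pairs one scheduler per optimizer by
looking up suffixed arguments, args.scheduler/args.scheduler_conf for optimizers[0] and
args.scheduler{i}/args.scheduler{i}_conf for optimizers[i-1]. A minimal sketch of the
assumed namespace layout for two optimizers; "scheduler2"/"scheduler2_conf" follow the
getattr() pattern in the hunk above, "warmuplr" is assumed to be a registered key (its
class WarmupLR is imported at the top of build_scheduler.py), and all values are
illustrative only:

    import argparse

    # Mirrors the getattr(args, f"scheduler{suf}") lookups in build_scheduler();
    # "ReduceLROnPlateau" is a key of its scheduler_classes dict.
    args = argparse.Namespace(
        scheduler="warmuplr",            # paired with optimizers[0]
        scheduler_conf={"warmup_steps": 25000},
        scheduler2="ReduceLROnPlateau",  # paired with optimizers[1]
        scheduler2_conf={"mode": "min", "factor": 0.5, "patience": 1},
    )
    # schedulers = build_scheduler(args, [opt1, opt2])
    # -> [WarmupLR(opt1, warmup_steps=25000), ReduceLROnPlateau(opt2, ...)]

Setting a scheduler{i} argument to None yields a None entry in the returned list, which
run() and train_one_epoch() above skip explicitly.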
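Note (not part of the patch): train_one_epoch() accepts a model forward() that returns a
dict with an optional "optim_idx" key; when present, only the optimizer and scheduler at
that index are unscaled, stepped, and zeroed for the batch. A hedged sketch of a model
using this for alternating updates, e.g. GAN-style; the method names _generator_loss,
_discriminator_loss, and the step_discriminator flag are hypothetical:

    # Hypothetical forward(): routes each batch to one optimizer via "optim_idx"
    # (0 -> optimizers[0], 1 -> optimizers[1]), matching the dict pattern ("a.")
    # handled in train_one_epoch() above.
    def forward(self, **batch):
        if self.step_discriminator:  # hypothetical flag toggled elsewhere
            loss, stats, weight = self._discriminator_loss(**batch)
            return {"loss": loss, "stats": stats, "weight": weight, "optim_idx": 1}
        loss, stats, weight = self._generator_loss(**batch)
        return {"loss": loss, "stats": stats, "weight": weight, "optim_idx": 0}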
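Note (not part of the patch): the checkpoint.pb written in step 4 of run() is a plain
dict saved with torch.save(); the PAI/OSS branch additionally stores an "ema_model"
entry when the encoder has an EMA module. A minimal sketch for inspecting or partially
restoring it, assuming model/optimizers/schedulers are already built and the path is
illustrative:

    import torch

    states = torch.load("exp/asr_train/checkpoint.pb", map_location="cpu")
    print(sorted(states))  # ['model', 'optimizers', 'reporter', 'scaler', 'schedulers']

    model.load_state_dict(states["model"])
    for optimizer, st in zip(optimizers, states["optimizers"]):
        optimizer.load_state_dict(st)
    for scheduler, st in zip(schedulers, states["schedulers"]):
        if scheduler is not None:  # None schedulers are saved as None
            scheduler.load_state_dict(st)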