# Mirror of https://github.com/modelscope/FunASR (synced 2025-09-15 14:48:36 +08:00).
# 532 lines, 23 KiB, Python.
import math
|
|
import os
|
|
import time
|
|
import torch
|
|
import logging
|
|
from tqdm import tqdm
|
|
from datetime import datetime
|
|
import torch.distributed as dist
|
|
from torch.cuda.amp import autocast, GradScaler
|
|
from contextlib import nullcontext, contextmanager
|
|
from pathlib import Path
|
|
|
|
from funasr.train_utils.device_funcs import to_device
|
|
from funasr.train_utils.recursive_op import recursive_average
|
|
from funasr.train_utils.average_nbest_models import average_checkpoints
|
|
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
|
|
|
|
@contextmanager
def maybe_autocast(enabled):
    """Context manager that wraps its body in ``autocast()`` only when requested.

    Args:
        enabled: Truthy to run the body inside ``torch.cuda.amp.autocast()``;
            falsy to run the body unchanged.
    """
    # Guard clause: nothing to set up or tear down when mixed precision is off.
    if not enabled:
        yield
        return

    with autocast():
        yield
|
|
|
|
class Trainer:
    """
    Training-loop helper for a PyTorch model: runs training and validation epochs,
    writes checkpoints, and can restore state from ``<output_dir>/model.pt``.

    The model, optimizer, scheduler, gradient scaler and dataloaders are NOT stored
    on the instance; they are passed to each method that needs them.

    Attributes:
        output_dir (str): Directory where checkpoints are written.
        resume (bool): Whether ``resume_checkpoint`` should try to load ``model.pt``.
        start_epoch (int): First epoch to run (set by ``resume_checkpoint``).
        max_epoch (int): Maximum number of epochs (used for logging here).
        local_rank: Rank identifier used in log messages and scalar tags.
        use_ddp (bool) / use_fsdp (bool): Enable the distributed code paths.
        use_fp16 (bool): Enable autocast and the gradient-scaler code paths.
        device (str): Device batches are moved to.
        log_interval (int): Log every N batches.
        save_checkpoint_interval (int): Save a checkpoint every N batches.
        validate_interval (int): Run validation every N batches.
        keep_nbest_models (int): If > 0, keep only this many step/epoch checkpoints.
        accum_grad (int): Number of micro-batches per optimizer step.
        grad_clip (float) / grad_clip_type (float): Max norm and norm type for clipping.
        val_acc_list (list): Validation accuracy appended by each ``validate_epoch``.
        best_acc_idx (int): Index into ``val_acc_list`` of the best accuracy so far.
        saved_ckpts (dict): Checkpoint file name -> accuracy, used for n-best pruning.
        step_or_epoch (int): Number of checkpoints saved so far, minus one.
    """

    def __init__(self,
                 local_rank,
                 use_ddp: bool = False,
                 use_fsdp: bool = False,
                 use_fp16: bool = False,
                 output_dir: str = "./",
                 **kwargs):
        """
        Store the training configuration and discover the distributed rank/world size.

        Args:
            local_rank: Rank identifier used in log messages and scalar tags.
            use_ddp (bool): Enable the distributed (DDP) code paths.
            use_fsdp (bool): Enable the distributed (FSDP) code paths.
            use_fp16 (bool): Enable autocast / gradient scaling.
            output_dir (str): Directory where checkpoints are saved. Default is './'.
            **kwargs: Optional settings, each read with a default:
                resume (True), max_epoch (100), device ("cuda"), avg_nbest_model (5),
                log_interval (50), disable_gpu_cache (True),
                save_checkpoint_interval (5000), keep_nbest_models (-1),
                accum_grad (1), grad_clip (10.0), grad_clip_type (2.0),
                validate_interval (5000).
        """
        self.output_dir = output_dir
        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir, exist_ok=True)
        self.resume = kwargs.get('resume', True)
        self.start_epoch = 0
        self.max_epoch = kwargs.get('max_epoch', 100)
        self.local_rank = local_rank
        self.use_ddp = use_ddp
        self.use_fsdp = use_fsdp
        self.device = kwargs.get('device', "cuda")
        self.avg_nbest_model = kwargs.get("avg_nbest_model", 5)
        self.log_interval = kwargs.get("log_interval", 50)
        self.batch_total = 0
        self.use_fp16 = use_fp16
        self.disable_gpu_cache = kwargs.get("disable_gpu_cache", True)
        self.save_checkpoint_interval = kwargs.get("save_checkpoint_interval", 5000)
        self.keep_nbest_models = kwargs.get("keep_nbest_models", -1)
        self.accum_grad = kwargs.get("accum_grad", 1)
        self.grad_clip = kwargs.get("grad_clip", 10.0)
        self.grad_clip_type = kwargs.get("grad_clip_type", 2.0)
        self.validate_interval = kwargs.get("validate_interval", 5000)

        # Ask explicitly whether a process group exists instead of swallowing
        # every exception raised by get_rank().
        if dist.is_available() and dist.is_initialized():
            rank = dist.get_rank()
            world_size = dist.get_world_size()
        else:
            rank = 0
            world_size = 1
            logging.warning("distributed is not initialized, only single shard")
        self.rank = rank
        self.world_size = world_size
        self.train_acc_avg = 0.0
        self.train_loss_avg = 0.0
        self.val_acc_avg = 0.0
        self.val_loss_avg = 0.0
        self.best_acc_idx = 0
        self.saved_ckpts = {}
        self.val_acc_list = []
        self.step_or_epoch = -1

    def save_checkpoint(self, epoch,
                        step=None,
                        model=None,
                        optim=None,
                        scheduler=None,
                        scaler=None,
                        ):
        """
        Save model/optimizer/scheduler (and optional scaler) state on rank 0.

        Writes ``model.pt.ep{epoch}`` (or ``model.pt.ep{epoch}.{step}`` when ``step``
        is given) and refreshes ``model.pt``. If at least one validation accuracy
        has been recorded, also refreshes ``model.pt.best`` when the latest accuracy
        is >= the best so far, and prunes the lowest-accuracy checkpoint once more
        than ``keep_nbest_models`` have been recorded.

        Args:
            epoch (int): Epoch number stored in the checkpoint and used in the file name.
            step (int, optional): Batch step within the epoch, appended to the file name.
            model: Model whose ``state_dict`` is saved (``model.module`` if present).
            optim: Optimizer whose ``state_dict`` is saved.
            scheduler: LR scheduler whose ``state_dict`` is saved.
            scaler (optional): Gradient scaler; its state is saved when truthy.
        """
        if self.rank == 0:
            logging.info(f"Save checkpoint: {epoch}, rank: {self.local_rank}\n")
            self.step_or_epoch += 1
            # Save the wrapped module's weights when the model exposes ``.module``.
            unwrapped = model.module if hasattr(model, "module") else model
            state = {
                'epoch': epoch,
                'state_dict': unwrapped.state_dict(),
                'optimizer': optim.state_dict(),
                'scheduler': scheduler.state_dict(),
                "acc": self.val_acc_list,
                "step_or_epoch": self.step_or_epoch,
            }
            if scaler:
                state["scaler_state"] = scaler.state_dict()

            # Create output directory if it does not exist
            os.makedirs(self.output_dir, exist_ok=True)
            if step is None:
                ckpt_name = f'model.pt.ep{epoch}'
            else:
                ckpt_name = f'model.pt.ep{epoch}.{step}'
            filename = os.path.join(self.output_dir, ckpt_name)
            torch.save(state, filename)

            logging.info(f'\nCheckpoint saved to {filename}\n')
            latest = Path(os.path.join(self.output_dir, 'model.pt'))
            torch.save(state, latest)

            if self.val_acc_list:
                # Use the most recent validation result. Indexing by the save
                # counter raised IndexError whenever saves and validations were
                # not in lock-step.
                cur_idx = len(self.val_acc_list) - 1
                cur_acc = self.val_acc_list[cur_idx]
                best_idx = min(max(self.best_acc_idx, 0), cur_idx)
                if cur_acc >= self.val_acc_list[best_idx]:
                    self.best_acc_idx = cur_idx
                    best_ckpt = Path(os.path.join(self.output_dir, 'model.pt.best'))
                    torch.save(state, best_ckpt)
                    logging.info(f"Update best acc: {self.val_acc_list[self.best_acc_idx]}, {best_ckpt}")
                else:
                    self.best_acc_idx = best_idx
                    logging.info(f"No improvement in acc: {self.val_acc_list[self.best_acc_idx]}")

                if self.keep_nbest_models > 0:
                    self.saved_ckpts[ckpt_name] = cur_acc
                    if len(self.saved_ckpts) > self.keep_nbest_models:
                        worst_name = min(self.saved_ckpts, key=self.saved_ckpts.get)
                        del self.saved_ckpts[worst_name]
                        stale = os.path.join(self.output_dir, worst_name)
                        logging.info(f"Delete: {stale}")
                        if os.path.exists(stale):
                            os.remove(stale)
            else:
                logging.warning(
                    "No validation accuracy recorded yet; "
                    "skipping best-model and n-best bookkeeping."
                )

        if self.use_ddp or self.use_fsdp:
            dist.barrier()

    def resume_checkpoint(self,
                          model=None,
                          optim=None,
                          scheduler=None,
                          scaler=None,
                          ):
        """
        Restore training state from ``<output_dir>/model.pt`` when ``self.resume`` is set.

        Loads model weights (tolerating a ``module.`` prefix mismatch between the
        checkpoint and the model), optimizer and scheduler state, the scaler state
        when both a scaler and a saved scaler state exist, and the accuracy history.
        Sets ``start_epoch`` to the saved epoch + 1.

        Args:
            model: Model to load weights into.
            optim: Optimizer to restore.
            scheduler: LR scheduler to restore.
            scaler (optional): Gradient scaler to restore.
        """
        if self.resume:
            ckpt = os.path.join(self.output_dir, "model.pt")
            if os.path.isfile(ckpt):
                # Load onto CPU so ranks do not all materialise the checkpoint on
                # the device it was saved from; load_state_dict copies to the
                # model's own device afterwards.
                checkpoint = torch.load(ckpt, map_location="cpu")
                self.start_epoch = checkpoint['epoch'] + 1
                src_state = checkpoint['state_dict']
                dst_state = model.state_dict()
                prefix = "module."
                for k in dst_state.keys():
                    if k in src_state:
                        k_ddp = k
                    elif k.startswith(prefix):
                        k_ddp = k[len(prefix):]
                    else:
                        k_ddp = prefix + k
                    if k_ddp in src_state:
                        dst_state[k] = src_state[k_ddp]
                    else:
                        print(f"Miss key in ckpt: model: {k}, ckpt: {k_ddp}")

                model.load_state_dict(dst_state)
                optim.load_state_dict(checkpoint['optimizer'])
                scheduler.load_state_dict(checkpoint['scheduler'])
                if scaler is not None and 'scaler_state' in checkpoint:
                    scaler.load_state_dict(checkpoint['scaler_state'])

                self.val_acc_list = checkpoint["acc"]
                self.step_or_epoch = checkpoint["step_or_epoch"]
                # Re-derive the best index so later saves compare against the true
                # best rather than the first entry. Ties go to the latest entry,
                # matching the ``>=`` used in save_checkpoint.
                accs = self.val_acc_list
                if accs:
                    self.best_acc_idx = max(range(len(accs)), key=lambda i: (accs[i], i))
                else:
                    self.best_acc_idx = 0

                print(f"Checkpoint loaded successfully from '{ckpt}'")
            else:
                print(f"No checkpoint found at '{ckpt}', does not resume status!")

        if self.use_ddp or self.use_fsdp:
            dist.barrier()

    def train_epoch(self,
                    model=None,
                    optim=None,
                    scheduler=None,
                    scaler=None,
                    dataloader_train=None,
                    dataloader_val=None,
                    epoch=None,
                    writer=None,
                    ):
        """
        Run one training epoch with gradient accumulation.

        Every ``validate_interval`` batches ``validate_epoch`` is called, and every
        ``save_checkpoint_interval`` batches ``save_checkpoint`` is called.

        Args:
            model: Model to train; called as ``model(**batch)`` and expected to
                return ``(loss, stats, weight)``.
            optim: Optimizer.
            scheduler: LR scheduler, stepped once per optimizer step.
            scaler: Gradient scaler, used only when ``use_fp16`` is set.
            dataloader_train: Training dataloader; its ``batch_sampler.set_epoch``
                is called with ``epoch``.
            dataloader_val: Validation dataloader forwarded to ``validate_epoch``.
            epoch (int): The current epoch number.
            writer (optional): Object with ``add_scalar`` used by ``log``.
        """
        distributed = self.use_ddp or self.use_fsdp
        if distributed:
            dist.barrier()
        logging.info(f"Train epoch: {epoch}, rank: {self.local_rank}\n")
        model.train()

        # Set the number of steps for gradient accumulation
        accum_grad = self.accum_grad
        # Initialize the gradient accumulation
        optim.zero_grad()
        speed_stats = {}
        time5 = time.perf_counter()
        # Flag summed across ranks: becomes > 0 once any rank has exhausted its data.
        iterator_stop = torch.tensor(0).to(self.device)

        dataloader_train.batch_sampler.set_epoch(epoch)
        for batch_idx, batch in enumerate(dataloader_train):
            if distributed:
                dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
                if iterator_stop > 0:
                    break
            self.batch_total += 1
            time1 = time.perf_counter()
            speed_stats["data_load"] = f"{time1-time5:0.3f}"

            batch = to_device(batch, self.device)

            is_update_step = (batch_idx + 1) % accum_grad == 0
            # Gradient synchronisation must happen on the backward pass that
            # precedes the optimizer step, so only the other micro-batches may use
            # no_sync. Models without ``no_sync`` always use a no-op context.
            if not is_update_step and hasattr(model, "no_sync"):
                sync_context = model.no_sync
            else:
                sync_context = nullcontext
            with sync_context():
                time2 = time.perf_counter()
                with maybe_autocast(self.use_fp16):
                    retval = model(**batch)

                time3 = time.perf_counter()
                speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
                loss, stats, weight = retval
                stats = {k: v for k, v in stats.items() if v is not None}
                if distributed:
                    # Apply weighted averaging for loss and stats
                    loss = (loss * weight.type(loss.dtype)).sum()
                    dist.all_reduce(weight, op=dist.ReduceOp.SUM)
                    # Now weight is summation over all workers
                    loss /= weight.sum()  # shape:[1] -> shape:[]
                    # Multiply world_size because DistributedDataParallel
                    # automatically normalizes the gradient by world_size.
                    loss *= self.world_size
                # Scale the loss since we're not updating for every mini-batch
                loss = loss / accum_grad
                if self.use_fp16:
                    scaler.scale(loss).backward()
                else:
                    loss.backward()
                time4 = time.perf_counter()
                speed_stats["backward_time"] = f"{time4 - time3:0.3f}"

            loss_value = loss.detach().cpu().item()
            self.train_loss_avg = (self.train_loss_avg * batch_idx + loss_value) / (batch_idx + 1)
            if "acc" in stats:
                self.train_acc_avg = (self.train_acc_avg * batch_idx + stats["acc"].detach().cpu().item()) / (batch_idx + 1)
            if distributed:
                train_loss_avg = torch.tensor(self.train_loss_avg, dtype=torch.float32).to(self.device)
                train_acc_avg = torch.tensor(self.train_acc_avg, dtype=torch.float32).to(self.device)
                dist.all_reduce(train_loss_avg, op=dist.ReduceOp.SUM)
                dist.all_reduce(train_acc_avg, op=dist.ReduceOp.SUM)
                self.train_loss_avg = train_loss_avg.detach().cpu().item() / self.world_size
                self.train_acc_avg = train_acc_avg.detach().cpu().item() / self.world_size

            # Perform an optimizer step only after accumulating enough gradients
            if is_update_step:
                # Perform gradient clipping if it is set
                if self.grad_clip > 0:
                    if self.use_fp16:
                        # Gradients are still multiplied by the loss scale here;
                        # unscale so clipping and the finiteness check see real values.
                        scaler.unscale_(optim)
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        model.parameters(),
                        max_norm=self.grad_clip,
                        norm_type=self.grad_clip_type,
                    )
                    if not torch.isfinite(grad_norm):
                        logging.warning(
                            f"The grad norm is {grad_norm}. Skipping updating the model."
                        )
                        if self.use_fp16:
                            # scaler.step() skips optim.step() on inf/nan gradients;
                            # update() must still run so the loss scale is lowered
                            # and the unscale_() bookkeeping is reset.
                            scaler.step(optim)
                            scaler.update()
                        optim.zero_grad()  # Reset gradients
                        continue

                # Execute an optimization step (update model parameters)
                if distributed:
                    dist.barrier()
                if self.use_fp16:
                    scaler.step(optim)
                    scaler.update()
                else:
                    optim.step()
                scheduler.step()
                # Clear gradients for the next accumulation stage
                optim.zero_grad(set_to_none=True)
                total_time = f"{time.perf_counter() - time5:0.3f}"
                time5 = time.perf_counter()
                speed_stats["optim_time"] = f"{time5 - time4:0.3f}"

                speed_stats["total_time"] = total_time
                lr = scheduler.get_last_lr()[0]
                batch_num_epoch = 1
                if hasattr(dataloader_train, "__len__"):
                    batch_num_epoch = len(dataloader_train)
                self.log(epoch, batch_idx,
                         batch_num_epoch=batch_num_epoch,
                         lr=lr,
                         loss=loss_value,
                         speed_stats=speed_stats,
                         stats=stats,
                         writer=writer,
                         tag="train",
                         )

            if (batch_idx + 1) % self.validate_interval == 0:
                self.validate_epoch(
                    model=model,
                    dataloader_val=dataloader_val,
                    epoch=epoch,
                    writer=writer
                )

            if (batch_idx + 1) % self.save_checkpoint_interval == 0:
                self.save_checkpoint(epoch, model=model, optim=optim, scheduler=scheduler, scaler=scaler, step=batch_idx + 1)

        else:
            # Dataloader exhausted on this rank: tell the other ranks to stop.
            if distributed:
                iterator_stop.fill_(1)
                dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)

        if distributed:
            dist.barrier()

    def validate_epoch(self,
                       model=None,
                       dataloader_val=None,
                       epoch=None,
                       writer=None,
                       **kwargs,
                       ):
        """
        Run one pass over the validation data with gradients disabled.

        Updates ``val_loss_avg`` / ``val_acc_avg`` as running averages, appends the
        final ``val_acc_avg`` to ``val_acc_list``, and puts the model back into
        training mode before returning.

        Args:
            model: Model to evaluate; called as ``model(**batch)`` and expected to
                return ``(loss, stats, weight)``.
            dataloader_val: Validation dataloader; its ``batch_sampler.set_epoch``
                is called with ``epoch``.
            epoch (int): The current epoch number.
            writer (optional): Object with ``add_scalar`` used by ``log``.
            **kwargs: Accepted and ignored.
        """
        distributed = self.use_ddp or self.use_fsdp
        if distributed:
            dist.barrier()
        logging.info(f"Validate epoch: {epoch}, rank: {self.local_rank}\n")
        model.eval()

        with torch.no_grad():
            speed_stats = {}
            time5 = time.perf_counter()
            iterator_stop = torch.tensor(0).to(self.device)
            dataloader_val.batch_sampler.set_epoch(epoch)
            # NOTE: unlike train_epoch, there is no per-batch early-stop all_reduce here.
            for batch_idx, batch in enumerate(dataloader_val):
                time1 = time.perf_counter()
                speed_stats["data_load"] = f"{time1 - time5:0.3f}"
                batch = to_device(batch, self.device)
                time2 = time.perf_counter()
                retval = model(**batch)
                time3 = time.perf_counter()
                speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
                loss, stats, weight = retval
                stats = {k: v for k, v in stats.items() if v is not None}
                if distributed:
                    # Apply weighted averaging for loss and stats
                    loss = (loss * weight.type(loss.dtype)).sum()
                    dist.all_reduce(weight, op=dist.ReduceOp.SUM)
                    # Now weight is summation over all workers
                    loss /= weight.sum()  # shape:[1] -> shape:[]
                    # Same world_size scaling as in train_epoch.
                    loss *= self.world_size

                loss_value = loss.detach().cpu().item()
                self.val_loss_avg = (self.val_loss_avg * batch_idx + loss_value) / (batch_idx + 1)
                if "acc" in stats:
                    self.val_acc_avg = (self.val_acc_avg * batch_idx + stats["acc"].detach().cpu().item()) / (batch_idx + 1)
                if distributed:
                    val_loss_avg = torch.tensor(self.val_loss_avg, dtype=torch.float32).to(self.device)
                    val_acc_avg = torch.tensor(self.val_acc_avg, dtype=torch.float32).to(self.device)
                    dist.all_reduce(val_loss_avg, op=dist.ReduceOp.SUM)
                    dist.all_reduce(val_acc_avg, op=dist.ReduceOp.SUM)
                    self.val_loss_avg = val_loss_avg.detach().cpu().item() / self.world_size
                    self.val_acc_avg = val_acc_avg.detach().cpu().item() / self.world_size

                batch_num_epoch = 1
                if hasattr(dataloader_val, "__len__"):
                    batch_num_epoch = len(dataloader_val)
                self.log(epoch, batch_idx,
                         batch_num_epoch=batch_num_epoch,
                         lr=0.0,
                         loss=loss_value,
                         speed_stats=speed_stats,
                         stats=stats,
                         writer=writer,
                         tag="val",
                         )

            else:
                if distributed:
                    iterator_stop.fill_(1)
                    dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)

        self.val_acc_list.append(self.val_acc_avg)
        model.train()

        if distributed:
            dist.barrier()

    def log(self,
            epoch=0,
            batch_idx=0,
            batch_num_epoch=-1,
            lr=0.0,
            loss=0.0,
            speed_stats=None,
            stats=None,
            writer=None,
            tag="train",
            ):
        """
        Emit a progress line (and optional scalar summaries) every ``log_interval`` batches.

        Args:
            epoch (int): Current epoch.
            batch_idx (int): Zero-based batch index within the epoch.
            batch_num_epoch (int): Number of batches per epoch, for display only.
            lr (float): Current learning rate.
            loss (float): Loss of the current batch.
            speed_stats (dict, optional): Name -> timing value formatted as a string.
            stats (dict, optional): Name -> tensor statistics returned by the model.
            writer (optional): Object with ``add_scalar(tag, value, step)``.
            tag (str): "train" or "val"; selects ``{tag}_loss_avg`` / ``{tag}_acc_avg``.
        """
        if (batch_idx + 1) % self.log_interval != 0:
            return

        stats = {} if stats is None else stats
        speed_stats = {} if speed_stats is None else speed_stats

        gib = 1024 * 1024 * 1024
        gpu_info = "GPU, memory: usage: {:.3f} GB, " \
                   "peak: {:.3f} GB, " \
                   "cache: {:.3f} GB, " \
                   "cache_peak: {:.3f} GB".format(torch.cuda.memory_allocated() / gib,
                                                  torch.cuda.max_memory_allocated() / gib,
                                                  torch.cuda.memory_reserved() / gib,
                                                  torch.cuda.max_memory_reserved() / gib,
                                                  )

        loss_avg_epoch = getattr(self, f"{tag}_loss_avg")
        acc_avg_epoch = getattr(self, f"{tag}_acc_avg")
        try:
            ppl_avg_epoch = math.exp(loss_avg_epoch)
        except OverflowError:
            # A very large average loss must not abort training just to print a number.
            ppl_avg_epoch = float("inf")
        description = (
            f"{tag}, "
            f"rank: {self.local_rank}, "
            f"epoch: {epoch}/{self.max_epoch}, "
            f"step: {batch_idx + 1}/{batch_num_epoch}, total step: {self.batch_total}, "
            f"(loss_avg_rank: {loss:.3f}), "
            f"(loss_avg_epoch: {loss_avg_epoch:.3f}), "
            f"(ppl_avg_epoch: {ppl_avg_epoch:.3f}), "
            f"(acc_avg_epoch: {acc_avg_epoch:.3f}), "
            f"(lr: {lr:.3e}), "
            f"{[(k, round(v.detach().cpu().item(), 3)) for k, v in stats.items()]}, "
            f"{speed_stats}, "
            f"{gpu_info}"
        )
        logging.info(description)

        if writer is not None:
            writer.add_scalar(f'rank{self.local_rank}_loss/{tag}', loss, self.batch_total)
            writer.add_scalar(f'rank{self.local_rank}_lr/{tag}', lr, self.batch_total)
            for key, var in stats.items():
                writer.add_scalar(f'stats_rank{self.local_rank}_{key}/{tag}', var.item(), self.batch_total)
            for key, var in speed_stats.items():
                # Timings are stored as formatted strings; parse them as numbers
                # rather than evaluating them as code.
                writer.add_scalar(f'stats_rank{self.local_rank}_{key}/{tag}', float(var), self.batch_total)

    def close(self, writer=None):
        """
        Finish training: synchronise ranks, close the writer, tear down the process group.

        Args:
            writer (optional): Object with ``close()``; closed when not None.
        """
        if self.use_ddp or self.use_fsdp:
            dist.barrier()

        if writer is not None:
            writer.close()

        if self.use_ddp or self.use_fsdp:
            torch.distributed.destroy_process_group()