mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
Dev gzf llm (#1503)
* update * update * update * update onnx * update with main (#1492) * contextual&seaco ONNX export (#1481) * contextual&seaco ONNX export * update ContextualEmbedderExport2 * update ContextualEmbedderExport2 * update code * onnx (#1482) * qwenaudio qwenaudiochat * qwenaudio qwenaudiochat * whisper * whisper * llm * llm * llm * llm * llm * llm * llm * llm * export onnx * export onnx * export onnx * dingding * dingding * llm * doc * onnx * onnx * onnx * onnx * onnx * onnx * v1.0.15 * qwenaudio * qwenaudio * issue doc * update * update * bugfix * onnx * update export calling * update codes * remove useless code * update code --------- Co-authored-by: zhifu gao <zhifu.gzf@alibaba-inc.com> * acknowledge --------- Co-authored-by: Shi Xian <40013335+R1ckShi@users.noreply.github.com> * update onnx * update onnx * train update * train update * train update * train update --------- Co-authored-by: Shi Xian <40013335+R1ckShi@users.noreply.github.com>
This commit is contained in:
parent
a2d6575d89
commit
5023dd0422
@ -7,10 +7,10 @@ from funasr import AutoModel
|
||||
|
||||
model = AutoModel(model="iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
|
||||
model_revision="v2.0.4",
|
||||
vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
|
||||
vad_model_revision="v2.0.4",
|
||||
punc_model="iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
|
||||
punc_model_revision="v2.0.4",
|
||||
# vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
|
||||
# vad_model_revision="v2.0.4",
|
||||
# punc_model="iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
|
||||
# punc_model_revision="v2.0.4",
|
||||
# spk_model="iic/speech_campplus_sv_zh-cn_16k-common",
|
||||
# spk_model_revision="v2.0.2",
|
||||
)
|
||||
|
||||
@ -8,8 +8,14 @@
|
||||
from funasr import AutoModel
|
||||
|
||||
model = AutoModel(model="iic/Whisper-large-v3",
|
||||
model_revision="v2.0.4",
|
||||
model_revision="v2.0.5",
|
||||
vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
|
||||
)
|
||||
|
||||
res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav", language=None)
|
||||
res = model.generate(
|
||||
language=None,
|
||||
task="transcribe",
|
||||
batch_size_s=0,
|
||||
input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav")
|
||||
|
||||
print(res)
|
||||
|
||||
@ -10,10 +10,11 @@ from funasr import AutoModel
|
||||
# model = AutoModel(model="Whisper-small", hub="openai")
|
||||
# model = AutoModel(model="Whisper-medium", hub="openai")
|
||||
# model = AutoModel(model="Whisper-large-v2", hub="openai")
|
||||
model = AutoModel(model="Whisper-large-v3", hub="openai")
|
||||
model = AutoModel(model="Whisper-large-v3", hub="openai", vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",)
|
||||
|
||||
res = model.generate(
|
||||
language=None,
|
||||
task="transcribe",
|
||||
batch_size_s=0,
|
||||
input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav")
|
||||
print(res)
|
||||
|
||||
@ -291,7 +291,7 @@ class AutoModel:
|
||||
# step.2 compute asr model
|
||||
model = self.model
|
||||
deep_update(kwargs, cfg)
|
||||
batch_size = int(kwargs.get("batch_size_s", 300))*1000
|
||||
batch_size = max(int(kwargs.get("batch_size_s", 300))*1000, 1)
|
||||
batch_size_threshold_ms = int(kwargs.get("batch_size_threshold_s", 60))*1000
|
||||
kwargs["batch_size"] = batch_size
|
||||
|
||||
|
||||
@ -6,17 +6,22 @@ import sys
|
||||
import torch
|
||||
import hydra
|
||||
import logging
|
||||
import time
|
||||
import argparse
|
||||
from io import BytesIO
|
||||
|
||||
import torch.distributed as dist
|
||||
from collections.abc import Sequence
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from torch.cuda.amp import autocast, GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
||||
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
|
||||
from funasr.train_utils.average_nbest_models import average_checkpoints
|
||||
|
||||
from funasr.register import tables
|
||||
from funasr.optimizers import optim_classes
|
||||
from funasr.train_utils.trainer import Trainer
|
||||
from funasr.train_utils.trainer_llm import Trainer
|
||||
from funasr.schedulers import scheduler_classes
|
||||
from funasr.train_utils.initialize import initialize
|
||||
from funasr.download.download_from_hub import download_model
|
||||
@ -61,14 +66,9 @@ def main(**kwargs):
|
||||
dist.init_process_group(backend=kwargs.get("backend", "nccl"), init_method='env://')
|
||||
torch.cuda.set_device(local_rank)
|
||||
|
||||
device = kwargs.get("device", "cpu")
|
||||
device = kwargs.get("device", "cuda")
|
||||
kwargs["device"] = "cpu"
|
||||
model = AutoModel(**kwargs)
|
||||
kwargs["device"] = device
|
||||
model = model.model
|
||||
tokenizer = kwargs["tokenizer"]
|
||||
frontend = kwargs["frontend"]
|
||||
|
||||
|
||||
|
||||
# save config.yaml
|
||||
@ -77,35 +77,14 @@ def main(**kwargs):
|
||||
yaml_file = os.path.join(kwargs.get("output_dir", "./"), "config.yaml")
|
||||
OmegaConf.save(config=kwargs, f=yaml_file)
|
||||
logging.info("config.yaml is saved to: %s", yaml_file)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# init_param
|
||||
init_param = kwargs.get("init_param", None)
|
||||
if init_param is not None:
|
||||
if not isinstance(init_param, (list, tuple)):
|
||||
init_param = (init_param,)
|
||||
logging.info("init_param is not None: %s", init_param)
|
||||
for p in init_param:
|
||||
if os.path.exists(p):
|
||||
logging.info(f"Loading pretrained params from {p}")
|
||||
load_pretrained_model(
|
||||
model=model,
|
||||
path=p,
|
||||
ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True),
|
||||
oss_bucket=kwargs.get("oss_bucket", None),
|
||||
scope_map=kwargs.get("scope_map", []),
|
||||
excludes=kwargs.get("excludes", None),
|
||||
)
|
||||
else:
|
||||
logging.info(f"Checkpoint does not exist, init randomly: {p}")
|
||||
elif kwargs.get("init", None):
|
||||
initialize(model, kwargs.get("init", "kaiming_normal"))
|
||||
else:
|
||||
print("No initialize method")
|
||||
|
||||
# parse kwargs
|
||||
kwargs = model.kwargs
|
||||
kwargs["device"] = device
|
||||
tokenizer = kwargs["tokenizer"]
|
||||
frontend = kwargs["frontend"]
|
||||
model = model.model
|
||||
del kwargs["model"]
|
||||
|
||||
# freeze_param
|
||||
freeze_param = kwargs.get("freeze_param", None)
|
||||
@ -129,7 +108,8 @@ def main(**kwargs):
|
||||
model = FSDP(model).cuda(local_rank)
|
||||
else:
|
||||
model = model.to(device=kwargs.get("device", "cuda"))
|
||||
|
||||
|
||||
kwargs["device"] = next(model.parameters()).device
|
||||
|
||||
# optim
|
||||
optim = kwargs.get("optim", "adam")
|
||||
@ -156,34 +136,68 @@ def main(**kwargs):
|
||||
batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
|
||||
batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf"))
|
||||
batch_sampler_val = batch_sampler_class(dataset_val, is_training=False, **kwargs.get("dataset_conf"))
|
||||
dataloader_tr = torch.utils.data.DataLoader(dataset_tr,
|
||||
collate_fn=dataset_tr.collator,
|
||||
batch_sampler=batch_sampler,
|
||||
num_workers=kwargs.get("dataset_conf").get("num_workers", 4),
|
||||
pin_memory=True)
|
||||
|
||||
dataloader_tr = torch.utils.data.DataLoader(dataset_tr, collate_fn=dataset_tr.collator, **batch_sampler)
|
||||
dataloader_val = torch.utils.data.DataLoader(dataset_val, collate_fn=dataset_val.collator, **batch_sampler_val)
|
||||
|
||||
trainer = Trainer(local_rank=local_rank,
|
||||
use_ddp=use_ddp,
|
||||
resume=kwargs.get("resume", True),
|
||||
device=kwargs["device"],
|
||||
**kwargs.get("train_conf"),
|
||||
)
|
||||
|
||||
scaler = GradScaler(enabled=trainer.use_fp16) if trainer.use_fp16 else None
|
||||
scaler = ShardedGradScaler(enabled=trainer.use_fp16) if trainer.use_fsdp else scaler
|
||||
|
||||
trainer.resume_checkpoint(model=model, optim=optim, scheduler=scheduler, scaler=scaler)
|
||||
|
||||
tensorboard_dir = os.path.join(kwargs.get("output_dir"), "tensorboard")
|
||||
os.makedirs(tensorboard_dir, exist_ok=True)
|
||||
try:
|
||||
from tensorboardX import SummaryWriter
|
||||
writer = SummaryWriter(tensorboard_dir) if trainer.rank == 0 else None
|
||||
except:
|
||||
writer = None
|
||||
|
||||
dataloader_val = torch.utils.data.DataLoader(dataset_val,
|
||||
collate_fn=dataset_val.collator,
|
||||
batch_sampler=batch_sampler_val,
|
||||
num_workers=kwargs.get("dataset_conf").get("num_workers", 4),
|
||||
pin_memory=True)
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
optim=optim,
|
||||
scheduler=scheduler,
|
||||
dataloader_train=dataloader_tr,
|
||||
dataloader_val=dataloader_val,
|
||||
local_rank=local_rank,
|
||||
use_ddp=use_ddp,
|
||||
use_fsdp=use_fsdp,
|
||||
output_dir=kwargs.get("output_dir", "./exp"),
|
||||
resume=kwargs.get("resume", True),
|
||||
**kwargs.get("train_conf"),
|
||||
)
|
||||
trainer.run()
|
||||
|
||||
if use_ddp or use_fsdp:
|
||||
torch.distributed.destroy_process_group()
|
||||
for epoch in range(trainer.start_epoch, trainer.max_epoch + 1):
|
||||
time1 = time.perf_counter()
|
||||
trainer.train_epoch(
|
||||
model=model,
|
||||
optim=optim,
|
||||
scheduler=scheduler,
|
||||
scaler=scaler,
|
||||
dataloader_train=dataloader_tr,
|
||||
dataloader_val=dataloader_val,
|
||||
epoch=epoch,
|
||||
writer=writer
|
||||
)
|
||||
|
||||
trainer.validate_epoch(
|
||||
model=model,
|
||||
dataloader_val=dataloader_val,
|
||||
epoch=epoch,
|
||||
writer=writer
|
||||
)
|
||||
|
||||
trainer.save_checkpoint(epoch, model=model, optim=optim, scheduler=scheduler, scaler=scaler)
|
||||
|
||||
scheduler.step()
|
||||
|
||||
time2 = time.perf_counter()
|
||||
time_escaped = (time2 - time1) / 3600.0
|
||||
logging.info(
|
||||
f"\nrank: {local_rank}, "
|
||||
f"time_escaped_epoch: {time_escaped:.3f} hours, "
|
||||
f"estimated to finish {trainer.max_epoch} "
|
||||
f"epoch: {(trainer.max_epoch - epoch) * time_escaped:.3f} hours\n")
|
||||
|
||||
|
||||
if trainer.rank == 0:
|
||||
average_checkpoints(trainer.output_dir, trainer.avg_nbest_model)
|
||||
|
||||
trainer.close()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
62
funasr/datasets/audio_datasets/jsonl2scp.py
Normal file
62
funasr/datasets/audio_datasets/jsonl2scp.py
Normal file
@ -0,0 +1,62 @@
|
||||
import os
|
||||
import json
|
||||
import torch
|
||||
import logging
|
||||
import hydra
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
import concurrent.futures
|
||||
import librosa
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
|
||||
def gen_scp_from_jsonl(jsonl_file, data_type_list, wav_scp_file, text_file):
    """Split a jsonl manifest into a ``key<TAB>source`` file and a ``key<TAB>target`` file.

    Every non-empty line of ``jsonl_file`` is parsed as one JSON object. The
    object must contain a ``"key"`` entry plus the two fields named by
    ``data_type_list``.

    Args:
        jsonl_file: Path of the input manifest, read as UTF-8.
        data_type_list: Two-element sequence; element 0 names the JSON field
            written to ``wav_scp_file`` and element 1 the field written to
            ``text_file``.
        wav_scp_file: Output path for the ``key<TAB>source`` lines.
        text_file: Output path for the ``key<TAB>target`` lines.

    Raises:
        KeyError: If a record lacks ``"key"`` or one of the requested fields.
        json.JSONDecodeError: If a non-empty line is not valid JSON.
    """
    source_field, target_field = data_type_list[0], data_type_list[1]

    # The input is opened first so that a missing manifest does not truncate
    # existing output files. All three handles are closed by the ``with`` even
    # when a record fails to parse. Outputs are written as UTF-8 to match the
    # encoding used for reading.
    with open(jsonl_file, encoding="utf-8") as fin, \
            open(wav_scp_file, "w", encoding="utf-8") as wav_f, \
            open(text_file, "w", encoding="utf-8") as text_f:
        for line in fin:
            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing newline at end of file).
                continue
            data = json.loads(line)

            source = data[source_field]
            target = data[target_field]
            if "aishell" in source:
                # Records whose source contains "aishell" get the spaces
                # removed from their target text.
                target = target.replace(" ", "")
            key = data["key"]
            wav_f.write(f"{key}\t{source}\n")
            text_f.write(f"{key}\t{target}\n")
|
||||
|
||||
|
||||
|
||||
@hydra.main(config_name=None, version_base=None)
def main_hydra(cfg: DictConfig):
    """Hydra entry point: read the options from ``cfg`` and run ``gen_scp_from_jsonl``.

    Recognised config keys:
        scp_file_list: Pair ``(wav_scp_path, text_path)`` of output files. May
            also be given as a string holding a Python list/tuple literal.
        data_type_list: Pair of JSON field names (default
            ``("source", "target")``). May also be given as a literal string.
        jsonl_file_in: Path of the input jsonl manifest.

    Args:
        cfg: Hydra/OmegaConf configuration built from the command line.
    """
    # Function-scope import keeps this block self-contained.
    import ast

    kwargs = OmegaConf.to_container(cfg, resolve=True)

    scp_file_list = kwargs.get("scp_file_list", ("/Users/zhifu/funasr1.0/test_local/wav.scp", "/Users/zhifu/funasr1.0/test_local/text.txt"))
    if isinstance(scp_file_list, str):
        # literal_eval only accepts Python literals, so a command-line value
        # cannot execute arbitrary code the way eval() would.
        scp_file_list = ast.literal_eval(scp_file_list)
    data_type_list = kwargs.get("data_type_list", ("source", "target"))
    if isinstance(data_type_list, str):
        # Same treatment as scp_file_list: a string would otherwise be indexed
        # character by character.
        data_type_list = ast.literal_eval(data_type_list)
    jsonl_file = kwargs.get("jsonl_file_in", "/Users/zhifu/funasr1.0/test_local/audio_datasets.jsonl")
    gen_scp_from_jsonl(jsonl_file, data_type_list, *scp_file_list)
|
||||
|
||||
|
||||
"""
|
||||
python -m funasr.datasets.audio_datasets.jsonl2scp \
|
||||
++scp_file_list='["/Users/zhifu/funasr1.0/test_local/wav.scp", "/Users/zhifu/funasr1.0/test_local/text.txt"]' \
|
||||
++data_type_list='["source", "target"]' \
|
||||
++jsonl_file_in=/Users/zhifu/funasr1.0/test_local/audio_datasets.jsonl
|
||||
"""
|
||||
|
||||
if __name__ == "__main__":
|
||||
main_hydra()
|
||||
|
||||
|
||||
@ -142,9 +142,9 @@ class DistributedSamplerWarp(BatchSampler):
|
||||
def set_epoch(self, epoch):
|
||||
self.epoch = epoch
|
||||
|
||||
|
||||
@tables.register("batch_sampler_classes", "CustomDistributedBatchSampler_fn")
|
||||
def CustomDistributedBatchSampler_fn(dataset, **kwargs):
|
||||
dataloader_args = {"dataset": dataset}
|
||||
dataloader_args = {}
|
||||
dataloader_args["batch_sampler"] = CustomDistributedBatchSampler(dataset, **kwargs)
|
||||
dataloader_args["num_workers"] = kwargs.get("num_workers", 4)
|
||||
dataloader_args["pin_memory"] = kwargs.get("pin_memory", True)
|
||||
|
||||
@ -264,7 +264,7 @@ class LLMASRNAR(nn.Module):
|
||||
audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000),
|
||||
data_type=kwargs.get("data_type", "sound"),
|
||||
tokenizer=None)
|
||||
if len(kwargs.get("data_type")) > 1:
|
||||
if len(kwargs.get("data_type", [])) > 1:
|
||||
audio_sample_list, text_token_int_list = audio_sample_list
|
||||
text_token_int = text_token_int_list[0].replace(" ", "")
|
||||
text_token_int = tokenizer.encode(text_token_int)
|
||||
@ -561,7 +561,7 @@ class LLMASRNARPrompt(nn.Module):
|
||||
audio_mask = kwargs.get("audio_mask", None)
|
||||
audio_token_lengths = audio_mask.sum(-1) if audio_mask is not None else None
|
||||
text_token_int = kwargs.get("text_token_int", None)
|
||||
if audio_token_lengths is None:
|
||||
if audio_token_lengths is None and text_token_int is not None:
|
||||
audio_token_lengths = torch.tensor([len(text_token_int)], dtype=torch.int64)
|
||||
|
||||
batch = {"speech": speech, "speech_lengths": speech_lengths}
|
||||
@ -572,7 +572,9 @@ class LLMASRNARPrompt(nn.Module):
|
||||
mask=enc_mask,
|
||||
target_label_length=audio_token_lengths,
|
||||
)
|
||||
loss_pre = self.criterion_pre(audio_token_lengths.type_as(pre_token_length), pre_token_length)
|
||||
loss_pre = 0.0
|
||||
if audio_token_lengths is not None:
|
||||
loss_pre = self.criterion_pre(audio_token_lengths.type_as(pre_token_length), pre_token_length)
|
||||
|
||||
return pre_acoustic_embeds, pre_token_length, loss_pre
|
||||
|
||||
@ -603,10 +605,12 @@ class LLMASRNARPrompt(nn.Module):
|
||||
audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000),
|
||||
data_type=kwargs.get("data_type", "sound"),
|
||||
tokenizer=None)
|
||||
if len(kwargs.get("data_type")) > 1:
|
||||
if len(kwargs.get("data_type", [])) > 1:
|
||||
audio_sample_list, text_token_int_list = audio_sample_list
|
||||
text_token_int = text_token_int_list[0].replace(" ", "")
|
||||
text_token_int = text_token_int_list[0]
|
||||
text_token_int = tokenizer.encode(text_token_int)
|
||||
if text_token_int[0] == tokenizer.bos_token_id:
|
||||
text_token_int = text_token_int[1:]
|
||||
else:
|
||||
text_token_int = None
|
||||
time2 = time.perf_counter()
|
||||
@ -621,24 +625,30 @@ class LLMASRNARPrompt(nn.Module):
|
||||
speech_lengths = speech_lengths.to(device=kwargs["device"])
|
||||
|
||||
# Encoder
|
||||
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, text_token_int=text_token_int)
|
||||
res = self.encode(speech, speech_lengths, text_token_int=text_token_int)
|
||||
encoder_out = res[0]
|
||||
|
||||
# adaptor
|
||||
encoder_out = self.adaptor(encoder_out)
|
||||
|
||||
prompt_pre = "USER: \nINSTRUCTION: {}\nINPUT: ".format(prompt)
|
||||
prompt_ids = tokenizer.encode(prompt_pre)
|
||||
if prompt_ids[0] == tokenizer.bos_token_id:
|
||||
prompt_ids = prompt_ids[1:]
|
||||
# prompt_ids = prompt_ids + [tokenizer.pad_token_id]
|
||||
prompt_length = len(prompt_ids)
|
||||
prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64).to(kwargs["device"])
|
||||
pad = torch.tensor([tokenizer.pad_token_id], dtype=torch.int64).to(kwargs["device"])
|
||||
|
||||
if hasattr(self.llm.model, "embed_tokens"):
|
||||
inputs_embeds = self.llm.model.embed_tokens(prompt_ids)
|
||||
pad = self.llm.model.embed_tokens(pad)
|
||||
elif hasattr(self.llm.model.model, "embed_tokens"):
|
||||
inputs_embeds = self.llm.model.model.embed_tokens(prompt_ids)
|
||||
else:
|
||||
inputs_embeds = self.llm.model.model.model.embed_tokens(prompt_ids)
|
||||
|
||||
inputs_embeds = torch.cat((inputs_embeds[None, :, :], encoder_out), dim=1) # [prompt, audio]
|
||||
inputs_embeds = torch.cat((inputs_embeds[None, :, :], encoder_out, pad[None, :, :]), dim=1) # [prompt, audio]
|
||||
attention_mask = torch.ones(inputs_embeds.size()[:-1], dtype=torch.long).to(kwargs["device"])
|
||||
|
||||
# model_outputs = self.llm.generate(
|
||||
@ -662,8 +672,11 @@ class LLMASRNARPrompt(nn.Module):
|
||||
preds = torch.argmax(model_outputs.logits, -1)
|
||||
text = tokenizer.batch_decode(preds, add_special_tokens=False, skip_special_tokens=True)
|
||||
|
||||
text = text[0].split(': ')[-1]
|
||||
text = text[0].split(':')[-1]
|
||||
text = text.strip()
|
||||
if text.startswith("Please\n "):
|
||||
text = text.replace("Please\n ", "")
|
||||
text = text.strip()
|
||||
|
||||
# preds = torch.argmax(model_outputs.logits, -1)
|
||||
|
||||
|
||||
462
funasr/train_utils/trainer_llm.py
Normal file
462
funasr/train_utils/trainer_llm.py
Normal file
@ -0,0 +1,462 @@
|
||||
import os
|
||||
import time
|
||||
import torch
|
||||
import logging
|
||||
from tqdm import tqdm
|
||||
from datetime import datetime
|
||||
import torch.distributed as dist
|
||||
from torch.cuda.amp import autocast, GradScaler
|
||||
from contextlib import nullcontext, contextmanager
|
||||
from pathlib import Path
|
||||
|
||||
from funasr.train_utils.device_funcs import to_device
|
||||
from funasr.train_utils.recursive_op import recursive_average
|
||||
from funasr.train_utils.average_nbest_models import average_checkpoints
|
||||
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
|
||||
|
||||
@contextmanager
def maybe_autocast(enabled):
    """Context manager that runs its body under ``autocast()`` only when ``enabled`` is truthy.

    Args:
        enabled: When truthy the body runs inside ``autocast()``; otherwise it
            runs with no extra context.
    """
    amp_context = autocast() if enabled else nullcontext()
    with amp_context:
        yield
|
||||
|
||||
class Trainer:
    """
    Epoch-level training helper for a PyTorch model.

    The trainer does not own the model, optimizer, scheduler, scaler or data
    loaders; they are passed to each method. The epoch loop itself is driven by
    the caller, which invokes ``resume_checkpoint``, ``train_epoch``,
    ``validate_epoch`` and ``save_checkpoint``.

    Attributes:
        output_dir (str): Directory where checkpoints are written and looked up.
        resume (bool): Whether ``resume_checkpoint`` should try to load ``model.pt``.
        start_epoch (int): First epoch to run; updated by ``resume_checkpoint``.
        max_epoch (int): Maximum number of epochs.
        local_rank (int): Local rank of this process, used in log/scalar names.
        rank (int): Global rank (0 when torch.distributed is not initialized).
        world_size (int): Number of workers (1 when not initialized).
        use_ddp (bool) / use_fsdp (bool): Whether the run is distributed.
        use_fp16 (bool): Whether autocast and a gradient scaler are used.
        device: Device the batches are moved to.
        batch_total (int): Number of training batches seen so far.
    """

    def __init__(self,
                 local_rank,
                 use_ddp: bool = False,
                 use_fsdp: bool = False,
                 use_fp16: bool = False,
                 output_dir: str = "./",
                 **kwargs):
        """
        Stores the run configuration and detects the distributed rank/world size.

        Args:
            local_rank (int): Local rank of this process.
            use_ddp (bool): True when the model is wrapped for DDP training.
            use_fsdp (bool): True when the model is wrapped for FSDP training.
            use_fp16 (bool): Enable autocast and gradient scaling.
            output_dir (str): Directory for checkpoints. Default is './'.
            **kwargs: Additional keyword arguments:
                resume (bool): Try to resume from ``output_dir/model.pt``. Default True.
                max_epoch (int): Maximum number of epochs. Default 100.
                device: Device for the batches. Default "cuda".
                avg_nbest_model (int): Default 5.
                log_interval (int): Log every N batches. Default 50.
                disable_gpu_cache (bool): Call ``torch.cuda.empty_cache()`` after
                    each forward pass. Default True.
                save_checkpoint_interval (int): Save every N batches. Default 5000.
                accum_grad (int): Gradient accumulation steps. Default 1.
                grad_clip (float): Max gradient norm; <= 0 disables clipping. Default 10.0.
                grad_clip_type (float): Norm type for clipping. Default 2.0.
                validate_interval (int): Validate every N batches. Default 5000.
        """
        self.output_dir = output_dir
        self.resume = kwargs.get('resume', True)
        self.start_epoch = 0
        self.max_epoch = kwargs.get('max_epoch', 100)
        self.local_rank = local_rank
        self.use_ddp = use_ddp
        self.use_fsdp = use_fsdp
        self.device = kwargs.get('device', "cuda")
        self.avg_nbest_model = kwargs.get("avg_nbest_model", 5)
        self.log_interval = kwargs.get("log_interval", 50)
        self.batch_total = 0
        self.use_fp16 = use_fp16
        self.disable_gpu_cache = kwargs.get("disable_gpu_cache", True)
        self.save_checkpoint_interval = kwargs.get("save_checkpoint_interval", 5000)
        self.accum_grad = kwargs.get("accum_grad", 1)
        self.grad_clip = kwargs.get("grad_clip", 10.0)
        self.grad_clip_type = kwargs.get("grad_clip_type", 2.0)
        self.validate_interval = kwargs.get("validate_interval", 5000)

        # Ask torch.distributed explicitly instead of catching every exception.
        if dist.is_available() and dist.is_initialized():
            rank = dist.get_rank()
            world_size = dist.get_world_size()
        else:
            rank = 0
            world_size = 1
            logging.warning("distributed is not initialized, only single shard")
        self.rank = rank
        self.world_size = world_size

    def save_checkpoint(self, epoch,
                        step=None,
                        model=None,
                        optim=None,
                        scheduler=None,
                        scaler=None,
                        ):
        """
        Saves model, optimizer, scheduler (and scaler) state to ``output_dir``.

        Only rank 0 writes. In distributed mode the method ends with a barrier,
        so it must be called by every rank.

        Args:
            epoch (int): The epoch number stored in the checkpoint and file name.
            step (int, optional): When given, appended to the file name
                (``model.pt.ep{epoch}.{step}``).
            model (torch.nn.Module): Model whose ``state_dict`` is saved.
            optim (torch.optim.Optimizer): Optimizer whose state is saved.
            scheduler: Learning-rate scheduler whose state is saved.
            scaler (optional): Gradient scaler whose state is saved when not None.
        """
        if self.rank == 0:
            state = {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer': optim.state_dict(),
                'scheduler': scheduler.state_dict(),
            }
            if scaler is not None:
                state["scaler_state"] = scaler.state_dict()
            # Create output directory if it does not exist
            os.makedirs(self.output_dir, exist_ok=True)
            if step is None:
                filename = os.path.join(self.output_dir, f'model.pt.ep{epoch}')
            else:
                filename = os.path.join(self.output_dir, f'model.pt.ep{epoch}.{step}')

            torch.save(state, filename)

            print(f'\nCheckpoint saved to {filename}\n')
            # "model.pt" always holds the most recent state; resume_checkpoint reads it.
            latest = os.path.join(self.output_dir, 'model.pt')
            torch.save(state, latest)

        if self.use_ddp or self.use_fsdp:
            dist.barrier()

    def resume_checkpoint(self,
                          model=None,
                          optim=None,
                          scheduler=None,
                          scaler=None,
                          ):
        """
        Restores training state from ``output_dir/model.pt`` when ``self.resume`` is set.

        Loads the model's, optimizer's and scheduler's state (and the scaler's
        if present) and sets ``self.start_epoch`` to the saved epoch + 1. Model
        keys are matched with or without a leading ``"module."`` prefix; keys
        missing from the checkpoint keep the model's current values.

        Args:
            model (torch.nn.Module): Model to load the weights into.
            optim (torch.optim.Optimizer): Optimizer to restore.
            scheduler: Learning-rate scheduler to restore.
            scaler (optional): Gradient scaler to restore.
        """
        if self.resume:
            ckpt = os.path.join(self.output_dir, "model.pt")
            if os.path.isfile(ckpt):
                # Load onto CPU so that every rank does not materialise the
                # checkpoint on the device it was saved from; load_state_dict
                # copies the values onto each object's own device.
                checkpoint = torch.load(ckpt, map_location="cpu")
                self.start_epoch = checkpoint['epoch'] + 1
                src_state = checkpoint['state_dict']
                dst_state = model.state_dict()
                for k in dst_state.keys():
                    if not k.startswith("module.") and "module." + k in src_state.keys():
                        k_ddp = "module." + k
                    else:
                        k_ddp = k
                    if k_ddp in src_state.keys():
                        dst_state[k] = src_state[k_ddp]
                    else:
                        print(f"Miss key in ckpt: model: {k}, ckpt: {k_ddp}")

                model.load_state_dict(dst_state)
                optim.load_state_dict(checkpoint['optimizer'])
                scheduler.load_state_dict(checkpoint['scheduler'])
                if scaler is not None and 'scaler_state' in checkpoint:
                    scaler.load_state_dict(checkpoint['scaler_state'])
                print(f"Checkpoint loaded successfully from '{ckpt}'")
            else:
                print(f"No checkpoint found at '{ckpt}', does not resume status!")

        if self.use_ddp or self.use_fsdp:
            dist.barrier()

    def _reduce_loss_and_stats(self, loss, stats, weight):
        """Weight-averages ``loss`` and ``stats`` over all workers in distributed mode."""
        if not (self.use_ddp or self.use_fsdp):
            return loss, stats
        # Apply weighted averaging for loss and stats
        loss = (loss * weight.type(loss.dtype)).sum()
        # if distributed, this method can also apply all_reduce()
        stats, weight = recursive_average(stats, weight, distributed=True)
        # Now weight is summation over all workers
        loss /= weight
        # Multiply world_size because DistributedDataParallel
        # automatically normalizes the gradient by world_size.
        loss *= self.world_size
        return loss, stats

    def train_epoch(self,
                    model=None,
                    optim=None,
                    scheduler=None,
                    scaler=None,
                    dataloader_train=None,
                    dataloader_val=None,
                    epoch=None,
                    writer=None,
                    ):
        """
        Defines the training process for a single epoch with gradient accumulation.

        The model is called as ``model(**batch)`` and must return a
        ``(loss, stats, weight)`` triple.

        Args:
            model (torch.nn.Module): Model to train.
            optim (torch.optim.Optimizer): Optimizer to step.
            scheduler: Learning-rate scheduler, stepped after each optimizer step.
            scaler: Gradient scaler; required when ``use_fp16`` is True.
            dataloader_train: Iterable of training batches (dicts).
            dataloader_val: Validation batches, used every ``validate_interval`` batches.
            epoch (int): The current epoch number.
            writer (optional): Object with ``add_scalar`` for scalar logging.
        """
        model.train()

        # Set the number of steps for gradient accumulation
        accum_grad = self.accum_grad
        # Initialize the gradient accumulation
        optim.zero_grad()
        speed_stats = {}
        time5 = time.perf_counter()

        for batch_idx, batch in enumerate(dataloader_train):
            self.batch_total += 1
            time1 = time.perf_counter()
            speed_stats["data_load"] = f"{time1-time5:0.3f}"

            batch = to_device(batch, self.device)

            # Gradient synchronisation may only be skipped on batches that do
            # NOT end an accumulation window; the batch on which the optimizer
            # steps has to synchronise. Models without ``no_sync`` (not
            # wrapped for distributed training) always use a null context.
            is_update_step = (batch_idx + 1) % accum_grad == 0
            if not is_update_step and hasattr(model, "no_sync"):
                my_context = model.no_sync
            else:
                my_context = nullcontext
            with my_context():
                time2 = time.perf_counter()
                with maybe_autocast(self.use_fp16):
                    retval = model(**batch)

                if self.disable_gpu_cache:
                    torch.cuda.empty_cache()

                time3 = time.perf_counter()
                speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
                loss, stats, weight = retval
                stats = {k: v for k, v in stats.items() if v is not None}
                loss, stats = self._reduce_loss_and_stats(loss, stats, weight)
                # Scale the loss since we're not updating for every mini-batch
                loss = loss / accum_grad
                if self.use_fp16:
                    scaler.scale(loss).backward()
                else:
                    loss.backward()
                time4 = time.perf_counter()
                speed_stats["backward_time"] = f"{time4 - time3:0.3f}"

            # Perform an optimizer step only after accumulating enough gradients
            if is_update_step:
                # Perform gradient clipping if it is set
                if self.grad_clip > 0:
                    if self.use_fp16:
                        # Gradients are still multiplied by the loss scale;
                        # unscale first so max_norm applies to true gradients.
                        scaler.unscale_(optim)
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        model.parameters(),
                        max_norm=self.grad_clip,
                        norm_type=self.grad_clip_type,
                    )
                    if not torch.isfinite(grad_norm):
                        logging.warning(
                            f"The grad norm is {grad_norm}. Skipping updating the model."
                        )
                        if self.use_fp16:
                            # Let the scaler react to the overflow; otherwise the
                            # loss scale never shrinks and every step is skipped.
                            scaler.update()
                        optim.zero_grad()  # Reset gradients
                        continue

                # Execute an optimization step (update model parameters)
                if self.use_ddp or self.use_fsdp:
                    dist.barrier()
                if self.use_fp16:
                    scaler.step(optim)
                    scaler.update()
                else:
                    optim.step()
                scheduler.step()
                # Clear gradients for the next accumulation stage
                optim.zero_grad(set_to_none=True)
                total_time = f"{time.perf_counter() - time5:0.3f}"
                time5 = time.perf_counter()
                speed_stats["optim_time"] = f"{time5 - time4:0.3f}"

                speed_stats["total_time"] = total_time
                lr = scheduler.get_last_lr()[0]

                self.log(epoch, batch_idx,
                         batch_num_epoch=len(dataloader_train),
                         lr=lr,
                         loss=loss.detach().cpu().item(),
                         speed_stats=speed_stats,
                         stats=stats,
                         writer=writer,
                         tag="train",
                         )

            if (batch_idx + 1) % self.validate_interval == 0:
                self.validate_epoch(
                    model=model,
                    dataloader_val=dataloader_val,
                    epoch=epoch,
                    writer=writer
                )

            if (batch_idx + 1) % self.save_checkpoint_interval == 0:
                # Every rank must enter save_checkpoint: it writes on rank 0
                # only, but ends with a barrier in distributed mode. Calling it
                # from rank 0 alone would leave that barrier unmatched.
                self.save_checkpoint(epoch, model=model, optim=optim, scheduler=scheduler, scaler=scaler, step=batch_idx + 1)

        if self.use_ddp or self.use_fsdp:
            dist.barrier()

    def validate_epoch(self,
                       model=None,
                       dataloader_val=None,
                       epoch=None,
                       writer=None,
                       **kwargs,
                       ):
        """
        Runs the model over ``dataloader_val`` without gradients and logs the loss.

        The model is switched to eval mode for the pass and back to train mode
        afterwards.

        Args:
            model (torch.nn.Module): Model to evaluate.
            dataloader_val: Iterable of validation batches (dicts).
            epoch (int): The current epoch number.
            writer (optional): Object with ``add_scalar`` for scalar logging.
            **kwargs: Accepted for call compatibility; unused.
        """
        model.eval()

        with torch.no_grad():

            speed_stats = {}
            time5 = time.perf_counter()
            for batch_idx, batch in enumerate(dataloader_val):
                time1 = time.perf_counter()
                speed_stats["data_load"] = f"{time1 - time5:0.3f}"
                batch = to_device(batch, self.device)
                time2 = time.perf_counter()
                retval = model(**batch)
                time3 = time.perf_counter()
                speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
                loss, stats, weight = retval
                stats = {k: v for k, v in stats.items() if v is not None}
                loss, stats = self._reduce_loss_and_stats(loss, stats, weight)

                # Logged under "val" so validation scalars do not overwrite
                # the training curves in the writer.
                self.log(epoch, batch_idx,
                         batch_num_epoch=len(dataloader_val),
                         lr=0.0,
                         loss=loss.detach().cpu().item(),
                         speed_stats=speed_stats,
                         stats=stats,
                         writer=writer,
                         tag="val",
                         )

        model.train()

    def log(self,
            epoch=0,
            batch_idx=0,
            batch_num_epoch=-1,
            lr=0.0,
            loss=0.0,
            speed_stats=None,
            stats=None,
            writer=None,
            tag="train",
            ):
        """
        Logs one progress line every ``log_interval`` batches and mirrors it to ``writer``.

        Args:
            epoch (int): Current epoch.
            batch_idx (int): Zero-based batch index within the epoch.
            batch_num_epoch (int): Number of batches in the epoch.
            lr (float): Current learning rate.
            loss (float): Loss value to report.
            speed_stats (dict, optional): Name -> timing formatted as a string.
            stats (dict, optional): Name -> scalar tensor.
            writer (optional): Object with ``add_scalar(name, value, step)``.
            tag (str): Suffix of the scalar names, e.g. "train" or "val".
        """
        if (batch_idx + 1) % self.log_interval != 0:
            return

        speed_stats = {} if speed_stats is None else speed_stats
        stats = {} if stats is None else stats

        gib = 1024 * 1024 * 1024
        gpu_info = "GPU, memory: {:.3f} GB, " \
                   "{:.3f} GB, " \
                   "{:.3f} GB, " \
                   "{:.3f} GB".format(torch.cuda.memory_allocated() / gib,
                                      torch.cuda.max_memory_allocated() / gib,
                                      torch.cuda.memory_reserved() / gib,
                                      torch.cuda.max_memory_reserved() / gib,
                                      )

        time_now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        description = (
            f"{time_now}, "
            f"rank: {self.local_rank}, "
            f"epoch: {epoch}/{self.max_epoch}, "
            f"step: {batch_idx + 1}/{batch_num_epoch}, total step: {self.batch_total}, "
            f"(loss: {loss:.3f}), "
            f"(lr: {lr:.3e}), "
            f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}, "
            f"{speed_stats}, "
            f"{gpu_info}"
        )
        logging.info(description)

        if writer is not None:
            writer.add_scalar(f'rank{self.local_rank}_Loss/{tag}', loss, self.batch_total)
            writer.add_scalar(f'rank{self.local_rank}_lr/{tag}', lr, self.batch_total)
            for key, var in stats.items():
                writer.add_scalar(f'rank{self.local_rank}_{key}/{tag}', var.item(), self.batch_total)
            for key, var in speed_stats.items():
                # Timings are stored as formatted strings; float() parses them
                # without evaluating arbitrary expressions.
                writer.add_scalar(f'rank{self.local_rank}_{key}/{tag}', float(var), self.batch_total)

    def close(self, writer=None):
        """
        Closes ``writer`` (when given) and tears down the process group in distributed mode.

        Args:
            writer (optional): Object with ``close()``.
        """
        if writer is not None:
            writer.close()

        if self.use_ddp or self.use_fsdp:
            torch.distributed.destroy_process_group()
|
||||
Loading…
Reference in New Issue
Block a user