From 5023dd04224eddd4c9a047bd946695c3932743ae Mon Sep 17 00:00:00 2001 From: zhifu gao Date: Fri, 15 Mar 2024 16:24:29 +0800 Subject: [PATCH] Dev gzf llm (#1503) * update * update * update * update onnx * update with main (#1492) * contextual&seaco ONNX export (#1481) * contextual&seaco ONNX export * update ContextualEmbedderExport2 * update ContextualEmbedderExport2 * update code * onnx (#1482) * qwenaudio qwenaudiochat * qwenaudio qwenaudiochat * whisper * whisper * llm * llm * llm * llm * llm * llm * llm * llm * export onnx * export onnx * export onnx * dingding * dingding * llm * doc * onnx * onnx * onnx * onnx * onnx * onnx * v1.0.15 * qwenaudio * qwenaudio * issue doc * update * update * bugfix * onnx * update export calling * update codes * remove useless code * update code --------- Co-authored-by: zhifu gao * acknowledge --------- Co-authored-by: Shi Xian <40013335+R1ckShi@users.noreply.github.com> * update onnx * update onnx * train update * train update * train update * train update --------- Co-authored-by: Shi Xian <40013335+R1ckShi@users.noreply.github.com> --- .../paraformer/demo.py | 8 +- .../whisper/demo.py | 10 +- .../whisper/demo_from_openai.py | 3 +- funasr/auto/auto_model.py | 2 +- funasr/bin/train_llm.py | 140 +++--- funasr/datasets/audio_datasets/jsonl2scp.py | 62 +++ .../datasets/llm_datasets_vicuna/samplers.py | 4 +- funasr/models/llm_asr_nar/model.py | 29 +- funasr/train_utils/trainer_llm.py | 462 ++++++++++++++++++ 9 files changed, 639 insertions(+), 81 deletions(-) create mode 100644 funasr/datasets/audio_datasets/jsonl2scp.py create mode 100644 funasr/train_utils/trainer_llm.py diff --git a/examples/industrial_data_pretraining/paraformer/demo.py b/examples/industrial_data_pretraining/paraformer/demo.py index 0265b123e..651df1e54 100644 --- a/examples/industrial_data_pretraining/paraformer/demo.py +++ b/examples/industrial_data_pretraining/paraformer/demo.py @@ -7,10 +7,10 @@ from funasr import AutoModel model = AutoModel(model="iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch", model_revision="v2.0.4", - vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch", - vad_model_revision="v2.0.4", - punc_model="iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch", - punc_model_revision="v2.0.4", + # vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch", + # vad_model_revision="v2.0.4", + # punc_model="iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch", + # punc_model_revision="v2.0.4", # spk_model="iic/speech_campplus_sv_zh-cn_16k-common", # spk_model_revision="v2.0.2", ) diff --git a/examples/industrial_data_pretraining/whisper/demo.py b/examples/industrial_data_pretraining/whisper/demo.py index 01e125d51..ddebbdfe7 100644 --- a/examples/industrial_data_pretraining/whisper/demo.py +++ b/examples/industrial_data_pretraining/whisper/demo.py @@ -8,8 +8,14 @@ from funasr import AutoModel model = AutoModel(model="iic/Whisper-large-v3", - model_revision="v2.0.4", + model_revision="v2.0.5", + vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch", ) -res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav", language=None) +res = model.generate( + language=None, + task="transcribe", + batch_size_s=0, + input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav") + print(res) diff --git a/examples/industrial_data_pretraining/whisper/demo_from_openai.py b/examples/industrial_data_pretraining/whisper/demo_from_openai.py index 2ee8ad53d..5cac06b5d 100644 --- a/examples/industrial_data_pretraining/whisper/demo_from_openai.py +++ b/examples/industrial_data_pretraining/whisper/demo_from_openai.py @@ -10,10 +10,11 @@ from funasr import AutoModel # model = AutoModel(model="Whisper-small", hub="openai") # model = AutoModel(model="Whisper-medium", hub="openai") # model = AutoModel(model="Whisper-large-v2", hub="openai") -model = AutoModel(model="Whisper-large-v3", hub="openai") +model = AutoModel(model="Whisper-large-v3", hub="openai", vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",) res = model.generate( language=None, task="transcribe", + batch_size_s=0, input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav") print(res) diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py index 2df191096..8c847c548 100644 --- a/funasr/auto/auto_model.py +++ b/funasr/auto/auto_model.py @@ -291,7 +291,7 @@ class AutoModel: # step.2 compute asr model model = self.model deep_update(kwargs, cfg) - batch_size = int(kwargs.get("batch_size_s", 300))*1000 + batch_size = max(int(kwargs.get("batch_size_s", 300))*1000, 1) batch_size_threshold_ms = int(kwargs.get("batch_size_threshold_s", 60))*1000 kwargs["batch_size"] = batch_size diff --git a/funasr/bin/train_llm.py b/funasr/bin/train_llm.py index a33cd5336..8742bf14f 100644 --- a/funasr/bin/train_llm.py +++ b/funasr/bin/train_llm.py @@ -6,17 +6,22 @@ import sys import torch import hydra import logging +import time import argparse from io import BytesIO + import torch.distributed as dist from collections.abc import Sequence from omegaconf import DictConfig, OmegaConf +from torch.cuda.amp import autocast, GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler +from funasr.train_utils.average_nbest_models import average_checkpoints from funasr.register import tables from funasr.optimizers import optim_classes -from funasr.train_utils.trainer import Trainer +from funasr.train_utils.trainer_llm import Trainer from funasr.schedulers import scheduler_classes from funasr.train_utils.initialize import initialize from funasr.download.download_from_hub import download_model @@ -61,14 +66,9 @@ def main(**kwargs): dist.init_process_group(backend=kwargs.get("backend", "nccl"), init_method='env://') torch.cuda.set_device(local_rank) - device = kwargs.get("device", "cpu") + device = kwargs.get("device", "cuda") kwargs["device"] = "cpu" model = AutoModel(**kwargs) - kwargs["device"] = device - model = model.model - tokenizer = kwargs["tokenizer"] - frontend = kwargs["frontend"] - # save config.yaml @@ -77,35 +77,14 @@ def main(**kwargs): yaml_file = os.path.join(kwargs.get("output_dir", "./"), "config.yaml") OmegaConf.save(config=kwargs, f=yaml_file) logging.info("config.yaml is saved to: %s", yaml_file) - - - - - # init_param - init_param = kwargs.get("init_param", None) - if init_param is not None: - if not isinstance(init_param, (list, tuple)): - init_param = (init_param,) - logging.info("init_param is not None: %s", init_param) - for p in init_param: - if os.path.exists(p): - logging.info(f"Loading pretrained params from {p}") - load_pretrained_model( - model=model, - path=p, - ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True), - oss_bucket=kwargs.get("oss_bucket", None), - scope_map=kwargs.get("scope_map", []), - excludes=kwargs.get("excludes", None), - ) - else: - logging.info(f"Checkpoint does not exist, init randomly: {p}") - elif kwargs.get("init", None): - initialize(model, kwargs.get("init", "kaiming_normal")) - else: - print("No initialize method") - + # parse kwargs + kwargs = model.kwargs + kwargs["device"] = device + tokenizer = kwargs["tokenizer"] + frontend = kwargs["frontend"] + model = model.model + del kwargs["model"] # freeze_param freeze_param = kwargs.get("freeze_param", None) @@ -129,7 +108,8 @@ def main(**kwargs): model = FSDP(model).cuda(local_rank) else: model = model.to(device=kwargs.get("device", "cuda")) - + + kwargs["device"] = next(model.parameters()).device # optim optim = kwargs.get("optim", "adam") @@ -156,34 +136,68 @@ def main(**kwargs): batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler) batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf")) batch_sampler_val = batch_sampler_class(dataset_val, is_training=False, **kwargs.get("dataset_conf")) - dataloader_tr = torch.utils.data.DataLoader(dataset_tr, - collate_fn=dataset_tr.collator, - batch_sampler=batch_sampler, - num_workers=kwargs.get("dataset_conf").get("num_workers", 4), - pin_memory=True) + + dataloader_tr = torch.utils.data.DataLoader(dataset_tr, collate_fn=dataset_tr.collator, **batch_sampler) + dataloader_val = torch.utils.data.DataLoader(dataset_val, collate_fn=dataset_val.collator, **batch_sampler_val) + + trainer = Trainer(local_rank=local_rank, + use_ddp=use_ddp, + resume=kwargs.get("resume", True), + device=kwargs["device"], + **kwargs.get("train_conf"), + ) + + scaler = GradScaler(enabled=trainer.use_fp16) if trainer.use_fp16 else None + scaler = ShardedGradScaler(enabled=trainer.use_fp16) if trainer.use_fsdp else scaler + + trainer.resume_checkpoint(model=model, optim=optim, scheduler=scheduler, scaler=scaler) + + tensorboard_dir = os.path.join(kwargs.get("output_dir"), "tensorboard") + os.makedirs(tensorboard_dir, exist_ok=True) + try: + from tensorboardX import SummaryWriter + writer = SummaryWriter(tensorboard_dir) if trainer.rank == 0 else None + except: + writer = None - dataloader_val = torch.utils.data.DataLoader(dataset_val, - collate_fn=dataset_val.collator, - batch_sampler=batch_sampler_val, - num_workers=kwargs.get("dataset_conf").get("num_workers", 4), - pin_memory=True) - trainer = Trainer( - model=model, - optim=optim, - scheduler=scheduler, - dataloader_train=dataloader_tr, - dataloader_val=dataloader_val, - local_rank=local_rank, - use_ddp=use_ddp, - use_fsdp=use_fsdp, - output_dir=kwargs.get("output_dir", "./exp"), - resume=kwargs.get("resume", True), - **kwargs.get("train_conf"), - ) - trainer.run() - - if use_ddp or use_fsdp: - torch.distributed.destroy_process_group() + for epoch in range(trainer.start_epoch, trainer.max_epoch + 1): + time1 = time.perf_counter() + trainer.train_epoch( + model=model, + optim=optim, + scheduler=scheduler, + scaler=scaler, + dataloader_train=dataloader_tr, + dataloader_val=dataloader_val, + epoch=epoch, + writer=writer + ) + + trainer.validate_epoch( + model=model, + dataloader_val=dataloader_val, + epoch=epoch, + writer=writer + ) + + trainer.save_checkpoint(epoch, model=model, optim=optim, scheduler=scheduler, scaler=scaler) + + scheduler.step() + + time2 = time.perf_counter() + time_escaped = (time2 - time1) / 3600.0 + logging.info( + f"\nrank: {local_rank}, " + f"time_escaped_epoch: {time_escaped:.3f} hours, " + f"estimated to finish {trainer.max_epoch} " + f"epoch: {(trainer.max_epoch - epoch) * time_escaped:.3f} hours\n") + + + if trainer.rank == 0: + average_checkpoints(trainer.output_dir, trainer.avg_nbest_model) + + trainer.close() + diff --git a/funasr/datasets/audio_datasets/jsonl2scp.py b/funasr/datasets/audio_datasets/jsonl2scp.py new file mode 100644 index 000000000..9a2b023a3 --- /dev/null +++ b/funasr/datasets/audio_datasets/jsonl2scp.py @@ -0,0 +1,62 @@ +import os +import json +import torch +import logging +import hydra +from omegaconf import DictConfig, OmegaConf +import concurrent.futures +import librosa +import torch.distributed as dist + + + +def gen_scp_from_jsonl(jsonl_file, data_type_list, wav_scp_file, text_file): + + wav_f = open(wav_scp_file, "w") + text_f = open(text_file, "w") + with open(jsonl_file, encoding='utf-8') as fin: + for line in fin: + data = json.loads(line.strip()) + + prompt = data.get("prompt", "") + source = data[data_type_list[0]] + target = data[data_type_list[1]] + source_len = data.get("source_len", 1) + target_len = data.get("target_len", 0) + if "aishell" in source: + target = target.replace(" ", "") + key = data["key"] + wav_f.write(f"{key}\t{source}\n") + wav_f.flush() + text_f.write(f"{key}\t{target}\n") + text_f.flush() + + wav_f.close() + text_f.close() + + + +@hydra.main(config_name=None, version_base=None) +def main_hydra(cfg: DictConfig): + + kwargs = OmegaConf.to_container(cfg, resolve=True) + + scp_file_list = kwargs.get("scp_file_list", ("/Users/zhifu/funasr1.0/test_local/wav.scp", "/Users/zhifu/funasr1.0/test_local/text.txt")) + if isinstance(scp_file_list, str): + scp_file_list = eval(scp_file_list) + data_type_list = kwargs.get("data_type_list", ("source", "target")) + jsonl_file = kwargs.get("jsonl_file_in", "/Users/zhifu/funasr1.0/test_local/audio_datasets.jsonl") + gen_scp_from_jsonl(jsonl_file, data_type_list, *scp_file_list) + + +""" +python -m funasr.datasets.audio_datasets.json2scp \ +++scp_file_list='["/Users/zhifu/funasr1.0/test_local/wav.scp", "/Users/zhifu/funasr1.0/test_local/text.txt"]' \ +++data_type_list='["source", "target"]' \ +++jsonl_file_in=/Users/zhifu/funasr1.0/test_local/audio_datasets.jsonl +""" + +if __name__ == "__main__": + main_hydra() + + \ No newline at end of file diff --git a/funasr/datasets/llm_datasets_vicuna/samplers.py b/funasr/datasets/llm_datasets_vicuna/samplers.py index fe840e262..c728d9c10 100644 --- a/funasr/datasets/llm_datasets_vicuna/samplers.py +++ b/funasr/datasets/llm_datasets_vicuna/samplers.py @@ -142,9 +142,9 @@ class DistributedSamplerWarp(BatchSampler): def set_epoch(self, epoch): self.epoch = epoch - +@tables.register("batch_sampler_classes", "CustomDistributedBatchSampler_fn") def CustomDistributedBatchSampler_fn(dataset, **kwargs): - dataloader_args = {"dataset": dataset} + dataloader_args = {} dataloader_args["batch_sampler"] = CustomDistributedBatchSampler(dataset, **kwargs) dataloader_args["num_workers"] = kwargs.get("num_workers", 4) dataloader_args["pin_memory"] = kwargs.get("pin_memory", True) diff --git a/funasr/models/llm_asr_nar/model.py b/funasr/models/llm_asr_nar/model.py index a6096b29a..06b21939f 100644 --- a/funasr/models/llm_asr_nar/model.py +++ b/funasr/models/llm_asr_nar/model.py @@ -264,7 +264,7 @@ class LLMASRNAR(nn.Module): audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000), data_type=kwargs.get("data_type", "sound"), tokenizer=None) - if len(kwargs.get("data_type")) > 1: + if len(kwargs.get("data_type", [])) > 1: audio_sample_list, text_token_int_list = audio_sample_list text_token_int = text_token_int_list[0].replace(" ", "") text_token_int = tokenizer.encode(text_token_int) @@ -561,7 +561,7 @@ class LLMASRNARPrompt(nn.Module): audio_mask = kwargs.get("audio_mask", None) audio_token_lengths = audio_mask.sum(-1) if audio_mask is not None else None text_token_int = kwargs.get("text_token_int", None) - if audio_token_lengths is None: + if audio_token_lengths is None and text_token_int is not None: audio_token_lengths = torch.tensor([len(text_token_int)], dtype=torch.int64) batch = {"speech": speech, "speech_lengths": speech_lengths} @@ -572,7 +572,9 @@ class LLMASRNARPrompt(nn.Module): mask=enc_mask, target_label_length=audio_token_lengths, ) - loss_pre = self.criterion_pre(audio_token_lengths.type_as(pre_token_length), pre_token_length) + loss_pre = 0.0 + if audio_token_lengths is not None: + loss_pre = self.criterion_pre(audio_token_lengths.type_as(pre_token_length), pre_token_length) return pre_acoustic_embeds, pre_token_length, loss_pre @@ -603,10 +605,12 @@ class LLMASRNARPrompt(nn.Module): audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000), data_type=kwargs.get("data_type", "sound"), tokenizer=None) - if len(kwargs.get("data_type")) > 1: + if len(kwargs.get("data_type", [])) > 1: audio_sample_list, text_token_int_list = audio_sample_list - text_token_int = text_token_int_list[0].replace(" ", "") + text_token_int = text_token_int_list[0] text_token_int = tokenizer.encode(text_token_int) + if text_token_int[0] == tokenizer.bos_token_id: + text_token_int = text_token_int[1:] else: text_token_int = None time2 = time.perf_counter() @@ -621,24 +625,30 @@ class LLMASRNARPrompt(nn.Module): speech_lengths = speech_lengths.to(device=kwargs["device"]) # Encoder - encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, text_token_int=text_token_int) + res = self.encode(speech, speech_lengths, text_token_int=text_token_int) + encoder_out = res[0] # adaptor encoder_out = self.adaptor(encoder_out) prompt_pre = "USER: \nINSTRUCTION: {}\nINPUT: ".format(prompt) prompt_ids = tokenizer.encode(prompt_pre) + if prompt_ids[0] == tokenizer.bos_token_id: + prompt_ids = prompt_ids[1:] + # prompt_ids = prompt_ids + [tokenizer.pad_token_id] prompt_length = len(prompt_ids) prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64).to(kwargs["device"]) + pad = torch.tensor([tokenizer.pad_token_id], dtype=torch.int64).to(kwargs["device"]) if hasattr(self.llm.model, "embed_tokens"): inputs_embeds = self.llm.model.embed_tokens(prompt_ids) + pad = self.llm.model.embed_tokens(pad) elif hasattr(self.llm.model.model, "embed_tokens"): inputs_embeds = self.llm.model.model.embed_tokens(prompt_ids) else: inputs_embeds = self.llm.model.model.model.embed_tokens(prompt_ids) - inputs_embeds = torch.cat((inputs_embeds[None, :, :], encoder_out), dim=1) # [prompt, audio] + inputs_embeds = torch.cat((inputs_embeds[None, :, :], encoder_out, pad[None, :, :]), dim=1) # [prompt, audio] attention_mask = torch.ones(inputs_embeds.size()[:-1], dtype=torch.long).to(kwargs["device"]) # model_outputs = self.llm.generate( @@ -662,8 +672,11 @@ class LLMASRNARPrompt(nn.Module): preds = torch.argmax(model_outputs.logits, -1) text = tokenizer.batch_decode(preds, add_special_tokens=False, skip_special_tokens=True) - text = text[0].split(': ')[-1] + text = text[0].split(':')[-1] text = text.strip() + if text.startswith("Please\n "): + text = text.replace("Please\n ", "") + text = text.strip() # preds = torch.argmax(model_outputs.logits, -1) diff --git a/funasr/train_utils/trainer_llm.py b/funasr/train_utils/trainer_llm.py new file mode 100644 index 000000000..6a3b83bb8 --- /dev/null +++ b/funasr/train_utils/trainer_llm.py @@ -0,0 +1,462 @@ +import os +import time +import torch +import logging +from tqdm import tqdm +from datetime import datetime +import torch.distributed as dist +from torch.cuda.amp import autocast, GradScaler +from contextlib import nullcontext, contextmanager +from pathlib import Path + +from funasr.train_utils.device_funcs import to_device +from funasr.train_utils.recursive_op import recursive_average +from funasr.train_utils.average_nbest_models import average_checkpoints +from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler + +@contextmanager +def maybe_autocast(enabled): + if enabled: + with autocast(): + yield + else: + yield + +class Trainer: + """ + A simple trainer class for training a PyTorch model, saving checkpoints at the end of each epoch, + and optionally resuming from a saved checkpoint. + + Attributes: + max_epoch (int): Maximum number of epochs for training. + model (torch.nn.Module): The model to be trained. + optim (torch.optim.Optimizer): The optimizer to use for training. + scheduler (torch.optim.lr_scheduler._LRScheduler): The learning rate scheduler. + dataloader_train (torch.utils.data.DataLoader): DataLoader for the training dataset. + dataloader_val (torch.utils.data.DataLoader): DataLoader for the validation dataset. + output_dir (str): Directory where model checkpoints will be saved. + resume (str, optional): Path to a checkpoint to resume training from. + """ + + def __init__(self, + local_rank, + use_ddp: bool = False, + use_fsdp: bool = False, + use_fp16: bool = False, + output_dir: str="./", + **kwargs): + """ + Initializes the Trainer class with the model, optimizer, scheduler, dataloaders, and other settings. + + Args: + model (torch.nn.Module): The model to be trained. + optim (torch.optim.Optimizer): The optimizer to use for training. + scheduler (torch.optim.lr_scheduler._LRScheduler): The learning rate scheduler. + dataloader_train (torch.utils.data.DataLoader): The DataLoader for the training dataset. + dataloader_val (torch.utils.data.DataLoader): The DataLoader for the validation dataset. + **kwargs: Additional keyword arguments: + max_epoch (int): The maximum number of epochs for training. + output_dir (str): The directory where model checkpoints will be saved. Default is './'. + resume (str, optional): The file path to a checkpoint to resume training from. + """ + + self.output_dir = output_dir + self.resume = kwargs.get('resume', True) + self.start_epoch = 0 + self.max_epoch = kwargs.get('max_epoch', 100) + self.local_rank = local_rank + self.use_ddp = use_ddp + self.use_fsdp = use_fsdp + self.device = kwargs.get('device', "cuda") + self.avg_nbest_model = kwargs.get("avg_nbest_model", 5) + # self.kwargs = kwargs + self.log_interval = kwargs.get("log_interval", 50) + self.batch_total = 0 + self.use_fp16 = use_fp16 + self.disable_gpu_cache = kwargs.get("disable_gpu_cache", True) + # scaler = GradScaler(enabled=use_fp16) if use_fp16 else None + # scaler = ShardedGradScaler(enabled=use_fp16) if use_fsdp else scaler + # self.scaler = scaler + self.save_checkpoint_interval = kwargs.get("save_checkpoint_interval", 5000) + self.accum_grad = kwargs.get("accum_grad", 1) + self.grad_clip = kwargs.get("grad_clip", 10.0) + self.grad_clip_type = kwargs.get("grad_clip_type", 2.0) + self.validate_interval = kwargs.get("validate_interval", 5000) + + + try: + rank = dist.get_rank() + world_size = dist.get_world_size() + except: + rank = 0 + world_size = 1 + logging.warning("distributed is not initialized, only single shard") + self.rank = rank + self.world_size = world_size + + + + + def save_checkpoint(self, epoch, + step=None, + model=None, + optim=None, + scheduler=None, + scaler=None, + ): + """ + Saves a checkpoint containing the model's state, the optimizer's state, + and the scheduler's state at the end of the given epoch. This method is + intended to be called at the end of each epoch to save the training progress. + + Args: + epoch (int): The epoch number at which the checkpoint is being saved. + """ + if self.rank == 0: + state = { + 'epoch': epoch, + 'state_dict': model.state_dict(), + 'optimizer': optim.state_dict(), + 'scheduler': scheduler.state_dict(), + } + if scaler: + state["scaler_state"] = scaler.state_dict() + # Create output directory if it does not exist + os.makedirs(self.output_dir, exist_ok=True) + if step is None: + filename = os.path.join(self.output_dir, f'model.pt.ep{epoch}') + else: + filename = os.path.join(self.output_dir, f'model.pt.ep{epoch}.{step}') + + torch.save(state, filename) + + print(f'\nCheckpoint saved to {filename}\n') + latest = Path(os.path.join(self.output_dir, f'model.pt')) + torch.save(state, latest) + + if self.use_ddp or self.use_fsdp: + dist.barrier() + + def resume_checkpoint(self, + model=None, + optim=None, + scheduler=None, + scaler=None, + ): + """ + Resumes training from a checkpoint at the given file path. + Loads the model's state, the optimizer's state, and the scheduler's state. + + Args: + resume_path (str): The file path to the checkpoint to resume from. + """ + if self.resume: + ckpt = os.path.join(self.output_dir, "model.pt") + if os.path.isfile(ckpt): + checkpoint = torch.load(ckpt) + self.start_epoch = checkpoint['epoch'] + 1 + # self.model.load_state_dict(checkpoint['state_dict']) + src_state = checkpoint['state_dict'] + dst_state = model.state_dict() + for k in dst_state.keys(): + if not k.startswith("module.") and "module."+k in src_state.keys(): + k_ddp = "module."+k + else: + k_ddp = k + if k_ddp in src_state.keys(): + dst_state[k] = src_state[k_ddp] + else: + print(f"Miss key in ckpt: model: {k}, ckpt: {k_ddp}") + + model.load_state_dict(dst_state) + optim.load_state_dict(checkpoint['optimizer']) + scheduler.load_state_dict(checkpoint['scheduler']) + if scaler is not None and 'scaler_state' in checkpoint: + scaler.load_state_dict(checkpoint['scaler_state']) + print(f"Checkpoint loaded successfully from '{ckpt}'") + else: + print(f"No checkpoint found at '{ckpt}', does not resume status!") + + if self.use_ddp or self.use_fsdp: + dist.barrier() + + # def train(self): + # """ + # Starts the training process, iterating over epochs, training the model, + # and saving checkpoints at the end of each epoch. + # """ + # if self.resume: + # self.resume_checkpoint(self.output_dir) + # + # for epoch in range(self.start_epoch, self.max_epoch + 1): + # time1 = time.perf_counter() + # self.train_epoch(epoch) + # + # + # + # if self.use_ddp or self.use_fsdp: + # dist.barrier() + # + # self._validate_epoch(epoch) + # + # if self.use_ddp or self.use_fsdp: + # dist.barrier() + # + # + # if self.rank == 0: + # self._save_checkpoint(epoch) + # + # if self.use_ddp or self.use_fsdp: + # dist.barrier() + # + # self.scheduler.step() + # + # time2 = time.perf_counter() + # time_escaped = (time2 - time1)/3600.0 + # print(f"\nrank: {self.local_rank}, time_escaped_epoch: {time_escaped:.3f} hours, estimated to finish {self.max_epoch} epoch: {(self.max_epoch-epoch)*time_escaped:.3f} hours\n") + # + # if self.rank == 0: + # average_checkpoints(self.output_dir, self.avg_nbest_model) + # + # if self.use_ddp or self.use_fsdp: + # dist.barrier() + # + # + # if writer: + # writer.close() + # + + def train_epoch(self, + model=None, + optim=None, + scheduler=None, + scaler=None, + dataloader_train=None, + dataloader_val=None, + epoch=None, + writer=None, + ): + """ + Defines the training process for a single epoch with gradient accumulation. + Args: + epoch (int): The current epoch number. + """ + model.train() + + + # Set the number of steps for gradient accumulation + accum_grad = self.accum_grad + # Initialize the gradient accumulation + optim.zero_grad() + speed_stats = {} + time5 = time.perf_counter() + + for batch_idx, batch in enumerate(dataloader_train): + self.batch_total += 1 + time1 = time.perf_counter() + speed_stats["data_load"] = f"{time1-time5:0.3f}" + + batch = to_device(batch, self.device) + + my_context = model.no_sync if batch_idx % accum_grad != 0 else nullcontext + with my_context(): + time2 = time.perf_counter() + with maybe_autocast(self.use_fp16): + retval = model(**batch) + + if self.disable_gpu_cache: torch.cuda.empty_cache() + + time3 = time.perf_counter() + speed_stats["forward_time"] = f"{time3 - time2:0.3f}" + loss, stats, weight = retval + stats = {k: v for k, v in stats.items() if v is not None} + if self.use_ddp or self.use_fsdp: + # Apply weighted averaging for loss and stats + loss = (loss * weight.type(loss.dtype)).sum() + # if distributed, this method can also apply all_reduce() + stats, weight = recursive_average(stats, weight, distributed=True) + # Now weight is summation over all workers + loss /= weight + # Multiply world_size because DistributedDataParallel + # automatically normalizes the gradient by world_size. + loss *= self.world_size + # Scale the loss since we're not updating for every mini-batch + loss = loss / accum_grad + if self.use_fp16: + scaler.scale(loss).backward() + else: + loss.backward() + time4 = time.perf_counter() + speed_stats["backward_time"] = f"{time4 - time3:0.3f}" + + # Perform an optimizer step only after accumulating enough gradients + if (batch_idx + 1) % accum_grad == 0: + # Perform gradient clipping if it is set + if self.grad_clip > 0: + grad_norm = torch.nn.utils.clip_grad_norm_( + model.parameters(), + max_norm=self.grad_clip, + norm_type=self.grad_clip_type, + ) + if not torch.isfinite(grad_norm): + logging.warning( + f"The grad norm is {grad_norm}. Skipping updating the model." + ) + optim.zero_grad() # Reset gradients + continue + + # Execute an optimization step (update model parameters) + if self.use_ddp or self.use_fsdp: + dist.barrier() + if self.use_fp16: + scaler.step(optim) + scaler.update() + else: + optim.step() + scheduler.step() + # Clear gradients for the next accumulation stage + optim.zero_grad(set_to_none=True) + total_time = f"{time.perf_counter() - time5:0.3f}" + time5 = time.perf_counter() + speed_stats["optim_time"] = f"{time5 - time4:0.3f}" + + speed_stats["total_time"] = total_time + lr = scheduler.get_last_lr()[0] + + self.log(epoch, batch_idx, + batch_num_epoch=len(dataloader_train), + lr=lr, + loss=loss.detach().cpu().item(), + speed_stats=speed_stats, + stats=stats, + writer=writer, + tag="train", + ) + + if (batch_idx + 1) % self.validate_interval == 0: + self.validate_epoch( + model=model, + dataloader_val=dataloader_val, + epoch=epoch, + writer=writer + ) + + if (batch_idx+1) % self.save_checkpoint_interval == 0 and self.rank == 0: + self.save_checkpoint(epoch, model=model, optim=optim, scheduler=scheduler, scaler=scaler, step=batch_idx+1) + + + if self.use_ddp or self.use_fsdp: + dist.barrier() + + + + def validate_epoch(self, + model=None, + dataloader_val=None, + epoch=None, + writer=None, + **kwargs, + ): + """ + Defines the validation process for a single epoch. + Should be implemented with the actual model validation steps. + + Args: + epoch (int): The current epoch number. + """ + model.eval() + + with torch.no_grad(): + + speed_stats = {} + time5 = time.perf_counter() + for batch_idx, batch in enumerate(dataloader_val): + time1 = time.perf_counter() + speed_stats["data_load"] = f"{time1 - time5:0.3f}" + batch = to_device(batch, self.device) + time2 = time.perf_counter() + retval = model(**batch) + time3 = time.perf_counter() + speed_stats["forward_time"] = f"{time3 - time2:0.3f}" + loss, stats, weight = retval + stats = {k: v for k, v in stats.items() if v is not None} + if self.use_ddp or self.use_fsdp: + # Apply weighted averaging for loss and stats + loss = (loss * weight.type(loss.dtype)).sum() + # if distributed, this method can also apply all_reduce() + stats, weight = recursive_average(stats, weight, distributed=True) + # Now weight is summation over all workers + loss /= weight + # Multiply world_size because DistributedDataParallel + # automatically normalizes the gradient by world_size. + loss *= self.world_size + # Scale the loss since we're not updating for every mini-batch + loss = loss + time4 = time.perf_counter() + + + self.log(epoch, batch_idx, + batch_num_epoch=len(dataloader_val), + lr=0.0, + loss=loss.detach().cpu().item(), + speed_stats=speed_stats, + stats=stats, + writer=writer, + tag="train", + ) + + model.train() + + + def log(self, + epoch=0, + batch_idx=0, + batch_num_epoch=-1, + lr=0.0, + loss=0.0, + speed_stats=None, + stats=None, + writer=None, + tag="train", + ): + + if (batch_idx + 1) % self.log_interval == 0: + + gpu_info = "GPU, memory: {:.3f} GB, " \ + "{:.3f} GB, " \ + "{:.3f} GB, " \ + "{:.3f} GB".format(torch.cuda.memory_allocated() / 1024 / 1024 / 1024, + torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024, + torch.cuda.memory_reserved() / 1024 / 1024 / 1024, + torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, + ) + + time_now = datetime.now() + time_now = time_now.strftime("%Y-%m-%d %H:%M:%S") + description = ( + f"{time_now}, " + f"rank: {self.local_rank}, " + f"epoch: {epoch}/{self.max_epoch}, " + f"step: {batch_idx + 1}/{batch_num_epoch}, total step: {self.batch_total}, " + f"(loss: {loss:.3f}), " + f"(lr: {lr:.3e}), " + f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}, " + f"{speed_stats}, " + f"{gpu_info}" + ) + logging.info(description) + + if writer is not None: + writer.add_scalar(f'rank{self.local_rank}_Loss/{tag}', loss, self.batch_total) + writer.add_scalar(f'rank{self.local_rank}_lr/{tag}', lr, self.batch_total) + for key, var in stats.items(): + writer.add_scalar(f'rank{self.local_rank}_{key}/{tag}', var.item(), self.batch_total) + for key, var in speed_stats.items(): + writer.add_scalar(f'rank{self.local_rank}_{key}/{tag}', eval(var), self.batch_total) + + def close(self, writer=None): + if writer is not None: + writer.close() + + if self.use_ddp or self.use_fsdp: + torch.distributed.destroy_process_group() \ No newline at end of file