Dev gzf deepspeed (#1732)

* resume from step

* batch

* batch

* batch

* batch

* batch

* batch

* batch

* batch

* batch

* batch

* batch

* batch

* batch

* batch

* batch

* train_loss_avg train_acc_avg

* train_loss_avg train_acc_avg

* train_loss_avg train_acc_avg

* log step

* wav is not exist

* wav is not exist

* decoding

* decoding

* decoding

* wechat

* decoding key

* decoding key

* decoding key

* decoding key

* decoding key

* decoding key

* dynamic batch

* start_data_split_i=0

* total_time/accum_grad

* total_time/accum_grad

* total_time/accum_grad

* update avg slice

* update avg slice

* sensevoice sanm

* sensevoice sanm

* add

* add

* add

* add

* deepspeed

* update with main (#1731)

* c++ runtime adapt to 1.0 (#1724)

* adapt vad runtime to 1.0

* add json

* change yml name

* add func LoadVocabFromJson

* add token file for InitAsr

* add token path for OfflineStream

* add funcOpenYaml

* add token file for InitPunc

* add token file for stream

* update punc-model

* update funasr-wss-server

* update runtime_sdk_download_tool.py

* update docker list

* Delete docs/images/wechat.png

* Add files via upload

* Emo2Vec限定选择的情感类别 (#1730)

* 限定选择的情感类别

* 使用none来禁用情感标签输出

* 修改输出接口

* 使用unuse来禁用token

---------

Co-authored-by: 常材 <gaochangfeng.gcf@alibaba-inc.com>

* bugfix

* v1.0.27

* update docs

* hf hub

* Fix incorrect assignment of 'end' attribute to 'start' in sentences list comprehension (#1680)

---------

Co-authored-by: Yabin Li <wucong.lyb@alibaba-inc.com>
Co-authored-by: gaochangfeng <54253717+gaochangfeng@users.noreply.github.com>
Co-authored-by: 常材 <gaochangfeng.gcf@alibaba-inc.com>
Co-authored-by: nsdou <168500039+nsdou@users.noreply.github.com>

* docs

---------

Co-authored-by: Yabin Li <wucong.lyb@alibaba-inc.com>
Co-authored-by: gaochangfeng <54253717+gaochangfeng@users.noreply.github.com>
Co-authored-by: 常材 <gaochangfeng.gcf@alibaba-inc.com>
Co-authored-by: nsdou <168500039+nsdou@users.noreply.github.com>
This commit is contained in:
zhifu gao 2024-05-15 19:48:50 +08:00 committed by GitHub
parent d50edc297a
commit a0f03bd2a8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 1553 additions and 58 deletions

View File

@ -28,6 +28,7 @@
<a name="whats-new"></a> <a name="whats-new"></a>
## What's new: ## What's new:
- 2024/05/15emotion recognition models are new supported. [emotion2vec+large](https://modelscope.cn/models/iic/emotion2vec_plus_large/summary)[emotion2vec+base](https://modelscope.cn/models/iic/emotion2vec_plus_base/summary)[emotion2vec+seed](https://modelscope.cn/models/iic/emotion2vec_plus_seed/summary). currently supports the following categories: 0: angry 1: happy 2: neutral 3: sad 4: unknown.
- 2024/05/15: Offline File Transcription Service 4.5, Offline File Transcription Service of English 1.6Real-time Transcription Service 1.10 releasedadapting to FunASR 1.0 model structure([docs](runtime/readme.md)) - 2024/05/15: Offline File Transcription Service 4.5, Offline File Transcription Service of English 1.6Real-time Transcription Service 1.10 releasedadapting to FunASR 1.0 model structure([docs](runtime/readme.md))
- 2024/03/05Added the Qwen-Audio and Qwen-Audio-Chat large-scale audio-text multimodal models, which have topped multiple audio domain leaderboards. These models support speech dialogue, [usage](examples/industrial_data_pretraining/qwen_audio). - 2024/03/05Added the Qwen-Audio and Qwen-Audio-Chat large-scale audio-text multimodal models, which have topped multiple audio domain leaderboards. These models support speech dialogue, [usage](examples/industrial_data_pretraining/qwen_audio).
- 2024/03/05Added support for the Whisper-large-v3 model, a multitasking model that can perform multilingual speech recognition, speech translation, and language identification. It can be downloaded from the[modelscope](examples/industrial_data_pretraining/whisper/demo.py), and [openai](examples/industrial_data_pretraining/whisper/demo_from_openai.py). - 2024/03/05Added support for the Whisper-large-v3 model, a multitasking model that can perform multilingual speech recognition, speech translation, and language identification. It can be downloaded from the[modelscope](examples/industrial_data_pretraining/whisper/demo.py), and [openai](examples/industrial_data_pretraining/whisper/demo_from_openai.py).
@ -84,10 +85,11 @@ FunASR has open-sourced a large number of pre-trained models on industrial data.
| fsmn-vad <br> ( [](https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/summary) [🤗](https://huggingface.co/funasr/fsmn-vad) ) | voice activity detection | 5000 hours, Mandarin and English | 0.4M | | fsmn-vad <br> ( [](https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/summary) [🤗](https://huggingface.co/funasr/fsmn-vad) ) | voice activity detection | 5000 hours, Mandarin and English | 0.4M |
| fa-zh <br> ( [](https://modelscope.cn/models/damo/speech_timestamp_prediction-v1-16k-offline/summary) [🤗](https://huggingface.co/funasr/fa-zh) ) | timestamp prediction | 5000 hours, Mandarin | 38M | | fa-zh <br> ( [](https://modelscope.cn/models/damo/speech_timestamp_prediction-v1-16k-offline/summary) [🤗](https://huggingface.co/funasr/fa-zh) ) | timestamp prediction | 5000 hours, Mandarin | 38M |
| cam++ <br> ( [](https://modelscope.cn/models/iic/speech_campplus_sv_zh-cn_16k-common/summary) [🤗](https://huggingface.co/funasr/campplus) ) | speaker verification/diarization | 5000 hours | 7.2M | | cam++ <br> ( [](https://modelscope.cn/models/iic/speech_campplus_sv_zh-cn_16k-common/summary) [🤗](https://huggingface.co/funasr/campplus) ) | speaker verification/diarization | 5000 hours | 7.2M |
| Whisper-large-v2 <br> ([⭐](https://www.modelscope.cn/models/iic/speech_whisper-large_asr_multilingual/summary) [🍀](https://github.com/openai/whisper) ) | speech recognition, with timestamps, non-streaming | multilingual | 1550 M | | Whisper-large-v2 <br> ([⭐](https://www.modelscope.cn/models/iic/speech_whisper-large_asr_multilingual/summary) [🍀](https://github.com/openai/whisper) ) | speech recognition, with timestamps, non-streaming | multilingual | 1550 M |
| Whisper-large-v3 <br> ([⭐](https://www.modelscope.cn/models/iic/Whisper-large-v3/summary) [🍀](https://github.com/openai/whisper) ) | speech recognition, with timestamps, non-streaming | multilingual | 1550 M | | Whisper-large-v3 <br> ([⭐](https://www.modelscope.cn/models/iic/Whisper-large-v3/summary) [🍀](https://github.com/openai/whisper) ) | speech recognition, with timestamps, non-streaming | multilingual | 1550 M |
| Qwen-Audio <br> ([⭐](examples/industrial_data_pretraining/qwen_audio/demo.py) [🤗](https://huggingface.co/Qwen/Qwen-Audio) ) | audio-text multimodal models (pretraining) | multilingual | 8B | | Qwen-Audio <br> ([⭐](examples/industrial_data_pretraining/qwen_audio/demo.py) [🤗](https://huggingface.co/Qwen/Qwen-Audio) ) | audio-text multimodal models (pretraining) | multilingual | 8B |
| Qwen-Audio-Chat <br> ([⭐](examples/industrial_data_pretraining/qwen_audio/demo_chat.py) [🤗](https://huggingface.co/Qwen/Qwen-Audio-Chat) ) | audio-text multimodal models (chat) | multilingual | 8B | | Qwen-Audio-Chat <br> ([⭐](examples/industrial_data_pretraining/qwen_audio/demo_chat.py) [🤗](https://huggingface.co/Qwen/Qwen-Audio-Chat) ) | audio-text multimodal models (chat) | multilingual | 8B |
| emotion2vec+large <br> ([⭐](https://modelscope.cn/models/iic/emotion2vec_plus_large/summary) [🤗](https://huggingface.co/emotion2vec/emotion2vec_plus_large) ) | speech emotion recongintion | 40000 hours | 300M |

View File

@ -29,6 +29,7 @@ FunASR希望在语音识别的学术研究和工业应用之间架起一座桥
<a name="最新动态"></a> <a name="最新动态"></a>
## 最新动态 ## 最新动态
- 2024/05/15新增加情感识别模型[emotion2vec+large](https://modelscope.cn/models/iic/emotion2vec_plus_large/summary)[emotion2vec+base](https://modelscope.cn/models/iic/emotion2vec_plus_base/summary)[emotion2vec+seed](https://modelscope.cn/models/iic/emotion2vec_plus_seed/summary),输出情感类别为:生气/angry开心/happy中立/neutral难过/sad。
- 2024/05/15: 中文离线文件转写服务 4.5、英文离线文件转写服务 1.6、中文实时语音听写服务 1.10 发布适配FunASR 1.0模型结构;详细信息参阅([部署文档](runtime/readme_cn.md)) - 2024/05/15: 中文离线文件转写服务 4.5、英文离线文件转写服务 1.6、中文实时语音听写服务 1.10 发布适配FunASR 1.0模型结构;详细信息参阅([部署文档](runtime/readme_cn.md))
- 2024/03/05新增加Qwen-Audio与Qwen-Audio-Chat音频文本模态大模型在多个音频领域测试榜单刷榜中支持语音对话详细用法见 [示例](examples/industrial_data_pretraining/qwen_audio)。 - 2024/03/05新增加Qwen-Audio与Qwen-Audio-Chat音频文本模态大模型在多个音频领域测试榜单刷榜中支持语音对话详细用法见 [示例](examples/industrial_data_pretraining/qwen_audio)。
- 2024/03/05新增加Whisper-large-v3模型支持多语言语音识别/翻译/语种识别,支持从 [modelscope](examples/industrial_data_pretraining/whisper/demo.py)仓库下载,也支持从 [openai](examples/industrial_data_pretraining/whisper/demo_from_openai.py)仓库下载模型。 - 2024/03/05新增加Whisper-large-v3模型支持多语言语音识别/翻译/语种识别,支持从 [modelscope](examples/industrial_data_pretraining/whisper/demo.py)仓库下载,也支持从 [openai](examples/industrial_data_pretraining/whisper/demo_from_openai.py)仓库下载模型。
@ -75,19 +76,20 @@ FunASR开源了大量在工业数据上预训练模型您可以在[模型许
(注:⭐ 表示ModelScope模型仓库🤗 表示Huggingface模型仓库🍀表示OpenAI模型仓库 (注:⭐ 表示ModelScope模型仓库🤗 表示Huggingface模型仓库🍀表示OpenAI模型仓库
| 模型名字 | 任务详情 | 训练数据 | 参数量 | | 模型名字 | 任务详情 | 训练数据 | 参数量 |
|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:|:------------------:|:------------:|:----:| |:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:|:------------------:|:--------------:|:------:|
| paraformer-zh <br> ([⭐](https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary) [🤗](https://huggingface.co/funasr/paraformer-tp) ) | 语音识别,带时间戳输出,非实时 | 60000小时中文 | 220M | | paraformer-zh <br> ([⭐](https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary) [🤗](https://huggingface.co/funasr/paraformer-tp) ) | 语音识别,带时间戳输出,非实时 | 60000小时中文 | 220M |
| paraformer-zh-streaming <br> ( [](https://modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/summary) [🤗](https://huggingface.co/funasr/paraformer-zh-streaming) ) | 语音识别,实时 | 60000小时中文 | 220M | | paraformer-zh-streaming <br> ( [](https://modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/summary) [🤗](https://huggingface.co/funasr/paraformer-zh-streaming) ) | 语音识别,实时 | 60000小时中文 | 220M |
| paraformer-en <br> ( [](https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020/summary) [🤗](https://huggingface.co/funasr/paraformer-en) ) | 语音识别,非实时 | 50000小时英文 | 220M | | paraformer-en <br> ( [](https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020/summary) [🤗](https://huggingface.co/funasr/paraformer-en) ) | 语音识别,非实时 | 50000小时英文 | 220M |
| conformer-en <br> ( [](https://modelscope.cn/models/damo/speech_conformer_asr-en-16k-vocab4199-pytorch/summary) [🤗](https://huggingface.co/funasr/conformer-en) ) | 语音识别,非实时 | 50000小时英文 | 220M | | conformer-en <br> ( [](https://modelscope.cn/models/damo/speech_conformer_asr-en-16k-vocab4199-pytorch/summary) [🤗](https://huggingface.co/funasr/conformer-en) ) | 语音识别,非实时 | 50000小时英文 | 220M |
| ct-punc <br> ( [](https://modelscope.cn/models/damo/punc_ct-transformer_cn-en-common-vocab471067-large/summary) [🤗](https://huggingface.co/funasr/ct-punc) ) | 标点恢复 | 100M中文与英文 | 1.1B | | ct-punc <br> ( [](https://modelscope.cn/models/damo/punc_ct-transformer_cn-en-common-vocab471067-large/summary) [🤗](https://huggingface.co/funasr/ct-punc) ) | 标点恢复 | 100M中文与英文 | 1.1B |
| fsmn-vad <br> ( [](https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/summary) [🤗](https://huggingface.co/funasr/fsmn-vad) ) | 语音端点检测,实时 | 5000小时中文与英文 | 0.4M | | fsmn-vad <br> ( [](https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/summary) [🤗](https://huggingface.co/funasr/fsmn-vad) ) | 语音端点检测,实时 | 5000小时中文与英文 | 0.4M |
| fa-zh <br> ( [](https://modelscope.cn/models/damo/speech_timestamp_prediction-v1-16k-offline/summary) [🤗](https://huggingface.co/funasr/fa-zh) ) | 字级别时间戳预测 | 50000小时中文 | 38M | | fa-zh <br> ( [](https://modelscope.cn/models/damo/speech_timestamp_prediction-v1-16k-offline/summary) [🤗](https://huggingface.co/funasr/fa-zh) ) | 字级别时间戳预测 | 50000小时中文 | 38M |
| cam++ <br> ( [](https://modelscope.cn/models/iic/speech_campplus_sv_zh-cn_16k-common/summary) [🤗](https://huggingface.co/funasr/campplus) ) | 说话人确认/分割 | 5000小时 | 7.2M | | cam++ <br> ( [](https://modelscope.cn/models/iic/speech_campplus_sv_zh-cn_16k-common/summary) [🤗](https://huggingface.co/funasr/campplus) ) | 说话人确认/分割 | 5000小时 | 7.2M |
| Whisper-large-v3 <br> ([⭐](https://www.modelscope.cn/models/iic/Whisper-large-v3/summary) [🍀](https://github.com/openai/whisper) ) | 语音识别,带时间戳输出,非实时 | 多语言 | 1550 M | | Whisper-large-v3 <br> ([⭐](https://www.modelscope.cn/models/iic/Whisper-large-v3/summary) [🍀](https://github.com/openai/whisper) ) | 语音识别,带时间戳输出,非实时 | 多语言 | 1550 M |
| Qwen-Audio <br> ([⭐](examples/industrial_data_pretraining/qwen_audio/demo.py) [🤗](https://huggingface.co/Qwen/Qwen-Audio) ) | 音频文本多模态大模型(预训练) | 多语言 | 8B | | Qwen-Audio <br> ([⭐](examples/industrial_data_pretraining/qwen_audio/demo.py) [🤗](https://huggingface.co/Qwen/Qwen-Audio) ) | 音频文本多模态大模型(预训练) | 多语言 | 8B |
| Qwen-Audio-Chat <br> ([⭐](examples/industrial_data_pretraining/qwen_audio/demo_chat.py) [🤗](https://huggingface.co/Qwen/Qwen-Audio-Chat) ) | 音频文本多模态大模型chat版本 | 多语言 | 8B | | Qwen-Audio-Chat <br> ([⭐](examples/industrial_data_pretraining/qwen_audio/demo_chat.py) [🤗](https://huggingface.co/Qwen/Qwen-Audio-Chat) ) | 音频文本多模态大模型chat版本 | 多语言 | 8B |
| emotion2vec+large <br> ([⭐](https://modelscope.cn/models/iic/emotion2vec_plus_large/summary) [🤗](https://huggingface.co/emotion2vec/emotion2vec_plus_large) ) | 情感识别模型 | 40000小时4种情感类别 | 300M |
<a name="快速开始"></a> <a name="快速开始"></a>
## 快速开始 ## 快速开始

View File

@ -6,14 +6,20 @@
from funasr import AutoModel from funasr import AutoModel
# model="iic/emotion2vec_base" # model="iic/emotion2vec_base"
# model="iic/emotion2vec_base_finetuned"
# model="iic/emotion2vec_plus_seed"
# model="iic/emotion2vec_plus_base"
model = "iic/emotion2vec_plus_large"
model = AutoModel( model = AutoModel(
model="iic/emotion2vec_base_finetuned", model=model,
# vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch", # vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
# vad_model_revision="master", # vad_model_revision="master",
# vad_kwargs={"max_single_segment_time": 2000}, # vad_kwargs={"max_single_segment_time": 2000},
) )
wav_file = f"{model.model_path}/example/test.wav" wav_file = f"{model.model_path}/example/test.wav"
res = model.generate( res = model.generate(
wav_file, output_dir="./outputs", granularity="utterance", extract_embedding=False wav_file, output_dir="./outputs", granularity="utterance", extract_embedding=False
) )

View File

@ -7,8 +7,8 @@ from funasr import AutoModel
model = AutoModel( model = AutoModel(
model="/Users/zhifu/Downloads/modelscope_models/SenseVoiceModelscope", model="/Users/zhifu/Downloads/modelscope_models/SenseVoiceModelscope",
vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch", # vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
vad_kwargs={"max_single_segment_time": 30000}, # vad_kwargs={"max_single_segment_time": 30000},
) )
@ -21,6 +21,7 @@ DecodingOptions = {
"language": "auto", "language": "auto",
"fp16": True, "fp16": True,
"gain_event": True, "gain_event": True,
"beam_size": 5,
} }
res = model.generate(input=input_wav, batch_size_s=0, DecodingOptions=DecodingOptions) res = model.generate(input=input_wav, batch_size_s=0, DecodingOptions=DecodingOptions)

View File

@ -21,6 +21,7 @@ DecodingOptions = {
"language": "auto", "language": "auto",
"fp16": True, "fp16": True,
"gain_event": True, "gain_event": True,
"beam_size": 5,
} }
res = model.generate(input=input_wav, batch_size_s=0, DecodingOptions=DecodingOptions, beam_size=5) res = model.generate(input=input_wav, batch_size_s=0, DecodingOptions=DecodingOptions, beam_size=5)

View File

@ -223,6 +223,7 @@ def main(**kwargs):
torch.cuda.empty_cache() torch.cuda.empty_cache()
trainer.start_data_split_i = 0
trainer.validate_epoch( trainer.validate_epoch(
model=model, dataloader_val=dataloader_val, epoch=epoch + 1, writer=writer model=model, dataloader_val=dataloader_val, epoch=epoch + 1, writer=writer
) )
@ -240,6 +241,8 @@ def main(**kwargs):
f"estimated to finish {trainer.max_epoch} " f"estimated to finish {trainer.max_epoch} "
f"epoch: {(trainer.max_epoch - epoch) * time_escaped:.3f} hours\n" f"epoch: {(trainer.max_epoch - epoch) * time_escaped:.3f} hours\n"
) )
trainer.train_acc_avg = 0.0
trainer.train_loss_avg = 0.0
if trainer.rank == 0: if trainer.rank == 0:
average_checkpoints(trainer.output_dir, trainer.avg_nbest_model) average_checkpoints(trainer.output_dir, trainer.avg_nbest_model)

241
funasr/bin/train_ds.py Normal file
View File

@ -0,0 +1,241 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
import os
import sys
import torch
import torch.nn as nn
import hydra
import logging
import time
import argparse
from io import BytesIO
from contextlib import nullcontext
import torch.distributed as dist
from omegaconf import DictConfig, OmegaConf
from torch.cuda.amp import autocast, GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.algorithms.join import Join
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
from funasr.train_utils.average_nbest_models import average_checkpoints
from funasr.register import tables
from funasr.optimizers import optim_classes
from funasr.train_utils.trainer_ds import Trainer
from funasr.schedulers import scheduler_classes
from funasr.train_utils.initialize import initialize
from funasr.download.download_from_hub import download_model
from funasr.models.lora.utils import mark_only_lora_as_trainable
from funasr.train_utils.set_all_random_seed import set_all_random_seed
from funasr.train_utils.load_pretrained_model import load_pretrained_model
from funasr.utils.misc import prepare_model_dir
from funasr.train_utils.model_summary import model_summary
from funasr import AutoModel
try:
import deepspeed
except:
deepspeed = None
@hydra.main(config_name=None, version_base=None)
def main_hydra(kwargs: DictConfig):
if kwargs.get("debug", False):
import pdb
pdb.set_trace()
assert "model" in kwargs
if "model_conf" not in kwargs:
logging.info("download models from model hub: {}".format(kwargs.get("hub", "ms")))
kwargs = download_model(is_training=kwargs.get("is_training", True), **kwargs)
main(**kwargs)
def main(**kwargs):
# set random seed
set_all_random_seed(kwargs.get("seed", 0))
torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled)
torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark)
torch.backends.cudnn.deterministic = kwargs.get("cudnn_deterministic", True)
# open tf32
torch.backends.cuda.matmul.allow_tf32 = kwargs.get("enable_tf32", True)
rank = int(os.environ.get("RANK", 0))
local_rank = int(os.environ.get("LOCAL_RANK", 0))
world_size = int(os.environ.get("WORLD_SIZE", 1))
if local_rank == 0:
tables.print()
use_ddp = world_size > 1
use_fsdp = kwargs.get("use_fsdp", False)
use_deepspeed = kwargs.get("use_deepspeed", False)
if use_deepspeed:
logging.info(f"use_deepspeed: {use_deepspeed}")
deepspeed.init_distributed(dist_backend=kwargs.get("backend", "nccl"))
elif use_ddp or use_fsdp:
logging.info(f"use_ddp: {use_ddp}, use_fsdp: {use_fsdp}")
dist.init_process_group(backend=kwargs.get("backend", "nccl"), init_method="env://")
torch.cuda.set_device(local_rank)
logging.info("Build model, frontend, tokenizer")
device = kwargs.get("device", "cuda")
kwargs["device"] = "cpu"
model = AutoModel(**kwargs)
# save config.yaml
if rank == 0:
prepare_model_dir(**kwargs)
# parse kwargs
kwargs = model.kwargs
kwargs["device"] = device
tokenizer = kwargs["tokenizer"]
frontend = kwargs["frontend"]
model = model.model
del kwargs["model"]
# freeze_param
freeze_param = kwargs.get("freeze_param", None)
if freeze_param is not None:
if "," in freeze_param:
freeze_param = eval(freeze_param)
if not isinstance(freeze_param, (list, tuple)):
freeze_param = (freeze_param,)
logging.info("freeze_param is not None: %s", freeze_param)
for t in freeze_param:
for k, p in model.named_parameters():
if k.startswith(t + ".") or k == t:
logging.info(f"Setting {k}.requires_grad = False")
p.requires_grad = False
if local_rank == 0:
logging.info(f"{model_summary(model)}")
trainer = Trainer(
rank=rank,
local_rank=local_rank,
world_size=world_size,
use_ddp=use_ddp,
use_fsdp=use_fsdp,
device=kwargs["device"],
output_dir=kwargs.get("output_dir", "./exp"),
**kwargs.get("train_conf"),
)
model = trainer.warp_model(model)
kwargs["device"] = next(model.parameters()).device
trainer.device = kwargs["device"]
# optim
logging.info("Build optim")
optim = kwargs.get("optim", "adam")
assert optim in optim_classes
optim_class = optim_classes.get(optim)
optim = optim_class(model.parameters(), **kwargs.get("optim_conf"))
# scheduler
logging.info("Build scheduler")
scheduler = kwargs.get("scheduler", "warmuplr")
assert scheduler in scheduler_classes
scheduler_class = scheduler_classes.get(scheduler)
scheduler = scheduler_class(optim, **kwargs.get("scheduler_conf"))
if use_deepspeed:
args = OmegaConf.create({"deepspeed_config": kwargs.get("deepspeed_config", "")})
model, optimizer, _, scheduler = deepspeed.initialize(
args=args,
model=model,
optimizer=optim,
lr_scheduler=scheduler,
model_parameters=model.parameters(),
)
# dataset
logging.info("Build dataloader")
dataloader_class = tables.dataloader_classes.get(
kwargs["dataset_conf"].get("dataloader", "DataloaderMapStyle")
)
dataloader = dataloader_class(**kwargs)
# dataloader_tr, dataloader_val = dataloader_class(**kwargs)
scaler = GradScaler(enabled=trainer.use_fp16) if trainer.use_fp16 else None
scaler = ShardedGradScaler(enabled=trainer.use_fp16) if trainer.use_fsdp else scaler
trainer.resume_checkpoint(
model=model,
optim=optim,
scheduler=scheduler,
scaler=scaler,
)
tensorboard_dir = os.path.join(kwargs.get("output_dir"), "tensorboard")
os.makedirs(tensorboard_dir, exist_ok=True)
try:
from tensorboardX import SummaryWriter
writer = SummaryWriter(tensorboard_dir) # if trainer.rank == 0 else None
except:
writer = None
dataloader_tr, dataloader_val = None, None
for epoch in range(trainer.start_epoch, trainer.max_epoch):
time1 = time.perf_counter()
for data_split_i in range(trainer.start_data_split_i, dataloader.data_split_num):
dataloader_tr, dataloader_val = dataloader.build_iter(
epoch, data_split_i=data_split_i, start_step=trainer.start_step
)
trainer.train_epoch(
model=model,
optim=optim,
scheduler=scheduler,
scaler=scaler,
dataloader_train=dataloader_tr,
dataloader_val=dataloader_val,
epoch=epoch,
writer=writer,
data_split_i=data_split_i,
data_split_num=dataloader.data_split_num,
start_step=trainer.start_step,
)
trainer.start_step = 0
torch.cuda.empty_cache()
trainer.start_data_split_i = 0
trainer.validate_epoch(
model=model, dataloader_val=dataloader_val, epoch=epoch + 1, writer=writer
)
scheduler.step()
trainer.step_in_epoch = 0
trainer.save_checkpoint(
epoch + 1, model=model, optim=optim, scheduler=scheduler, scaler=scaler
)
time2 = time.perf_counter()
time_escaped = (time2 - time1) / 3600.0
logging.info(
f"rank: {local_rank}, "
f"time_escaped_epoch: {time_escaped:.3f} hours, "
f"estimated to finish {trainer.max_epoch} "
f"epoch: {(trainer.max_epoch - epoch) * time_escaped:.3f} hours\n"
)
trainer.train_acc_avg = 0.0
trainer.train_loss_avg = 0.0
if trainer.rank == 0:
average_checkpoints(trainer.output_dir, trainer.avg_nbest_model)
trainer.close()
if __name__ == "__main__":
main_hydra()

View File

@ -146,10 +146,9 @@ class EspnetStyleBatchSampler(DistributedSampler):
start_idx = self.rank * batches_per_rank start_idx = self.rank * batches_per_rank
end_idx = start_idx + batches_per_rank end_idx = start_idx + batches_per_rank
rank_batches = buffer_batches[start_idx + self.start_step : end_idx] rank_batches = buffer_batches[start_idx + self.start_step : end_idx]
if self.start_step > 0: logging.info(
logging.info( f"rank: {self.rank}, dataloader start from step: {self.start_step}, batch_num: {end_idx-start_idx}, batch_num_after_step: {len(rank_batches)}"
f"Warning, rank: {self.rank}, dataloader start from step: {self.start_step}, batch_num_before: {end_idx-start_idx}, now: {len(rank_batches)}" )
)
# Return an iterator over the batches for the current rank # Return an iterator over the batches for the current rank
return iter(rank_batches) return iter(rank_batches)

View File

@ -53,6 +53,12 @@ class SenseVoiceDataset(torch.utils.data.Dataset):
self.prompt_ids_len = 0 self.prompt_ids_len = 0
self.retry = kwargs.get("retry", 5) self.retry = kwargs.get("retry", 5)
self.permute = False
from funasr.frontends.whisper_frontend import WhisperFrontend
if isinstance(self.frontend, WhisperFrontend):
self.permute = True
def get_source_len(self, index): def get_source_len(self, index):
item = self.index_ds[index] item = self.index_ds[index]
return self.index_ds.get_source_len(item) return self.index_ds.get_source_len(item)
@ -92,7 +98,8 @@ class SenseVoiceDataset(torch.utils.data.Dataset):
if speech_lengths > self.batch_size: if speech_lengths > self.batch_size:
continue continue
speech = speech.permute(0, 2, 1) if self.permute:
speech = speech.permute(0, 2, 1)
target = item["target"] target = item["target"]
if self.preprocessor_text: if self.preprocessor_text:
target = self.preprocessor_text(target) target = self.preprocessor_text(target)
@ -100,8 +107,14 @@ class SenseVoiceDataset(torch.utils.data.Dataset):
task = item.get("prompt", "<|ASR|>") task = item.get("prompt", "<|ASR|>")
text_language = item.get("text_language", "<|zh|>") text_language = item.get("text_language", "<|zh|>")
prompt = f"{self.sos}{task}{text_language}" if isinstance(self.sos, str):
prompt_ids = self.tokenizer.encode(prompt, allowed_special="all") prompt = f"{self.sos}{task}{text_language}"
prompt_ids = self.tokenizer.encode(prompt, allowed_special="all")
else:
prompt = f"{task}{text_language}"
prompt_ids = self.tokenizer.encode(prompt, allowed_special="all")
prompt_ids = [self.sos] + prompt_ids
prompt_ids_len = len(prompt_ids) - 1 # [sos, task] prompt_ids_len = len(prompt_ids) - 1 # [sos, task]
self.prompt_ids_len = prompt_ids_len self.prompt_ids_len = prompt_ids_len
@ -110,7 +123,10 @@ class SenseVoiceDataset(torch.utils.data.Dataset):
if target_ids_len > 200: if target_ids_len > 200:
continue continue
eos = self.tokenizer.encode(self.eos, allowed_special="all") # [eos] if isinstance(self.eos, str):
eos = self.tokenizer.encode(self.eos, allowed_special="all") # [eos]
else:
eos = [self.eos]
ids = prompt_ids + target_ids + eos # [sos, task, lid, text, eos] ids = prompt_ids + target_ids + eos # [sos, task, lid, text, eos]
ids_lengths = len(ids) ids_lengths = len(ids)

View File

@ -966,3 +966,415 @@ class SenseVoiceFSMN(nn.Module):
ibest_writer["text"][key[i]] = text ibest_writer["text"][key[i]] = text
return results, meta_data return results, meta_data
@tables.register("model_classes", "SenseVoiceSANM")
class SenseVoiceSANM(nn.Module):
def __init__(
self,
specaug: str = None,
specaug_conf: dict = None,
normalize: str = None,
normalize_conf: dict = None,
encoder: str = None,
encoder_conf: dict = None,
decoder: str = None,
decoder_conf: dict = None,
input_size: int = 80,
vocab_size: int = -1,
ignore_id: int = -1,
blank_id: int = 0,
sos: int = 1,
eos: int = 2,
lsm_weight: float = 0.0,
length_normalized_loss: bool = False,
report_cer: bool = True,
report_wer: bool = True,
sym_space: str = "<space>",
sym_blank: str = "<blank>",
# extract_feats_in_collect_stats: bool = True,
share_embedding: bool = False,
# preencoder: Optional[AbsPreEncoder] = None,
# postencoder: Optional[AbsPostEncoder] = None,
**kwargs,
):
super().__init__()
if specaug is not None:
specaug_class = tables.specaug_classes.get(specaug)
specaug = specaug_class(**specaug_conf)
encoder_class = tables.encoder_classes.get(encoder)
encoder = encoder_class(input_size=input_size, **encoder_conf)
encoder_output_size = encoder.output_size()
decoder_class = tables.decoder_classes.get(decoder)
decoder = decoder_class(
vocab_size=vocab_size,
encoder_output_size=encoder_output_size,
**decoder_conf,
)
self.blank_id = blank_id
self.sos = sos if sos is not None else vocab_size - 1
self.eos = eos if eos is not None else vocab_size - 1
self.vocab_size = vocab_size
self.ignore_id = ignore_id
self.specaug = specaug
self.encoder = encoder
self.decoder = decoder
self.criterion_att = LabelSmoothingLoss(
size=vocab_size,
padding_idx=ignore_id,
smoothing=lsm_weight,
normalize_length=length_normalized_loss,
)
self.error_calculator = None
self.length_normalized_loss = length_normalized_loss
self.beam_search = None
self.activation_checkpoint = kwargs.get("activation_checkpoint", False)
def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
**kwargs,
):
target_mask = kwargs.get("target_mask", None)
# import pdb;
# pdb.set_trace()
if len(text_lengths.size()) > 1:
text_lengths = text_lengths[:, 0]
if len(speech_lengths.size()) > 1:
speech_lengths = speech_lengths[:, 0]
batch_size, frames, _ = speech.shape
_, text_tokens = text.shape
if self.activation_checkpoint:
from torch.utils.checkpoint import checkpoint
encoder_out, encoder_out_lens = checkpoint(
self.encode, speech, speech_lengths, use_reentrant=False
)
else:
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
encoder_out, encoder_out_lens, text, text_lengths, target_mask=target_mask
)
loss = loss_att
stats = {}
stats["acc"] = acc_att
stats["loss"] = torch.clone(loss.detach())
stats["batch_size"] = batch_size
stats["batch_size_x_frames"] = frames * batch_size
stats["batch_size_real_frames"] = speech_lengths.sum().item()
stats["padding_frames"] = stats["batch_size_x_frames"] - stats["batch_size_real_frames"]
stats["batch_size_x_tokens"] = text_tokens * batch_size
stats["batch_size_real_tokens"] = text_lengths.sum().item()
stats["padding_tokens"] = stats["batch_size_x_tokens"] - stats["batch_size_real_tokens"]
stats["batch_size_x_frames_plus_tokens"] = (text_tokens + frames) * batch_size
# force_gatherable: to-device and to-tensor if scalar for DataParallel
if self.length_normalized_loss:
batch_size = int((text_lengths + 1).sum())
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
return loss, stats, weight
def encode(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
**kwargs,
):
"""Frontend + Encoder. Note that this method is used by asr_inference.py
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
ind: int
"""
with autocast(False):
# Data augmentation
if self.specaug is not None and self.training:
speech, speech_lengths = self.specaug(speech, speech_lengths)
# Forward encoder
# feats: (Batch, Length, Dim)
# -> encoder_out: (Batch, Length2, Dim2)
encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
if isinstance(encoder_out, (tuple, list)):
encoder_out = encoder_out[0]
return encoder_out, encoder_out_lens
def _calc_att_loss(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
ys_pad: torch.Tensor,
ys_pad_lens: torch.Tensor,
**kwargs,
):
target_mask = kwargs.get("target_mask", None)
stats = {}
# 1. Forward decoder
ys_pad[ys_pad == -1] = 0
decoder_out = self.decoder(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
if isinstance(decoder_out, (list, tuple)):
decoder_out = decoder_out[0]
# 2. Compute attention loss
mask = torch.ones_like(ys_pad) * (-1)
ys_pad_mask = (ys_pad * target_mask + mask * (1 - target_mask)).to(torch.int64)
ys_pad_mask[ys_pad_mask == 0] = -1
loss_att = self.criterion_att(decoder_out[:, :-1, :], ys_pad_mask[:, 1:])
with torch.no_grad():
preds = torch.argmax(decoder_out, -1)
acc_att = compute_accuracy(
preds[:, :-1], ys_pad_mask[:, 1:], ignore_label=self.ignore_id
)
return loss_att, acc_att, None, None
def init_beam_search(
self,
**kwargs,
):
from .search import BeamSearch
from funasr.models.transformer.scorers.length_bonus import LengthBonus
# 1. Build ASR model
scorers = {}
scorers.update(
decoder=self.decoder,
length_bonus=LengthBonus(self.vocab_size),
)
weights = dict(
decoder=1.0,
ctc=0.0,
lm=0.0,
ngram=0.0,
length_bonus=kwargs.get("penalty", 0.0),
)
beam_search = BeamSearch(
beam_size=kwargs.get("beam_size", 5),
weights=weights,
scorers=scorers,
sos=None,
eos=None,
vocab_size=self.vocab_size,
token_list=None,
pre_beam_score_key="full",
)
self.beam_search = beam_search
def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
if kwargs.get("batch_size", 1) > 1:
raise NotImplementedError("batch decoding is not implemented")
# init beamsearch
if not hasattr(self, "beam_search") or self.beam_search is None:
logging.info("enable beam_search")
self.init_beam_search(**kwargs)
self.nbest = kwargs.get("nbest", 1)
if frontend is None and not hasattr(self, "frontend"):
frontend_class = tables.frontend_classes.get("WhisperFrontend")
frontend = frontend_class(
n_mels=self.model.dims.n_mels, do_pad_trim=kwargs.get("do_pad_trim", True)
)
self.frontend = frontend
else:
frontend = frontend if frontend is not None else self.frontend
meta_data = {}
if (
isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
): # fbank
speech, speech_lengths = data_in, data_lengths
if len(speech.shape) < 3:
speech = speech[None, :, :]
if speech_lengths is None:
speech_lengths = speech.shape[1]
else:
# extract fbank feats
time1 = time.perf_counter()
audio_sample_list = load_audio_text_image_video(
data_in,
fs=frontend.fs if hasattr(frontend, "fs") else 16000,
audio_fs=kwargs.get("fs", 16000),
data_type=kwargs.get("data_type", "sound"),
tokenizer=tokenizer,
)
if (
isinstance(kwargs.get("data_type", None), (list, tuple))
and len(kwargs.get("data_type", [])) > 1
):
audio_sample_list, text_token_int_list = audio_sample_list
text_token_int = text_token_int_list[0]
else:
text_token_int = None
time2 = time.perf_counter()
meta_data["load_data"] = f"{time2 - time1:0.3f}"
speech, speech_lengths = extract_fbank(
audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend
)
time3 = time.perf_counter()
meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
frame_shift = frontend.frame_shift if hasattr(frontend, "frame_shift") else 10
lfr_n = frontend.lfr_n if hasattr(frontend, "lfr_n") else 1
meta_data["batch_data_time"] = speech_lengths.sum().item() * frame_shift * lfr_n / 1000
speech = speech.to(device=kwargs["device"])[0, :, :]
speech_lengths = speech_lengths.to(device=kwargs["device"])
DecodingOptions = kwargs.get("DecodingOptions", {})
task = DecodingOptions.get("task", "ASR")
if isinstance(task, str):
task = [task]
task = "".join([f"<|{x}|>" for x in task])
initial_prompt = kwargs.get("initial_prompt", f"<|startoftranscript|>{task}")
language = DecodingOptions.get("language", None)
language = None if language == "auto" else language
sos = f"{initial_prompt}<|{language}|>" if language is not None else initial_prompt
sos_int = tokenizer.encode(sos, allowed_special="all")
eos = kwargs.get("model_conf").get("eos")
eos_int = tokenizer.encode(eos, allowed_special="all")
self.beam_search.sos = sos_int
self.beam_search.eos = eos_int[0]
# Paramterts for rich decoding
self.beam_search.emo_unk = tokenizer.encode(
DecodingOptions.get("emo_unk_token", "<|SPECIAL_TOKEN_1|>"), allowed_special="all"
)[0]
self.beam_search.emo_unk_score = 1
self.beam_search.emo_tokens = tokenizer.encode(
DecodingOptions.get("emo_target_tokens", "<|HAPPY|><|SAD|><|ANGRY|>"),
allowed_special="all",
)
self.beam_search.emo_scores = DecodingOptions.get("emo_target_threshold", [0.1, 0.1, 0.1])
self.beam_search.event_bg_token = tokenizer.encode(
DecodingOptions.get("gain_tokens_bg", "<|Speech|><|BGM|><|Applause|><|Laughter|>"),
allowed_special="all",
)
self.beam_search.event_ed_token = tokenizer.encode(
DecodingOptions.get("gain_tokens_ed", "<|/Speech|><|/BGM|><|/Applause|><|/Laughter|>"),
allowed_special="all",
)
self.beam_search.event_score_ga = DecodingOptions.get("gain_tokens_score", [1, 1, 1, 1])
encoder_out, encoder_out_lens = self.encode(
speech[None, :, :].permute(0, 2, 1), speech_lengths
)
if text_token_int is not None:
i = 0
results = []
ibest_writer = None
if kwargs.get("output_dir") is not None:
if not hasattr(self, "writer"):
self.writer = DatadirWriter(kwargs.get("output_dir"))
ibest_writer = self.writer[f"1best_recog"]
# 1. Forward decoder
ys_pad = torch.tensor(sos_int + text_token_int, dtype=torch.int64).to(kwargs["device"])[
None, :
]
ys_pad_lens = torch.tensor([len(sos_int + text_token_int)], dtype=torch.int64).to(
kwargs["device"]
)[None, :]
decoder_out = self.model.decoder(
x=ys_pad, xa=encoder_out, hlens=encoder_out_lens, ys_in_lens=ys_pad_lens
)
token_int = decoder_out.argmax(-1)[0, :].tolist()
text = tokenizer.decode(token_int)
result_i = {"key": key[i], "text": text}
results.append(result_i)
if ibest_writer is not None:
# ibest_writer["token"][key[i]] = " ".join(token)
ibest_writer["text"][key[i]] = text
return results, meta_data
# c. Passed the encoder result and the beam search
nbest_hyps = self.beam_search(
x=encoder_out[0],
maxlenratio=kwargs.get("maxlenratio", 0.0),
minlenratio=kwargs.get("minlenratio", 0.0),
)
nbest_hyps = nbest_hyps[: self.nbest]
results = []
b, n, d = encoder_out.size()
for i in range(b):
for nbest_idx, hyp in enumerate(nbest_hyps):
ibest_writer = None
if kwargs.get("output_dir") is not None:
if not hasattr(self, "writer"):
self.writer = DatadirWriter(kwargs.get("output_dir"))
ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"]
# remove sos/eos and get results
last_pos = -1
if isinstance(hyp.yseq, list):
token_int = hyp.yseq[1:last_pos]
else:
token_int = hyp.yseq[1:last_pos].tolist()
# # remove blank symbol id, which is assumed to be 0
# token_int = list(
# filter(
# lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int
# )
# )
# Change integer-ids to tokens
# token = tokenizer.ids2tokens(token_int)
text = tokenizer.decode(token_int)
result_i = {"key": key[i], "text": text}
results.append(result_i)
if ibest_writer is not None:
# ibest_writer["token"][key[i]] = " ".join(token)
ibest_writer["text"][key[i]] = text
return results, meta_data

View File

@ -20,6 +20,7 @@ class SentencepiecesTokenizer(BaseTokenizer):
# "TypeError: can't pickle SwigPyObject objects", # "TypeError: can't pickle SwigPyObject objects",
# when giving it as argument of "multiprocessing.Process()". # when giving it as argument of "multiprocessing.Process()".
self.sp = None self.sp = None
self._build_sentence_piece_processor()
def __repr__(self): def __repr__(self):
return f'{self.__class__.__name__}(model="{self.bpemodel}")' return f'{self.__class__.__name__}(model="{self.bpemodel}")'
@ -38,10 +39,13 @@ class SentencepiecesTokenizer(BaseTokenizer):
self._build_sentence_piece_processor() self._build_sentence_piece_processor()
return self.sp.DecodePieces(list(tokens)) return self.sp.DecodePieces(list(tokens))
def encode(self, line: str) -> List[int]: def encode(self, line: str, **kwargs) -> List[int]:
self._build_sentence_piece_processor() self._build_sentence_piece_processor()
return self.sp.EncodeAsIds(line) return self.sp.EncodeAsIds(line)
def decode(self, line: List[int]): def decode(self, line: List[int], **kwargs):
self._build_sentence_piece_processor() self._build_sentence_piece_processor()
return self.sp.DecodeIds(line) return self.sp.DecodeIds(line)
def get_vocab_size(self):
return self.sp.GetPieceSize()

View File

@ -382,8 +382,6 @@ class Trainer:
): ):
torch.cuda.empty_cache() torch.cuda.empty_cache()
time3 = time.perf_counter()
speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
loss, stats, weight = retval loss, stats, weight = retval
stats = {k: v for k, v in stats.items() if v is not None} stats = {k: v for k, v in stats.items() if v is not None}
if self.use_ddp or self.use_fsdp: if self.use_ddp or self.use_fsdp:
@ -398,34 +396,28 @@ class Trainer:
# Multiply world_size because DistributedDataParallel # Multiply world_size because DistributedDataParallel
# automatically normalizes the gradient by world_size. # automatically normalizes the gradient by world_size.
loss *= self.world_size loss *= self.world_size
# loss *= self.world_size
# Scale the loss since we're not updating for every mini-batch # Scale the loss since we're not updating for every mini-batch
loss = loss / accum_grad loss = loss / accum_grad
time3 = time.perf_counter()
speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
if self.use_fp16: if self.use_fp16:
scaler.scale(loss).backward() scaler.scale(loss).backward()
else: else:
loss.backward() loss.backward()
time4 = time.perf_counter() time4 = time.perf_counter()
speed_stats["backward_time"] = f"{time4 - time3:0.3f}" speed_stats["backward_and_AllReaduce_time"] = f"{time4 - time3:0.3f}"
self.train_loss_avg = ( self.train_loss_avg = (
self.train_loss_avg * (self.step_in_epoch - 1) + loss.detach().cpu().item() self.train_loss_avg * (batch_idx + kwargs.get("start_step", 0))
) / self.step_in_epoch + loss.detach().cpu().item()
) / (batch_idx + kwargs.get("start_step", 0) + 1)
if "acc" in stats: if "acc" in stats:
self.train_acc_avg = ( self.train_acc_avg = (
self.train_acc_avg * (self.step_in_epoch - 1) self.train_acc_avg * (batch_idx + kwargs.get("start_step", 0))
+ stats["acc"].detach().cpu().item() + stats["acc"].detach().cpu().item()
) / self.step_in_epoch ) / (batch_idx + kwargs.get("start_step", 0) + 1)
if self.use_ddp or self.use_fsdp:
train_loss_avg = torch.tensor(self.train_loss_avg, dtype=torch.float32).to(
self.device
)
train_acc_avg = torch.tensor(self.train_acc_avg, dtype=torch.float32).to(
self.device
)
dist.all_reduce(train_loss_avg, op=dist.ReduceOp.SUM)
dist.all_reduce(train_acc_avg, op=dist.ReduceOp.SUM)
self.train_loss_avg = train_loss_avg.detach().cpu().item() / self.world_size
self.train_acc_avg = train_acc_avg.detach().cpu().item() / self.world_size
# Perform an optimizer step only after accumulating enough gradients # Perform an optimizer step only after accumulating enough gradients
if (batch_idx + 1) % accum_grad == 0: if (batch_idx + 1) % accum_grad == 0:
@ -454,8 +446,22 @@ class Trainer:
scheduler.step() scheduler.step()
# Clear gradients for the next accumulation stage # Clear gradients for the next accumulation stage
optim.zero_grad(set_to_none=True) optim.zero_grad(set_to_none=True)
total_time = f"{time.perf_counter() - time5:0.3f}"
if self.use_ddp or self.use_fsdp:
train_loss_avg = torch.tensor(self.train_loss_avg, dtype=torch.float32).to(
self.device
)
train_acc_avg = torch.tensor(self.train_acc_avg, dtype=torch.float32).to(
self.device
)
dist.all_reduce(train_loss_avg, op=dist.ReduceOp.SUM)
dist.all_reduce(train_acc_avg, op=dist.ReduceOp.SUM)
self.train_loss_avg = train_loss_avg.detach().cpu().item() / self.world_size
self.train_acc_avg = train_acc_avg.detach().cpu().item() / self.world_size
total_time = f"{(time.perf_counter() - time5)/accum_grad:0.3f}"
time5 = time.perf_counter() time5 = time.perf_counter()
speed_stats["optim_time"] = f"{time5 - time4:0.3f}" speed_stats["optim_time"] = f"{time5 - time4:0.3f}"
speed_stats["total_time"] = total_time speed_stats["total_time"] = total_time
@ -662,9 +668,9 @@ class Trainer:
f"data_slice: {data_split_i}/{data_split_num}, " f"data_slice: {data_split_i}/{data_split_num}, "
f"step_in_slice: {batch_idx + 1}/{batch_num_epoch}, step_in_epoch: {step_in_epoch}, total step: {self.batch_total}, " f"step_in_slice: {batch_idx + 1}/{batch_num_epoch}, step_in_epoch: {step_in_epoch}, total step: {self.batch_total}, "
f"(loss_avg_rank: {loss:.3f}), " f"(loss_avg_rank: {loss:.3f}), "
f"(loss_avg_epoch: {loss_avg_epoch:.3f}), " f"(loss_avg_slice: {loss_avg_epoch:.3f}), "
f"(ppl_avg_epoch: {math.exp(loss_avg_epoch):.3e}), " f"(ppl_avg_slice: {math.exp(loss_avg_epoch):.3e}), "
f"(acc_avg_epoch: {acc_avg_epoch:.3f}), " f"(acc_avg_slice: {acc_avg_epoch:.3f}), "
f"(lr: {lr:.3e}), " f"(lr: {lr:.3e}), "
f"{[(k, round(v.detach().cpu().item(), 3)) for k, v in stats.items()]}, " f"{[(k, round(v.detach().cpu().item(), 3)) for k, v in stats.items()]}, "
f"{speed_stats}, " f"{speed_stats}, "

View File

@ -0,0 +1,800 @@
import math
import os
import time
import torch
import logging
from tqdm import tqdm
from datetime import datetime
import torch.distributed as dist
from torch.cuda.amp import autocast, GradScaler
from contextlib import nullcontext, contextmanager
from pathlib import Path
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from funasr.train_utils.device_funcs import to_device
from funasr.train_utils.recursive_op import recursive_average
from funasr.train_utils.average_nbest_models import average_checkpoints
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
try:
import wandb
except:
wandb = None
@contextmanager
def maybe_autocast(enabled):
if enabled:
with autocast():
yield
else:
yield
class Trainer:
"""
A simple trainer class for training a PyTorch model, saving checkpoints at the end of each epoch,
and optionally resuming from a saved checkpoint.
Attributes:
max_epoch (int): Maximum number of epochs for training.
model (torch.nn.Module): The model to be trained.
optim (torch.optim.Optimizer): The optimizer to use for training.
scheduler (torch.optim.lr_scheduler._LRScheduler): The learning rate scheduler.
dataloader_train (torch.utils.data.DataLoader): DataLoader for the training dataset.
dataloader_val (torch.utils.data.DataLoader): DataLoader for the validation dataset.
output_dir (str): Directory where model checkpoints will be saved.
resume (str, optional): Path to a checkpoint to resume training from.
"""
def __init__(
self,
rank=0,
local_rank=0,
world_size=1,
use_ddp: bool = False,
use_fsdp: bool = False,
use_fp16: bool = False,
use_deepspeed: bool = False,
output_dir: str = "./",
**kwargs,
):
"""
Initializes the Trainer class with the model, optimizer, scheduler, dataloaders, and other settings.
Args:
model (torch.nn.Module): The model to be trained.
optim (torch.optim.Optimizer): The optimizer to use for training.
scheduler (torch.optim.lr_scheduler._LRScheduler): The learning rate scheduler.
dataloader_train (torch.utils.data.DataLoader): The DataLoader for the training dataset.
dataloader_val (torch.utils.data.DataLoader): The DataLoader for the validation dataset.
**kwargs: Additional keyword arguments:
max_epoch (int): The maximum number of epochs for training.
output_dir (str): The directory where model checkpoints will be saved. Default is './'.
resume (str, optional): The file path to a checkpoint to resume training from.
"""
self.rank = kwargs.get("rank", 0)
self.local_rank = local_rank
self.world_size = world_size
self.use_ddp = use_ddp
self.use_fsdp = use_fsdp
self.use_deepspeed = use_deepspeed
self.device = kwargs.get("device", "cuda")
self.output_dir = output_dir
if not os.path.exists(self.output_dir):
os.makedirs(self.output_dir, exist_ok=True)
self.resume = kwargs.get("resume", True)
self.start_epoch = 0
self.max_epoch = kwargs.get("max_epoch", 100)
# self.kwargs = kwargs
self.log_interval = kwargs.get("log_interval", 50)
self.batch_total = 0
self.use_fp16 = use_fp16
self.save_checkpoint_interval = kwargs.get("save_checkpoint_interval", 5000)
self.validate_interval = kwargs.get("validate_interval", 5000)
self.keep_nbest_models = kwargs.get("keep_nbest_models", 500)
self.avg_keep_nbest_models_type = kwargs.get("avg_keep_nbest_models_type", "acc")
self.avg_nbest_model = kwargs.get("avg_nbest_model", 10)
self.accum_grad = kwargs.get("accum_grad", 1)
self.grad_clip = kwargs.get("grad_clip", 10.0)
self.grad_clip_type = kwargs.get("grad_clip_type", 2.0)
self.train_acc_avg = 0.0
self.train_loss_avg = 0.0
self.val_acc_avg = 0.0
self.val_loss_avg = 0.0
self.best_acc_idx = 0
self.saved_ckpts = {}
self.step_or_epoch = -1
self.best_step_or_epoch = ""
self.val_acc_step_or_eoch = {}
self.val_loss_step_or_eoch = {}
self.reset_gpu_cache = kwargs.get("reset_gpu_cache", False)
self.start_data_split_i = 0
self.start_step = 0
self.step_in_epoch = 0
self.use_wandb = kwargs.get("use_wandb", False)
if self.use_wandb:
wandb.login(key=kwargs.get("wandb_token"))
wandb.init(
config=kwargs,
project=kwargs.get("wandb_project", "my_project"),
entity=kwargs.get("wandb_team", "my_team"),
name=kwargs.get("wandb_exp_name", "my_exp"),
dir=output_dir,
job_type="training",
reinit=True,
)
def save_checkpoint(
self,
epoch,
step=None,
model=None,
optim=None,
scheduler=None,
scaler=None,
step_in_epoch=None,
**kwargs,
):
"""
Saves a checkpoint containing the model's state, the optimizer's state,
and the scheduler's state at the end of the given epoch. This method is
intended to be called at the end of each epoch to save the training progress.
Args:
epoch (int): The epoch number at which the checkpoint is being saved.
"""
step_in_epoch = None if step is None else step_in_epoch
if self.rank == 0:
logging.info(f"Save checkpoint: {epoch}, rank: {self.local_rank}\n")
# self.step_or_epoch += 1
state = {
"epoch": epoch,
"state_dict": model.state_dict(),
"optimizer": optim.state_dict(),
"scheduler": scheduler.state_dict(),
"saved_ckpts": self.saved_ckpts,
"val_acc_step_or_eoch": self.val_acc_step_or_eoch,
"val_loss_step_or_eoch": self.val_loss_step_or_eoch,
"best_step_or_epoch": self.best_step_or_epoch,
"avg_keep_nbest_models_type": self.avg_keep_nbest_models_type,
"step": step,
"step_in_epoch": step_in_epoch,
"data_split_i": kwargs.get("data_split_i", 0),
"data_split_num": kwargs.get("data_split_num", 1),
"batch_total": self.batch_total,
"train_loss_avg": kwargs.get("train_loss_avg", 0),
"train_acc_avg": kwargs.get("train_acc_avg", 0),
}
step = step_in_epoch
if hasattr(model, "module"):
state["state_dict"] = model.module.state_dict()
if scaler:
state["scaler_state"] = scaler.state_dict()
# Create output directory if it does not exist
os.makedirs(self.output_dir, exist_ok=True)
if step is None:
ckpt_name = f"model.pt.ep{epoch}"
else:
ckpt_name = f"model.pt.ep{epoch}.{step}"
filename = os.path.join(self.output_dir, ckpt_name)
torch.save(state, filename)
logging.info(f"\nCheckpoint saved to {filename}\n")
latest = Path(os.path.join(self.output_dir, f"model.pt"))
torch.save(state, latest)
if self.best_step_or_epoch == "":
self.best_step_or_epoch = ckpt_name
if self.avg_keep_nbest_models_type == "acc":
if (
self.val_acc_step_or_eoch[ckpt_name]
>= self.val_acc_step_or_eoch[self.best_step_or_epoch]
):
self.best_step_or_epoch = ckpt_name
best_ckpt = Path(os.path.join(self.output_dir, f"model.pt.best"))
torch.save(state, best_ckpt)
logging.info(
f"Update best acc: {self.val_acc_step_or_eoch[self.best_step_or_epoch]:.4f}, {best_ckpt}"
)
else:
logging.info(
f"No improvement in acc: {self.val_acc_step_or_eoch[ckpt_name]:.4f} < {self.val_acc_step_or_eoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}"
)
elif self.avg_keep_nbest_models_type == "loss":
if (
self.val_loss_step_or_eoch[ckpt_name]
<= self.val_loss_step_or_eoch[self.best_step_or_epoch]
):
self.best_step_or_epoch = ckpt_name
best_ckpt = Path(os.path.join(self.output_dir, f"model.pt.best"))
torch.save(state, best_ckpt)
logging.info(
f"Update best loss: {self.val_loss_step_or_eoch[self.best_step_or_epoch]:.4f}, {best_ckpt}"
)
else:
logging.info(
f"No improvement in loss: {self.val_loss_step_or_eoch[ckpt_name]:.4f} > {self.val_loss_step_or_eoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}"
)
else:
print("Undo")
self.saved_ckpts[ckpt_name] = getattr(
self, f"val_{self.avg_keep_nbest_models_type}_step_or_eoch"
)[ckpt_name]
if self.keep_nbest_models > 0:
if len(self.saved_ckpts) > self.keep_nbest_models:
if self.avg_keep_nbest_models_type == "acc":
key = min(self.saved_ckpts, key=self.saved_ckpts.get)
else:
key = max(self.saved_ckpts, key=self.saved_ckpts.get)
if key in self.saved_ckpts:
del self.saved_ckpts[key]
filename = os.path.join(self.output_dir, key)
logging.info(f"Delete: {filename}")
if os.path.exists(filename):
os.remove(filename)
if self.use_ddp or self.use_fsdp:
dist.barrier()
def resume_checkpoint(
self,
model=None,
optim=None,
scheduler=None,
scaler=None,
):
"""
Resumes training from a checkpoint at the given file path.
Loads the model's state, the optimizer's state, and the scheduler's state.
Args:
resume_path (str): The file path to the checkpoint to resume from.
"""
if self.resume:
ckpt = os.path.join(self.output_dir, "model.pt")
if os.path.isfile(ckpt):
checkpoint = torch.load(ckpt, map_location="cpu")
self.start_epoch = checkpoint["epoch"]
# self.model.load_state_dict(checkpoint['state_dict'])
src_state = checkpoint["state_dict"]
dst_state = model.state_dict()
for k in dst_state.keys():
if not k.startswith("module.") and "module." + k in src_state.keys():
k_ddp = "module." + k
elif k.startswith("module.") and "module." + k not in src_state.keys():
k_ddp = k.replace("module.", "", 1)
else:
k_ddp = k
if k_ddp in src_state.keys():
dst_state[k] = src_state[k_ddp]
else:
print(f"Miss key in ckpt: model: {k}, ckpt: {k_ddp}")
model.load_state_dict(dst_state)
optim.load_state_dict(checkpoint["optimizer"])
scheduler.load_state_dict(checkpoint["scheduler"])
if scaler is not None and "scaler_state" in checkpoint:
scaler.load_state_dict(checkpoint["scaler_state"])
self.saved_ckpts = checkpoint["saved_ckpts"]
self.val_acc_step_or_eoch = (
checkpoint["val_acc_step_or_eoch"]
if "val_acc_step_or_eoch" in checkpoint
else {}
)
self.val_loss_step_or_eoch = (
checkpoint["val_loss_step_or_eoch"]
if "val_loss_step_or_eoch" in checkpoint
else {}
)
self.best_step_or_epoch = (
checkpoint["best_step_or_epoch"] if "best_step_or_epoch" in checkpoint else ""
)
self.start_data_split_i = (
checkpoint["data_split_i"] if "data_split_i" in checkpoint else 0
)
self.batch_total = checkpoint["batch_total"] if "batch_total" in checkpoint else 0
self.start_step = checkpoint["step"] if "step" in checkpoint else 0
self.start_step = 0 if self.start_step is None else self.start_step
self.step_in_epoch = (
checkpoint["step_in_epoch"] if "step_in_epoch" in checkpoint else 0
)
self.step_in_epoch = 0 if self.step_in_epoch is None else self.step_in_epoch
print(checkpoint["train_acc_avg"])
self.train_acc_avg = (
checkpoint["train_acc_avg"] if "train_acc_avg" in checkpoint else 0
)
self.train_loss_avg = (
checkpoint["train_loss_avg"] if "train_loss_avg" in checkpoint else 0
)
model.to(self.device)
print(f"Checkpoint loaded successfully from '{ckpt}'")
else:
print(f"No checkpoint found at '{ckpt}', does not resume status!")
if self.use_ddp or self.use_fsdp:
dist.barrier()
def train_epoch(
self,
model=None,
optim=None,
scheduler=None,
scaler=None,
dataloader_train=None,
dataloader_val=None,
epoch=None,
writer=None,
**kwargs,
):
"""
Defines the training process for a single epoch with gradient accumulation.
Args:
epoch (int): The current epoch number.
"""
if self.use_ddp or self.use_fsdp:
dist.barrier()
logging.info(f"Train epoch: {epoch}, rank: {self.rank}\n")
model.train()
# Set the number of steps for gradient accumulation
accum_grad = self.accum_grad
# Initialize the gradient accumulation
optim.zero_grad()
speed_stats = {}
iterator_stop = torch.tensor(0).to(self.device)
dataloader_train.batch_sampler.set_epoch(epoch)
time_beg = time.perf_counter()
time5 = time_beg
for batch_idx, batch in enumerate(dataloader_train):
if self.use_ddp or self.use_fsdp:
dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
if iterator_stop > 0:
break
self.batch_total += 1
self.step_in_epoch += 1
time1 = time.perf_counter()
speed_stats["data_load"] = f"{time1-time_beg:0.3f}"
batch = to_device(batch, self.device)
my_context = nullcontext
if self.use_ddp or self.use_fsdp:
my_context = model.no_sync if batch_idx % accum_grad != 0 else my_context
with my_context():
time2 = time.perf_counter()
loss_dict = {}
self.forward_step(model, batch, loss_dict=loss_dict)
time3 = time.perf_counter()
speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
self.backward_step(model, scaler, loss_dict=loss_dict)
time4 = time.perf_counter()
speed_stats["backward_and_AllReaduce_time"] = f"{time4 - time3:0.3f}"
# self.train_loss_avg = (
# self.train_loss_avg * (batch_idx + kwargs.get("start_step", 0))
# + loss.detach().cpu().item()
# ) / (batch_idx + kwargs.get("start_step", 0) + 1)
# if "acc" in stats:
# self.train_acc_avg = (
# self.train_acc_avg * (batch_idx + kwargs.get("start_step", 0))
# + stats["acc"].detach().cpu().item()
# ) / (batch_idx + kwargs.get("start_step", 0) + 1)
self.update_step(model, optim, scheduler, scaler, loss_dict)
# Perform an optimizer step only after accumulating enough gradients
if self.step_in_epoch % self.validate_interval == 0:
self.validate_epoch(
model=model,
dataloader_val=dataloader_val,
epoch=epoch,
writer=writer,
step=batch_idx + 1,
step_in_epoch=self.step_in_epoch,
)
if self.step_in_epoch % self.save_checkpoint_interval == 0:
self.save_checkpoint(
epoch,
model=model,
optim=optim,
scheduler=scheduler,
scaler=scaler,
step=batch_idx + 1,
step_in_epoch=self.step_in_epoch,
data_split_i=kwargs.get("data_split_i", 0),
data_split_num=kwargs.get("data_split_num", 1),
train_loss_avg=self.train_loss_avg,
train_acc_avg=self.train_acc_avg,
)
time_beg = time.perf_counter()
else:
if self.use_ddp or self.use_fsdp:
iterator_stop.fill_(1)
dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
if self.use_ddp or self.use_fsdp:
dist.barrier()
iterator_stop = torch.tensor(0).to(self.device)
def forward_step(self, model, batch, loss_dict={}):
with maybe_autocast(self.use_fp16):
retval = model(**batch)
if (
self.reset_gpu_cache
and (torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024) > 70
):
torch.cuda.empty_cache()
loss, stats, weight = retval
stats = {k: v for k, v in stats.items() if v is not None}
# if self.use_ddp or self.use_fsdp:
# # Apply weighted averaging for loss and stats
# loss = (loss * weight.type(loss.dtype)).sum()
# # if distributed, this method can also apply all_reduce()
# # stats, weight = recursive_average(stats, weight, distributed=True)
# if self.use_ddp or self.use_fsdp:
# dist.all_reduce(weight, op=dist.ReduceOp.SUM)
# # Now weight is summation over all workers
# loss /= weight.sum() # shape:[1] -> shape:[]
# # Multiply world_size because DistributedDataParallel
# # automatically normalizes the gradient by world_size.
# loss *= self.world_size
# loss *= self.world_size
# Scale the loss since we're not updating for every mini-batch
loss_dict["loss"] = loss
loss_dict["stats"] = stats
loss_dict["weight"] = weight
def backward_step(self, model, scaler, loss_dict={}):
loss = loss_dict["loss"]
if self.use_deepspeed:
scaled_loss = model.backward(loss)
else:
loss = loss / self.accum_grad
if self.use_fp16:
scaler.scale(loss).backward()
else:
loss.backward()
def update_step(self, model, optim, scheduler, scaler, batch_idx=0, loss_dict=loss_dict):
if (batch_idx + 1) % self.accum_grad == 0:
# Perform gradient clipping if it is set
if self.grad_clip > 0:
grad_norm = torch.nn.utils.clip_grad_norm_(
model.parameters(),
max_norm=self.grad_clip,
norm_type=self.grad_clip_type,
)
if not torch.isfinite(grad_norm):
logging.warning(f"The grad norm is {grad_norm}. Skipping updating the model.")
optim.zero_grad() # Reset gradients
return
# Execute an optimization step (update model parameters)
if self.use_ddp or self.use_fsdp:
dist.barrier()
if self.use_fp16:
scaler.step(optim)
scaler.update()
else:
optim.step()
scheduler.step()
# Clear gradients for the next accumulation stage
optim.zero_grad(set_to_none=True)
if self.use_ddp or self.use_fsdp:
train_loss_avg = torch.tensor(self.train_loss_avg, dtype=torch.float32).to(
self.device
)
train_acc_avg = torch.tensor(self.train_acc_avg, dtype=torch.float32).to(
self.device
)
dist.all_reduce(train_loss_avg, op=dist.ReduceOp.SUM)
dist.all_reduce(train_acc_avg, op=dist.ReduceOp.SUM)
self.train_loss_avg = train_loss_avg.detach().cpu().item() / self.world_size
self.train_acc_avg = train_acc_avg.detach().cpu().item() / self.world_size
total_time = f"{(time.perf_counter() - time5) / accum_grad:0.3f}"
time5 = time.perf_counter()
speed_stats["optim_time"] = f"{time5 - time4:0.3f}"
speed_stats["total_time"] = total_time
lr = scheduler.get_last_lr()[0]
batch_num_epoch = 1
if hasattr(dataloader_train, "__len__"):
batch_num_epoch = len(dataloader_train)
self.log(
epoch,
batch_idx,
log_step=batch_idx + kwargs.get("start_step", 0),
step_in_epoch=self.step_in_epoch,
batch_num_epoch=batch_num_epoch,
lr=lr,
loss=loss.detach().cpu().item(),
speed_stats=speed_stats,
stats=stats,
writer=writer,
tag="train",
data_split_i=kwargs.get("data_split_i", 0),
data_split_num=kwargs.get("data_split_num", 1),
)
def validate_epoch(
self,
model=None,
dataloader_val=None,
epoch=None,
writer=None,
**kwargs,
):
"""
Defines the validation process for a single epoch.
Should be implemented with the actual model validation steps.
Args:
epoch (int): The current epoch number.
"""
if self.use_ddp or self.use_fsdp:
dist.barrier()
logging.info(f"Validate epoch: {epoch}, rank: {self.rank}\n")
model.eval()
with torch.no_grad():
speed_stats = {}
time5 = time.perf_counter()
iterator_stop = torch.tensor(0).to(self.device)
dataloader_val.batch_sampler.set_epoch(epoch)
for batch_idx, batch in enumerate(dataloader_val):
if self.use_ddp or self.use_fsdp:
dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
if iterator_stop > 0:
break
time1 = time.perf_counter()
speed_stats["data_load"] = f"{time1 - time5:0.3f}"
batch = to_device(batch, self.device)
time2 = time.perf_counter()
retval = model(**batch)
time3 = time.perf_counter()
speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
loss, stats, weight = retval
stats = {k: v for k, v in stats.items() if v is not None}
if self.use_ddp or self.use_fsdp:
# Apply weighted averaging for loss and stats
loss = (loss * weight.type(loss.dtype)).sum()
# if distributed, this method can also apply all_reduce()
# stats, weight = recursive_average(stats, weight, distributed=True)
if self.use_ddp or self.use_fsdp:
dist.all_reduce(weight, op=dist.ReduceOp.SUM)
# Now weight is summation over all workers
loss /= weight.sum() # shape:[1] -> shape:[]
# Multiply world_size because DistributedDataParallel
# automatically normalizes the gradient by world_size.
loss *= self.world_size
# Scale the loss since we're not updating for every mini-batch
loss = loss
time4 = time.perf_counter()
self.val_loss_avg = (self.val_loss_avg * batch_idx + loss.detach().cpu().item()) / (
batch_idx + 1
)
if "acc" in stats:
self.val_acc_avg = (
self.val_acc_avg * batch_idx + stats["acc"].detach().cpu().item()
) / (batch_idx + 1)
if self.use_ddp or self.use_fsdp:
val_loss_avg = torch.tensor(self.val_loss_avg, dtype=torch.float32).to(
self.device
)
val_acc_avg = torch.tensor(self.val_acc_avg, dtype=torch.float32).to(
self.device
)
dist.all_reduce(val_loss_avg, op=dist.ReduceOp.SUM)
dist.all_reduce(val_acc_avg, op=dist.ReduceOp.SUM)
self.val_loss_avg = val_loss_avg.detach().cpu().item() / self.world_size
self.val_acc_avg = val_acc_avg.detach().cpu().item() / self.world_size
time5 = time.perf_counter()
batch_num_epoch = 1
if hasattr(dataloader_val, "__len__"):
batch_num_epoch = len(dataloader_val)
self.log(
epoch,
batch_idx,
batch_num_epoch=batch_num_epoch,
lr=0.0,
loss=loss.detach().cpu().item(),
speed_stats=speed_stats,
stats=stats,
writer=writer,
tag="val",
)
else:
if self.use_ddp or self.use_fsdp:
iterator_stop.fill_(1)
dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
if kwargs.get("step_in_epoch", None) is None:
ckpt_name = f"model.pt.ep{epoch}"
else:
ckpt_name = f'model.pt.ep{epoch}.{kwargs.get("step_in_epoch")}'
self.val_acc_step_or_eoch[ckpt_name] = self.val_acc_avg
self.val_loss_step_or_eoch[ckpt_name] = self.val_loss_avg
model.train()
if self.use_ddp or self.use_fsdp:
dist.barrier()
iterator_stop = torch.tensor(0).to(self.device)
def log(
self,
epoch=0,
batch_idx=0,
step_in_epoch=0,
batch_num_epoch=-1,
lr=0.0,
loss=0.0,
speed_stats=None,
stats=None,
writer=None,
tag="train",
data_split_i=0,
data_split_num=1,
log_step=None,
**kwargs,
):
if (batch_idx + 1) % self.log_interval == 0:
batch_idx = log_step if log_step is not None else batch_idx
gpu_info = (
"GPU, memory: usage: {:.3f} GB, "
"peak: {:.3f} GB, "
"cache: {:.3f} GB, "
"cache_peak: {:.3f} GB".format(
torch.cuda.memory_allocated() / 1024 / 1024 / 1024,
torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024,
torch.cuda.memory_reserved() / 1024 / 1024 / 1024,
torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024,
)
)
loss_avg_epoch = getattr(self, f"{tag}_loss_avg")
acc_avg_epoch = getattr(self, f"{tag}_acc_avg")
description = (
f"{tag}, "
f"rank: {self.rank}, "
f"epoch: {epoch}/{self.max_epoch}, "
f"data_slice: {data_split_i}/{data_split_num}, "
f"step_in_slice: {batch_idx + 1}/{batch_num_epoch}, step_in_epoch: {step_in_epoch}, total step: {self.batch_total}, "
f"(loss_avg_rank: {loss:.3f}), "
f"(loss_avg_slice: {loss_avg_epoch:.3f}), "
f"(ppl_avg_slice: {math.exp(loss_avg_epoch):.3e}), "
f"(acc_avg_slice: {acc_avg_epoch:.3f}), "
f"(lr: {lr:.3e}), "
f"{[(k, round(v.detach().cpu().item(), 3)) for k, v in stats.items()]}, "
f"{speed_stats}, "
f"{gpu_info}"
)
logging.info(description)
description_dict = {
f"rank{self.rank}_loss/{tag}": loss,
f"rank{self.rank}_lr/{tag}": lr,
}
if writer is not None:
writer.add_scalar(f"rank{self.rank}_loss/{tag}", loss, self.batch_total)
writer.add_scalar(f"rank{self.rank}_lr/{tag}", lr, self.batch_total)
for key, var in stats.items():
writer.add_scalar(
f"stats_rank{self.rank}_{key}/{tag}", var.item(), self.batch_total
)
description_dict[f"stats_rank{self.rank}_{key}/{tag}"] = var.item()
for key, var in speed_stats.items():
writer.add_scalar(
f"stats_rank{self.rank}_{key}/{tag}", eval(var), self.batch_total
)
description_dict[f"stats_rank{self.rank}_{key}/{tag}"] = eval(var)
if self.use_wandb and wandb is not None:
wandb.log(
description_dict,
setp=self.batch_total,
)
def close(self, writer=None):
if self.use_ddp or self.use_fsdp:
dist.barrier()
if writer is not None:
writer.close()
if self.use_ddp or self.use_fsdp:
torch.distributed.destroy_process_group()
def warp_model(self, model, **kwargs):
if self.use_deepspeed:
from deepspeed.runtime.zero.stage_1_and_2 import (
estimate_zero2_model_states_mem_needs_all_live,
)
from deepspeed.runtime.zero.stage3 import estimate_zero3_model_states_mem_needs_all_live
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE", 1))
world_size = int(os.environ.get("WORLD_SIZE", 1))
# NOTE(xcsong): look in detail how the memory estimator API works:
# https://deepspeed.readthedocs.io/en/latest/memory.html#discussion
if int(os.environ.get("RANK", 0)) == 0:
logging.info("Estimating model states memory needs (zero2)...")
estimate_zero2_model_states_mem_needs_all_live(
model,
num_gpus_per_node=local_world_size,
num_nodes=world_size // local_world_size,
)
logging.info("Estimating model states memory needs (zero3)...")
estimate_zero3_model_states_mem_needs_all_live(
model,
num_gpus_per_node=local_world_size,
num_nodes=world_size // local_world_size,
)
device = None # Init device later
pass # Init DeepSpeed later
elif self.use_ddp:
local_rank = int(os.environ.get("LOCAL_RANK", 0))
model = model.cuda(local_rank)
model = DDP(
model,
device_ids=[local_rank],
find_unused_parameters=kwargs.get("train_conf", {}).get(
"find_unused_parameters", False
),
)
# elif self.use_fsdp:
# # model = FSDP(model).cuda(local_rank)
#
# def custom_auto_wrap_policy(
# module: nn.Module,
# recurse: bool,
# nonwrapped_numel: int,
# # Additional custom arguments
# min_num_params: int = int(1e8),
# ) -> bool:
# # 根据自定义逻辑决定是否包装模块
# is_large = unwrapped_params >= min_num_params
# requires_grad_uniform = len({p.requires_grad for p in module.parameters()}) == 1
# return is_large and requires_grad_uniform
#
# # Configure a custom `min_num_params`
# my_auto_wrap_policy = functools.partial(custom_auto_wrap_policy, min_num_params=int(1e5))
# torch.cuda.set_device(local_rank)
# model = FSDP(
# model,
# auto_wrap_policy=custom_auto_wrap_policy,
# mixed_precision=None,
# device_id=torch.cuda.current_device(),
# )
else:
model = model.to(device=kwargs.get("device", "cuda"))
return model

View File

@ -70,14 +70,16 @@ def prepare_model_dir(**kwargs):
yaml_file = os.path.join(kwargs.get("output_dir", "./"), "config.yaml") yaml_file = os.path.join(kwargs.get("output_dir", "./"), "config.yaml")
OmegaConf.save(config=kwargs, f=yaml_file) OmegaConf.save(config=kwargs, f=yaml_file)
print(kwargs) logging.info(f"kwargs: {kwargs}")
logging.info("config.yaml is saved to: %s", yaml_file) logging.info("config.yaml is saved to: %s", yaml_file)
# model_path = kwargs.get("model_path") model_path = kwargs.get("model_path", None)
# if model_path is not None: if model_path is not None:
# config_json = os.path.join(model_path, "configuration.json") config_json = os.path.join(model_path, "configuration.json")
# if os.path.exists(config_json): if os.path.exists(config_json):
# shutil.copy(config_json, os.path.join(kwargs.get("output_dir", "./"), "configuration.json")) shutil.copy(
config_json, os.path.join(kwargs.get("output_dir", "./"), "configuration.json")
)
def extract_filename_without_extension(file_path): def extract_filename_without_extension(file_path):