Merge pull request #768 from alibaba-damo-academy/dev_lhn

Dev lhn
This commit is contained in:
hnluo 2023-07-21 15:07:34 +08:00 committed by GitHub
commit 0063e4356a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 566 additions and 41 deletions

View File

@ -1,6 +1,6 @@
# Speech Recognition
> **Note**:
> **Note**:
> The modelscope pipeline supports all the models in [model zoo](https://alibaba-damo-academy.github.io/FunASR/en/model_zoo/modelscope_models.html#pretrained-models-on-modelscope) to inference and finetine. Here we take the typic models as examples to demonstrate the usage.
## Inference
@ -36,7 +36,7 @@ chunk_size = [5, 10, 5] #[5, 10, 5] 600ms, [8, 8, 4] 480ms
param_dict = {"cache": dict(), "is_final": False, "chunk_size": chunk_size}
chunk_stride = chunk_size[1] * 960 # 600ms、480ms
# first chunk, 600ms
speech_chunk = speech[0:chunk_stride]
speech_chunk = speech[0:chunk_stride]
rec_result = inference_pipeline(audio_in=speech_chunk, param_dict=param_dict)
print(rec_result)
# next chunk, 600ms
@ -101,16 +101,16 @@ print(rec_result)
- `task`: `Tasks.auto_speech_recognition`
- `model`: model name in [model zoo](https://alibaba-damo-academy.github.io/FunASR/en/model_zoo/modelscope_models.html#pretrained-models-on-modelscope), or model path in local disk
- `ngpu`: `1` (Default), decoding on GPU. If ngpu=0, decoding on CPU
- `ncpu`: `1` (Default), sets the number of threads used for intraop parallelism on CPU
- `ncpu`: `1` (Default), sets the number of threads used for intraop parallelism on CPU
- `output_dir`: `None` (Default), the output path of results if set
- `batch_size`: `1` (Default), batch size when decoding
#### Infer pipeline
- `audio_in`: the input to decode, which could be:
- `audio_in`: the input to decode, which could be:
- wav_path, `e.g.`: asr_example.wav,
- pcm_path, `e.g.`: asr_example.pcm,
- pcm_path, `e.g.`: asr_example.pcm,
- audio bytes stream, `e.g.`: bytes data from a microphone
- audio sample point`e.g.`: `audio, rate = soundfile.read("asr_example_zh.wav")`, the dtype is numpy.ndarray or torch.Tensor
- wav.scp, kaldi style wav list (`wav_id \t wav_path`), `e.g.`:
- wav.scp, kaldi style wav list (`wav_id \t wav_path`), `e.g.`:
```text
asr_example1 ./audios/asr_example1.wav
asr_example2 ./audios/asr_example2.wav
@ -168,15 +168,19 @@ If you decode the SpeechIO test sets, you can use textnorm with `stage`=3, and `
[finetune.py](https://github.com/alibaba-damo-academy/FunASR/blob/main/egs_modelscope/asr/TEMPLATE/finetune.py)
```python
import os
from modelscope.metainfo import Trainers
from modelscope.trainers import build_trainer
from modelscope.msdatasets.audio.asr_dataset import ASRDataset
from funasr.datasets.ms_dataset import MsDataset
from funasr.utils.modelscope_param import modelscope_args
def modelscope_finetune(params):
if not os.path.exists(params.output_dir):
os.makedirs(params.output_dir, exist_ok=True)
# dataset split ["train", "validation"]
ds_dict = ASRDataset.load(params.data_path, namespace='speech_asr')
ds_dict = MsDataset.load(params.data_path)
kwargs = dict(
model=params.model,
data_dir=ds_dict,
@ -184,21 +188,32 @@ def modelscope_finetune(params):
work_dir=params.output_dir,
batch_bins=params.batch_bins,
max_epoch=params.max_epoch,
lr=params.lr)
lr=params.lr,
mate_params=params.param_dict)
trainer = build_trainer(Trainers.speech_asr_trainer, default_args=kwargs)
trainer.train()
if __name__ == '__main__':
from funasr.utils.modelscope_param import modelscope_args
params = modelscope_args(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
params.output_dir = "./checkpoint" # 模型保存路径
params.data_path = "speech_asr_aishell1_trainsets" # 数据路径可以为modelscope中已上传数据也可以是本地数据
params.dataset_type = "small" # 小数据量设置small若数据量大于1000小时请使用large
params.batch_bins = 2000 # batch size如果dataset_type="small"batch_bins单位为fbank特征帧数如果dataset_type="large"batch_bins单位为毫秒
params.max_epoch = 50 # 最大训练轮数
params.lr = 0.00005 # 设置学习率
params = modelscope_args(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", data_path="./data")
params.output_dir = "./checkpoint" # m模型保存路径
params.data_path = "./example_data/" # 数据路径
params.dataset_type = "small" # 小数据量设置small若数据量大于1000小时请使用large
params.batch_bins = 2000 # batch size如果dataset_type="small"batch_bins单位为fbank特征帧数如果dataset_type="large"batch_bins单位为毫秒
params.max_epoch = 20 # 最大训练轮数
params.lr = 0.00005 # 设置学习率
init_param = [] # 初始模型路径默认加载modelscope模型初始化例如: ["checkpoint/20epoch.pb"]
freeze_param = [] # 模型参数freeze, 例如: ["encoder"]
ignore_init_mismatch = True # 是否忽略模型参数初始化不匹配
use_lora = False # 是否使用lora进行模型微调
params.param_dict = {"init_param":init_param, "freeze_param": freeze_param, "ignore_init_mismatch": ignore_init_mismatch}
if use_lora:
enable_lora = True
lora_bias = "all"
lora_params = {"lora_list":['q','v'], "lora_rank":8, "lora_alpha":16, "lora_dropout":0.1}
lora_config = {"enable_lora": enable_lora, "lora_bias": lora_bias, "lora_params": lora_params}
params.param_dict.update(lora_config)
modelscope_finetune(params)
```
@ -215,6 +230,10 @@ python finetune.py &> log.txt &
- `batch_bins`: batch size. For dataset_type is `small`, `batch_bins` indicates the feature frames. For dataset_type is `large`, `batch_bins` indicates the duration in ms
- `max_epoch`: number of training epoch
- `lr`: learning rate
- `init_param`: init model path, load modelscope model initialization by default. For example: ["checkpoint/20epoch.pb"]
- `freeze_param`: Freeze model parameters. For example["encoder"]
- `ignore_init_mismatch`: Ignore size mismatch when loading pre-trained model
- `use_lora`: Fine-tuning model use lora, more detail please refer to [LORA](https://arxiv.org/pdf/2106.09685.pdf)
- Training data formats
```sh

View File

@ -2,14 +2,6 @@
- Model link: <https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary>
- Model size: 220M
# Environments
- date: `Tue Nov 22 18:48:39 CST 2022`
- python version: `3.7.12`
- FunASR version: `0.1.0`
- pytorch version: `pytorch 1.7.0`
- Git hash: ``
- Commit date: ``
# Beachmark Results
## AISHELL-1
@ -73,3 +65,50 @@
|SPEECHIO_ASR_ZH000013| 2.57 | 2.25 |
|SPEECHIO_ASR_ZH000014| 3.86 | 3.08 |
|SPEECHIO_ASR_ZH000015| 3.34 | 2.67 |
# Fine-tuning Results
## Fine-tuning
- Train config:
- Training data: aishell-1
- Training info: lr 0.0002, dataset_type: small, batch bins 2000, 2 gpu, acc_grad 1, 20 epochs
- Decoding info: beam_size 1, average_num 10
| model | dev cer(%) | test cer(%) |
|:---------:|:-------------:|:-------------:|
| Pretrain | 1.75 |1.95 |
| Finetune | 1.62 |1.78 |
- Train config:
- Training data: 16k sichuan dialect
- Training info: lr 0.0002, dataset_type: small, batch bins 2000, 2 gpu, acc_grad 1, 20 epochs
- Decoding info: beam_size 1, average_num 10
| model | Training Data(h) | cn cer(%) | sichuan cer(%) |
|:--------:|:-------------:|:-------:|:------------:|
| Pretrain | | 8.57 | 19.81 |
| Finetune | 50 | 8.8 | 12 |
| | 100 | 9.24 | 11.63 |
| | 200 | 9.82 | 10.47 |
| | 300 | 9.95 | 10.44 |
| | 1000 | 9.99 | 9.78 |
## Lora Fine-tuning
- Train config:
- Training data: 16k sichuan dialect
- Training info: lr 0.0002, dataset_type: small, batch bins 2000, 2 gpu, acc_grad 1, 20 epochs
- Lora info: lora_bias: "all", lora_list ['q','v'], lora_rank:8, lora_alpha:16, lora_dropout:0.1
- Decoding info: beam_size 1, average_num 10
| model | Training Data(h) | Trainable Parameters(M) | cn cer(%) | sichuan cer(%) |
|:-------------:|:----------------:|:-----------------------:|:---------:|:--------------:|
| Pretrain | | | 8.57 | 19.81 |
| | | | | |
| Finetune | 50 | 220.9 | 8.8 | 12 |
| Lora Finetune | 50 | 2.29 | 9.13 | 12.13 |
| | | | | |
| Finetune | 200 | 220.9 | 9.82 | 10.47 |
| Lora Finetune | 200 | 2.29 | 9.21 | 11.28 |

View File

@ -19,7 +19,8 @@ def modelscope_finetune(params):
work_dir=params.output_dir,
batch_bins=params.batch_bins,
max_epoch=params.max_epoch,
lr=params.lr)
lr=params.lr,
mate_params=params.param_dict)
trainer = build_trainer(Trainers.speech_asr_trainer, default_args=kwargs)
trainer.train()
@ -30,7 +31,18 @@ if __name__ == '__main__':
params.data_path = "./example_data/" # 数据路径
params.dataset_type = "small" # 小数据量设置small若数据量大于1000小时请使用large
params.batch_bins = 2000 # batch size如果dataset_type="small"batch_bins单位为fbank特征帧数如果dataset_type="large"batch_bins单位为毫秒
params.max_epoch = 50 # 最大训练轮数
params.lr = 0.00005 # 设置学习率
params.max_epoch = 20 # 最大训练轮数
params.lr = 0.0002 # 设置学习率
init_param = [] # 初始模型路径默认加载modelscope模型初始化例如: ["checkpoint/20epoch.pb"]
freeze_param = [] # 模型参数freeze, 例如: ["encoder"]
ignore_init_mismatch = True # 是否忽略模型参数初始化不匹配
use_lora = False # 是否使用lora进行模型微调
params.param_dict = {"init_param":init_param, "freeze_param": freeze_param, "ignore_init_mismatch": ignore_init_mismatch}
if use_lora:
enable_lora = True
lora_bias = "all"
lora_params = {"lora_list":['q','v'], "lora_rank":8, "lora_alpha":16, "lora_dropout":0.1}
lora_config = {"enable_lora": enable_lora, "lora_bias": lora_bias, "lora_params": lora_params}
params.param_dict.update(lora_config)
modelscope_finetune(params)

View File

@ -2,7 +2,6 @@ import os
import yaml
def update_dct(fin_configs, root):
if root == {}:
return {}
@ -55,7 +54,7 @@ def build_trainer(modelscope_dict,
scheduler_conf=None,
specaug=None,
specaug_conf=None,
param_dict=None,
mate_params=None,
**kwargs):
mode = modelscope_dict['mode']
args, ASRTask = parse_args(mode=mode)
@ -92,6 +91,14 @@ def build_trainer(modelscope_dict,
for key, value in finetune_configs.items():
if hasattr(args, key):
setattr(args, key, value)
if mate_params is not None:
for key, value in mate_params.items():
if hasattr(args, key):
setattr(args, key, value)
if mate_params is not None and "lora_params" in mate_params:
lora_params = mate_params['lora_params']
configs['encoder_conf'].update(lora_params)
configs['decoder_conf'].update(lora_params)
# prepare data
args.dataset_type = dataset_type
@ -106,6 +113,9 @@ def build_trainer(modelscope_dict,
else:
raise ValueError(f"Not supported dataset_type={args.dataset_type}")
args.init_param = [init_param]
if mate_params is not None and "init_param" in mate_params:
if len(mate_params["init_param"]) != 0:
args.init_param = mate_params["init_param"]
args.cmvn_file = cmvn_file
if os.path.exists(seg_dict_file):
args.seg_dict_file = seg_dict_file

View File

@ -28,6 +28,7 @@ from funasr.utils.types import int_or_none
from funasr.utils.types import str2bool
from funasr.utils.types import str_or_none
from funasr.utils.yaml_no_alias_safe_dump import yaml_no_alias_safe_dump
from funasr.modules.lora.utils import mark_only_lora_as_trainable
def get_parser():
@ -478,6 +479,18 @@ def get_parser():
default=None,
help="oss bucket.",
)
parser.add_argument(
"--enable_lora",
type=str2bool,
default=False,
help="Apply lora for finetuning.",
)
parser.add_argument(
"--lora_bias",
type=str,
default="none",
help="lora bias.",
)
return parser
@ -521,6 +534,8 @@ if __name__ == '__main__':
dtype=getattr(torch, args.train_dtype),
device="cuda" if args.ngpu > 0 else "cpu",
)
if args.enable_lora:
mark_only_lora_as_trainable(model, args.lora_bias)
for t in args.freeze_param:
for k, p in model.named_parameters():
if k.startswith(t + ".") or k == t:

View File

@ -833,6 +833,10 @@ class ParaformerSANMDecoder(BaseTransformerDecoder):
att_layer_num: int = 6,
kernel_size: int = 21,
sanm_shfit: int = 0,
lora_list: List[str] = None,
lora_rank: int = 8,
lora_alpha: int = 16,
lora_dropout: float = 0.1,
tf2torch_tensor_name_prefix_torch: str = "decoder",
tf2torch_tensor_name_prefix_tf: str = "seq2seq/decoder",
):
@ -885,7 +889,7 @@ class ParaformerSANMDecoder(BaseTransformerDecoder):
attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
),
MultiHeadedAttentionCrossAtt(
attention_heads, attention_dim, src_attention_dropout_rate
attention_heads, attention_dim, src_attention_dropout_rate, lora_list, lora_rank, lora_alpha, lora_dropout
),
PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
dropout_rate,

View File

@ -146,6 +146,10 @@ class SANMEncoder(AbsEncoder):
interctc_use_conditioning: bool = False,
kernel_size : int = 11,
sanm_shfit : int = 0,
lora_list: List[str] = None,
lora_rank: int = 8,
lora_alpha: int = 16,
lora_dropout: float = 0.1,
selfattention_layer_type: str = "sanm",
tf2torch_tensor_name_prefix_torch: str = "encoder",
tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
@ -229,6 +233,10 @@ class SANMEncoder(AbsEncoder):
attention_dropout_rate,
kernel_size,
sanm_shfit,
lora_list,
lora_rank,
lora_alpha,
lora_dropout,
)
encoder_selfattn_layer_args = (
@ -238,6 +246,10 @@ class SANMEncoder(AbsEncoder):
attention_dropout_rate,
kernel_size,
sanm_shfit,
lora_list,
lora_rank,
lora_alpha,
lora_dropout,
)
self.encoders0 = repeat(
1,

View File

@ -15,6 +15,7 @@ from typing import Optional, Tuple
import torch.nn.functional as F
from funasr.modules.nets_utils import make_pad_mask
import funasr.modules.lora.layers as lora
class MultiHeadedAttention(nn.Module):
"""Multi-Head Attention layer.
@ -321,7 +322,7 @@ class MultiHeadedAttentionSANM(nn.Module):
"""
def __init__(self, n_head, in_feat, n_feat, dropout_rate, kernel_size, sanm_shfit=0):
def __init__(self, n_head, in_feat, n_feat, dropout_rate, kernel_size, sanm_shfit=0, lora_list=None, lora_rank=8, lora_alpha=16, lora_dropout=0.1):
"""Construct an MultiHeadedAttention object."""
super(MultiHeadedAttentionSANM, self).__init__()
assert n_feat % n_head == 0
@ -331,8 +332,19 @@ class MultiHeadedAttentionSANM(nn.Module):
# self.linear_q = nn.Linear(n_feat, n_feat)
# self.linear_k = nn.Linear(n_feat, n_feat)
# self.linear_v = nn.Linear(n_feat, n_feat)
self.linear_out = nn.Linear(n_feat, n_feat)
self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
if lora_list is not None:
if "o" in lora_list:
self.linear_out = lora.Linear(n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
else:
self.linear_out = nn.Linear(n_feat, n_feat)
lora_qkv_list = ["q" in lora_list, "k" in lora_list, "v" in lora_list]
if lora_qkv_list == [False, False, False]:
self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
else:
self.linear_q_k_v = lora.MergedLinear(in_feat, n_feat * 3, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout, enable_lora=lora_qkv_list)
else:
self.linear_out = nn.Linear(n_feat, n_feat)
self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
self.attn = None
self.dropout = nn.Dropout(p=dropout_rate)
@ -543,18 +555,32 @@ class MultiHeadedAttentionCrossAtt(nn.Module):
"""
def __init__(self, n_head, n_feat, dropout_rate, encoder_output_size=None):
def __init__(self, n_head, n_feat, dropout_rate, lora_list=None, lora_rank=8, lora_alpha=16, lora_dropout=0.1, encoder_output_size=None):
"""Construct an MultiHeadedAttention object."""
super(MultiHeadedAttentionCrossAtt, self).__init__()
assert n_feat % n_head == 0
# We assume d_v always equals d_k
self.d_k = n_feat // n_head
self.h = n_head
self.linear_q = nn.Linear(n_feat, n_feat)
# self.linear_k = nn.Linear(n_feat, n_feat)
# self.linear_v = nn.Linear(n_feat, n_feat)
self.linear_k_v = nn.Linear(n_feat if encoder_output_size is None else encoder_output_size, n_feat*2)
self.linear_out = nn.Linear(n_feat, n_feat)
if lora_list is not None:
if "q" in lora_list:
self.linear_q = lora.Linear(n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
else:
self.linear_q = nn.Linear(n_feat, n_feat)
lora_kv_list = ["k" in lora_list, "v" in lora_list]
if lora_kv_list == [False, False]:
self.linear_k_v = nn.Linear(n_feat if encoder_output_size is None else encoder_output_size, n_feat*2)
else:
self.linear_k_v = lora.MergedLinear(n_feat if encoder_output_size is None else encoder_output_size, n_feat * 2,
r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout, enable_lora=lora_kv_list)
if "o" in lora_list:
self.linear_out = lora.Linear(n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
else:
self.linear_out = nn.Linear(n_feat, n_feat)
else:
self.linear_q = nn.Linear(n_feat, n_feat)
self.linear_k_v = nn.Linear(n_feat if encoder_output_size is None else encoder_output_size, n_feat*2)
self.linear_out = nn.Linear(n_feat, n_feat)
self.attn = None
self.dropout = nn.Dropout(p=dropout_rate)

View File

View File

@ -0,0 +1,323 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional, List
class LoRALayer():
def __init__(
self,
r: int,
lora_alpha: int,
lora_dropout: float,
merge_weights: bool,
):
self.r = r
self.lora_alpha = lora_alpha
# Optional dropout
if lora_dropout > 0.:
self.lora_dropout = nn.Dropout(p=lora_dropout)
else:
self.lora_dropout = lambda x: x
# Mark the weight as unmerged
self.merged = False
self.merge_weights = merge_weights
class Embedding(nn.Embedding, LoRALayer):
# LoRA implemented in a dense layer
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
r: int = 0,
lora_alpha: int = 1,
merge_weights: bool = True,
**kwargs
):
nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs)
LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=0,
merge_weights=merge_weights)
# Actual trainable parameters
if r > 0:
self.lora_A = nn.Parameter(self.weight.new_zeros((r, num_embeddings)))
self.lora_B = nn.Parameter(self.weight.new_zeros((embedding_dim, r)))
self.scaling = self.lora_alpha / self.r
# Freezing the pre-trained weight matrix
self.weight.requires_grad = False
self.reset_parameters()
def reset_parameters(self):
nn.Embedding.reset_parameters(self)
if hasattr(self, 'lora_A'):
# initialize A the same way as the default for nn.Linear and B to zero
nn.init.zeros_(self.lora_A)
nn.init.normal_(self.lora_B)
def train(self, mode: bool = True):
nn.Embedding.train(self, mode)
if self.merge_weights and self.merged:
# Make sure that the weights are not merged
if self.r > 0:
self.weight.data -= (self.lora_B @ self.lora_A).T * self.scaling
self.merged = False
def eval(self):
nn.Linear.eval(self)
if self.merge_weights and not self.merged:
# Merge the weights and mark it
if self.r > 0:
self.weight.data += (self.lora_B @ self.lora_A) * self.scaling
self.merged = True
def forward(self, x: torch.Tensor):
if self.r > 0 and not self.merged:
result = nn.Embedding.forward(self, x)
if self.r > 0:
after_A = F.embedding(
x, self.lora_A.T, self.padding_idx, self.max_norm,
self.norm_type, self.scale_grad_by_freq, self.sparse
)
result += (after_A @ self.lora_B.T) * self.scaling
return result
else:
return nn.Embedding.forward(self, x)
class Linear(nn.Linear, LoRALayer):
# LoRA implemented in a dense layer
def __init__(
self,
in_features: int,
out_features: int,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.,
fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
merge_weights: bool = True,
**kwargs
):
nn.Linear.__init__(self, in_features, out_features, **kwargs)
LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
merge_weights=merge_weights)
self.fan_in_fan_out = fan_in_fan_out
# Actual trainable parameters
if r > 0:
self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
self.scaling = self.lora_alpha / self.r
# Freezing the pre-trained weight matrix
self.weight.requires_grad = False
self.reset_parameters()
if fan_in_fan_out:
self.weight.data = self.weight.data.T
def reset_parameters(self):
nn.Linear.reset_parameters(self)
if hasattr(self, 'lora_A'):
# initialize A the same way as the default for nn.Linear and B to zero
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
nn.init.zeros_(self.lora_B)
def train(self, mode: bool = True):
def T(w):
return w.T if self.fan_in_fan_out else w
nn.Linear.train(self, mode)
if self.merge_weights and self.merged:
# Make sure that the weights are not merged
if self.r > 0:
self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
self.merged = False
def eval(self):
def T(w):
return w.T if self.fan_in_fan_out else w
nn.Linear.eval(self)
if self.merge_weights and not self.merged:
# Merge the weights and mark it
if self.r > 0:
self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
self.merged = True
def forward(self, x: torch.Tensor):
def T(w):
return w.T if self.fan_in_fan_out else w
if self.r > 0 and not self.merged:
result = F.linear(x, T(self.weight), bias=self.bias)
if self.r > 0:
result += (self.lora_dropout(x) @ self.lora_A.T @ self.lora_B.T) * self.scaling
return result
else:
return F.linear(x, T(self.weight), bias=self.bias)
class MergedLinear(nn.Linear, LoRALayer):
# LoRA implemented in a dense layer
def __init__(
self,
in_features: int,
out_features: int,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.,
enable_lora: List[bool] = [False],
fan_in_fan_out: bool = False,
merge_weights: bool = True,
**kwargs
):
nn.Linear.__init__(self, in_features, out_features, **kwargs)
LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
merge_weights=merge_weights)
assert out_features % len(enable_lora) == 0, \
'The length of enable_lora must divide out_features'
self.enable_lora = enable_lora
self.fan_in_fan_out = fan_in_fan_out
# Actual trainable parameters
if r > 0 and any(enable_lora):
self.lora_A = nn.Parameter(
self.weight.new_zeros((r * sum(enable_lora), in_features)))
self.lora_B = nn.Parameter(
self.weight.new_zeros((out_features // len(enable_lora) * sum(enable_lora), r))
) # weights for Conv1D with groups=sum(enable_lora)
self.scaling = self.lora_alpha / self.r
# Freezing the pre-trained weight matrix
self.weight.requires_grad = False
# Compute the indices
self.lora_ind = self.weight.new_zeros(
(out_features, ), dtype=torch.bool
).view(len(enable_lora), -1)
self.lora_ind[enable_lora, :] = True
self.lora_ind = self.lora_ind.view(-1)
self.reset_parameters()
if fan_in_fan_out:
self.weight.data = self.weight.data.T
def reset_parameters(self):
nn.Linear.reset_parameters(self)
if hasattr(self, 'lora_A'):
# initialize A the same way as the default for nn.Linear and B to zero
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
nn.init.zeros_(self.lora_B)
def zero_pad(self, x):
result = x.new_zeros((*x.shape[:-1], self.out_features))
result = result.view(-1, self.out_features)
result[:, self.lora_ind] = x.reshape(
-1, self.out_features // len(self.enable_lora) * sum(self.enable_lora)
)
return result.view((*x.shape[:-1], self.out_features))
def train(self, mode: bool = True):
def T(w):
return w.T if self.fan_in_fan_out else w
nn.Linear.train(self, mode)
if self.merge_weights and self.merged:
# Make sure that the weights are not merged
if self.r > 0 and any(self.enable_lora):
delta_w = F.conv1d(
self.lora_A.data.unsqueeze(0),
self.lora_B.data.unsqueeze(-1),
groups=sum(self.enable_lora)
).squeeze(0)
self.weight.data -= self.zero_pad(T(delta_w * self.scaling))
self.merged = False
def eval(self):
def T(w):
return w.T if self.fan_in_fan_out else w
nn.Linear.eval(self)
if self.merge_weights and not self.merged:
# Merge the weights and mark it
if self.r > 0 and any(self.enable_lora):
delta_w = F.conv1d(
self.lora_A.data.unsqueeze(0),
self.lora_B.data.unsqueeze(-1),
groups=sum(self.enable_lora)
).squeeze(0)
self.weight.data += self.zero_pad(T(delta_w * self.scaling))
self.merged = True
def forward(self, x: torch.Tensor):
def T(w):
return w.T if self.fan_in_fan_out else w
if self.merged:
return F.linear(x, T(self.weight), bias=self.bias)
else:
result = F.linear(x, T(self.weight), bias=self.bias)
if self.r > 0:
after_A = F.linear(self.lora_dropout(x), self.lora_A)
after_B = F.conv1d(
after_A.transpose(-2, -1),
self.lora_B.unsqueeze(-1),
groups=sum(self.enable_lora)
).transpose(-2, -1)
result += self.zero_pad(after_B) * self.scaling
return result
class Conv2d(nn.Conv2d, LoRALayer):
# LoRA implemented in a dense layer
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.,
merge_weights: bool = True,
**kwargs
):
nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, **kwargs)
LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
merge_weights=merge_weights)
assert type(kernel_size) is int
# Actual trainable parameters
if r > 0:
self.lora_A = nn.Parameter(
self.weight.new_zeros((r*kernel_size, in_channels*kernel_size))
)
self.lora_B = nn.Parameter(
self.weight.new_zeros((out_channels*kernel_size, r*kernel_size))
)
self.scaling = self.lora_alpha / self.r
# Freezing the pre-trained weight matrix
self.weight.requires_grad = False
self.reset_parameters()
def reset_parameters(self):
nn.Conv2d.reset_parameters(self)
if hasattr(self, 'lora_A'):
# initialize A the same way as the default for nn.Linear and B to zero
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
nn.init.zeros_(self.lora_B)
def train(self, mode: bool = True):
nn.Conv2d.train(self, mode)
if self.merge_weights and self.merged:
# Make sure that the weights are not merged
self.weight.data -= (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling
self.merged = False
def eval(self):
nn.Conv2d.eval(self)
if self.merge_weights and not self.merged:
# Merge the weights and mark it
self.weight.data += (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling
self.merged = True
def forward(self, x: torch.Tensor):
if self.r > 0 and not self.merged:
return F.conv2d(
x,
self.weight + (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling,
self.bias, self.stride, self.padding, self.dilation, self.groups
)
return nn.Conv2d.forward(self, x)

View File

@ -0,0 +1,50 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import torch
import torch.nn as nn
from typing import Dict
from .layers import LoRALayer
def mark_only_lora_as_trainable(model: nn.Module, bias: str = 'none') -> None:
for n, p in model.named_parameters():
if 'lora_' not in n and 'cif' not in n:
p.requires_grad = False
if bias == 'none':
return
elif bias == 'all':
for n, p in model.named_parameters():
if 'bias' in n:
p.requires_grad = True
elif bias == 'lora_only':
for m in model.modules():
if isinstance(m, LoRALayer) and \
hasattr(m, 'bias') and \
m.bias is not None:
m.bias.requires_grad = True
else:
raise NotImplementedError
def lora_state_dict(model: nn.Module, bias: str = 'none') -> Dict[str, torch.Tensor]:
my_state_dict = model.state_dict()
if bias == 'none':
return {k: my_state_dict[k] for k in my_state_dict if 'lora_' in k}
elif bias == 'all':
return {k: my_state_dict[k] for k in my_state_dict if 'lora_' in k or 'bias' in k}
elif bias == 'lora_only':
to_return = {}
for k in my_state_dict:
if 'lora_' in k:
to_return[k] = my_state_dict[k]
bias_name = k.split('lora_')[0]+'bias'
if bias_name in my_state_dict:
to_return[bias_name] = my_state_dict[bias_name]
return to_return
else:
raise NotImplementedError

View File

@ -71,6 +71,7 @@ from funasr.utils.types import str_or_int
from funasr.utils.types import str_or_none
from funasr.utils.wav_utils import calc_shape, generate_data_list, filter_wav_text
from funasr.utils.yaml_no_alias_safe_dump import yaml_no_alias_safe_dump
from funasr.modules.lora.utils import mark_only_lora_as_trainable
try:
import wandb
@ -952,6 +953,18 @@ class AbsTask(ABC):
default=None,
help="oss bucket.",
)
group.add_argument(
"--enable_lora",
type=str2bool,
default=False,
help="Apply lora for finetuning.",
)
group.add_argument(
"--lora_bias",
type=str,
default="none",
help="lora bias.",
)
cls.trainer.add_arguments(parser)
cls.add_task_arguments(parser)
@ -1246,6 +1259,8 @@ class AbsTask(ABC):
dtype=getattr(torch, args.train_dtype),
device="cuda" if args.ngpu > 0 else "cpu",
)
if args.enable_lora:
mark_only_lora_as_trainable(model, args.lora_bias)
for t in args.freeze_param:
for k, p in model.named_parameters():
if k.startswith(t + ".") or k == t: