FunASR/funasr/models/llm_asr/adaptor.py
zhifu gao 32e7836645
update with main (#1786)
2024-06-06 09:54:35 +08:00


import torch
import torch.nn as nn
import torch.nn.functional as F

from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.register import tables


@tables.register("adaptor_classes", "Linear")
class Linear(nn.Module):
    """Projects encoder features into the LLM embedding space by concatenating
    every k consecutive frames and passing them through a two-layer MLP."""

    def __init__(self, downsample_rate, encoder_dim, llm_dim, ffn_dim: int = 2048, **kwargs):
        super().__init__()
        self.k = downsample_rate
        self.encoder_dim = encoder_dim
        self.llm_dim = llm_dim
        self.linear1 = nn.Linear(self.encoder_dim * self.k, ffn_dim)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(ffn_dim, self.llm_dim)

    def forward(self, x):
        batch_size, seq_len, dim = x.size()
        # Drop trailing frames so the sequence length is a multiple of k.
        num_frames_to_discard = seq_len % self.k
        if num_frames_to_discard > 0:
            x = x[:, :-num_frames_to_discard, :]
        seq_len = x.size(1)
        # Fold every k consecutive frames into one (dim * k)-dimensional vector.
        x = x.contiguous().view(batch_size, seq_len // self.k, dim * self.k)
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        return x
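

# Illustrative usage sketch, not part of the original file; the dimensions below
# are arbitrary example values. With downsample_rate k=2, a (B, T, encoder_dim)
# input becomes (B, T // 2, llm_dim).
def _demo_linear_adaptor():
    adaptor = Linear(downsample_rate=2, encoder_dim=80, llm_dim=256)
    x = torch.randn(4, 100, 80)  # (batch, frames, encoder_dim)
    y = adaptor(x)
    assert y.shape == (4, 50, 256)  # 100 frames folded pairwise, projected to llm_dim
    return y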
@tables.register("adaptor_classes", "QFormer")
class EncoderProjectorQFormer(nn.Module):
def __init__(self, downsample_rate, encoder_dim, llm_dim, ffn_dim: int = 2048, **kwargs):
super().__init__()
self.encoder_dim = encoder_dim
self.llm_dim = llm_dim
from transformers import Blip2QFormerConfig, Blip2QFormerModel
configuration = Blip2QFormerConfig()
configuration.encoder_hidden_size = self.encoder_dim
configuration.num_hidden_layers = 2
self.query_len = 64
self.query = nn.Parameter(torch.zeros(1, self.query_len, configuration.hidden_size))
self.query.data.normal_(mean=0.0, std=1.0)
self.qformer = Blip2QFormerModel(configuration)
self.linear = nn.Linear(configuration.hidden_size, self.llm_dim)
self.norm = nn.LayerNorm(self.llm_dim, eps=1e-5)
def forward(self, x, atts):
query = self.query.expand(x.shape[0], -1, -1)
query_output = self.qformer(
query_embeds=query,
encoder_hidden_states=x,
encoder_attention_mask=atts,
return_dict=True,
)
query_proj = self.norm(self.linear(query_output.last_hidden_state))
return query_proj
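

# Illustrative usage sketch, not part of the original file; shapes are arbitrary
# example values, and the `transformers` package must be installed. The output
# length is always query_len=64 regardless of the input length (downsample_rate
# is accepted but unused by this adaptor).
def _demo_qformer_adaptor():
    adaptor = EncoderProjectorQFormer(downsample_rate=1, encoder_dim=80, llm_dim=256)
    x = torch.randn(2, 30, 80)  # (batch, frames, encoder_dim)
    atts = torch.ones(2, 30, dtype=torch.long)  # all encoder frames are valid
    y = adaptor(x, atts)
    assert y.shape == (2, 64, 256)  # fixed number of queries, projected to llm_dim
    return y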
@tables.register("adaptor_classes", "Transformer")
class Transformer(nn.Module):
def __init__(
self, downsample_rate=2, encoder_dim=1280, llm_dim=4096, ffn_dim: int = 2048, **kwargs
):
super().__init__()
self.k = downsample_rate
self.encoder_dim = encoder_dim
self.llm_dim = llm_dim
self.linear1 = nn.Linear(self.encoder_dim * self.k, ffn_dim)
self.relu = nn.ReLU()
self.linear2 = nn.Linear(ffn_dim, self.llm_dim)
from funasr.models.transformer.encoder import EncoderLayer
from funasr.models.transformer.attention import MultiHeadedAttention
from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForward
self.blocks = nn.ModuleList(
[
EncoderLayer(
llm_dim,
MultiHeadedAttention(
kwargs.get("attention_heads", 8),
llm_dim,
kwargs.get("attention_dropout_rate", 0.0),
),
PositionwiseFeedForward(
llm_dim,
llm_dim // 4,
kwargs.get("dropout_rate", 0.0),
),
kwargs.get("dropout_rate", 0.0),
)
for i in range(kwargs.get("n_layer", 2))
]
)
def forward(self, x, ilens=None):
batch_size, seq_len, dim = x.size()
# num_frames_to_discard = seq_len % self.k
chunk_num = (seq_len - 1) // self.k + 1
pad_num = chunk_num * self.k - seq_len
x = F.pad(x, (0, 0, 0, pad_num, 0, 0), value=0.0)
# if num_frames_to_discard > 0:
# x = x[:, :-num_frames_to_discard, :]
seq_len = x.size(1)
x = x.contiguous()
x = x.view(batch_size, chunk_num, dim * self.k)
x = self.linear1(x)
x = self.relu(x)
x = self.linear2(x)
olens = None
olens = (ilens - 1) // self.k + 1
masks = (~make_pad_mask(olens)[:, None, :]).to(x.device)
for layer, block in enumerate(self.blocks):
x, masks = block(x, masks)
return x, olens
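

# Illustrative usage sketch, not part of the original file; the dimensions below
# are arbitrary example values. An odd-length input is padded to the next
# multiple of k before folding, and the downsampled lengths are returned
# alongside the projected features.
def _demo_transformer_adaptor():
    adaptor = Transformer(downsample_rate=2, encoder_dim=80, llm_dim=256)
    x = torch.randn(2, 101, 80)  # odd length: padded to 102 before folding
    ilens = torch.tensor([101, 80])  # valid frames per utterance
    y, olens = adaptor(x, ilens)
    assert y.shape == (2, 51, 256) and olens.tolist() == [51, 40]
    return y, olens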