mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
* add cmakelist * add paraformer-torch * add debug for funasr-onnx-offline * fix redefinition of jieba StdExtension.hpp * add loading torch models * update funasr-onnx-offline * add SwitchArg for wss-server * add SwitchArg for funasr-onnx-offline * update cmakelist * update funasr-onnx-offline-rtf * add define condition * add gpu define for offlne-stream * update com define * update offline-stream * update cmakelist * update func CompileHotwordEmbedding * add timestamp for paraformer-torch * add C10_USE_GLOG for paraformer-torch * update paraformer-torch * fix func FunASRWfstDecoderInit * update model.h * fix func FunASRWfstDecoderInit * fix tpass_stream * update paraformer-torch * add bladedisc for funasr-onnx-offline * update comdefine * update funasr-wss-server * add log for torch * fix GetValue BLADEDISC * fix log * update cmakelist * update warmup to 10 * update funasrruntime * add batch_size for wss-server * add batch for bins * add batch for offline-stream * add batch for paraformer * add batch for offline-stream * fix func SetBatchSize * add SetBatchSize for model * add SetBatchSize for model * fix func Forward * fix padding * update funasrruntime * add dec reset for batch * set batch default value * add argv for CutSplit * sort frame_queue * sorted msgs * fix FunOfflineInfer * add dynamic batch for fetch * fix FetchDynamic * update run_server.sh * update run_server.sh * cpp http post server support (#1739) * add cpp http server * add some comment * remove some comments * del debug infos * restore run_server.sh * adapt to new model struct * 修复了onnxruntime在macos下编译失败的错误 (#1748) * Add files via upload 增加macos的编译支持 * Add files via upload 增加macos支持 * Add files via upload target_link_directories(funasr PUBLIC ${ONNXRUNTIME_DIR}/lib) target_link_directories(funasr PUBLIC ${FFMPEG_DIR}/lib) 添加 if(APPLE) 限制 --------- Co-authored-by: Yabin Li <wucong.lyb@alibaba-inc.com> * Delete docs/images/wechat.png * Add files via upload * fixed the issues about seaco-onnx 
timestamp * fix bug (#1764) 当语音识别结果包含 `http` 时,标点符号预测会把它会被当成 url * fix empty asr result (#1765) 解码结果为空的语音片段,text 用空字符串 * docs * docs * docs * docs * docs * keep empty speech result (#1772) * docs * docs * update wechat QRcode * Add python funasr api support for websocket srv (#1777) * add python funasr_api supoort * change little to README.md * add core tools stream * modified a little * fix bug for timeout * support for buffer decode * add ffmpeg decode for buffer * auto frontend * auto frontend * auto frontend * auto frontend * auto frontend * auto frontend * auto frontend * auto frontend * Dev gzf exp (#1785) * resume from step * batch * batch * batch * batch * batch * batch * batch * batch * batch * batch * batch * batch * batch * batch * batch * train_loss_avg train_acc_avg * train_loss_avg train_acc_avg * train_loss_avg train_acc_avg * log step * wav is not exist * wav is not exist * decoding * decoding * decoding * wechat * decoding key * decoding key * decoding key * decoding key * decoding key * decoding key * dynamic batch * start_data_split_i=0 * total_time/accum_grad * total_time/accum_grad * total_time/accum_grad * update avg slice * update avg slice * sensevoice sanm * sensevoice sanm * sensevoice sanm --------- Co-authored-by: 北念 <lzr265946@alibaba-inc.com> * auto frontend --------- Co-authored-by: 雾聪 <wucong.lyb@alibaba-inc.com> Co-authored-by: zhaomingwork <61895407+zhaomingwork@users.noreply.github.com> Co-authored-by: szsteven008 <97944818+szsteven008@users.noreply.github.com> Co-authored-by: Ephemeroptera <605686962@qq.com> Co-authored-by: 彭震东 <zhendong.peng@qq.com> Co-authored-by: Shi Xian <40013335+R1ckShi@users.noreply.github.com> Co-authored-by: 维石 <shixian.shi@alibaba-inc.com> Co-authored-by: 北念 <lzr265946@alibaba-inc.com>
129 lines
4.4 KiB
Python
129 lines
4.4 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from funasr.models.transformer.utils.nets_utils import make_pad_mask
|
|
|
|
from funasr.register import tables
|
|
|
|
|
|
@tables.register("adaptor_classes", "Linear")
class Linear(nn.Module):
    """Adaptor that stacks consecutive encoder frames and maps them to ``llm_dim``.

    Every ``downsample_rate`` consecutive frames are concatenated along the
    feature axis and passed through ``Linear -> ReLU -> Linear``. Trailing
    frames that do not fill a complete group are discarded.
    """

    def __init__(self, downsample_rate, encoder_dim, llm_dim, ffn_dim: int = 2048, **kwargs):
        """
        Args:
            downsample_rate: number of consecutive frames merged into one output frame.
            encoder_dim: feature size of each input frame.
            llm_dim: output feature size.
            ffn_dim: hidden size of the two-layer MLP.
            **kwargs: accepted and ignored.
        """
        super().__init__()
        self.k = downsample_rate
        self.encoder_dim = encoder_dim
        self.llm_dim = llm_dim
        stacked_dim = encoder_dim * downsample_rate
        self.linear1 = nn.Linear(stacked_dim, ffn_dim)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(ffn_dim, llm_dim)

    def forward(self, x):
        """Stack every ``k`` frames of ``x`` and project them.

        Args:
            x: tensor of shape ``(batch, seq_len, encoder_dim)``.

        Returns:
            Tensor of shape ``(batch, seq_len // k, llm_dim)``.
        """
        bsz, n_frames, feat_dim = x.size()
        n_groups = n_frames // self.k
        # Keep only whole groups of k frames; any remainder at the end is dropped.
        kept = x[:, : n_groups * self.k, :].contiguous()
        stacked = kept.view(bsz, n_groups, feat_dim * self.k)
        return self.linear2(self.relu(self.linear1(stacked)))
|
|
|
|
|
|
@tables.register("adaptor_classes", "QFormer")
class EncoderProjectorQFormer(nn.Module):
    """Adaptor that summarizes encoder output with a BLIP-2 Q-Former.

    A fixed set of 64 learnable query vectors is fed to a 2-layer Q-Former
    together with the encoder features. The Q-Former's query outputs are then
    projected to ``llm_dim`` and layer-normalized.
    """

    def __init__(self, downsample_rate, encoder_dim, llm_dim, ffn_dim: int = 2048, **kwargs):
        """
        Args:
            downsample_rate: accepted for a uniform adaptor signature; not used here.
            encoder_dim: feature size of the encoder output, used as the
                Q-Former's ``encoder_hidden_size``.
            llm_dim: output feature size.
            ffn_dim: accepted for a uniform adaptor signature; not used here.
            **kwargs: accepted and ignored.
        """
        super().__init__()
        self.encoder_dim = encoder_dim
        self.llm_dim = llm_dim

        # Imported here, so ``transformers`` is only needed when this adaptor is instantiated.
        from transformers import Blip2QFormerConfig, Blip2QFormerModel

        qformer_cfg = Blip2QFormerConfig()
        qformer_cfg.encoder_hidden_size = self.encoder_dim
        qformer_cfg.num_hidden_layers = 2
        hidden_size = qformer_cfg.hidden_size

        # Learnable queries (one set, broadcast over the batch in forward), drawn from N(0, 1).
        self.query_len = 64
        self.query = nn.Parameter(torch.zeros(1, self.query_len, hidden_size))
        self.query.data.normal_(mean=0.0, std=1.0)
        self.qformer = Blip2QFormerModel(qformer_cfg)

        self.linear = nn.Linear(hidden_size, self.llm_dim)
        self.norm = nn.LayerNorm(self.llm_dim, eps=1e-5)

    def forward(self, x, atts):
        """Run the learnable queries through the Q-Former against ``x``.

        Args:
            x: encoder features, passed as ``encoder_hidden_states``; the batch
                size is taken from ``x.shape[0]``.
            atts: attention mask over ``x``, passed as ``encoder_attention_mask``.

        Returns:
            Layer-normalized projection of the Q-Former's ``last_hidden_state``
            (presumably ``(batch, query_len, llm_dim)`` — the exact shape is
            determined by the Q-Former; TODO confirm).
        """
        batch_queries = self.query.expand(x.shape[0], -1, -1)
        qformer_out = self.qformer(
            query_embeds=batch_queries,
            encoder_hidden_states=x,
            encoder_attention_mask=atts,
            return_dict=True,
        )
        projected = self.linear(qformer_out.last_hidden_state)
        return self.norm(projected)
|
|
|
|
|
|
@tables.register("adaptor_classes", "Transformer")
class Transformer(nn.Module):
    """Adaptor that stacks frames, projects them to ``llm_dim``, then applies encoder layers.

    Every ``downsample_rate`` consecutive frames are concatenated along the
    feature axis (after zero-padding the time axis up to a multiple of
    ``downsample_rate``). The result goes through ``Linear -> ReLU -> Linear``
    and then a stack of Transformer ``EncoderLayer`` blocks of width ``llm_dim``.
    """

    def __init__(
        self, downsample_rate=2, encoder_dim=1280, llm_dim=4096, ffn_dim: int = 2048, **kwargs
    ):
        """
        Args:
            downsample_rate: number of consecutive frames merged into one output frame.
            encoder_dim: feature size of each input frame.
            llm_dim: output feature size; also the model size of the encoder layers.
            ffn_dim: hidden size of the frame-stacking MLP.
            **kwargs: optional settings for the encoder-layer stack:
                ``attention_heads`` (default 8), ``attention_dropout_rate`` (0.0),
                ``dropout_rate`` (0.0) and ``n_layer`` (2).
        """
        super().__init__()
        self.k = downsample_rate
        self.encoder_dim = encoder_dim
        self.llm_dim = llm_dim
        self.linear1 = nn.Linear(self.encoder_dim * self.k, ffn_dim)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(ffn_dim, self.llm_dim)

        from funasr.models.transformer.encoder import EncoderLayer
        from funasr.models.transformer.attention import MultiHeadedAttention
        from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForward

        attention_heads = kwargs.get("attention_heads", 8)
        attention_dropout_rate = kwargs.get("attention_dropout_rate", 0.0)
        dropout_rate = kwargs.get("dropout_rate", 0.0)
        n_layer = kwargs.get("n_layer", 2)

        self.blocks = nn.ModuleList(
            [
                EncoderLayer(
                    llm_dim,
                    MultiHeadedAttention(attention_heads, llm_dim, attention_dropout_rate),
                    # The feed-forward hidden size is a quarter of the model size.
                    PositionwiseFeedForward(llm_dim, llm_dim // 4, dropout_rate),
                    dropout_rate,
                )
                for _ in range(n_layer)
            ]
        )

    def forward(self, x, ilens=None):
        """Downsample ``x`` by frame stacking and run the encoder-layer stack.

        Args:
            x: tensor of shape ``(batch, seq_len, encoder_dim)``.
            ilens: 1-D integer tensor of valid lengths of ``x`` along the time
                axis, one per batch item. When ``None``, every sequence is
                treated as spanning the full ``seq_len``.

        Returns:
            ``(x, olens)``: the adapted features of shape
            ``(batch, ceil(seq_len / k), llm_dim)`` and the valid output
            lengths ``ceil(ilens / k)``.
        """
        batch_size, seq_len, dim = x.size()
        if ilens is None:
            ilens = torch.full((batch_size,), seq_len, dtype=torch.long, device=x.device)

        # Zero-pad the time axis up to a multiple of k so trailing frames are kept.
        chunk_num = (seq_len - 1) // self.k + 1  # ceil(seq_len / k)
        pad_num = chunk_num * self.k - seq_len
        x = F.pad(x, (0, 0, 0, pad_num, 0, 0), value=0.0)

        # Concatenate each group of k consecutive frames along the feature axis.
        x = x.contiguous().view(batch_size, chunk_num, dim * self.k)
        x = self.linear2(self.relu(self.linear1(x)))

        olens = (ilens - 1) // self.k + 1
        # Size the mask by the real output length (chunk_num), not max(olens),
        # so it always matches x's time axis even when the batch is padded
        # beyond its longest valid sequence. True marks valid positions.
        positions = torch.arange(chunk_num, device=x.device)
        masks = (positions[None, :] < olens.to(x.device)[:, None])[:, None, :]

        for block in self.blocks:
            x, masks = block(x, masks)
        return x, olens
|