update online runtime, including vad-online, paraformer-online, punc-online,2pass (#815)

* init

* update

* add LoadConfigFromYaml

* update

* update

* update

* del time stat

* update

* update

* update

* update

* update

* update

* update

* add cpp websocket online 2pass srv

* [feature] multithread grpc server

* update

* update

* update

* [feature] support 2pass grpc cpp server and python client, can change mode to use offline, online or 2pass decoding

* update

* update

* update

* update

* add paraformer online onnx model export

* add paraformer online onnx model export

* add paraformer online onnx model export

* add paraformer online onnxruntime

* add paraformer online onnxruntime

* add paraformer online onnxruntime

* fix export paraformer online onnx model bug

* for client closed earlier and core dump

* support GRPC two pass decoding (#813)

* [refator] optimize grpc server pipeline and instruction

* [refator] rm useless file

* [refator] optimize grpc client pipeline and instruction

* [debug] hanlde coredump when client ternimated

* [refator] rm useless log

* [refator] modify grpc cmake

* Create run_server_2pass.sh

* Update SDK_tutorial_online_zh.md

* Update SDK_tutorial_online.md

* Update SDK_advanced_guide_online.md

* Update SDK_advanced_guide_online_zh.md

* Update SDK_tutorial_online_zh.md

* Update SDK_tutorial_online.md

* update

---------

Co-authored-by: zhaoming <zhaomingwork@qq.com>
Co-authored-by: boji123 <boji123@aliyun.com>
Co-authored-by: haoneng.lhn <haoneng.lhn@alibaba-inc.com>
This commit is contained in:
Yabin Li 2023-08-08 11:17:43 +08:00 committed by GitHub
parent 57968c2180
commit b454a1054f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
66 changed files with 5193 additions and 1194 deletions

3
.gitignore vendored
View File

@ -19,4 +19,5 @@ build
funasr.egg-info
docs/_build
modelscope
samples
samples
.ipynb_checkpoints

View File

@ -55,18 +55,21 @@ class ModelExport:
# export encoder1
self.export_config["model_name"] = "model"
model = get_model(
models = get_model(
model,
self.export_config,
)
model.eval()
# self._export_onnx(model, verbose, export_dir)
if self.onnx:
self._export_onnx(model, verbose, export_dir)
else:
self._export_torchscripts(model, verbose, export_dir)
print("output dir: {}".format(export_dir))
if not isinstance(models, tuple):
models = (models,)
for i, model in enumerate(models):
model.eval()
if self.onnx:
self._export_onnx(model, verbose, export_dir)
else:
self._export_torchscripts(model, verbose, export_dir)
print("output dir: {}".format(export_dir))
def _torch_quantize(self, model):

View File

@ -1,4 +1,4 @@
from funasr.models.e2e_asr_paraformer import Paraformer, BiCifParaformer
from funasr.models.e2e_asr_paraformer import Paraformer, BiCifParaformer, ParaformerOnline
from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
from funasr.export.models.e2e_asr_paraformer import BiCifParaformer as BiCifParaformer_export
from funasr.export.models.e2e_asr_conformer import Conformer as Conformer_export
@ -10,10 +10,15 @@ from funasr.export.models.CT_Transformer import CT_Transformer as CT_Transformer
from funasr.train.abs_model import PunctuationModel
from funasr.models.vad_realtime_transformer import VadRealtimeTransformer
from funasr.export.models.CT_Transformer import CT_Transformer_VadRealtime as CT_Transformer_VadRealtime_export
from funasr.export.models.e2e_asr_paraformer import ParaformerOnline_encoder_predictor as ParaformerOnline_encoder_predictor_export
from funasr.export.models.e2e_asr_paraformer import ParaformerOnline_decoder as ParaformerOnline_decoder_export
def get_model(model, export_config=None):
if isinstance(model, BiCifParaformer):
return BiCifParaformer_export(model, **export_config)
elif isinstance(model, ParaformerOnline):
return (ParaformerOnline_encoder_predictor_export(model, model_name="model"),
ParaformerOnline_decoder_export(model, model_name="decoder"))
elif isinstance(model, Paraformer):
return Paraformer_export(model, **export_config)
elif isinstance(model, Conformer_export):

View File

@ -157,3 +157,158 @@ class ParaformerSANMDecoder(nn.Module):
"n_layers": len(self.model.decoders) + len(self.model.decoders2),
"odim": self.model.decoders[0].size
}
class ParaformerSANMDecoderOnline(nn.Module):
def __init__(self, model,
max_seq_len=512,
model_name='decoder',
onnx: bool = True, ):
super().__init__()
# self.embed = model.embed #Embedding(model.embed, max_seq_len)
self.model = model
if onnx:
self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
else:
self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
for i, d in enumerate(self.model.decoders):
if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
d.self_attn = MultiHeadedAttentionSANMDecoder_export(d.self_attn)
if isinstance(d.src_attn, MultiHeadedAttentionCrossAtt):
d.src_attn = MultiHeadedAttentionCrossAtt_export(d.src_attn)
self.model.decoders[i] = DecoderLayerSANM_export(d)
if self.model.decoders2 is not None:
for i, d in enumerate(self.model.decoders2):
if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
d.self_attn = MultiHeadedAttentionSANMDecoder_export(d.self_attn)
self.model.decoders2[i] = DecoderLayerSANM_export(d)
for i, d in enumerate(self.model.decoders3):
if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
self.model.decoders3[i] = DecoderLayerSANM_export(d)
self.output_layer = model.output_layer
self.after_norm = model.after_norm
self.model_name = model_name
def prepare_mask(self, mask):
mask_3d_btd = mask[:, :, None]
if len(mask.shape) == 2:
mask_4d_bhlt = 1 - mask[:, None, None, :]
elif len(mask.shape) == 3:
mask_4d_bhlt = 1 - mask[:, None, :]
mask_4d_bhlt = mask_4d_bhlt * -10000.0
return mask_3d_btd, mask_4d_bhlt
def forward(
self,
hs_pad: torch.Tensor,
hlens: torch.Tensor,
ys_in_pad: torch.Tensor,
ys_in_lens: torch.Tensor,
*args,
):
tgt = ys_in_pad
tgt_mask = self.make_pad_mask(ys_in_lens)
tgt_mask, _ = self.prepare_mask(tgt_mask)
# tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
memory = hs_pad
memory_mask = self.make_pad_mask(hlens)
_, memory_mask = self.prepare_mask(memory_mask)
# memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
x = tgt
out_caches = list()
for i, decoder in enumerate(self.model.decoders):
in_cache = args[i]
x, tgt_mask, memory, memory_mask, out_cache = decoder(
x, tgt_mask, memory, memory_mask, cache=in_cache
)
out_caches.append(out_cache)
if self.model.decoders2 is not None:
for i, decoder in enumerate(self.model.decoders2):
in_cache = args[i+len(self.model.decoders)]
x, tgt_mask, memory, memory_mask, out_cache = decoder(
x, tgt_mask, memory, memory_mask, cache=in_cache
)
out_caches.append(out_cache)
x, tgt_mask, memory, memory_mask, _ = self.model.decoders3(
x, tgt_mask, memory, memory_mask
)
x = self.after_norm(x)
x = self.output_layer(x)
return x, out_caches
def get_dummy_inputs(self, enc_size):
enc = torch.randn(2, 100, enc_size).type(torch.float32)
enc_len = torch.tensor([30, 100], dtype=torch.int32)
acoustic_embeds = torch.randn(2, 10, enc_size).type(torch.float32)
acoustic_embeds_len = torch.tensor([5, 10], dtype=torch.int32)
cache_num = len(self.model.decoders)
if hasattr(self.model, 'decoders2') and self.model.decoders2 is not None:
cache_num += len(self.model.decoders2)
cache = [
torch.zeros((2, self.model.decoders[0].size, self.model.decoders[0].self_attn.kernel_size-1), dtype=torch.float32)
for _ in range(cache_num)
]
return (enc, enc_len, acoustic_embeds, acoustic_embeds_len, *cache)
def get_input_names(self):
cache_num = len(self.model.decoders)
if hasattr(self.model, 'decoders2') and self.model.decoders2 is not None:
cache_num += len(self.model.decoders2)
return ['enc', 'enc_len', 'acoustic_embeds', 'acoustic_embeds_len'] \
+ ['in_cache_%d' % i for i in range(cache_num)]
def get_output_names(self):
cache_num = len(self.model.decoders)
if hasattr(self.model, 'decoders2') and self.model.decoders2 is not None:
cache_num += len(self.model.decoders2)
return ['logits', 'sample_ids'] \
+ ['out_cache_%d' % i for i in range(cache_num)]
def get_dynamic_axes(self):
ret = {
'enc': {
0: 'batch_size',
1: 'enc_length'
},
'acoustic_embeds': {
0: 'batch_size',
1: 'token_length'
},
'enc_len': {
0: 'batch_size',
},
'acoustic_embeds_len': {
0: 'batch_size',
},
}
cache_num = len(self.model.decoders)
if hasattr(self.model, 'decoders2') and self.model.decoders2 is not None:
cache_num += len(self.model.decoders2)
ret.update({
'in_cache_%d' % d: {
0: 'batch_size',
}
for d in range(cache_num)
})
ret.update({
'out_cache_%d' % d: {
0: 'batch_size',
}
for d in range(cache_num)
})
return ret

View File

@ -4,7 +4,7 @@ import torch.nn as nn
from funasr.export.utils.torch_function import MakePadMask
from funasr.export.utils.torch_function import sequence_mask
from funasr.models.encoder.sanm_encoder import SANMEncoder
from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
from funasr.models.encoder.conformer_encoder import ConformerEncoder
from funasr.export.models.encoder.sanm_encoder import SANMEncoder as SANMEncoder_export
from funasr.export.models.encoder.conformer_encoder import ConformerEncoder as ConformerEncoder_export
@ -15,6 +15,7 @@ from funasr.models.decoder.sanm_decoder import ParaformerSANMDecoder
from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
from funasr.export.models.decoder.sanm_decoder import ParaformerSANMDecoder as ParaformerSANMDecoder_export
from funasr.export.models.decoder.transformer_decoder import ParaformerDecoderSAN as ParaformerDecoderSAN_export
from funasr.export.models.decoder.sanm_decoder import ParaformerSANMDecoderOnline as ParaformerSANMDecoderOnline_export
class Paraformer(nn.Module):
@ -216,4 +217,150 @@ class BiCifParaformer(nn.Module):
0: 'batch_size',
1: 'alphas_length'
},
}
}
class ParaformerOnline_encoder_predictor(nn.Module):
"""
Author: Speech Lab, Alibaba Group, China
Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
https://arxiv.org/abs/2206.08317
"""
def __init__(
self,
model,
max_seq_len=512,
feats_dim=560,
model_name='model',
**kwargs,
):
super().__init__()
onnx = False
if "onnx" in kwargs:
onnx = kwargs["onnx"]
if isinstance(model.encoder, SANMEncoder) or isinstance(model.encoder, SANMEncoderChunkOpt):
self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)
elif isinstance(model.encoder, ConformerEncoder):
self.encoder = ConformerEncoder_export(model.encoder, onnx=onnx)
if isinstance(model.predictor, CifPredictorV2):
self.predictor = CifPredictorV2_export(model.predictor)
self.feats_dim = feats_dim
self.model_name = model_name
if onnx:
self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
else:
self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
):
# a. To device
batch = {"speech": speech, "speech_lengths": speech_lengths, "online": True}
# batch = to_device(batch, device=self.device)
enc, enc_len = self.encoder(**batch)
mask = self.make_pad_mask(enc_len)[:, None, :]
alphas, _ = self.predictor.forward_cnn(enc, mask)
return enc, enc_len, alphas
def get_dummy_inputs(self):
speech = torch.randn(2, 30, self.feats_dim)
speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
return (speech, speech_lengths)
def get_input_names(self):
return ['speech', 'speech_lengths']
def get_output_names(self):
return ['enc', 'enc_len', 'alphas']
def get_dynamic_axes(self):
return {
'speech': {
0: 'batch_size',
1: 'feats_length'
},
'speech_lengths': {
0: 'batch_size',
},
'enc': {
0: 'batch_size',
1: 'feats_length'
},
'enc_len': {
0: 'batch_size',
},
'alphas': {
0: 'batch_size',
1: 'feats_length'
},
}
class ParaformerOnline_decoder(nn.Module):
"""
Author: Speech Lab, Alibaba Group, China
Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
https://arxiv.org/abs/2206.08317
"""
def __init__(
self,
model,
max_seq_len=512,
feats_dim=560,
model_name='model',
**kwargs,
):
super().__init__()
onnx = False
if "onnx" in kwargs:
onnx = kwargs["onnx"]
if isinstance(model.decoder, ParaformerDecoderSAN):
self.decoder = ParaformerDecoderSAN_export(model.decoder, onnx=onnx)
elif isinstance(model.decoder, ParaformerSANMDecoder):
self.decoder = ParaformerSANMDecoderOnline_export(model.decoder, onnx=onnx)
self.feats_dim = feats_dim
self.model_name = model_name
self.enc_size = model.encoder._output_size
if onnx:
self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
else:
self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
def forward(
self,
enc: torch.Tensor,
enc_len: torch.Tensor,
acoustic_embeds: torch.Tensor,
acoustic_embeds_len: torch.Tensor,
*args,
):
decoder_out, out_caches = self.decoder(enc, enc_len, acoustic_embeds, acoustic_embeds_len, *args)
sample_ids = decoder_out.argmax(dim=-1)
return decoder_out, sample_ids, out_caches
def get_dummy_inputs(self, ):
dummy_inputs = self.decoder.get_dummy_inputs(enc_size=self.enc_size)
return dummy_inputs
def get_input_names(self):
return self.decoder.get_input_names()
def get_output_names(self):
return self.decoder.get_output_names()
def get_dynamic_axes(self):
return self.decoder.get_dynamic_axes()

View File

@ -8,6 +8,7 @@ from funasr.export.models.modules.multihead_att import MultiHeadedAttentionSANM
from funasr.export.models.modules.encoder_layer import EncoderLayerSANM as EncoderLayerSANM_export
from funasr.modules.positionwise_feed_forward import PositionwiseFeedForward
from funasr.export.models.modules.feedforward import PositionwiseFeedForward as PositionwiseFeedForward_export
from funasr.modules.embedding import StreamSinusoidalPositionEncoder
class SANMEncoder(nn.Module):
@ -21,6 +22,8 @@ class SANMEncoder(nn.Module):
):
super().__init__()
self.embed = model.embed
if isinstance(self.embed, StreamSinusoidalPositionEncoder):
self.embed = None
self.model = model
self.feats_dim = feats_dim
self._output_size = model._output_size
@ -63,8 +66,10 @@ class SANMEncoder(nn.Module):
def forward(self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
online: bool = False
):
speech = speech * self._output_size ** 0.5
if not online:
speech = speech * self._output_size ** 0.5
mask = self.make_pad_mask(speech_lengths)
mask = self.prepare_mask(mask)
if self.embed is None:

View File

@ -64,14 +64,14 @@ class MultiHeadedAttentionSANM(nn.Module):
return self.linear_out(context_layer) # (batch, time1, d_model)
def preprocess_for_attn(x, mask, cache, pad_fn):
def preprocess_for_attn(x, mask, cache, pad_fn, kernel_size):
x = x * mask
x = x.transpose(1, 2)
if cache is None:
x = pad_fn(x)
else:
x = torch.cat((cache[:, :, 1:], x), dim=2)
cache = x
x = torch.cat((cache, x), dim=2)
cache = x[:, :, -(kernel_size-1):]
return x, cache
@ -90,7 +90,7 @@ class MultiHeadedAttentionSANMDecoder(nn.Module):
self.attn = None
def forward(self, inputs, mask, cache=None):
x, cache = preprocess_for_attn(inputs, mask, cache, self.pad_fn)
x, cache = preprocess_for_attn(inputs, mask, cache, self.pad_fn, self.kernel_size)
x = self.fsmn_block(x)
x = x.transpose(1, 2)

View File

@ -36,6 +36,17 @@ class CifPredictorV2(nn.Module):
def forward(self, hidden: torch.Tensor,
mask: torch.Tensor,
):
alphas, token_num = self.forward_cnn(hidden, mask)
mask = mask.transpose(-1, -2).float()
mask = mask.squeeze(-1)
hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, mask=mask)
acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
return acoustic_embeds, token_num, alphas, cif_peak
def forward_cnn(self, hidden: torch.Tensor,
mask: torch.Tensor,
):
h = hidden
context = h.transpose(1, 2)
queries = self.pad(context)
@ -49,12 +60,8 @@ class CifPredictorV2(nn.Module):
alphas = alphas * mask
alphas = alphas.squeeze(-1)
token_num = alphas.sum(-1)
mask = mask.squeeze(-1)
hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, mask=mask)
acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
return acoustic_embeds, token_num, alphas, cif_peak
return alphas, token_num
def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
b, t, d = hidden.size()
@ -285,4 +292,4 @@ def cif_wo_hidden(alphas, threshold: float):
integrate)
fires = torch.stack(list_fires, 1)
return fires
return fires

View File

@ -185,7 +185,9 @@ Introduction to command parameters:
--port: the port number of the server listener.
--wav-path: the audio input. Input can be a path to a wav file or a wav.scp file (a Kaldi-formatted wav list in which each line includes a wav_id followed by a tab and a wav_path).
--is-ssl: whether to use SSL encryption. The default is to use SSL.
--mode: offline mode.
--mode: 2pass.
--thread-num 1
```
### Custom client
@ -194,7 +196,9 @@ If you want to define your own client, the Websocket communication protocol is a
```text
# First communication
{"mode": "offline", "wav_name": "wav_name", "is_speaking": True, "wav_format":"pcm", "chunk_size":[5,10,5]}# Send wav data
{"mode": "offline", "wav_name": "wav_name", "is_speaking": True, "wav_format":"pcm", "chunk_size":[5,10,5]}
# Send wav data
Bytes data
# Send end flag
{"is_speaking": False}

View File

@ -76,7 +76,7 @@ Command parameter instructions:
After entering the samples/cpp directory, you can test it with CPP. The command is as follows:
```shell
./funasr-wss-client --server-ip 127.0.0.1 --port 10095 --wav-path ../audio/asr_example.wav
./funasr-wss-client-2pass --server-ip 127.0.0.1 --port 10095 --wav-path ../audio/asr_example.wav
```
Command parameter description:

View File

@ -1,51 +1,44 @@
# Copyright 2018 gRPC authors.
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
# Reserved. MIT License (https://opensource.org/licenses/MIT)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# cmake build file for C++ paraformer example.
# Assumes protobuf and gRPC have been installed using cmake.
# See cmake_externalproject/CMakeLists.txt for all-in-one cmake build
# that automatically builds all the dependencies before building paraformer.
# 2023 by burkliu(刘柏基) liubaiji@xverse.cn
cmake_minimum_required(VERSION 3.10)
project(ASR C CXX)
set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.")
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
set(CMAKE_VERBOSE_MAKEFILE on)
set(BUILD_TESTING OFF)
include(common.cmake)
# Proto file
get_filename_component(rg_proto "../python/grpc/proto/paraformer.proto" ABSOLUTE)
get_filename_component(rg_proto_path "${rg_proto}" PATH)
get_filename_component(rg_proto ../python/grpc/proto/paraformer.proto ABSOLUTE)
get_filename_component(rg_proto_path ${rg_proto} PATH)
# Generated sources
set(rg_proto_srcs "${CMAKE_CURRENT_BINARY_DIR}/paraformer.pb.cc")
set(rg_proto_hdrs "${CMAKE_CURRENT_BINARY_DIR}/paraformer.pb.h")
set(rg_grpc_srcs "${CMAKE_CURRENT_BINARY_DIR}/paraformer.grpc.pb.cc")
set(rg_grpc_hdrs "${CMAKE_CURRENT_BINARY_DIR}/paraformer.grpc.pb.h")
set(rg_proto_srcs ${CMAKE_CURRENT_BINARY_DIR}/paraformer.pb.cc)
set(rg_proto_hdrs ${CMAKE_CURRENT_BINARY_DIR}/paraformer.pb.h)
set(rg_grpc_srcs ${CMAKE_CURRENT_BINARY_DIR}/paraformer.grpc.pb.cc)
set(rg_grpc_hdrs ${CMAKE_CURRENT_BINARY_DIR}/paraformer.grpc.pb.h)
add_custom_command(
OUTPUT "${rg_proto_srcs}" "${rg_proto_hdrs}" "${rg_grpc_srcs}" "${rg_grpc_hdrs}"
COMMAND ${_PROTOBUF_PROTOC}
ARGS --grpc_out "${CMAKE_CURRENT_BINARY_DIR}"
--cpp_out "${CMAKE_CURRENT_BINARY_DIR}"
-I "${rg_proto_path}"
--plugin=protoc-gen-grpc="${_GRPC_CPP_PLUGIN_EXECUTABLE}"
"${rg_proto}"
DEPENDS "${rg_proto}")
OUTPUT ${rg_proto_srcs} ${rg_proto_hdrs} ${rg_grpc_srcs} ${rg_grpc_hdrs}
COMMAND ${_PROTOBUF_PROTOC}
ARGS --grpc_out ${CMAKE_CURRENT_BINARY_DIR}
--cpp_out ${CMAKE_CURRENT_BINARY_DIR}
-I ${rg_proto_path}
--plugin=protoc-gen-grpc=${_GRPC_CPP_PLUGIN_EXECUTABLE}
${rg_proto}
DEPENDS ${rg_proto})
# Include generated *.pb.h files
include_directories("${CMAKE_CURRENT_BINARY_DIR}")
include_directories(${CMAKE_CURRENT_BINARY_DIR})
link_directories(${ONNXRUNTIME_DIR}/lib)
link_directories(${FFMPEG_DIR}/lib)
include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/include/)
include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/yaml-cpp/include/)
@ -53,33 +46,21 @@ include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/kaldi-nativ
add_subdirectory(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/yaml-cpp yaml-cpp)
add_subdirectory(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/kaldi-native-fbank/kaldi-native-fbank/csrc csrc)
add_subdirectory("../onnxruntime/src" onnx_src)
add_subdirectory(${PROJECT_SOURCE_DIR}/../onnxruntime/src src)
include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/glog)
set(BUILD_TESTING OFF)
add_subdirectory(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/glog glog)
# rg_grpc_proto
add_library(rg_grpc_proto
${rg_grpc_srcs}
${rg_grpc_hdrs}
${rg_proto_srcs}
${rg_proto_hdrs})
add_library(rg_grpc_proto ${rg_grpc_srcs} ${rg_grpc_hdrs} ${rg_proto_srcs} ${rg_proto_hdrs})
target_link_libraries(rg_grpc_proto
target_link_libraries(rg_grpc_proto ${_REFLECTION} ${_GRPC_GRPCPP} ${_PROTOBUF_LIBPROTOBUF})
add_executable(paraformer-server paraformer-server.cc)
target_link_libraries(paraformer-server
rg_grpc_proto
funasr
${EXTRA_LIBS}
${_REFLECTION}
${_GRPC_GRPCPP}
${_PROTOBUF_LIBPROTOBUF})
foreach(_target
paraformer-server)
add_executable(${_target}
"${_target}.cc")
target_link_libraries(${_target}
rg_grpc_proto
funasr
${EXTRA_LIBS}
${_REFLECTION}
${_GRPC_GRPCPP}
${_PROTOBUF_LIBPROTOBUF})
endforeach()

View File

@ -2,17 +2,20 @@
## For the Server
### Build [onnxruntime](./onnxruntime_cpp.md) as it's document
### 1. Build [onnxruntime](../websocket/readme.md) as it's document
### Compile and install grpc v1.52.0 in case of grpc bugs
```
export GRPC_INSTALL_DIR=/data/soft/grpc
export PKG_CONFIG_PATH=$GRPC_INSTALL_DIR/lib/pkgconfig
### 2. Compile and install grpc v1.52.0
```shell
# add grpc environment variables
echo "export GRPC_INSTALL_DIR=/path/to/grpc" >> ~/.bashrc
echo "export PKG_CONFIG_PATH=\$GRPC_INSTALL_DIR/lib/pkgconfig" >> ~/.bashrc
echo "export PATH=\$GRPC_INSTALL_DIR/bin/:\$PKG_CONFIG_PATH:\$PATH" >> ~/.bashrc
source ~/.bashrc
# install grpc
git clone --recurse-submodules -b v1.52.0 --depth 1 --shallow-submodules https://github.com/grpc/grpc
git clone -b v1.52.0 --depth=1 https://github.com/grpc/grpc.git
cd grpc
git submodule update --init --recursive
mkdir -p cmake/build
pushd cmake/build
cmake -DgRPC_INSTALL=ON \
@ -22,182 +25,57 @@ cmake -DgRPC_INSTALL=ON \
make
make install
popd
echo "export GRPC_INSTALL_DIR=/data/soft/grpc" >> ~/.bashrc
echo "export PKG_CONFIG_PATH=\$GRPC_INSTALL_DIR/lib/pkgconfig" >> ~/.bashrc
echo "export PATH=\$GRPC_INSTALL_DIR/bin/:\$PKG_CONFIG_PATH:\$PATH" >> ~/.bashrc
source ~/.bashrc
```
### Compile and start grpc onnx paraformer server
```
# set -DONNXRUNTIME_DIR=/path/to/asrmodel/onnxruntime-linux-x64-1.14.0
./rebuild.sh
### 3. Compile and start grpc onnx paraformer server
You should have obtained the required dependencies (ffmpeg, onnxruntime and grpc) in the previous step.
If no, run [download_ffmpeg](../onnxruntime/third_party/download_ffmpeg.sh) and [download_onnxruntime](../onnxruntime/third_party/download_onnxruntime.sh)
```shell
cd /cfs/user/burkliu/work2023/FunASR/funasr/runtime/grpc
./build.sh
```
### Start grpc paraformer server
```
### 4. Download paraformer model
To do.
### 5. Start grpc paraformer server
```shell
# run as default
./run_server.sh
# or run server directly
./build/bin/paraformer-server \
--port-id <string> \
--offline-model-dir <string> \
--online-model-dir <string> \
--quantize <string> \
--vad-dir <string> \
--vad-quant <string> \
--punc-dir <string> \
--punc-quant <string>
./cmake/build/paraformer-server --port-id <string> [--punc-quant <string>]
[--punc-dir <string>] [--vad-quant <string>]
[--vad-dir <string>] [--quantize <string>]
--model-dir <string> [--] [--version] [-h]
Where:
--port-id <string>
(required) port id
--model-dir <string>
(required) the asr model path, which contains model.onnx, config.yaml, am.mvn
--quantize <string>
false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir
--port-id <string> (required) the port server listen to
--vad-dir <string>
the vad model path, which contains model.onnx, vad.yaml, vad.mvn
--vad-quant <string>
false (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir
--offline-model-dir <string> (required) the offline asr model path
--online-model-dir <string> (required) the online asr model path
--quantize <string> (optional) false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir
--punc-dir <string>
the punc model path, which contains model.onnx, punc.yaml
--punc-quant <string>
false (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir
Required: --port-id <string> --model-dir <string>
If use vad, please add: --vad-dir <string>
If use punc, please add: --punc-dir <string>
--vad-dir <string> (required) the vad model path
--vad-quant <string> (optional) false (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir
--punc-dir <string> (required) the punc model path
--punc-quant <string> (optional) false (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir
```
## For the client
Currently we only support python grpc server.
### Install the requirements as in [grpc-python](./docs/grpc_python.md)
```shell
git clone https://github.com/alibaba/FunASR.git && cd FunASR
cd funasr/runtime/python/grpc
pip install -r requirements_client.txt
```
### Generate protobuf file
Run on server, the two generated pb files are both used for server and client
```shell
# paraformer_pb2.py and paraformer_pb2_grpc.py are already generated,
# regenerate it only when you make changes to ./proto/paraformer.proto file.
python -m grpc_tools.protoc --proto_path=./proto -I ./proto --python_out=. --grpc_python_out=./ ./proto/paraformer.proto
```
### Start grpc client
```
# Start client.
python grpc_main_client_mic.py --host 127.0.0.1 --port 10095
```
[//]: # (```)
[//]: # (# go to ../python/grpc to find this package)
[//]: # (import paraformer_pb2)
[//]: # ()
[//]: # ()
[//]: # (class RecognizeStub:)
[//]: # ( def __init__&#40;self, channel&#41;:)
[//]: # ( self.Recognize = channel.stream_stream&#40;)
[//]: # ( '/paraformer.ASR/Recognize',)
[//]: # ( request_serializer=paraformer_pb2.Request.SerializeToString,)
[//]: # ( response_deserializer=paraformer_pb2.Response.FromString,)
[//]: # ( &#41;)
[//]: # ()
[//]: # ()
[//]: # (async def send&#40;channel, data, speaking, isEnd&#41;:)
[//]: # ( stub = RecognizeStub&#40;channel&#41;)
[//]: # ( req = paraformer_pb2.Request&#40;&#41;)
[//]: # ( if data:)
[//]: # ( req.audio_data = data)
[//]: # ( req.user = 'zz')
[//]: # ( req.language = 'zh-CN')
[//]: # ( req.speaking = speaking)
[//]: # ( req.isEnd = isEnd)
[//]: # ( q = queue.SimpleQueue&#40;&#41;)
[//]: # ( q.put&#40;req&#41;)
[//]: # ( return stub.Recognize&#40;iter&#40;q.get, None&#41;&#41;)
[//]: # ()
[//]: # (# send the audio data once)
[//]: # (async def grpc_rec&#40;data, grpc_uri&#41;:)
[//]: # ( with grpc.insecure_channel&#40;grpc_uri&#41; as channel:)
[//]: # ( b = time.time&#40;&#41;)
[//]: # ( response = await send&#40;channel, data, False, False&#41;)
[//]: # ( resp = response.next&#40;&#41;)
[//]: # ( text = '')
[//]: # ( if 'decoding' == resp.action:)
[//]: # ( resp = response.next&#40;&#41;)
[//]: # ( if 'finish' == resp.action:)
[//]: # ( text = json.loads&#40;resp.sentence&#41;['text'])
[//]: # ( response = await send&#40;channel, None, False, True&#41;)
[//]: # ( return {)
[//]: # ( 'text': text,)
[//]: # ( 'time': time.time&#40;&#41; - b,)
[//]: # ( })
[//]: # ()
[//]: # (async def test&#40;&#41;:)
[//]: # ( # fc = FunAsrGrpcClient&#40;'127.0.0.1', 9900&#41;)
[//]: # ( # t = await fc.rec&#40;wav.tobytes&#40;&#41;&#41;)
[//]: # ( # print&#40;t&#41;)
[//]: # ( wav, _ = sf.read&#40;'z-10s.wav', dtype='int16'&#41;)
[//]: # ( uri = '127.0.0.1:9900')
[//]: # ( res = await grpc_rec&#40;wav.tobytes&#40;&#41;, uri&#41;)
[//]: # ( print&#40;res&#41;)
[//]: # ()
[//]: # ()
[//]: # (if __name__ == '__main__':)
[//]: # ( asyncio.run&#40;test&#40;&#41;&#41;)
[//]: # ()
[//]: # (```)
Install the requirements as in [grpc-python](../python/grpc/Readme.md)
## Acknowledge
1. This project is maintained by [FunASR community](https://github.com/alibaba-damo-academy/FunASR).
2. We acknowledge [DeepScience](https://www.deepscience.cn) for contributing the grpc service.
2. We acknowledge burkliu (刘柏基, liubaiji@xverse.cn) for contributing the grpc service.

15
funasr/runtime/grpc/build.sh Executable file
View File

@ -0,0 +1,15 @@
#!/bin/bash
mode=debug #[debug|release]
onnxruntime_dir=`pwd`/../onnxruntime/onnxruntime-linux-x64-1.14.0
ffmpeg_dir=`pwd`/../onnxruntime/ffmpeg-N-111383-g20b8688092-linux64-gpl-shared
rm build -rf
mkdir -p build
cd build
cmake -DCMAKE_BUILD_TYPE=$mode ../ -DONNXRUNTIME_DIR=$onnxruntime_dir -DFFMPEG_DIR=$ffmpeg_dir
cmake --build . -j 4
echo "Build server successfully!"

View File

@ -1,235 +1,261 @@
#include <algorithm>
#include <chrono>
#include <cmath>
#include <iostream>
#include <sstream>
#include <memory>
#include <string>
/**
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
* Reserved. MIT License (https://opensource.org/licenses/MIT)
*/
/* 2023 by burkliu(刘柏基) liubaiji@xverse.cn */
#include <grpc/grpc.h>
#include <grpcpp/server.h>
#include <grpcpp/server_builder.h>
#include <grpcpp/server_context.h>
#include <grpcpp/security/server_credentials.h>
#include "paraformer.grpc.pb.h"
#include "paraformer-server.h"
#include "tclap/CmdLine.h"
#include "com-define.h"
#include "glog/logging.h"
using grpc::Server;
using grpc::ServerBuilder;
using grpc::ServerContext;
using grpc::ServerReader;
using grpc::ServerReaderWriter;
using grpc::ServerWriter;
using grpc::Status;
GrpcEngine::GrpcEngine(
grpc::ServerReaderWriter<Response, Request>* stream,
std::shared_ptr<FUNASR_HANDLE> asr_handler)
: stream_(std::move(stream)),
asr_handler_(std::move(asr_handler)) {
using paraformer::Request;
using paraformer::Response;
using paraformer::ASR;
ASRServicer::ASRServicer(std::map<std::string, std::string>& model_path) {
AsrHanlde=FunOfflineInit(model_path, 1);
std::cout << "ASRServicer init" << std::endl;
init_flag = 0;
request_ = std::make_shared<Request>();
}
void ASRServicer::clear_states(const std::string& user) {
clear_buffers(user);
clear_transcriptions(user);
}
void GrpcEngine::DecodeThreadFunc() {
FUNASR_HANDLE tpass_online_handler = FunTpassOnlineInit(*asr_handler_, chunk_size_);
int step = (sampling_rate_ * step_duration_ms_ / 1000) * 2; // int16 = 2bytes;
std::vector<std::vector<std::string>> punc_cache(2);
void ASRServicer::clear_buffers(const std::string& user) {
if (client_buffers.count(user)) {
client_buffers.erase(user);
}
}
bool is_final = false;
std::string online_result = "";
std::string tpass_result = "";
void ASRServicer::clear_transcriptions(const std::string& user) {
if (client_transcription.count(user)) {
client_transcription.erase(user);
}
}
LOG(INFO) << "Decoder init, start decoding loop with mode";
void ASRServicer::disconnect(const std::string& user) {
clear_states(user);
std::cout << "Disconnecting user: " << user << std::endl;
}
while (true) {
if (audio_buffer_.length() > step || is_end_) {
if (audio_buffer_.length() <= step && is_end_) {
is_final = true;
step = audio_buffer_.length();
}
grpc::Status ASRServicer::Recognize(
grpc::ServerContext* context,
grpc::ServerReaderWriter<Response, Request>* stream) {
FUNASR_RESULT result = FunTpassInferBuffer(*asr_handler_,
tpass_online_handler,
audio_buffer_.c_str(),
step,
punc_cache,
is_final,
sampling_rate_,
encoding_,
mode_);
audio_buffer_ = audio_buffer_.substr(step);
Request req;
while (stream->Read(&req)) {
if (req.isend()) {
std::cout << "asr end" << std::endl;
disconnect(req.user());
Response res;
res.set_sentence(
R"({"success": true, "detail": "asr end"})"
);
res.set_user(req.user());
res.set_action("terminate");
res.set_language(req.language());
stream->Write(res);
} else if (req.speaking()) {
if (req.audio_data().size() > 0) {
auto& buf = client_buffers[req.user()];
buf.insert(buf.end(), req.audio_data().begin(), req.audio_data().end());
}
Response res;
res.set_sentence(
R"({"success": true, "detail": "speaking"})"
);
res.set_user(req.user());
res.set_action("speaking");
res.set_language(req.language());
stream->Write(res);
} else if (!req.speaking()) {
if (client_buffers.count(req.user()) == 0 && req.audio_data().size() == 0) {
Response res;
res.set_sentence(
R"({"success": true, "detail": "waiting_for_voice"})"
);
res.set_user(req.user());
res.set_action("waiting");
res.set_language(req.language());
stream->Write(res);
}else {
auto begin_time = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now().time_since_epoch()).count();
if (req.audio_data().size() > 0) {
auto& buf = client_buffers[req.user()];
buf.insert(buf.end(), req.audio_data().begin(), req.audio_data().end());
}
std::string tmp_data = this->client_buffers[req.user()];
this->clear_states(req.user());
Response res;
res.set_sentence(
R"({"success": true, "detail": "decoding data: " + std::to_string(tmp_data.length()) + " bytes"})"
);
int data_len_int = tmp_data.length();
std::string data_len = std::to_string(data_len_int);
std::stringstream ss;
ss << R"({"success": true, "detail": "decoding data: )" << data_len << R"( bytes")" << R"("})";
std::string result = ss.str();
res.set_sentence(result);
res.set_user(req.user());
res.set_action("decoding");
res.set_language(req.language());
stream->Write(res);
if (tmp_data.length() < 800) { //min input_len for asr model
auto end_time = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now().time_since_epoch()).count();
std::string delay_str = std::to_string(end_time - begin_time);
std::cout << "user: " << req.user() << " , delay(ms): " << delay_str << ", error: data_is_not_long_enough" << std::endl;
Response res;
std::stringstream ss;
std::string asr_result = "";
ss << R"({"success": true, "detail": "finish_sentence","server_delay_ms":)" << delay_str << R"(,"text":")" << asr_result << R"("})";
std::string result = ss.str();
res.set_sentence(result);
res.set_user(req.user());
res.set_action("finish");
res.set_language(req.language());
stream->Write(res);
}
else {
FUNASR_RESULT Result= FunOfflineInferBuffer(AsrHanlde, tmp_data.c_str(), data_len_int, RASR_NONE, NULL, 16000);
std::string asr_result = ((FUNASR_RECOG_RESULT*)Result)->msg;
auto end_time = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now().time_since_epoch()).count();
std::string delay_str = std::to_string(end_time - begin_time);
std::cout << "user: " << req.user() << " , delay(ms): " << delay_str << ", text: " << asr_result << std::endl;
Response res;
std::stringstream ss;
ss << R"({"success": true, "detail": "finish_sentence","server_delay_ms":)" << delay_str << R"(,"text":")" << asr_result << R"("})";
std::string result = ss.str();
res.set_sentence(result);
res.set_user(req.user());
res.set_action("finish");
res.set_language(req.language());
stream->Write(res);
}
}
}else {
Response res;
res.set_sentence(
R"({"success": false, "detail": "error, no condition matched! Unknown reason."})"
);
res.set_user(req.user());
res.set_action("terminate");
res.set_language(req.language());
stream->Write(res);
if (result) {
std::string online_message = FunASRGetResult(result, 0);
online_result += online_message;
if(online_message != ""){
Response response;
response.set_mode(DecodeMode::online);
response.set_text(online_message);
response.set_is_final(is_final);
stream_->Write(response);
LOG(INFO) << "send online results: " << online_message;
}
std::string tpass_message = FunASRGetTpassResult(result, 0);
tpass_result += tpass_message;
if(tpass_message != ""){
Response response;
response.set_mode(DecodeMode::two_pass);
response.set_text(tpass_message);
response.set_is_final(is_final);
stream_->Write(response);
LOG(INFO) << "send offline results: " << tpass_message;
}
FunASRFreeResult(result);
}
if (is_final) {
FunTpassOnlineUninit(tpass_online_handler);
break;
}
}
return Status::OK;
sleep(0.001);
}
}
void RunServer(std::map<std::string, std::string>& model_path) {
std::string port;
try{
port = model_path.at(PORT_ID);
}catch(std::exception const &e){
printf("Error when read port.\n");
exit(0);
void GrpcEngine::OnSpeechStart() {
if (request_->chunk_size_size() == 3) {
for (int i = 0; i < 3; i++) {
chunk_size_[i] = int(request_->chunk_size(i));
}
std::string server_address;
server_address = "0.0.0.0:" + port;
ASRServicer service(model_path);
}
std::string chunk_size_str;
for (int i = 0; i < 3; i++) {
chunk_size_str = " " + chunk_size_[i];
}
LOG(INFO) << "chunk_size is" << chunk_size_str;
ServerBuilder builder;
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
builder.RegisterService(&service);
std::unique_ptr<Server> server(builder.BuildAndStart());
std::cout << "Server listening on " << server_address << std::endl;
server->Wait();
if (request_->sampling_rate() != 0) {
sampling_rate_ = request_->sampling_rate();
}
LOG(INFO) << "sampling_rate is " << sampling_rate_;
switch(request_->wav_format()) {
case WavFormat::pcm: encoding_ = "pcm";
}
LOG(INFO) << "encoding is " << encoding_;
std::string mode_str;
switch(request_->mode()) {
case DecodeMode::offline:
mode_ = ASR_OFFLINE;
mode_str = "offline";
break;
case DecodeMode::online:
mode_ = ASR_ONLINE;
mode_str = "online";
break;
case DecodeMode::two_pass:
mode_ = ASR_TWO_PASS;
mode_str = "two_pass";
break;
}
LOG(INFO) << "decode mode is " << mode_str;
decode_thread_ = std::make_shared<std::thread>(&GrpcEngine::DecodeThreadFunc, this);
is_start_ = true;
}
void GetValue(TCLAP::ValueArg<std::string>& value_arg, std::string key, std::map<std::string, std::string>& model_path)
{
if (value_arg.isSet()){
model_path.insert({key, value_arg.getValue()});
LOG(INFO)<< key << " : " << value_arg.getValue();
void GrpcEngine::OnSpeechData() {
audio_buffer_ += request_->audio_data();
}
void GrpcEngine::OnSpeechEnd() {
is_end_ = true;
LOG(INFO) << "Read all pcm data, wait for decoding thread";
if (decode_thread_ != nullptr) {
decode_thread_->join();
}
}
void GrpcEngine::operator()() {
try {
LOG(INFO) << "start engine main loop";
while (stream_->Read(request_.get())) {
LOG(INFO) << "receive data";
if (!is_start_) {
OnSpeechStart();
}
OnSpeechData();
if (request_->is_final()) {
break;
}
}
OnSpeechEnd();
LOG(INFO) << "Connect finish";
} catch (std::exception const& e) {
LOG(ERROR) << e.what();
}
}
GrpcService::GrpcService(std::map<std::string, std::string>& config, int onnx_thread)
: config_(config) {
asr_handler_ = std::make_shared<FUNASR_HANDLE>(std::move(FunTpassInit(config_, onnx_thread)));
LOG(INFO) << "GrpcService model loaded";
std::vector<int> chunk_size = {5, 10, 5};
FUNASR_HANDLE tmp_online_handler = FunTpassOnlineInit(*asr_handler_, chunk_size);
int sampling_rate = 16000;
int buffer_len = sampling_rate * 1;
std::string tmp_data(buffer_len, '0');
std::vector<std::vector<std::string>> punc_cache(2);
bool is_final = true;
std::string encoding = "pcm";
FUNASR_RESULT result = FunTpassInferBuffer(*asr_handler_,
tmp_online_handler,
tmp_data.c_str(),
buffer_len,
punc_cache,
is_final,
buffer_len,
encoding,
ASR_TWO_PASS);
if (result) {
FunASRFreeResult(result);
}
FunTpassOnlineUninit(tmp_online_handler);
LOG(INFO) << "GrpcService model warmup";
}
grpc::Status GrpcService::Recognize(
grpc::ServerContext* context,
grpc::ServerReaderWriter<Response, Request>* stream) {
LOG(INFO) << "Get Recognize request";
GrpcEngine engine(
stream,
asr_handler_
);
std::thread t(std::move(engine));
t.join();
return grpc::Status::OK;
}
void GetValue(TCLAP::ValueArg<std::string>& value_arg, std::string key, std::map<std::string, std::string>& config) {
if (value_arg.isSet()) {
config.insert({key, value_arg.getValue()});
LOG(INFO) << key << " : " << value_arg.getValue();
}
}
int main(int argc, char* argv[]) {
FLAGS_logtostderr = true;
google::InitGoogleLogging(argv[0]);
google::InitGoogleLogging(argv[0]);
FLAGS_logtostderr = true;
TCLAP::CmdLine cmd("funasr-onnx-2pass", ' ', "1.0");
TCLAP::ValueArg<std::string> offline_model_dir("", OFFLINE_MODEL_DIR, "the asr offline model path, which contains model.onnx, config.yaml, am.mvn", true, "", "string");
TCLAP::ValueArg<std::string> online_model_dir("", ONLINE_MODEL_DIR, "the asr online model path, which contains encoder.onnx, decoder.onnx, config.yaml, am.mvn", true, "", "string");
TCLAP::ValueArg<std::string> quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "false", "string");
TCLAP::ValueArg<std::string> vad_dir("", VAD_DIR, "the vad online model path, which contains model.onnx, vad.yaml, vad.mvn", false, "", "string");
TCLAP::ValueArg<std::string> vad_quant("", VAD_QUANT, "false (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir", false, "true", "string");
TCLAP::ValueArg<std::string> punc_dir("", PUNC_DIR, "the punc online model path, which contains model.onnx, punc.yaml", false, "", "string");
TCLAP::ValueArg<std::string> punc_quant("", PUNC_QUANT, "false (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "true", "string");
TCLAP::ValueArg<std::int32_t> onnx_thread("", "onnx-inter-thread", "onnxruntime SetIntraOpNumThreads", false, 1, "int32_t");
TCLAP::ValueArg<std::string> port_id("", PORT_ID, "port id", true, "", "string");
TCLAP::CmdLine cmd("paraformer-server", ' ', "1.0");
TCLAP::ValueArg<std::string> model_dir("", MODEL_DIR, "the asr model path, which contains model.onnx, config.yaml, am.mvn", true, "", "string");
TCLAP::ValueArg<std::string> quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "false", "string");
TCLAP::ValueArg<std::string> vad_dir("", VAD_DIR, "the vad model path, which contains model.onnx, vad.yaml, vad.mvn", false, "", "string");
TCLAP::ValueArg<std::string> vad_quant("", VAD_QUANT, "false (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir", false, "false", "string");
TCLAP::ValueArg<std::string> punc_dir("", PUNC_DIR, "the punc model path, which contains model.onnx, punc.yaml", false, "", "string");
TCLAP::ValueArg<std::string> punc_quant("", PUNC_QUANT, "false (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "false", "string");
TCLAP::ValueArg<std::string> port_id("", PORT_ID, "port id", true, "", "string");
cmd.add(offline_model_dir);
cmd.add(online_model_dir);
cmd.add(quantize);
cmd.add(vad_dir);
cmd.add(vad_quant);
cmd.add(punc_dir);
cmd.add(punc_quant);
cmd.add(onnx_thread);
cmd.add(port_id);
cmd.parse(argc, argv);
cmd.add(model_dir);
cmd.add(quantize);
cmd.add(vad_dir);
cmd.add(vad_quant);
cmd.add(punc_dir);
cmd.add(punc_quant);
cmd.add(port_id);
cmd.parse(argc, argv);
std::map<std::string, std::string> config;
GetValue(offline_model_dir, OFFLINE_MODEL_DIR, config);
GetValue(online_model_dir, ONLINE_MODEL_DIR, config);
GetValue(quantize, QUANTIZE, config);
GetValue(vad_dir, VAD_DIR, config);
GetValue(vad_quant, VAD_QUANT, config);
GetValue(punc_dir, PUNC_DIR, config);
GetValue(punc_quant, PUNC_QUANT, config);
GetValue(port_id, PORT_ID, config);
std::map<std::string, std::string> model_path;
GetValue(model_dir, MODEL_DIR, model_path);
GetValue(quantize, QUANTIZE, model_path);
GetValue(vad_dir, VAD_DIR, model_path);
GetValue(vad_quant, VAD_QUANT, model_path);
GetValue(punc_dir, PUNC_DIR, model_path);
GetValue(punc_quant, PUNC_QUANT, model_path);
GetValue(port_id, PORT_ID, model_path);
std::string port;
try {
port = config.at(PORT_ID);
} catch(std::exception const &e) {
LOG(INFO) << ("Error when read port.");
exit(0);
}
std::string server_address;
server_address = "0.0.0.0:" + port;
GrpcService service(config, onnx_thread);
RunServer(model_path);
return 0;
grpc::ServerBuilder builder;
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
builder.RegisterService(&service);
std::unique_ptr<grpc::Server> server(builder.BuildAndStart());
LOG(INFO) << "Server listening on " << server_address;
server->Wait();
return 0;
}

View File

@ -1,55 +1,65 @@
#include <algorithm>
#include <chrono>
#include <cmath>
#include <iostream>
#include <memory>
/**
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
* Reserved. MIT License (https://opensource.org/licenses/MIT)
*/
/* 2023 by burkliu(刘柏基) liubaiji@xverse.cn */
#include <string>
#include <thread>
#include <unistd.h>
#include <grpc/grpc.h>
#include <grpcpp/server.h>
#include <grpcpp/server_builder.h>
#include <grpcpp/server_context.h>
#include <grpcpp/security/server_credentials.h>
#include <unordered_map>
#include <chrono>
#include "grpcpp/server_builder.h"
#include "paraformer.grpc.pb.h"
#include "funasrruntime.h"
#include "tclap/CmdLine.h"
#include "com-define.h"
#include "glog/logging.h"
using grpc::Server;
using grpc::ServerBuilder;
using grpc::ServerContext;
using grpc::ServerReader;
using grpc::ServerReaderWriter;
using grpc::ServerWriter;
using grpc::Status;
using paraformer::WavFormat;
using paraformer::DecodeMode;
using paraformer::Request;
using paraformer::Response;
using paraformer::ASR;
typedef struct
{
std::string msg;
float snippet_time;
}FUNASR_RECOG_RESULT;
std::string msg;
float snippet_time;
} FUNASR_RECOG_RESULT;
class ASRServicer final : public ASR::Service {
private:
int init_flag;
std::unordered_map<std::string, std::string> client_buffers;
std::unordered_map<std::string, std::string> client_transcription;
class GrpcEngine {
public:
GrpcEngine(grpc::ServerReaderWriter<Response, Request>* stream, std::shared_ptr<FUNASR_HANDLE> asr_handler);
void operator()();
public:
ASRServicer(std::map<std::string, std::string>& model_path);
void clear_states(const std::string& user);
void clear_buffers(const std::string& user);
void clear_transcriptions(const std::string& user);
void disconnect(const std::string& user);
grpc::Status Recognize(grpc::ServerContext* context, grpc::ServerReaderWriter<Response, Request>* stream);
FUNASR_HANDLE AsrHanlde;
private:
void DecodeThreadFunc();
void OnSpeechStart();
void OnSpeechData();
void OnSpeechEnd();
grpc::ServerReaderWriter<Response, Request>* stream_;
std::shared_ptr<Request> request_;
std::shared_ptr<Response> response_;
std::shared_ptr<FUNASR_HANDLE> asr_handler_;
std::string audio_buffer_;
std::shared_ptr<std::thread> decode_thread_ = nullptr;
bool is_start_ = false;
bool is_end_ = false;
std::vector<int> chunk_size_ = {5, 10, 5};
int sampling_rate_ = 16000;
std::string encoding_;
ASR_TYPE mode_ = ASR_TWO_PASS;
int step_duration_ms_ = 100;
};
class GrpcService final : public ASR::Service {
public:
GrpcService(std::map<std::string, std::string>& config, int num_thread);
grpc::Status Recognize(grpc::ServerContext* context, grpc::ServerReaderWriter<Response, Request>* stream);
private:
std::map<std::string, std::string> config_;
std::shared_ptr<FUNASR_HANDLE> asr_handler_;
};

View File

@ -1,12 +0,0 @@
#!/bin/bash
# Build the (legacy) paraformer gRPC server out-of-tree under cmake/build.
# NOTE: options must precede operands ("rm -rf cmake"); "rm cmake -rf" only
# works with GNU rm and fails on BSD/macOS.
rm -rf cmake
mkdir -p cmake/build
cd cmake/build
cmake -DCMAKE_BUILD_TYPE=release ../.. -DONNXRUNTIME_DIR=/data/asrmodel/onnxruntime-linux-x64-1.14.0
make
echo "Build cmake/build/paraformer_server successfully!"

View File

@ -0,0 +1,12 @@
#!/bin/bash
# Launch the 2-pass paraformer gRPC server on port 10100, loading the
# offline and online ASR models plus VAD and realtime punctuation models,
# all with their quantized (model_quant.onnx) variants enabled.
# No comments may appear between the backslash continuation lines below.
./build/bin/paraformer-server \
  --port-id 10100 \
  --offline-model-dir funasr_models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-onnx \
  --online-model-dir funasr_models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online \
  --quantize true \
  --vad-dir funasr_models/damo/speech_fsmn_vad_zh-cn-16k-common-onnx \
  --vad-quant true \
  --punc-dir funasr_models/damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727 \
  --punc-quant true \
  2>&1

View File

@ -9,6 +9,9 @@ target_link_libraries(funasr-onnx-offline-vad PUBLIC funasr)
add_executable(funasr-onnx-online-vad "funasr-onnx-online-vad.cpp")
target_link_libraries(funasr-onnx-online-vad PUBLIC funasr)
add_executable(funasr-onnx-online-asr "funasr-onnx-online-asr.cpp")
target_link_libraries(funasr-onnx-online-asr PUBLIC funasr)
add_executable(funasr-onnx-offline-punc "funasr-onnx-offline-punc.cpp")
target_link_libraries(funasr-onnx-offline-punc PUBLIC funasr)
@ -17,3 +20,16 @@ target_link_libraries(funasr-onnx-online-punc PUBLIC funasr)
add_executable(funasr-onnx-offline-rtf "funasr-onnx-offline-rtf.cpp")
target_link_libraries(funasr-onnx-offline-rtf PUBLIC funasr)
add_executable(funasr-onnx-2pass "funasr-onnx-2pass.cpp")
target_link_libraries(funasr-onnx-2pass PUBLIC funasr)
add_executable(funasr-onnx-2pass-rtf "funasr-onnx-2pass-rtf.cpp")
target_link_libraries(funasr-onnx-2pass-rtf PUBLIC funasr)
add_executable(funasr-onnx-online-rtf "funasr-onnx-online-rtf.cpp")
target_link_libraries(funasr-onnx-online-rtf PUBLIC funasr)
# include_directories(${FFMPEG_DIR}/include)
# add_executable(ff "ffmpeg.cpp")
# target_link_libraries(ff PUBLIC avutil avcodec avformat swresample)

View File

@ -0,0 +1,310 @@
/**
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
* MIT License (https://opensource.org/licenses/MIT)
*/
#ifndef _WIN32
#include <sys/time.h>
#else
#include <win_func.h>
#endif
#include <iostream>
#include <fstream>
#include <sstream>
#include <map>
#include <atomic>
#include <mutex>
#include <thread>
#include <glog/logging.h>
#include "funasrruntime.h"
#include "tclap/CmdLine.h"
#include "com-define.h"
#include "audio.h"
using namespace std;
std::atomic<int> wav_index(0);
std::mutex mtx;
// Returns true if `filename` has exactly the extension `target`
// (text after the last '.'); false when there is no '.' at all.
// `target` is now taken by const reference to avoid a needless copy
// (clang-tidy: performance-unnecessary-value-param); call sites unchanged.
bool is_target_file(const std::string& filename, const std::string& target) {
  std::size_t pos = filename.find_last_of(".");
  if (pos == std::string::npos) {
    return false;
  }
  std::string extension = filename.substr(pos + 1);
  return (extension == target);
}
void GetValue(TCLAP::ValueArg<std::string>& value_arg, string key, std::map<std::string, std::string>& model_path)
{
model_path.insert({key, value_arg.getValue()});
LOG(INFO)<< key << " : " << value_arg.getValue();
}
// Worker-thread body for the 2-pass RTF benchmark.
// Each thread owns one online feature handle, warms the model up twice on
// wav_list[0], then pulls wav indices from the shared atomic counter
// `wav_index` until the list is exhausted, streaming each file through
// FunTpassInferBuffer in fixed-size steps.
//
// Parameters:
//   tpass_handle  - shared two-pass model handle (created by FunTpassInit).
//   chunk_size    - online chunk configuration passed to FunTpassOnlineInit.
//   wav_list      - audio paths to decode (shared work list).
//   wav_ids       - ids parallel to wav_list, used only for logging.
//   total_length  - [out] accumulated audio duration (seconds), summed
//                   across threads under `mtx`.
//   total_time    - [out] wall decode time; note only the MAXIMUM of the
//                   per-thread times is kept (see the guarded update below).
//   core_id       - thread index; currently unused inside the function.
//   asr_mode_     - 0 = offline, 1 = online, 2 = 2pass (see main()).
void runReg(FUNASR_HANDLE tpass_handle, std::vector<int> chunk_size, vector<string> wav_list, vector<string> wav_ids,
            float* total_length, long* total_time, int core_id, ASR_TYPE asr_mode_) {
    struct timeval start, end;
    long seconds = 0;
    float n_total_length = 0.0f;  // per-thread audio seconds decoded
    long n_total_time = 0;        // per-thread decode time in microseconds
    // init online features
    FUNASR_HANDLE tpass_online_handle=FunTpassOnlineInit(tpass_handle, chunk_size);
    // warm up: decode wav_list[0] twice, discarding results, so later
    // timings are not skewed by lazy initialization.
    for (size_t i = 0; i < 2; i++)
    {
        int32_t sampling_rate_ = 16000;
        funasr::Audio audio(1);
        // Pick the loader by extension; anything that is not .wav/.pcm goes
        // through ffmpeg.
        if(is_target_file(wav_list[0].c_str(), "wav")){
            if(!audio.LoadWav2Char(wav_list[0].c_str(), &sampling_rate_)){
                LOG(ERROR)<<"Failed to load "<< wav_list[0];
                exit(-1);
            }
        }else if(is_target_file(wav_list[0].c_str(), "pcm")){
            if (!audio.LoadPcmwav2Char(wav_list[0].c_str(), &sampling_rate_)){
                LOG(ERROR)<<"Failed to load "<< wav_list[0];
                exit(-1);
            }
        }else{
            if (!audio.FfmpegLoad(wav_list[0].c_str(), true)){
                LOG(ERROR)<<"Failed to load "<< wav_list[0];
                exit(-1);
            }
        }
        char* speech_buff = audio.GetSpeechChar();
        // *2: byte length of the buffer — presumably int16 samples (matches
        // the "int16 = 2bytes" convention in the gRPC server); TODO confirm.
        int buff_len = audio.GetSpeechLen()*2;
        int step = 1600*2;  // 1600 samples per step (100 ms at 16 kHz) in bytes
        bool is_final = false;
        std::vector<std::vector<string>> punc_cache(2);
        for (int sample_offset = 0; sample_offset < buff_len; sample_offset += std::min(step, buff_len - sample_offset)) {
            if (sample_offset + step >= buff_len - 1) {
                // Last chunk: shrink step to the remainder and flush.
                step = buff_len - sample_offset;
                is_final = true;
            } else {
                is_final = false;
            }
            FUNASR_RESULT result = FunTpassInferBuffer(tpass_handle, tpass_online_handle, speech_buff+sample_offset, step, punc_cache, is_final, sampling_rate_, "pcm", (ASR_TYPE)asr_mode_);
            if (result)
            {
                FunASRFreeResult(result);
            }
        }
    }
    // Main work loop: claim wav indices until the shared list is exhausted.
    while (true) {
        // Atomically fetch the next index and increment the shared counter.
        int i = wav_index.fetch_add(1);
        if (i >= wav_list.size()) {
            break;
        }
        int32_t sampling_rate_ = 16000;
        funasr::Audio audio(1);
        if(is_target_file(wav_list[i].c_str(), "wav")){
            if(!audio.LoadWav2Char(wav_list[i].c_str(), &sampling_rate_)){
                LOG(ERROR)<<"Failed to load "<< wav_list[i];
                exit(-1);
            }
        }else if(is_target_file(wav_list[i].c_str(), "pcm")){
            if (!audio.LoadPcmwav2Char(wav_list[i].c_str(), &sampling_rate_)){
                LOG(ERROR)<<"Failed to load "<< wav_list[i];
                exit(-1);
            }
        }else{
            if (!audio.FfmpegLoad(wav_list[i].c_str(), true)){
                LOG(ERROR)<<"Failed to load "<< wav_list[i];
                exit(-1);
            }
        }
        char* speech_buff = audio.GetSpeechChar();
        int buff_len = audio.GetSpeechLen()*2;
        int step = 1600*2;
        bool is_final = false;
        string online_res="";  // concatenated first-pass (online) text
        string tpass_res="";   // concatenated second-pass (offline) text
        std::vector<std::vector<string>> punc_cache(2);
        for (int sample_offset = 0; sample_offset < buff_len; sample_offset += std::min(step, buff_len - sample_offset)) {
            if (sample_offset + step >= buff_len - 1) {
                step = buff_len - sample_offset;
                is_final = true;
            } else {
                is_final = false;
            }
            // Time only the inference call itself.
            gettimeofday(&start, NULL);
            FUNASR_RESULT result = FunTpassInferBuffer(tpass_handle, tpass_online_handle, speech_buff+sample_offset, step, punc_cache, is_final, sampling_rate_, "pcm", (ASR_TYPE)asr_mode_);
            gettimeofday(&end, NULL);
            seconds = (end.tv_sec - start.tv_sec);
            long taking_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
            n_total_time += taking_micros;
            if (result)
            {
                string online_msg = FunASRGetResult(result, 0);
                online_res += online_msg;
                if(online_msg != ""){
                    LOG(INFO)<< wav_ids[i] <<" : "<<online_msg;
                }
                string tpass_msg = FunASRGetTpassResult(result, 0);
                tpass_res += tpass_msg;
                if(tpass_msg != ""){
                    LOG(INFO)<< wav_ids[i] <<" offline results : "<<tpass_msg;
                }
                float snippet_time = FunASRGetRetSnippetTime(result);
                n_total_length += snippet_time;
                FunASRFreeResult(result);
            }
            else
            {
                LOG(ERROR) << ("No return data!\n");
            }
        }
        // Per-file summaries. NOTE(review): for mode 1 (online) the text is
        // read from tpass_res, and the label says "online results" — looks
        // inconsistent with mode 2, which reads online_res; confirm intent.
        if(asr_mode_ == 2){
            LOG(INFO) <<"Thread: " << this_thread::get_id() <<" " << wav_ids[i] << " Final online results "<<" : "<<online_res;
        }
        if(asr_mode_==1){
            LOG(INFO) <<"Thread: " << this_thread::get_id() <<" " << wav_ids[i] << " Final online results "<<" : "<<tpass_res;
        }
        if(asr_mode_ == 0 || asr_mode_==2){
            LOG(INFO) <<"Thread: " << this_thread::get_id() <<" " << wav_ids[i] << " Final offline results " <<" : "<<tpass_res;
        }
    }
    // Merge per-thread stats: lengths are summed, but total_time keeps the
    // slowest thread's time (wall-clock style RTF, not CPU-time sum).
    {
        lock_guard<mutex> guard(mtx);
        *total_length += n_total_length;
        if(*total_time < n_total_time){
            *total_time = n_total_time;
        }
    }
    FunTpassOnlineUninit(tpass_online_handle);
}
/**
 * Entry point of the funasr-onnx-2pass RTF benchmark.
 *
 * Parses model/VAD/punctuation paths and the decode mode from the command
 * line, initializes the shared two-pass handle, builds the wav work list
 * (single file or kaldi-style wav.scp), runs 5 decoding threads over it,
 * and prints aggregate RTF statistics.
 *
 * Fix: the "failed to open wav.scp" branch used model_path.at(WAV_SCP),
 * but only WAV_PATH is ever inserted into model_path, so .at() threw
 * std::out_of_range instead of logging; it now logs the actual path.
 */
int main(int argc, char** argv)
{
    google::InitGoogleLogging(argv[0]);
    FLAGS_logtostderr = true;

    TCLAP::CmdLine cmd("funasr-onnx-2pass", ' ', "1.0");
    TCLAP::ValueArg<std::string> offline_model_dir("", OFFLINE_MODEL_DIR, "the asr offline model path, which contains model.onnx, config.yaml, am.mvn", true, "", "string");
    TCLAP::ValueArg<std::string> online_model_dir("", ONLINE_MODEL_DIR, "the asr online model path, which contains encoder.onnx, decoder.onnx, config.yaml, am.mvn", true, "", "string");
    TCLAP::ValueArg<std::string> quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "false", "string");
    TCLAP::ValueArg<std::string> vad_dir("", VAD_DIR, "the vad online model path, which contains model.onnx, vad.yaml, vad.mvn", false, "", "string");
    TCLAP::ValueArg<std::string> vad_quant("", VAD_QUANT, "false (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir", false, "true", "string");
    TCLAP::ValueArg<std::string> punc_dir("", PUNC_DIR, "the punc online model path, which contains model.onnx, punc.yaml", false, "", "string");
    TCLAP::ValueArg<std::string> punc_quant("", PUNC_QUANT, "false (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "true", "string");
    TCLAP::ValueArg<std::string> asr_mode("", ASR_MODE, "offline, online, 2pass", false, "2pass", "string");
    TCLAP::ValueArg<std::int32_t> onnx_thread("", "onnx-inter-thread", "onnxruntime SetIntraOpNumThreads", false, 1, "int32_t");
    TCLAP::ValueArg<std::string> wav_path("", WAV_PATH, "the input could be: wav_path, e.g.: asr_example.wav; pcm_path, e.g.: asr_example.pcm; wav.scp, kaldi style wav list (wav_id \t wav_path)", true, "", "string");
    cmd.add(offline_model_dir);
    cmd.add(online_model_dir);
    cmd.add(quantize);
    cmd.add(vad_dir);
    cmd.add(vad_quant);
    cmd.add(punc_dir);
    cmd.add(punc_quant);
    cmd.add(wav_path);
    cmd.add(asr_mode);
    cmd.add(onnx_thread);
    cmd.parse(argc, argv);

    // Collect all string options into one map consumed by FunTpassInit.
    std::map<std::string, std::string> model_path;
    GetValue(offline_model_dir, OFFLINE_MODEL_DIR, model_path);
    GetValue(online_model_dir, ONLINE_MODEL_DIR, model_path);
    GetValue(quantize, QUANTIZE, model_path);
    GetValue(vad_dir, VAD_DIR, model_path);
    GetValue(vad_quant, VAD_QUANT, model_path);
    GetValue(punc_dir, PUNC_DIR, model_path);
    GetValue(punc_quant, PUNC_QUANT, model_path);
    GetValue(wav_path, WAV_PATH, model_path);
    GetValue(asr_mode, ASR_MODE, model_path);

    struct timeval start, end;
    gettimeofday(&start, NULL);
    int thread_num = onnx_thread.getValue();

    // Map the textual asr-mode flag onto the numeric ASR_TYPE value.
    int asr_mode_ = -1;
    if(model_path[ASR_MODE] == "offline"){
        asr_mode_ = 0;
    }else if(model_path[ASR_MODE] == "online"){
        asr_mode_ = 1;
    }else if(model_path[ASR_MODE] == "2pass"){
        asr_mode_ = 2;
    }else{
        LOG(ERROR) << "Wrong asr-mode : " << model_path[ASR_MODE];
        exit(-1);
    }

    FUNASR_HANDLE tpass_handle=FunTpassInit(model_path, thread_num);
    if (!tpass_handle)
    {
        LOG(ERROR) << "FunTpassInit init failed";
        exit(-1);
    }
    gettimeofday(&end, NULL);
    long seconds = (end.tv_sec - start.tv_sec);
    long model_init_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
    LOG(INFO) << "Model initialization takes " << (double)model_init_micros / 1000000 << " s";

    // Build the wav work list: a .scp file is parsed as "wav_id wav_path"
    // per line; anything else is decoded as a single file.
    vector<string> wav_list;
    vector<string> wav_ids;
    string default_id = "wav_default_id";
    string wav_path_ = model_path.at(WAV_PATH);
    if(is_target_file(wav_path_, "scp")){
        ifstream in(wav_path_);
        if (!in.is_open()) {
            LOG(ERROR) << "Failed to open file: " << wav_path_ ;
            return 0;
        }
        string line;
        while(getline(in, line))
        {
            istringstream iss(line);
            string column1, column2;
            iss >> column1 >> column2;
            wav_list.emplace_back(column2);
            wav_ids.emplace_back(column1);
        }
        in.close();
    }else{
        wav_list.emplace_back(wav_path_);
        wav_ids.emplace_back(default_id);
    }

    std::vector<int> chunk_size = {5,10,5};
    // Multi-threaded benchmark: every thread shares the work list via the
    // global atomic index inside runReg.
    float total_length = 0.0f;
    long total_time = 0;
    std::vector<std::thread> threads;
    int rtf_threads = 5;
    for (int i = 0; i < rtf_threads; i++)
    {
        threads.emplace_back(thread(runReg, tpass_handle, chunk_size, wav_list, wav_ids, &total_length, &total_time, i, (ASR_TYPE)asr_mode_));
    }
    for (auto& thread : threads)
    {
        thread.join();
    }

    LOG(INFO) << "total_time_wav " << (long)(total_length * 1000) << " ms";
    LOG(INFO) << "total_time_comput " << total_time / 1000 << " ms";
    LOG(INFO) << "total_rtf " << (double)total_time/ (total_length*1000000);
    LOG(INFO) << "speedup " << 1.0/((double)total_time/ (total_length*1000000));
    FunTpassUninit(tpass_handle);
    return 0;
}

View File

@ -0,0 +1,217 @@
/**
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
* MIT License (https://opensource.org/licenses/MIT)
*/
#ifndef _WIN32
#include <sys/time.h>
#else
#include <win_func.h>
#endif
#include <iostream>
#include <fstream>
#include <sstream>
#include <map>
#include <glog/logging.h>
#include "funasrruntime.h"
#include "tclap/CmdLine.h"
#include "com-define.h"
#include "audio.h"
using namespace std;
// True when the text after the last '.' in `filename` equals `target`;
// false if `filename` contains no '.' at all.
bool is_target_file(const std::string& filename, const std::string target) {
  const std::size_t dot_pos = filename.find_last_of(".");
  return dot_pos != std::string::npos &&
         filename.compare(dot_pos + 1, std::string::npos, target) == 0;
}
// Records one parsed command-line value in the model_path map under `key`
// (no overwrite if the key already exists) and logs the pair.
void GetValue(TCLAP::ValueArg<std::string>& value_arg, string key, std::map<std::string, std::string>& model_path)
{
    const std::string& parsed = value_arg.getValue();
    model_path.emplace(key, parsed);
    LOG(INFO) << key << " : " << parsed;
}
int main(int argc, char** argv)
{
google::InitGoogleLogging(argv[0]);
FLAGS_logtostderr = true;
TCLAP::CmdLine cmd("funasr-onnx-2pass", ' ', "1.0");
TCLAP::ValueArg<std::string> offline_model_dir("", OFFLINE_MODEL_DIR, "the asr offline model path, which contains model.onnx, config.yaml, am.mvn", true, "", "string");
TCLAP::ValueArg<std::string> online_model_dir("", ONLINE_MODEL_DIR, "the asr online model path, which contains encoder.onnx, decoder.onnx, config.yaml, am.mvn", true, "", "string");
TCLAP::ValueArg<std::string> quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "false", "string");
TCLAP::ValueArg<std::string> vad_dir("", VAD_DIR, "the vad online model path, which contains model.onnx, vad.yaml, vad.mvn", false, "", "string");
TCLAP::ValueArg<std::string> vad_quant("", VAD_QUANT, "false (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir", false, "true", "string");
TCLAP::ValueArg<std::string> punc_dir("", PUNC_DIR, "the punc online model path, which contains model.onnx, punc.yaml", false, "", "string");
TCLAP::ValueArg<std::string> punc_quant("", PUNC_QUANT, "false (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "true", "string");
TCLAP::ValueArg<std::string> asr_mode("", ASR_MODE, "offline, online, 2pass", false, "2pass", "string");
TCLAP::ValueArg<std::int32_t> onnx_thread("", "onnx-inter-thread", "onnxruntime SetIntraOpNumThreads", false, 1, "int32_t");
TCLAP::ValueArg<std::string> wav_path("", WAV_PATH, "the input could be: wav_path, e.g.: asr_example.wav; pcm_path, e.g.: asr_example.pcm; wav.scp, kaldi style wav list (wav_id \t wav_path)", true, "", "string");
cmd.add(offline_model_dir);
cmd.add(online_model_dir);
cmd.add(quantize);
cmd.add(vad_dir);
cmd.add(vad_quant);
cmd.add(punc_dir);
cmd.add(punc_quant);
cmd.add(wav_path);
cmd.add(asr_mode);
cmd.add(onnx_thread);
cmd.parse(argc, argv);
std::map<std::string, std::string> model_path;
GetValue(offline_model_dir, OFFLINE_MODEL_DIR, model_path);
GetValue(online_model_dir, ONLINE_MODEL_DIR, model_path);
GetValue(quantize, QUANTIZE, model_path);
GetValue(vad_dir, VAD_DIR, model_path);
GetValue(vad_quant, VAD_QUANT, model_path);
GetValue(punc_dir, PUNC_DIR, model_path);
GetValue(punc_quant, PUNC_QUANT, model_path);
GetValue(wav_path, WAV_PATH, model_path);
GetValue(asr_mode, ASR_MODE, model_path);
struct timeval start, end;
gettimeofday(&start, NULL);
int thread_num = onnx_thread.getValue();
int asr_mode_ = -1;
if(model_path[ASR_MODE] == "offline"){
asr_mode_ = 0;
}else if(model_path[ASR_MODE] == "online"){
asr_mode_ = 1;
}else if(model_path[ASR_MODE] == "2pass"){
asr_mode_ = 2;
}else{
LOG(ERROR) << "Wrong asr-mode : " << model_path[ASR_MODE];
exit(-1);
}
FUNASR_HANDLE tpass_handle=FunTpassInit(model_path, thread_num);
if (!tpass_handle)
{
LOG(ERROR) << "FunTpassInit init failed";
exit(-1);
}
gettimeofday(&end, NULL);
long seconds = (end.tv_sec - start.tv_sec);
long modle_init_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
LOG(INFO) << "Model initialization takes " << (double)modle_init_micros / 1000000 << " s";
// read wav_path
vector<string> wav_list;
vector<string> wav_ids;
string default_id = "wav_default_id";
string wav_path_ = model_path.at(WAV_PATH);
if(is_target_file(wav_path_, "scp")){
ifstream in(wav_path_);
if (!in.is_open()) {
LOG(ERROR) << "Failed to open file: " << model_path.at(WAV_SCP) ;
return 0;
}
string line;
while(getline(in, line))
{
istringstream iss(line);
string column1, column2;
iss >> column1 >> column2;
wav_list.emplace_back(column2);
wav_ids.emplace_back(column1);
}
in.close();
}else{
wav_list.emplace_back(wav_path_);
wav_ids.emplace_back(default_id);
}
// init online features
std::vector<int> chunk_size = {5,10,5};
FUNASR_HANDLE tpass_online_handle=FunTpassOnlineInit(tpass_handle, chunk_size);
float snippet_time = 0.0f;
long taking_micros = 0;
for (int i = 0; i < wav_list.size(); i++) {
auto& wav_file = wav_list[i];
auto& wav_id = wav_ids[i];
int32_t sampling_rate_ = 16000;
funasr::Audio audio(1);
if(is_target_file(wav_file.c_str(), "wav")){
if(!audio.LoadWav2Char(wav_file.c_str(), &sampling_rate_)){
LOG(ERROR)<<"Failed to load "<< wav_file;
exit(-1);
}
}else if(is_target_file(wav_file.c_str(), "pcm")){
if (!audio.LoadPcmwav2Char(wav_file.c_str(), &sampling_rate_)){
LOG(ERROR)<<"Failed to load "<< wav_file;
exit(-1);
}
}else{
if (!audio.FfmpegLoad(wav_file.c_str(), true)){
LOG(ERROR)<<"Failed to load "<< wav_file;
exit(-1);
}
}
char* speech_buff = audio.GetSpeechChar();
int buff_len = audio.GetSpeechLen()*2;
int step = 1600*2;
bool is_final = false;
string online_res="";
string tpass_res="";
std::vector<std::vector<string>> punc_cache(2);
for (int sample_offset = 0; sample_offset < buff_len; sample_offset += std::min(step, buff_len - sample_offset)) {
if (sample_offset + step >= buff_len - 1) {
step = buff_len - sample_offset;
is_final = true;
} else {
is_final = false;
}
gettimeofday(&start, NULL);
FUNASR_RESULT result = FunTpassInferBuffer(tpass_handle, tpass_online_handle, speech_buff+sample_offset, step, punc_cache, is_final, sampling_rate_, "pcm", (ASR_TYPE)asr_mode_);
gettimeofday(&end, NULL);
seconds = (end.tv_sec - start.tv_sec);
taking_micros += ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
if (result)
{
string online_msg = FunASRGetResult(result, 0);
online_res += online_msg;
if(online_msg != ""){
LOG(INFO)<< wav_id <<" : "<<online_msg;
}
string tpass_msg = FunASRGetTpassResult(result, 0);
tpass_res += tpass_msg;
if(tpass_msg != ""){
LOG(INFO)<< wav_id <<" offline results : "<<tpass_msg;
}
snippet_time += FunASRGetRetSnippetTime(result);
FunASRFreeResult(result);
}
}
if(asr_mode_==2){
LOG(INFO) << wav_id << " Final online results "<<" : "<<online_res;
}
if(asr_mode_==1){
LOG(INFO) << wav_id << " Final online results "<<" : "<<tpass_res;
}
if(asr_mode_==0 || asr_mode_==2){
LOG(INFO) << wav_id << " Final offline results " <<" : "<<tpass_res;
}
}
LOG(INFO) << "Audio length: " << (double)snippet_time << " s";
LOG(INFO) << "Model inference takes: " << (double)taking_micros / 1000000 <<" s";
LOG(INFO) << "Model inference RTF: " << (double)taking_micros/ (snippet_time*1000000);
FunTpassOnlineUninit(tpass_online_handle);
FunTpassUninit(tpass_handle);
return 0;
}

View File

@ -40,6 +40,9 @@ void runReg(FUNASR_HANDLE asr_handle, vector<string> wav_list, vector<string> wa
for (size_t i = 0; i < 1; i++)
{
FUNASR_RESULT result=FunOfflineInfer(asr_handle, wav_list[0].c_str(), RASR_NONE, NULL, 16000);
if(result){
FunASRFreeResult(result);
}
}
while (true) {

View File

@ -0,0 +1,174 @@
/**
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
* MIT License (https://opensource.org/licenses/MIT)
*/
#ifndef _WIN32
#include <sys/time.h>
#else
#include <win_func.h>
#endif
#include <iostream>
#include <fstream>
#include <sstream>
#include <map>
#include <vector>
#include <glog/logging.h>
#include "funasrruntime.h"
#include "tclap/CmdLine.h"
#include "com-define.h"
#include "audio.h"
using namespace std;
// Returns true when `filename` ends with the extension `target` (the text
// after the final '.'); returns false when the name contains no dot at all.
bool is_target_file(const std::string& filename, const std::string target) {
    const std::size_t dot_pos = filename.find_last_of('.');
    if (dot_pos == std::string::npos) {
        return false;
    }
    // Compare the suffix in place instead of materializing a substring.
    return filename.compare(dot_pos + 1, std::string::npos, target) == 0;
}
// Copies a parsed command-line option into `model_path` under `key`, but only
// when the user explicitly supplied it; logs the accepted value so runs are
// reproducible from the log output.
void GetValue(TCLAP::ValueArg<std::string>& value_arg, string key, std::map<std::string, std::string>& model_path)
{
    if (!value_arg.isSet()) {
        return;  // option left at its default: keep it out of the map
    }
    const std::string& parsed = value_arg.getValue();
    model_path.insert({key, parsed});
    LOG(INFO)<< key << " : " << parsed;
}
int main(int argc, char *argv[])
{
google::InitGoogleLogging(argv[0]);
FLAGS_logtostderr = true;
TCLAP::CmdLine cmd("funasr-onnx-offline-vad", ' ', "1.0");
TCLAP::ValueArg<std::string> model_dir("", MODEL_DIR, "the vad model path, which contains model.onnx, vad.yaml, vad.mvn", true, "", "string");
TCLAP::ValueArg<std::string> quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "false", "string");
TCLAP::ValueArg<std::string> wav_path("", WAV_PATH, "the input could be: wav_path, e.g.: asr_example.wav; pcm_path, e.g.: asr_example.pcm; wav.scp, kaldi style wav list (wav_id \t wav_path)", true, "", "string");
cmd.add(model_dir);
cmd.add(quantize);
cmd.add(wav_path);
cmd.parse(argc, argv);
std::map<std::string, std::string> model_path;
GetValue(model_dir, MODEL_DIR, model_path);
GetValue(quantize, QUANTIZE, model_path);
GetValue(wav_path, WAV_PATH, model_path);
struct timeval start, end;
gettimeofday(&start, NULL);
int thread_num = 1;
FUNASR_HANDLE asr_handle=FunASRInit(model_path, thread_num, ASR_ONLINE);
if (!asr_handle)
{
LOG(ERROR) << "FunVad init failed";
exit(-1);
}
gettimeofday(&end, NULL);
long seconds = (end.tv_sec - start.tv_sec);
long modle_init_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
LOG(INFO) << "Model initialization takes " << (double)modle_init_micros / 1000000 << " s";
// read wav_path
vector<string> wav_list;
vector<string> wav_ids;
string default_id = "wav_default_id";
string wav_path_ = model_path.at(WAV_PATH);
if(is_target_file(wav_path_, "scp")){
ifstream in(wav_path_);
if (!in.is_open()) {
LOG(ERROR) << "Failed to open file: " << model_path.at(WAV_SCP) ;
return 0;
}
string line;
while(getline(in, line))
{
istringstream iss(line);
string column1, column2;
iss >> column1 >> column2;
wav_list.emplace_back(column2);
wav_ids.emplace_back(column1);
}
in.close();
}else{
wav_list.emplace_back(wav_path_);
wav_ids.emplace_back(default_id);
}
// init online features
FUNASR_HANDLE online_handle=FunASROnlineInit(asr_handle);
float snippet_time = 0.0f;
long taking_micros = 0;
for (int i = 0; i < wav_list.size(); i++) {
auto& wav_file = wav_list[i];
auto& wav_id = wav_ids[i];
int32_t sampling_rate_ = -1;
funasr::Audio audio(1);
if(is_target_file(wav_file.c_str(), "wav")){
if(!audio.LoadWav2Char(wav_file.c_str(), &sampling_rate_)){
LOG(ERROR)<<"Failed to load "<< wav_file;
exit(-1);
}
}else if(is_target_file(wav_file.c_str(), "pcm")){
if (!audio.LoadPcmwav2Char(wav_file.c_str(), &sampling_rate_)){
LOG(ERROR)<<"Failed to load "<< wav_file;
exit(-1);
}
}else{
if (!audio.FfmpegLoad(wav_file.c_str(), true)){
LOG(ERROR)<<"Failed to load "<< wav_file;
exit(-1);
}
}
char* speech_buff = audio.GetSpeechChar();
int buff_len = audio.GetSpeechLen()*2;
int step = 9600*2;
bool is_final = false;
string final_res="";
for (int sample_offset = 0; sample_offset < buff_len; sample_offset += std::min(step, buff_len - sample_offset)) {
if (sample_offset + step >= buff_len - 1) {
step = buff_len - sample_offset;
is_final = true;
} else {
is_final = false;
}
gettimeofday(&start, NULL);
FUNASR_RESULT result = FunASRInferBuffer(online_handle, speech_buff+sample_offset, step, RASR_NONE, NULL, is_final, 16000);
gettimeofday(&end, NULL);
seconds = (end.tv_sec - start.tv_sec);
taking_micros += ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
if (result)
{
string msg = FunASRGetResult(result, 0);
final_res += msg;
LOG(INFO)<< wav_id <<" : "<<msg;
snippet_time += FunASRGetRetSnippetTime(result);
FunASRFreeResult(result);
}
else
{
LOG(ERROR) << ("No return data!\n");
}
}
LOG(INFO)<<"Final results " << wav_id <<" : "<<final_res;
}
LOG(INFO) << "Audio length: " << (double)snippet_time << " s";
LOG(INFO) << "Model inference takes: " << (double)taking_micros / 1000000 <<" s";
LOG(INFO) << "Model inference RTF: " << (double)taking_micros/ (snippet_time*1000000);
FunASRUninit(asr_handle);
FunASRUninit(online_handle);
return 0;
}

View File

@ -0,0 +1,278 @@
/**
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
* MIT License (https://opensource.org/licenses/MIT)
*/
#ifndef _WIN32
#include <sys/time.h>
#else
#include <win_func.h>
#endif
#include <glog/logging.h>
#include "funasrruntime.h"
#include "tclap/CmdLine.h"
#include "com-define.h"
#include <iostream>
#include <fstream>
#include <sstream>
#include <vector>
#include <atomic>
#include <mutex>
#include <thread>
#include <map>
#include "audio.h"
using namespace std;
std::atomic<int> wav_index(0);
std::mutex mtx;
// Returns true iff `filename`'s extension (the text after the last '.')
// equals `target`; returns false when the name contains no dot.
bool is_target_file(const std::string& filename, const std::string target) {
    std::size_t pos = filename.find_last_of(".");
    if (pos == std::string::npos) {
        return false;  // no extension at all
    }
    std::string extension = filename.substr(pos + 1);
    return (extension == target);
}
// Worker-thread body for the multi-threaded RTF benchmark.
// Each thread builds its own online handle from the shared `asr_handle`,
// warms up by decoding wav_list[0] ten times, then pulls wav indices from
// the global atomic `wav_index` until the list is exhausted.
//   total_length: accumulated decoded audio duration in seconds (summed
//                 across threads under `mtx`)
//   total_time:   written with the MAXIMUM per-thread compute time in
//                 microseconds, not a sum — wall-clock of the slowest thread
//   core_id:      currently unused inside this function
void runReg(FUNASR_HANDLE asr_handle, vector<string> wav_list, vector<string> wav_ids,
            float* total_length, long* total_time, int core_id) {
    struct timeval start, end;
    long seconds = 0;
    float n_total_length = 0.0f;  // this thread's decoded audio (s)
    long n_total_time = 0;        // this thread's compute time (us)

    // init online features (one streaming handle per worker thread)
    FUNASR_HANDLE online_handle=FunASROnlineInit(asr_handle);

    // warm up: decode the first wav 10 times so caches/JIT effects do not
    // distort the timed loop below
    for (size_t i = 0; i < 10; i++)
    {
        int32_t sampling_rate_ = -1;
        funasr::Audio audio(1);
        if(is_target_file(wav_list[0].c_str(), "wav")){
            if(!audio.LoadWav2Char(wav_list[0].c_str(), &sampling_rate_)){
                LOG(ERROR)<<"Failed to load "<< wav_list[0];
                exit(-1);
            }
        }else if(is_target_file(wav_list[0].c_str(), "pcm")){
            if (!audio.LoadPcmwav2Char(wav_list[0].c_str(), &sampling_rate_)){
                LOG(ERROR)<<"Failed to load "<< wav_list[0];
                exit(-1);
            }
        }else{
            if (!audio.FfmpegLoad(wav_list[0].c_str(), true)){
                LOG(ERROR)<<"Failed to load "<< wav_list[0];
                exit(-1);
            }
        }

        char* speech_buff = audio.GetSpeechChar();
        int buff_len = audio.GetSpeechLen()*2;  // bytes (16-bit samples)
        int step = 9600*2;                      // chunk size in bytes
        bool is_final = false;

        string final_res="";
        for (int sample_offset = 0; sample_offset < buff_len; sample_offset += std::min(step, buff_len - sample_offset)) {
            if (sample_offset + step >= buff_len - 1) {
                step = buff_len - sample_offset;  // last (possibly short) chunk
                is_final = true;
            } else {
                is_final = false;
            }
            FUNASR_RESULT result = FunASRInferBuffer(online_handle, speech_buff+sample_offset, step, RASR_NONE, NULL, is_final, 16000);
            if (result)
            {
                FunASRFreeResult(result);  // warm-up: results are discarded
            }
        }
    }

    while (true) {
        // Claim the next wav index atomically so threads never decode the
        // same file twice.
        int i = wav_index.fetch_add(1);
        if (i >= wav_list.size()) {
            break;
        }

        int32_t sampling_rate_ = -1;
        funasr::Audio audio(1);
        if(is_target_file(wav_list[i].c_str(), "wav")){
            if(!audio.LoadWav2Char(wav_list[i].c_str(), &sampling_rate_)){
                LOG(ERROR)<<"Failed to load "<< wav_list[i];
                exit(-1);
            }
        }else if(is_target_file(wav_list[i].c_str(), "pcm")){
            if (!audio.LoadPcmwav2Char(wav_list[i].c_str(), &sampling_rate_)){
                LOG(ERROR)<<"Failed to load "<< wav_list[i];
                exit(-1);
            }
        }else{
            if (!audio.FfmpegLoad(wav_list[i].c_str(), true)){
                LOG(ERROR)<<"Failed to load "<< wav_list[i];
                exit(-1);
            }
        }

        char* speech_buff = audio.GetSpeechChar();
        int buff_len = audio.GetSpeechLen()*2;
        int step = 9600*2;
        bool is_final = false;

        string final_res="";
        for (int sample_offset = 0; sample_offset < buff_len; sample_offset += std::min(step, buff_len - sample_offset)) {
            if (sample_offset + step >= buff_len - 1) {
                step = buff_len - sample_offset;
                is_final = true;
            } else {
                is_final = false;
            }
            // Only the inference call itself is timed; audio loading is not.
            gettimeofday(&start, NULL);
            FUNASR_RESULT result = FunASRInferBuffer(online_handle, speech_buff+sample_offset, step, RASR_NONE, NULL, is_final, 16000);
            gettimeofday(&end, NULL);
            seconds = (end.tv_sec - start.tv_sec);
            long taking_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
            n_total_time += taking_micros;

            if (result)
            {
                string msg = FunASRGetResult(result, 0);
                final_res += msg;
                LOG(INFO) << "Thread: " << this_thread::get_id() << "," << wav_ids[i] << " : " << msg;
                float snippet_time = FunASRGetRetSnippetTime(result);
                n_total_length += snippet_time;
                FunASRFreeResult(result);
            }
            else
            {
                LOG(ERROR) << ("No return data!\n");
            }
        }
        LOG(INFO) << "Thread: " << this_thread::get_id() << ", Final results " << wav_ids[i] << " : " << final_res;
    }

    {
        // Merge this thread's stats: audio length is summed, while
        // total_time keeps only the slowest thread's compute time.
        lock_guard<mutex> guard(mtx);
        *total_length += n_total_length;
        if(*total_time < n_total_time){
            *total_time = n_total_time;
        }
    }
    FunASRUninit(online_handle);
}
// Copies a parsed TCLAP string option into `model_path` under `key`, but
// only when the option was explicitly supplied on the command line; logs
// the accepted value so runs are reproducible from the log output.
void GetValue(TCLAP::ValueArg<std::string>& value_arg, string key, std::map<std::string, std::string>& model_path)
{
    if (value_arg.isSet()){
        model_path.insert({key, value_arg.getValue()});
        LOG(INFO)<< key << " : " << value_arg.getValue();
    }
}
int main(int argc, char *argv[])
{
    // RTF benchmark driver: decodes a wav list with N worker threads
    // (see runReg) and reports total audio time, compute time, real-time
    // factor and speedup.
    google::InitGoogleLogging(argv[0]);
    FLAGS_logtostderr = true;

    TCLAP::CmdLine cmd("funasr-onnx-online-rtf", ' ', "1.0");
    TCLAP::ValueArg<std::string> model_dir("", MODEL_DIR, "the model path, which contains model.onnx, config.yaml, am.mvn", true, "", "string");
    TCLAP::ValueArg<std::string> quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "false", "string");
    TCLAP::ValueArg<std::string> vad_dir("", VAD_DIR, "the vad model path, which contains model.onnx, vad.yaml, vad.mvn", false, "", "string");
    TCLAP::ValueArg<std::string> vad_quant("", VAD_QUANT, "false (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir", false, "false", "string");
    TCLAP::ValueArg<std::string> punc_dir("", PUNC_DIR, "the punc model path, which contains model.onnx, punc.yaml", false, "", "string");
    TCLAP::ValueArg<std::string> punc_quant("", PUNC_QUANT, "false (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "false", "string");
    TCLAP::ValueArg<std::string> wav_path("", WAV_PATH, "the input could be: wav_path, e.g.: asr_example.wav; pcm_path, e.g.: asr_example.pcm; wav.scp, kaldi style wav list (wav_id \t wav_path)", true, "", "string");
    TCLAP::ValueArg<std::int32_t> thread_num("", THREAD_NUM, "multi-thread num for rtf", true, 0, "int32_t");

    cmd.add(model_dir);
    cmd.add(quantize);
    cmd.add(vad_dir);
    cmd.add(vad_quant);
    cmd.add(punc_dir);
    cmd.add(punc_quant);
    cmd.add(wav_path);
    cmd.add(thread_num);
    cmd.parse(argc, argv);

    std::map<std::string, std::string> model_path;
    GetValue(model_dir, MODEL_DIR, model_path);
    GetValue(quantize, QUANTIZE, model_path);
    GetValue(vad_dir, VAD_DIR, model_path);
    GetValue(vad_quant, VAD_QUANT, model_path);
    GetValue(punc_dir, PUNC_DIR, model_path);
    GetValue(punc_quant, PUNC_QUANT, model_path);
    GetValue(wav_path, WAV_PATH, model_path);

    // Time the model initialization.
    struct timeval start, end;
    gettimeofday(&start, NULL);
    FUNASR_HANDLE asr_handle=FunASRInit(model_path, 1, ASR_ONLINE);
    if (!asr_handle)
    {
        LOG(ERROR) << "FunASR init failed";
        exit(-1);
    }
    gettimeofday(&end, NULL);
    long seconds = (end.tv_sec - start.tv_sec);
    long modle_init_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
    LOG(INFO) << "Model initialization takes " << (double)modle_init_micros / 1000000 << " s";

    // Collect (id, path) pairs: a single wav/pcm file, or a kaldi-style
    // wav.scp list of "wav_id wav_path" lines.
    vector<string> wav_list;
    vector<string> wav_ids;
    string default_id = "wav_default_id";
    string wav_path_ = model_path.at(WAV_PATH);
    if(is_target_file(wav_path_, "wav") || is_target_file(wav_path_, "pcm")){
        wav_list.emplace_back(wav_path_);
        wav_ids.emplace_back(default_id);
    }
    else if(is_target_file(wav_path_, "scp")){
        ifstream in(wav_path_);
        if (!in.is_open()) {
            // Fix: log the scp path itself. model_path never receives a
            // WAV_SCP entry, so model_path.at(WAV_SCP) threw out_of_range.
            LOG(ERROR) << "Failed to open file: " << wav_path_;
            return 0;
        }
        string line;
        while(getline(in, line))
        {
            istringstream iss(line);
            string column1, column2;
            iss >> column1 >> column2;
            wav_list.emplace_back(column2);
            wav_ids.emplace_back(column1);
        }
        in.close();
    }else{
        LOG(ERROR)<<"Please check the wav extension!";
        exit(-1);
    }

    // Multi-threaded benchmark: workers pull wav indices via a shared atomic.
    float total_length = 0.0f;  // total decoded audio across threads (s)
    long total_time = 0;        // slowest thread's compute time (us)
    std::vector<std::thread> threads;
    int rtf_threds = thread_num.getValue();
    threads.reserve(rtf_threds);
    for (int i = 0; i < rtf_threds; i++)
    {
        // Construct the thread in place instead of moving a temporary.
        threads.emplace_back(runReg, asr_handle, wav_list, wav_ids, &total_length, &total_time, i);
    }
    for (auto& worker : threads)  // renamed from `thread` to avoid shadowing std::thread
    {
        worker.join();
    }

    LOG(INFO) << "total_time_wav " << (long)(total_length * 1000) << " ms";
    LOG(INFO) << "total_time_comput " << total_time / 1000 << " ms";
    LOG(INFO) << "total_rtf " << (double)total_time/ (total_length*1000000);
    LOG(INFO) << "speedup " << 1.0/((double)total_time/ (total_length*1000000));
    FunASRUninit(asr_handle);
    return 0;
}

View File

@ -159,7 +159,7 @@ int main(int argc, char *argv[])
char* speech_buff = audio.GetSpeechChar();
int buff_len = audio.GetSpeechLen()*2;
int step = 3200;
int step = 800*2;
bool is_final = false;
for (int sample_offset = 0; sample_offset < buff_len; sample_offset += std::min(step, buff_len - sample_offset)) {

View File

@ -5,6 +5,7 @@
#include <stdint.h>
#include "vad-model.h"
#include "offline-stream.h"
#include "com-define.h"
#ifndef WAV_HEADER_SIZE
#define WAV_HEADER_SIZE 44
@ -17,11 +18,13 @@ class AudioFrame {
private:
int start;
int end;
int len;
public:
AudioFrame();
AudioFrame(int len);
AudioFrame(const AudioFrame &other);
AudioFrame(int start, int end, bool is_final);
~AudioFrame();
int SetStart(int val);
@ -29,6 +32,10 @@ class AudioFrame {
int GetStart();
int GetLen();
int Disp();
// 2pass
bool is_final = false;
float* data = nullptr;
int len;
};
class Audio {
@ -38,10 +45,11 @@ class Audio {
char* speech_char=nullptr;
int speech_len;
int speech_align_len;
int offset;
float align_size;
int data_type;
queue<AudioFrame *> frame_queue;
queue<AudioFrame *> asr_online_queue;
queue<AudioFrame *> asr_offline_queue;
public:
Audio(int data_type);
@ -56,17 +64,35 @@ class Audio {
bool LoadPcmwav(const char* filename, int32_t* sampling_rate);
bool LoadPcmwav2Char(const char* filename, int32_t* sampling_rate);
bool LoadOthers2Char(const char* filename);
bool FfmpegLoad(const char *filename);
bool FfmpegLoad(const char *filename, bool copy2char=false);
bool FfmpegLoad(const char* buf, int n_file_len);
int FetchChunck(float *&dout, int len);
int FetchChunck(AudioFrame *&frame);
int FetchTpass(AudioFrame *&frame);
int Fetch(float *&dout, int &len, int &flag);
void Padding();
void Split(OfflineStream* offline_streamj);
void Split(VadModel* vad_obj, vector<std::vector<int>>& vad_segments, bool input_finished=true);
void Split(VadModel* vad_obj, int chunk_len, bool input_finished=true, ASR_TYPE asr_mode=ASR_TWO_PASS);
float GetTimeLen();
int GetQueueSize() { return (int)frame_queue.size(); }
char* GetSpeechChar(){return speech_char;}
int GetSpeechLen(){return speech_len;}
// 2pass
vector<float> all_samples;
int offset = 0;
int speech_start=-1, speech_end=0;
int speech_offline_start=-1;
int seg_sample = MODEL_SAMPLE_RATE/1000;
bool LoadPcmwavOnline(const char* buf, int n_file_len, int32_t* sampling_rate);
void ResetIndex(){
speech_start=-1;
speech_end=0;
speech_offline_start=-1;
offset = 0;
all_samples.clear();
}
};
} // namespace funasr

View File

@ -13,11 +13,14 @@ namespace funasr {
// parser option
#define MODEL_DIR "model-dir"
#define OFFLINE_MODEL_DIR "model-dir"
#define ONLINE_MODEL_DIR "online-model-dir"
#define VAD_DIR "vad-dir"
#define PUNC_DIR "punc-dir"
#define QUANTIZE "quantize"
#define VAD_QUANT "vad-quant"
#define PUNC_QUANT "punc-quant"
#define ASR_MODE "mode"
#define WAV_PATH "wav-path"
#define WAV_SCP "wav-scp"
@ -42,6 +45,11 @@ namespace funasr {
#define AM_CONFIG_NAME "config.yaml"
#define PUNC_CONFIG_NAME "punc.yaml"
#define ENCODER_NAME "model.onnx"
#define QUANT_ENCODER_NAME "model_quant.onnx"
#define DECODER_NAME "decoder.onnx"
#define QUANT_DECODER_NAME "decoder_quant.onnx"
// vad
#ifndef VAD_SILENCE_DURATION
#define VAD_SILENCE_DURATION 800
@ -63,6 +71,19 @@ namespace funasr {
#define VAD_LFR_N 1
#endif
// asr
#ifndef PARA_LFR_M
#define PARA_LFR_M 7
#endif
#ifndef PARA_LFR_N
#define PARA_LFR_N 6
#endif
#ifndef ONLINE_STEP
#define ONLINE_STEP 9600
#endif
// punc
#define UNK_CHAR "<unk>"
#define TOKEN_LEN 20

View File

@ -46,21 +46,29 @@ typedef enum {
FUNASR_MODEL_PARAFORMER = 3,
}FUNASR_MODEL_TYPE;
typedef enum {
ASR_OFFLINE=0,
ASR_ONLINE=1,
ASR_TWO_PASS=2,
}ASR_TYPE;
typedef enum {
PUNC_OFFLINE=0,
PUNC_ONLINE=1,
}PUNC_TYPE;
typedef void (* QM_CALLBACK)(int cur_step, int n_total); // n_total: total steps; cur_step: Current Step.
// ASR
_FUNASRAPI FUNASR_HANDLE FunASRInit(std::map<std::string, std::string>& model_path, int thread_num);
_FUNASRAPI FUNASR_HANDLE FunASRInit(std::map<std::string, std::string>& model_path, int thread_num, ASR_TYPE type=ASR_OFFLINE);
_FUNASRAPI FUNASR_HANDLE FunASROnlineInit(FUNASR_HANDLE asr_handle, std::vector<int> chunk_size={5,10,5});
// buffer
_FUNASRAPI FUNASR_RESULT FunASRInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback, int sampling_rate=16000, std::string wav_format="pcm");
_FUNASRAPI FUNASR_RESULT FunASRInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback, bool input_finished=true, int sampling_rate=16000, std::string wav_format="pcm");
// file, support wav & pcm
_FUNASRAPI FUNASR_RESULT FunASRInfer(FUNASR_HANDLE handle, const char* sz_filename, FUNASR_MODE mode, QM_CALLBACK fn_callback, int sampling_rate=16000);
_FUNASRAPI const char* FunASRGetResult(FUNASR_RESULT result,int n_index);
_FUNASRAPI const char* FunASRGetTpassResult(FUNASR_RESULT result,int n_index);
_FUNASRAPI const int FunASRGetRetNumber(FUNASR_RESULT result);
_FUNASRAPI void FunASRFreeResult(FUNASR_RESULT result);
_FUNASRAPI void FunASRUninit(FUNASR_HANDLE handle);
@ -94,6 +102,14 @@ _FUNASRAPI FUNASR_RESULT FunOfflineInferBuffer(FUNASR_HANDLE handle, const char*
_FUNASRAPI FUNASR_RESULT FunOfflineInfer(FUNASR_HANDLE handle, const char* sz_filename, FUNASR_MODE mode, QM_CALLBACK fn_callback, int sampling_rate=16000);
_FUNASRAPI void FunOfflineUninit(FUNASR_HANDLE handle);
//2passStream
_FUNASRAPI FUNASR_HANDLE FunTpassInit(std::map<std::string, std::string>& model_path, int thread_num);
_FUNASRAPI FUNASR_HANDLE FunTpassOnlineInit(FUNASR_HANDLE tpass_handle, std::vector<int> chunk_size={5,10,5});
// buffer
_FUNASRAPI FUNASR_RESULT FunTpassInferBuffer(FUNASR_HANDLE handle, FUNASR_HANDLE online_handle, const char* sz_buf, int n_len, std::vector<std::vector<std::string>> &punc_cache, bool input_finished=true, int sampling_rate=16000, std::string wav_format="pcm", ASR_TYPE mode=ASR_TWO_PASS);
_FUNASRAPI void FunTpassUninit(FUNASR_HANDLE handle);
_FUNASRAPI void FunTpassOnlineUninit(FUNASR_HANDLE handle);
#ifdef __cplusplus
}

View File

@ -4,17 +4,21 @@
#include <string>
#include <map>
#include "funasrruntime.h"
namespace funasr {
class Model {
public:
virtual ~Model(){};
virtual void Reset() = 0;
virtual void InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num)=0;
virtual std::string ForwardChunk(float *din, int len, int flag) = 0;
virtual std::string Forward(float *din, int len, int flag) = 0;
virtual void InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){};
virtual void InitAsr(const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){};
virtual void InitAsr(const std::string &am_model, const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){};
virtual std::string Forward(float *din, int len, bool input_finished){return "";};
virtual std::string Rescoring() = 0;
};
Model *CreateModel(std::map<std::string, std::string>& model_path,int thread_num=1);
Model *CreateModel(std::map<std::string, std::string>& model_path, int thread_num=1, ASR_TYPE type=ASR_OFFLINE);
Model *CreateModel(void* asr_handle, std::vector<int> chunk_size);
} // namespace funasr
#endif

View File

@ -14,9 +14,9 @@ class OfflineStream {
OfflineStream(std::map<std::string, std::string>& model_path, int thread_num);
~OfflineStream(){};
std::unique_ptr<VadModel> vad_handle;
std::unique_ptr<Model> asr_handle;
std::unique_ptr<PuncModel> punc_handle;
std::unique_ptr<VadModel> vad_handle= nullptr;
std::unique_ptr<Model> asr_handle= nullptr;
std::unique_ptr<PuncModel> punc_handle= nullptr;
bool UseVad(){return use_vad;};
bool UsePunc(){return use_punc;};

View File

@ -0,0 +1,20 @@
#ifndef TPASS_ONLINE_STREAM_H
#define TPASS_ONLINE_STREAM_H
#include <memory>
#include "tpass-stream.h"
#include "model.h"
#include "vad-model.h"
namespace funasr {
// Per-stream online state for 2-pass decoding: owns the streaming VAD and
// streaming ASR model instances derived from a shared TpassStream.
class TpassOnlineStream {
public:
    // Builds the online handles from the models held by `tpass_stream`;
    // `chunk_size` is the streaming chunk configuration (defaults to
    // {5,10,5} at the call sites) — presumably left/middle/right encoder
    // chunks, TODO confirm against the model implementation.
    TpassOnlineStream(TpassStream* tpass_stream, std::vector<int> chunk_size);
    ~TpassOnlineStream(){};
    std::unique_ptr<VadModel> vad_online_handle = nullptr;  // streaming VAD instance
    std::unique_ptr<Model> asr_online_handle = nullptr;     // streaming ASR instance
};
TpassOnlineStream* CreateTpassOnlineStream(void* tpass_stream, std::vector<int> chunk_size);
} // namespace funasr
#endif

View File

@ -0,0 +1,31 @@
#ifndef TPASS_STREAM_H
#define TPASS_STREAM_H
#include <memory>
#include <string>
#include <map>
#include "model.h"
#include "punc-model.h"
#include "vad-model.h"
namespace funasr {
// Shared model bundle for 2-pass decoding: loads VAD, ASR and punctuation
// models once from `model_path` so multiple online streams can be created
// against them (see TpassOnlineStream).
class TpassStream {
public:
    TpassStream(std::map<std::string, std::string>& model_path, int thread_num);
    ~TpassStream(){};
    // std::unique_ptr<VadModel> vad_handle = nullptr;
    std::unique_ptr<VadModel> vad_handle = nullptr;           // VAD model; presumably set only when a vad dir is configured — confirm in ctor
    std::unique_ptr<Model> asr_handle = nullptr;              // ASR model
    std::unique_ptr<PuncModel> punc_online_handle = nullptr;  // online punctuation model
    bool UseVad(){return use_vad;};    // true when a VAD model was loaded
    bool UsePunc(){return use_punc;};  // true when a punctuation model was loaded
private:
    bool use_vad=false;
    bool use_punc=false;
};
TpassStream *CreateTpassStream(std::map<std::string, std::string>& model_path, int thread_num=1);
} // namespace funasr
#endif

View File

@ -132,40 +132,54 @@ class AudioWindow {
};
};
AudioFrame::AudioFrame(){};
AudioFrame::AudioFrame(){}
AudioFrame::AudioFrame(int len) : len(len)
{
start = 0;
};
AudioFrame::~AudioFrame(){};
}
AudioFrame::AudioFrame(const AudioFrame &other)
{
start = other.start;
end = other.end;
len = other.len;
is_final = other.is_final;
}
AudioFrame::AudioFrame(int start, int end, bool is_final):start(start),end(end),is_final(is_final){
len = end - start;
}
AudioFrame::~AudioFrame(){
if(data != NULL){
free(data);
}
}
int AudioFrame::SetStart(int val)
{
start = val < 0 ? 0 : val;
return start;
};
}
int AudioFrame::SetEnd(int val)
{
end = val;
len = end - start;
return end;
};
}
int AudioFrame::GetStart()
{
return start;
};
}
int AudioFrame::GetLen()
{
return len;
};
}
int AudioFrame::Disp()
{
LOG(ERROR) << "Not imp!!!!";
return 0;
};
}
Audio::Audio(int data_type) : data_type(data_type)
{
@ -230,7 +244,7 @@ void Audio::WavResample(int32_t sampling_rate, const float *waveform,
copy(samples.begin(), samples.end(), speech_data);
}
bool Audio::FfmpegLoad(const char *filename){
bool Audio::FfmpegLoad(const char *filename, bool copy2char){
// from file
AVFormatContext* formatContext = avformat_alloc_context();
if (avformat_open_input(&formatContext, filename, NULL, NULL) != 0) {
@ -353,8 +367,17 @@ bool Audio::FfmpegLoad(const char *filename){
if (speech_buff != NULL) {
free(speech_buff);
}
if (speech_char != NULL) {
free(speech_char);
}
offset = 0;
if(copy2char){
speech_char = (char *)malloc(resampled_buffers.size());
memset(speech_char, 0, resampled_buffers.size());
memcpy((void*)speech_char, (const void*)resampled_buffers.data(), resampled_buffers.size());
}
speech_len = (resampled_buffers.size()) / 2;
speech_buff = (int16_t*)malloc(sizeof(int16_t) * speech_len);
if (speech_buff)
@ -762,6 +785,55 @@ bool Audio::LoadPcmwav(const char* buf, int n_buf_len, int32_t* sampling_rate)
return false;
}
// Streaming (online) PCM loader: converts one incoming chunk of 16-bit PCM
// bytes into normalized floats, resamples to the model rate when needed,
// appends the samples to `all_samples`, and enqueues an AudioFrame for the
// chunk. Returns false only when the int16 buffer allocation fails.
//   buf:           raw PCM bytes for this chunk
//   n_buf_len:     length of `buf` in bytes (2 bytes per sample)
//   sampling_rate: input rate; resampled when != MODEL_SAMPLE_RATE
bool Audio::LoadPcmwavOnline(const char* buf, int n_buf_len, int32_t* sampling_rate)
{
    // Drop buffers left over from a previous chunk before loading this one.
    // NOTE(review): pointers are freed but not reset to NULL; on the
    // allocation-failure path below they stay dangling — confirm the
    // destructor / next call cannot double-free.
    if (speech_data != NULL) {
        free(speech_data);
    }
    if (speech_buff != NULL) {
        free(speech_buff);
    }
    if (speech_char != NULL) {
        free(speech_char);
    }

    speech_len = n_buf_len / 2;  // bytes -> 16-bit sample count
    speech_buff = (int16_t*)malloc(sizeof(int16_t) * speech_len);
    if (speech_buff)
    {
        memset(speech_buff, 0, sizeof(int16_t) * speech_len);
        memcpy((void*)speech_buff, (const void*)buf, speech_len * sizeof(int16_t));

        speech_data = (float*)malloc(sizeof(float) * speech_len);
        memset(speech_data, 0, sizeof(float) * speech_len);
        // data_type == 1 means int16 input: normalize to [-1, 1).
        float scale = 1;

        if (data_type == 1) {
            scale = 32768;
        }

        for (int32_t i = 0; i != speech_len; ++i) {
            speech_data[i] = (float)speech_buff[i] / scale;
        }
        //resample
        // NOTE(review): WavResample appears to rewrite speech_data in place;
        // whether it also updates speech_len to the resampled count is not
        // visible here — confirm before relying on speech_len below.
        if(*sampling_rate != MODEL_SAMPLE_RATE){
            WavResample(*sampling_rate, speech_data, speech_len);
        }

        // Keep a running copy of every sample seen so far; 2-pass decoding
        // reads segments back out of all_samples.
        for (int32_t i = 0; i != speech_len; ++i) {
            all_samples.emplace_back(speech_data[i]);
        }

        // Record the chunk length; the queue consumer takes ownership of
        // the frame (presumably deleted after Fetch/Split — confirm).
        AudioFrame* frame = new AudioFrame(speech_len);
        frame_queue.push(frame);
        return true;
    }
    else
        return false;
}
bool Audio::LoadPcmwav(const char* filename, int32_t* sampling_rate)
{
if (speech_data != NULL) {
@ -870,24 +942,25 @@ bool Audio::LoadOthers2Char(const char* filename)
return true;
}
int Audio::FetchChunck(float *&dout, int len)
int Audio::FetchTpass(AudioFrame *&frame)
{
if (offset >= speech_align_len) {
dout = NULL;
return S_ERR;
} else if (offset == speech_align_len - len) {
dout = speech_data + offset;
offset = speech_align_len;
// 临时解决
AudioFrame *frame = frame_queue.front();
frame_queue.pop();
delete frame;
return S_END;
if (asr_offline_queue.size() > 0) {
frame = asr_offline_queue.front();
asr_offline_queue.pop();
return 1;
} else {
dout = speech_data + offset;
offset += len;
return S_MIDDLE;
return 0;
}
}
int Audio::FetchChunck(AudioFrame *&frame)
{
if (asr_online_queue.size() > 0) {
frame = asr_online_queue.front();
asr_online_queue.pop();
return 1;
} else {
return 0;
}
}
@ -956,7 +1029,6 @@ void Audio::Split(OfflineStream* offline_stream)
std::vector<float> pcm_data(speech_data, speech_data+sp_len);
vector<std::vector<int>> vad_segments = (offline_stream->vad_handle)->Infer(pcm_data);
int seg_sample = MODEL_SAMPLE_RATE/1000;
for(vector<int> segment:vad_segments)
{
frame = new AudioFrame();
@ -969,7 +1041,6 @@ void Audio::Split(OfflineStream* offline_stream)
}
}
void Audio::Split(VadModel* vad_obj, vector<std::vector<int>>& vad_segments, bool input_finished)
{
AudioFrame *frame;
@ -984,4 +1055,161 @@ void Audio::Split(VadModel* vad_obj, vector<std::vector<int>>& vad_segments, boo
vad_segments = vad_obj->Infer(pcm_data, input_finished);
}
// 2pass
// Consume one queued raw-audio frame, run the streaming VAD over it, and
// slice the accumulated sample buffer (all_samples) into:
//   - fixed-size chunks for the streaming recognizer -> asr_online_queue
//   - whole speech segments for the offline 2nd pass -> asr_offline_queue
// asr_mode selects which of the two queues are fed.
// VAD segments are {start_ms, end_ms} pairs where -1 means "edge not seen
// yet" (segment still open). speech_start / speech_offline_start /
// speech_end are kept in milliseconds; seg_sample converts ms -> samples.
// NOTE(review): `offset` appears to be the number of samples already erased
// from the front of all_samples — confirm against the class definition.
void Audio::Split(VadModel* vad_obj, int chunk_len, bool input_finished, ASR_TYPE asr_mode)
{
    AudioFrame *frame;

    // Take the most recent raw frame; its payload lives in speech_data.
    frame = frame_queue.front();
    frame_queue.pop();
    int sp_len = frame->GetLen();
    delete frame;
    frame = NULL;

    std::vector<float> pcm_data(speech_data, speech_data+sp_len);
    vector<std::vector<int>> vad_segments = vad_obj->Infer(pcm_data, input_finished);
    speech_end += sp_len/seg_sample;  // advance stream position (ms)
    if(vad_segments.size() == 0){
        // No new VAD event. If we are inside an open speech segment, keep
        // emitting fixed-size online chunks from it.
        if(speech_start != -1){
            int start = speech_start*seg_sample;
            int end = speech_end*seg_sample;
            int buff_len = end-start;
            int step = chunk_len;
            if(asr_mode != ASR_OFFLINE){
                if(buff_len >= step){
                    frame = new AudioFrame(step);
                    frame->data = (float*)malloc(sizeof(float) * step);
                    memcpy(frame->data, all_samples.data()+start-offset, step*sizeof(float));
                    asr_online_queue.push(frame);
                    frame = NULL;
                    speech_start += step/seg_sample;  // consume one chunk
                }
            }
        }
    }else{
        for(auto vad_segment: vad_segments){
            int speech_start_i=-1, speech_end_i=-1;
            if(vad_segment[0] != -1){
                speech_start_i = vad_segment[0];
            }
            if(vad_segment[1] != -1){
                speech_end_i = vad_segment[1];
            }

            // Case [1, 100]: a complete segment arrived in one shot.
            if(speech_start_i != -1 && speech_end_i != -1){
                int start = speech_start_i*seg_sample;
                int end = speech_end_i*seg_sample;
                if(asr_mode != ASR_OFFLINE){
                    // Whole segment as one final online frame.
                    frame = new AudioFrame(end-start);
                    frame->is_final = true;
                    frame->data = (float*)malloc(sizeof(float) * (end-start));
                    memcpy(frame->data, all_samples.data()+start-offset, (end-start)*sizeof(float));
                    asr_online_queue.push(frame);
                    frame = NULL;
                }
                if(asr_mode != ASR_ONLINE){
                    // Same samples again for the offline 2nd pass.
                    frame = new AudioFrame(end-start);
                    frame->is_final = true;
                    frame->data = (float*)malloc(sizeof(float) * (end-start));
                    memcpy(frame->data, all_samples.data()+start-offset, (end-start)*sizeof(float));
                    asr_offline_queue.push(frame);
                    frame = NULL;
                }
                speech_start = -1;
                speech_offline_start = -1;

            // Case [70, -1]: segment opened, end still unknown.
            }else if(speech_start_i != -1){
                speech_start = speech_start_i;
                speech_offline_start = speech_start_i;
                int start = speech_start*seg_sample;
                int end = speech_end*seg_sample;
                int buff_len = end-start;
                int step = chunk_len;
                if(asr_mode != ASR_OFFLINE){
                    // Emit the first online chunk as soon as enough audio exists.
                    if(buff_len >= step){
                        frame = new AudioFrame(step);
                        frame->data = (float*)malloc(sizeof(float) * step);
                        memcpy(frame->data, all_samples.data()+start-offset, step*sizeof(float));
                        asr_online_queue.push(frame);
                        frame = NULL;
                        speech_start += step/seg_sample;
                    }
                }
            }else if(speech_end_i != -1){ // Case [-1, 100]: open segment closes.
                if(speech_start == -1 or speech_offline_start == -1){
                    // An end without a matching start is an internal VAD error.
                    LOG(ERROR) <<"Vad start is null while vad end is available." ;
                    exit(-1);
                }
                int start = speech_start*seg_sample;
                int offline_start = speech_offline_start*seg_sample;
                int end = speech_end_i*seg_sample;
                int buff_len = end-start;
                int step = chunk_len;
                if(asr_mode != ASR_ONLINE){
                    // Offline pass gets the full segment from its original start.
                    frame = new AudioFrame(end-offline_start);
                    frame->is_final = true;
                    frame->data = (float*)malloc(sizeof(float) * (end-offline_start));
                    memcpy(frame->data, all_samples.data()+offline_start-offset, (end-offline_start)*sizeof(float));
                    asr_offline_queue.push(frame);
                    frame = NULL;
                }
                if(asr_mode != ASR_OFFLINE){
                    if(buff_len > 0){
                        // Drain the remaining tail as chunk-sized online frames;
                        // the last one shrinks to fit and is flagged final.
                        for (int sample_offset = 0; sample_offset < buff_len; sample_offset += std::min(step, buff_len - sample_offset)) {
                            bool is_final = false;
                            if (sample_offset + step >= buff_len - 1) {
                                step = buff_len - sample_offset;
                                is_final = true;
                            }
                            frame = new AudioFrame(step);
                            frame->is_final = is_final;
                            frame->data = (float*)malloc(sizeof(float) * step);
                            memcpy(frame->data, all_samples.data()+start-offset+sample_offset, step*sizeof(float));
                            asr_online_queue.push(frame);
                            frame = NULL;
                        }
                    }else{
                        // Nothing left to send, but the online decoder still
                        // needs an empty final frame to flush its state.
                        frame = new AudioFrame(0);
                        frame->is_final = true;
                        asr_online_queue.push(frame);
                        frame = NULL;
                    }
                }
                speech_start = -1;
                speech_offline_start = -1;
            }
        }
    }

    // erase all_samples
    // Trim the rolling buffer, always keeping vector_cache samples (2 s at
    // MODEL_SAMPLE_RATE) plus any still-open offline segment.
    int vector_cache = MODEL_SAMPLE_RATE*2;
    if(speech_offline_start == -1){
        if(all_samples.size() > vector_cache){
            int erase_num = all_samples.size() - vector_cache;
            all_samples.erase(all_samples.begin(), all_samples.begin()+erase_num);
            offset += erase_num;
        }
    }else{
        int offline_start = speech_offline_start*seg_sample;
        if(offline_start-offset > vector_cache){
            int erase_num = offline_start-offset - vector_cache;
            all_samples.erase(all_samples.begin(), all_samples.begin()+erase_num);
            offset += erase_num;
        }
    }
}
} // namespace funasr

View File

@ -5,7 +5,8 @@ namespace funasr {
typedef struct
{
std::string msg;
float snippet_time;
std::string tpass_msg;
float snippet_time;
}FUNASR_RECOG_RESULT;
typedef struct

View File

@ -181,11 +181,12 @@ vector<int> CTTransformerOnline::Infer(vector<int32_t> input_data, int nCacheSiz
text_lengths_dim.size()); //, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32);
//vad_mask
vector<float> arVadMask,arSubMask;
// vector<float> arVadMask,arSubMask;
vector<float> arVadMask;
int nTextLength = input_data.size();
VadMask(nTextLength, nCacheSize, arVadMask);
Triangle(nTextLength, arSubMask);
// Triangle(nTextLength, arSubMask);
std::array<int64_t, 4> VadMask_Dim{ 1,1, nTextLength ,nTextLength };
Ort::Value onnx_vad_mask = Ort::Value::CreateTensor<float>(
m_memoryInfo,
@ -198,8 +199,8 @@ vector<int> CTTransformerOnline::Infer(vector<int32_t> input_data, int nCacheSiz
std::array<int64_t, 4> SubMask_Dim{ 1,1, nTextLength ,nTextLength };
Ort::Value onnx_sub_mask = Ort::Value::CreateTensor<float>(
m_memoryInfo,
arSubMask.data(),
arSubMask.size() ,
arVadMask.data(),
arVadMask.size(),
SubMask_Dim.data(),
SubMask_Dim.size()); // , ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT);

View File

@ -55,7 +55,7 @@ void FsmnVadOnline::ExtractFeats(float sample_rate, vector<std::vector<float>> &
int frame_from_waves = (waves.size() - frame_sample_length_) / frame_shift_sample_length_ + 1;
int minus_frame = reserve_waveforms_.empty() ? (lfr_m - 1) / 2 : 0;
int lfr_splice_frame_idxs = OnlineLfrCmvn(vad_feats, input_finished);
int reserve_frame_idx = lfr_splice_frame_idxs - minus_frame;
int reserve_frame_idx = std::abs(lfr_splice_frame_idxs - minus_frame);
reserve_waveforms_.clear();
reserve_waveforms_.insert(reserve_waveforms_.begin(),
waves.begin() + reserve_frame_idx * frame_shift_sample_length_,
@ -86,7 +86,7 @@ void FsmnVadOnline::ExtractFeats(float sample_rate, vector<std::vector<float>> &
int FsmnVadOnline::OnlineLfrCmvn(vector<vector<float>> &vad_feats, bool input_finished) {
vector<vector<float>> out_feats;
int T = vad_feats.size();
int T_lrf = ceil((T - (lfr_m - 1) / 2) / lfr_n);
int T_lrf = ceil((T - (lfr_m - 1) / 2) / (float)lfr_n);
int lfr_splice_frame_idxs = T_lrf;
vector<float> p;
for (int i = 0; i < T_lrf; i++) {
@ -175,6 +175,9 @@ void FsmnVadOnline::InitOnline(std::shared_ptr<Ort::Session> &vad_session,
vad_silence_duration_ = vad_silence_duration;
vad_max_len_ = vad_max_len;
vad_speech_noise_thres_ = vad_speech_noise_thres;
// 2pass
audio_handle = make_unique<Audio>(1);
}
FsmnVadOnline::~FsmnVadOnline() {

View File

@ -21,6 +21,8 @@ public:
std::vector<std::vector<int>> Infer(std::vector<float> &waves, bool input_finished);
void ExtractFeats(float sample_rate, vector<vector<float>> &vad_feats, vector<float> &waves, bool input_finished);
void Reset();
// 2pass
std::unique_ptr<Audio> audio_handle = nullptr;
private:
E2EVadModel vad_scorer = E2EVadModel();

View File

@ -5,9 +5,15 @@ extern "C" {
#endif
// APIs for Init
_FUNASRAPI FUNASR_HANDLE FunASRInit(std::map<std::string, std::string>& model_path, int thread_num)
_FUNASRAPI FUNASR_HANDLE FunASRInit(std::map<std::string, std::string>& model_path, int thread_num, ASR_TYPE type)
{
funasr::Model* mm = funasr::CreateModel(model_path, thread_num);
funasr::Model* mm = funasr::CreateModel(model_path, thread_num, type);
return mm;
}
// Wrap an already-initialized offline ASR handle into a streaming model that
// decodes with the given {left, chunk, right} chunk_size configuration.
_FUNASRAPI FUNASR_HANDLE FunASROnlineInit(FUNASR_HANDLE asr_hanlde, std::vector<int> chunk_size)
{
    return funasr::CreateModel(asr_hanlde, chunk_size);
}
@ -35,8 +41,19 @@ extern "C" {
return mm;
}
// Create the shared two-pass stream (offline ASR + online punctuation)
// from the model files described in model_path.
_FUNASRAPI FUNASR_HANDLE FunTpassInit(std::map<std::string, std::string>& model_path, int thread_num)
{
    return funasr::CreateTpassStream(model_path, thread_num);
}
// Create a per-session online stream bound to the shared two-pass models.
_FUNASRAPI FUNASR_HANDLE FunTpassOnlineInit(FUNASR_HANDLE tpass_handle, std::vector<int> chunk_size)
{
    auto stream = funasr::CreateTpassOnlineStream(tpass_handle, chunk_size);
    return stream;
}
// APIs for ASR Infer
_FUNASRAPI FUNASR_RESULT FunASRInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback, int sampling_rate, std::string wav_format)
_FUNASRAPI FUNASR_RESULT FunASRInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback, bool input_finished, int sampling_rate, std::string wav_format)
{
funasr::Model* recog_obj = (funasr::Model*)handle;
if (!recog_obj)
@ -57,12 +74,12 @@ extern "C" {
funasr::FUNASR_RECOG_RESULT* p_result = new funasr::FUNASR_RECOG_RESULT;
p_result->snippet_time = audio.GetTimeLen();
if(p_result->snippet_time == 0){
return p_result;
}
return p_result;
}
int n_step = 0;
int n_total = audio.GetQueueSize();
while (audio.Fetch(buff, len, flag) > 0) {
string msg = recog_obj->Forward(buff, len, flag);
string msg = recog_obj->Forward(buff, len, input_finished);
p_result->msg += msg;
n_step++;
if (fn_callback)
@ -102,7 +119,7 @@ extern "C" {
return p_result;
}
while (audio.Fetch(buff, len, flag) > 0) {
string msg = recog_obj->Forward(buff, len, flag);
string msg = recog_obj->Forward(buff, len, true);
p_result->msg += msg;
n_step++;
if (fn_callback)
@ -230,7 +247,7 @@ extern "C" {
int n_step = 0;
int n_total = audio.GetQueueSize();
while (audio.Fetch(buff, len, flag) > 0) {
string msg = (offline_stream->asr_handle)->Forward(buff, len, flag);
string msg = (offline_stream->asr_handle)->Forward(buff, len, true);
p_result->msg += msg;
n_step++;
if (fn_callback)
@ -277,7 +294,7 @@ extern "C" {
int n_step = 0;
int n_total = audio.GetQueueSize();
while (audio.Fetch(buff, len, flag) > 0) {
string msg = (offline_stream->asr_handle)->Forward(buff, len, flag);
string msg = (offline_stream->asr_handle)->Forward(buff, len, true);
p_result->msg+= msg;
n_step++;
if (fn_callback)
@ -291,6 +308,91 @@ extern "C" {
return p_result;
}
// APIs for 2pass-stream Infer
// Feed one buffer of audio through the two-pass pipeline:
//   1. streaming VAD segments the audio (Audio::Split),
//   2. the streaming ASR decodes each chunk (partial results -> p_result->msg),
//   3. completed segments are re-decoded by the offline ASR and punctuated
//      (final results -> p_result->tpass_msg).
// mode selects ASR_ONLINE / ASR_OFFLINE / ASR_TWO_PASS behaviour.
// Returns a heap-allocated FUNASR_RECOG_RESULT (caller frees via the
// corresponding FreeResult API), or nullptr on invalid handles/input.
// Only raw PCM input is supported; other formats abort the process.
_FUNASRAPI FUNASR_RESULT FunTpassInferBuffer(FUNASR_HANDLE handle, FUNASR_HANDLE online_handle, const char* sz_buf, int n_len, std::vector<std::vector<std::string>> &punc_cache, bool input_finished, int sampling_rate, std::string wav_format, ASR_TYPE mode)
{
    // Unpack the shared (two-pass) and per-session (online) streams.
    funasr::TpassStream* tpass_stream = (funasr::TpassStream*)handle;
    funasr::TpassOnlineStream* tpass_online_stream = (funasr::TpassOnlineStream*)online_handle;
    if (!tpass_stream || !tpass_online_stream)
        return nullptr;
    funasr::VadModel* vad_online_handle = (tpass_online_stream->vad_online_handle).get();
    if (!vad_online_handle)
        return nullptr;
    // The online VAD owns the rolling Audio buffer for this session.
    funasr::Audio* audio = ((funasr::FsmnVadOnline*)vad_online_handle)->audio_handle.get();
    funasr::Model* asr_online_handle = (tpass_online_stream->asr_online_handle).get();
    if (!asr_online_handle)
        return nullptr;
    int chunk_len = ((funasr::ParaformerOnline*)asr_online_handle)->chunk_len;
    funasr::Model* asr_handle = (tpass_stream->asr_handle).get();
    if (!asr_handle)
        return nullptr;
    funasr::PuncModel* punc_online_handle = (tpass_stream->punc_online_handle).get();
    if (!punc_online_handle)
        return nullptr;

    if(wav_format == "pcm" || wav_format == "PCM"){
        if (!audio->LoadPcmwavOnline(sz_buf, n_len, &sampling_rate))
            return nullptr;
    }else{
        // if (!audio->FfmpegLoad(sz_buf, n_len))
        //     return nullptr;
        LOG(ERROR) <<"Wrong wav_format: " << wav_format ;
        exit(-1);
    }
    funasr::FUNASR_RECOG_RESULT* p_result = new funasr::FUNASR_RECOG_RESULT;
    p_result->snippet_time = audio->GetTimeLen();
    if(p_result->snippet_time == 0){
        // No audio accumulated yet; return an empty result.
        return p_result;
    }
    // Segment the buffered audio and fill the online/offline frame queues.
    audio->Split(vad_online_handle, chunk_len, input_finished, mode);

    // First pass: streaming decode of every pending online chunk.
    funasr::AudioFrame* frame = NULL;
    while(audio->FetchChunck(frame) > 0){
        string msg = asr_online_handle->Forward(frame->data, frame->len, frame->is_final);
        if(mode == ASR_ONLINE){
            // Online-only mode: accumulate per-segment text and punctuate it
            // when the segment ends (punc_cache[0] = online punctuation state).
            ((funasr::ParaformerOnline*)asr_online_handle)->online_res += msg;
            if(frame->is_final){
                string online_msg = ((funasr::ParaformerOnline*)asr_online_handle)->online_res;
                string msg_punc = punc_online_handle->AddPunc(online_msg.c_str(), punc_cache[0]);
                p_result->tpass_msg = msg_punc;
                ((funasr::ParaformerOnline*)asr_online_handle)->online_res = "";
                p_result->msg += msg;
            }else{
                p_result->msg += msg;
            }
        }else if(mode == ASR_TWO_PASS){
            p_result->msg += msg;
        }
        if(frame != NULL){
            delete frame;
            frame = NULL;
        }
    }

    // Second pass: offline decode + punctuation of completed segments
    // (punc_cache[1] = offline punctuation state).
    while(audio->FetchTpass(frame) > 0){
        string msg = asr_handle->Forward(frame->data, frame->len, frame->is_final);
        string msg_punc = punc_online_handle->AddPunc(msg.c_str(), punc_cache[1]);
        p_result->tpass_msg = msg_punc;
        if(frame != NULL){
            delete frame;
            frame = NULL;
        }
    }
    if(input_finished){
        // End of stream: rewind the session's audio buffer indices.
        audio->ResetIndex();
    }
    return p_result;
}
_FUNASRAPI const int FunASRGetRetNumber(FUNASR_RESULT result)
{
if (!result)
@ -326,6 +428,15 @@ extern "C" {
return p_result->msg.c_str();
}
// Return the second-pass (punctuated) text held by a recognition result.
// The pointer stays valid only as long as `result` itself is alive.
// n_index is currently unused; kept for symmetry with the other getters.
_FUNASRAPI const char* FunASRGetTpassResult(FUNASR_RESULT result,int n_index)
{
    auto* p_result = static_cast<funasr::FUNASR_RECOG_RESULT*>(result);
    return p_result ? p_result->tpass_msg.c_str() : nullptr;
}
_FUNASRAPI const char* CTTransformerGetResult(FUNASR_RESULT result,int n_index)
{
funasr::FUNASR_PUNC_RESULT * p_result = (funasr::FUNASR_PUNC_RESULT*)result;
@ -414,6 +525,26 @@ extern "C" {
delete offline_stream;
}
// Destroy a two-pass stream created by FunTpassInit.
// Deleting a null pointer is a no-op, so no explicit guard is needed.
_FUNASRAPI void FunTpassUninit(FUNASR_HANDLE handle)
{
    delete static_cast<funasr::TpassStream*>(handle);
}
// Destroy a per-session online stream created by FunTpassOnlineInit.
// Deleting a null pointer is a no-op, so no explicit guard is needed.
_FUNASRAPI void FunTpassOnlineUninit(FUNASR_HANDLE handle)
{
    delete static_cast<funasr::TpassOnlineStream*>(handle);
}
#ifdef __cplusplus
}

View File

@ -1,22 +1,55 @@
#include "precomp.h"
namespace funasr {
Model *CreateModel(std::map<std::string, std::string>& model_path, int thread_num)
Model *CreateModel(std::map<std::string, std::string>& model_path, int thread_num, ASR_TYPE type)
{
string am_model_path;
string am_cmvn_path;
string am_config_path;
// offline
if(type == ASR_OFFLINE){
string am_model_path;
string am_cmvn_path;
string am_config_path;
am_model_path = PathAppend(model_path.at(MODEL_DIR), MODEL_NAME);
if(model_path.find(QUANTIZE) != model_path.end() && model_path.at(QUANTIZE) == "true"){
am_model_path = PathAppend(model_path.at(MODEL_DIR), QUANT_MODEL_NAME);
am_model_path = PathAppend(model_path.at(MODEL_DIR), MODEL_NAME);
if(model_path.find(QUANTIZE) != model_path.end() && model_path.at(QUANTIZE) == "true"){
am_model_path = PathAppend(model_path.at(MODEL_DIR), QUANT_MODEL_NAME);
}
am_cmvn_path = PathAppend(model_path.at(MODEL_DIR), AM_CMVN_NAME);
am_config_path = PathAppend(model_path.at(MODEL_DIR), AM_CONFIG_NAME);
Model *mm;
mm = new Paraformer();
mm->InitAsr(am_model_path, am_cmvn_path, am_config_path, thread_num);
return mm;
}else if(type == ASR_ONLINE){
// online
string en_model_path;
string de_model_path;
string am_cmvn_path;
string am_config_path;
en_model_path = PathAppend(model_path.at(MODEL_DIR), ENCODER_NAME);
de_model_path = PathAppend(model_path.at(MODEL_DIR), DECODER_NAME);
if(model_path.find(QUANTIZE) != model_path.end() && model_path.at(QUANTIZE) == "true"){
en_model_path = PathAppend(model_path.at(MODEL_DIR), QUANT_ENCODER_NAME);
de_model_path = PathAppend(model_path.at(MODEL_DIR), QUANT_DECODER_NAME);
}
am_cmvn_path = PathAppend(model_path.at(MODEL_DIR), AM_CMVN_NAME);
am_config_path = PathAppend(model_path.at(MODEL_DIR), AM_CONFIG_NAME);
Model *mm;
mm = new Paraformer();
mm->InitAsr(en_model_path, de_model_path, am_cmvn_path, am_config_path, thread_num);
return mm;
}else{
LOG(ERROR)<<"Wrong ASR_TYPE : " << type;
exit(-1);
}
am_cmvn_path = PathAppend(model_path.at(MODEL_DIR), AM_CMVN_NAME);
am_config_path = PathAppend(model_path.at(MODEL_DIR), AM_CONFIG_NAME);
}
Model *mm;
mm = new Paraformer();
mm->InitAsr(am_model_path, am_cmvn_path, am_config_path, thread_num);
Model *CreateModel(void* asr_handle, std::vector<int> chunk_size)
{
Model* mm;
mm = new ParaformerOnline((Paraformer*)asr_handle, chunk_size);
return mm;
}

View File

@ -0,0 +1,551 @@
/**
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
* MIT License (https://opensource.org/licenses/MIT)
*/
#include "precomp.h"
using namespace std;
namespace funasr {
// Build a streaming decoder on top of an already-initialized offline
// Paraformer, reusing its ONNX sessions, CMVN statistics and configuration.
// @param para_handle borrowed pointer to the offline model; not owned and
//                    never deleted here — it must outlive this object.
// @param chunk_size  {left-context, chunk, right-context} sizes (lfr frames).
// Fix: the original applied std::move to the raw pointer (a no-op); the
// by-value vector is what can actually be moved into the member.
ParaformerOnline::ParaformerOnline(Paraformer* para_handle, std::vector<int> chunk_size)
:para_handle_(para_handle),chunk_size(std::move(chunk_size)),session_options_{}{
    // Copy session handles and feature/decoder configuration from the
    // offline model, then prime the per-stream caches.
    InitOnline(
        para_handle_->fbank_opts_,
        para_handle_->encoder_session_,
        para_handle_->decoder_session_,
        para_handle_->en_szInputNames_,
        para_handle_->en_szOutputNames_,
        para_handle_->de_szInputNames_,
        para_handle_->de_szOutputNames_,
        para_handle_->means_list_,
        para_handle_->vars_list_);
    InitCache();
}
// Copy everything the streaming decoder needs out of the offline Paraformer:
// shared ONNX sessions, input/output tensor names, CMVN statistics, and the
// feature/decoder hyper-parameters. Also derives the per-chunk sample count
// (chunk_len) and the zero-filled initial FSMN cache.
void ParaformerOnline::InitOnline(
    knf::FbankOptions &fbank_opts,
    std::shared_ptr<Ort::Session> &encoder_session,
    std::shared_ptr<Ort::Session> &decoder_session,
    vector<const char*> &en_szInputNames,
    vector<const char*> &en_szOutputNames,
    vector<const char*> &de_szInputNames,
    vector<const char*> &de_szOutputNames,
    vector<float> &means_list,
    vector<float> &vars_list){
    // Sessions are shared (shared_ptr) with the offline model; name lists and
    // CMVN vectors are copied by value.
    fbank_opts_ = fbank_opts;
    encoder_session_ = encoder_session;
    decoder_session_ = decoder_session;
    en_szInputNames_ = en_szInputNames;
    en_szOutputNames_ = en_szOutputNames;
    de_szInputNames_ = de_szInputNames;
    de_szOutputNames_ = de_szOutputNames;
    means_list_ = means_list;
    vars_list_ = vars_list;

    // Mirror the offline model's feature and decoder configuration.
    frame_length = para_handle_->frame_length;
    frame_shift = para_handle_->frame_shift;
    n_mels = para_handle_->n_mels;
    lfr_m = para_handle_->lfr_m;
    lfr_n = para_handle_->lfr_n;
    encoder_size = para_handle_->encoder_size;
    fsmn_layers = para_handle_->fsmn_layers;
    fsmn_lorder = para_handle_->fsmn_lorder;
    fsmn_dims = para_handle_->fsmn_dims;
    cif_threshold = para_handle_->cif_threshold;
    tail_alphas = para_handle_->tail_alphas;

    // other vars
    sqrt_factor = std::sqrt(encoder_size);  // encoder-output scaling factor
    // Zero template for one layer's FSMN cache (lorder x dims).
    for(int i=0; i<fsmn_lorder*fsmn_dims; i++){
        fsmn_init_cache_.emplace_back(0);
    }
    // Samples per decoding chunk: chunk_size[1] lfr frames -> raw samples.
    chunk_len = chunk_size[1]*frame_shift*lfr_n*MODEL_SAMPLE_RATE/1000;
}
// Compute kaldi-style fbank features for a streaming waveform.
// Samples that do not yet fill a complete frame are kept in input_cache_ and
// prepended to the next call; `waves` is trimmed in place so the caller sees
// exactly the samples that were consumed by fbank.
void ParaformerOnline::FbankKaldi(float sample_rate, std::vector<std::vector<float>> &wav_feats,
                std::vector<float> &waves) {
    knf::OnlineFbank fbank(fbank_opts_);
    // cache merge
    waves.insert(waves.begin(), input_cache_.begin(), input_cache_.end());
    int frame_number = ComputeFrameNum(waves.size(), frame_sample_length_, frame_shift_sample_length_);
    // Send the audio after the last frame shift position to the cache
    input_cache_.clear();
    input_cache_.insert(input_cache_.begin(), waves.begin() + frame_number * frame_shift_sample_length_, waves.end());
    if (frame_number == 0) {
        return;  // not enough samples for even one frame
    }
    // Delete audio that haven't undergone fbank processing
    waves.erase(waves.begin() + (frame_number - 1) * frame_shift_sample_length_ + frame_sample_length_, waves.end());

    // Rescale to int16-range amplitudes expected by the fbank frontend.
    std::vector<float> buf(waves.size());
    for (int32_t i = 0; i != waves.size(); ++i) {
        buf[i] = waves[i] * 32768;
    }
    fbank.AcceptWaveform(sample_rate, buf.data(), buf.size());
    int32_t frames = fbank.NumFramesReady();
    for (int32_t i = 0; i != frames; ++i) {
        const float *frame = fbank.GetFrame(i);
        vector<float> frame_vector(frame, frame + fbank_opts_.mel_opts.num_bins);
        wav_feats.emplace_back(frame_vector);
    }
}
// Full streaming feature pipeline: fbank, then low-frame-rate stacking and
// CMVN with cross-call caching. reserve_waveforms_ holds raw samples that
// back the cached (not yet emitted) lfr frames; lfr_splice_cache_ holds the
// fbank frames needed to splice the next chunk.
void ParaformerOnline::ExtractFeats(float sample_rate, vector<std::vector<float>> &wav_feats,
                        vector<float> &waves, bool input_finished) {
    FbankKaldi(sample_rate, wav_feats, waves);
    // cache deal & online lfr,cmvn
    if (wav_feats.size() > 0) {
        if (!reserve_waveforms_.empty()) {
            waves.insert(waves.begin(), reserve_waveforms_.begin(), reserve_waveforms_.end());
        }
        if (lfr_splice_cache_.empty()) {
            // First chunk: left-pad the splice window by repeating frame 0.
            for (int i = 0; i < (lfr_m - 1) / 2; i++) {
                lfr_splice_cache_.emplace_back(wav_feats[0]);
            }
        }
        if (wav_feats.size() + lfr_splice_cache_.size() >= lfr_m) {
            // Enough frames for at least one lfr output frame.
            wav_feats.insert(wav_feats.begin(), lfr_splice_cache_.begin(), lfr_splice_cache_.end());
            int frame_from_waves = (waves.size() - frame_sample_length_) / frame_shift_sample_length_ + 1;
            int minus_frame = reserve_waveforms_.empty() ? (lfr_m - 1) / 2 : 0;
            int lfr_splice_frame_idxs = OnlineLfrCmvn(wav_feats, input_finished);
            // abs() guards the first-chunk case where the padding offset can
            // exceed the splice index — TODO confirm against Python frontend.
            int reserve_frame_idx = std::abs(lfr_splice_frame_idxs - minus_frame);
            reserve_waveforms_.clear();
            reserve_waveforms_.insert(reserve_waveforms_.begin(),
                                    waves.begin() + reserve_frame_idx * frame_shift_sample_length_,
                                    waves.begin() + frame_from_waves * frame_shift_sample_length_);
            int sample_length = (frame_from_waves - 1) * frame_shift_sample_length_ + frame_sample_length_;
            waves.erase(waves.begin() + sample_length, waves.end());
        } else {
            // Too few frames: stash everything for the next call.
            reserve_waveforms_.clear();
            reserve_waveforms_.insert(reserve_waveforms_.begin(),
                                    waves.begin() + frame_sample_length_ - frame_shift_sample_length_, waves.end());
            lfr_splice_cache_.insert(lfr_splice_cache_.end(), wav_feats.begin(), wav_feats.end());
        }
    } else {
        if (input_finished) {
            // Stream ended with no fresh frames: flush whatever is cached.
            if (!reserve_waveforms_.empty()) {
                waves = reserve_waveforms_;
            }
            wav_feats = lfr_splice_cache_;
            OnlineLfrCmvn(wav_feats, input_finished);
        }
    }
    if(input_finished){
        ResetCache();
    }
}
// Low-frame-rate stacking (concatenate lfr_m frames every lfr_n) followed by
// CMVN, streaming variant. Incomplete tail frames are either padded (final
// chunk) or cached in lfr_splice_cache_ for the next call.
// Returns the index of the first input frame that was NOT fully consumed.
int ParaformerOnline::OnlineLfrCmvn(vector<vector<float>> &wav_feats, bool input_finished) {
    vector<vector<float>> out_feats;
    int T = wav_feats.size();
    // float cast keeps the ceil meaningful (integer division would truncate).
    int T_lrf = ceil((T - (lfr_m - 1) / 2) / (float)lfr_n);
    int lfr_splice_frame_idxs = T_lrf;
    vector<float> p;
    for (int i = 0; i < T_lrf; i++) {
        if (lfr_m <= T - i * lfr_n) {
            // Full window available: concatenate lfr_m consecutive frames.
            for (int j = 0; j < lfr_m; j++) {
                p.insert(p.end(), wav_feats[i * lfr_n + j].begin(), wav_feats[i * lfr_n + j].end());
            }
            out_feats.emplace_back(p);
            p.clear();
        } else {
            if (input_finished) {
                // Final chunk: pad the short window by repeating the last frame.
                int num_padding = lfr_m - (T - i * lfr_n);
                for (int j = 0; j < (wav_feats.size() - i * lfr_n); j++) {
                    p.insert(p.end(), wav_feats[i * lfr_n + j].begin(), wav_feats[i * lfr_n + j].end());
                }
                for (int j = 0; j < num_padding; j++) {
                    p.insert(p.end(), wav_feats[wav_feats.size() - 1].begin(), wav_feats[wav_feats.size() - 1].end());
                }
                out_feats.emplace_back(p);
            } else {
                // Mid-stream: stop here and carry the tail over to next call.
                lfr_splice_frame_idxs = i;
                break;
            }
        }
    }
    // Cache the unconsumed input frames for splicing with the next chunk.
    lfr_splice_frame_idxs = std::min(T - 1, lfr_splice_frame_idxs * lfr_n);
    lfr_splice_cache_.clear();
    lfr_splice_cache_.insert(lfr_splice_cache_.begin(), wav_feats.begin() + lfr_splice_frame_idxs, wav_feats.end());

    // Apply cmvn
    // NOTE(review): means_list_ appears to hold negated means and vars_list_
    // inverse standard deviations — confirm against the CMVN file loader.
    for (auto &out_feat: out_feats) {
        for (int j = 0; j < means_list_.size(); j++) {
            out_feat[j] = (out_feat[j] + means_list_[j]) * vars_list_[j];
        }
    }
    wav_feats = out_feats;
    return lfr_splice_frame_idxs;
}
// Add sinusoidal positional embeddings to `timesteps` feature frames,
// continuing from the absolute position of the previous chunk
// (start_idx_cache_) so the embedding is consistent across the stream.
void ParaformerOnline::GetPosEmb(std::vector<std::vector<float>> &wav_feats, int timesteps, int feat_dim)
{
    int start_idx = start_idx_cache_;
    start_idx_cache_ += timesteps;  // persist absolute position for next chunk
    int mm = start_idx_cache_;
    int i;
    // Hard-coded frequency scale; presumably -ln(10000)/(feat_dim/2 - 1) for
    // the model's feature dimension — TODO confirm.
    float scale = -0.0330119726594128;
    std::vector<float> tmp(mm*feat_dim);
    // Build the table for positions [0, mm): first half sin, second half cos.
    for (i = 0; i < feat_dim/2; i++) {
        float tmptime = exp(i * scale);
        int j;
        for (j = 0; j < mm; j++) {
            int sin_idx = j * feat_dim + i;
            int cos_idx = j * feat_dim + i + feat_dim/2;
            float coe = tmptime * (j + 1);
            tmp[sin_idx] = sin(coe);
            tmp[cos_idx] = cos(coe);
        }
    }
    // Add only the slice belonging to this chunk's absolute positions.
    for (i = start_idx; i < start_idx + timesteps; i++) {
        for (int j = 0; j < feat_dim; j++) {
            wav_feats[i-start_idx][j] += tmp[i*feat_dim+j];
        }
    }
}
// Continuous Integrate-and-Fire (CIF) over encoder outputs: accumulate
// per-frame weights `alphas` and emit ("fire") one acoustic embedding into
// list_frame each time the accumulator crosses cif_threshold. Left/right
// context frames (chunk_size[0] / beyond chunk_size[0]+chunk_size[1]) are
// zeroed so only the central chunk contributes. Residual weight and the
// partially-integrated frame are cached for the next chunk.
void ParaformerOnline::CifSearch(std::vector<std::vector<float>> hidden, std::vector<float> alphas, bool is_final, std::vector<std::vector<float>>& list_frame)
{
    try{
        int hidden_size = 0;
        if(hidden.size() > 0){
            hidden_size = hidden[0].size();
        }
        // cache
        int i,j;
        // Zero out the left-context weights ...
        int chunk_size_pre = chunk_size[0];
        for (i = 0; i < chunk_size_pre; i++)
            alphas[i] = 0.0;

        // ... and the right-context weights past the central chunk.
        int chunk_size_suf = std::accumulate(chunk_size.begin(), chunk_size.end()-1, 0);
        for (i = chunk_size_suf; i < alphas.size(); i++){
            alphas[i] = 0.0;
        }

        // Prepend the carried-over partial integration from the last chunk.
        if(hidden_cache_.size()>0){
            hidden.insert(hidden.begin(), hidden_cache_.begin(), hidden_cache_.end());
            alphas.insert(alphas.begin(), alphas_cache_.begin(), alphas_cache_.end());
            hidden_cache_.clear();
            alphas_cache_.clear();
        }

        // Final chunk: append a synthetic tail frame so residual weight fires.
        if (is_last_chunk) {
            std::vector<float> tail_hidden(hidden_size, 0);
            hidden.emplace_back(tail_hidden);
            alphas.emplace_back(tail_alphas);
        }

        // intergrate [sic]: running weight accumulator; frames: weighted sum.
        float intergrate = 0.0;
        int len_time = alphas.size();
        std::vector<float> frames(hidden_size, 0);
        std::vector<float> list_fire;

        for (i = 0; i < len_time; i++) {
            float alpha = alphas[i];
            if (alpha + intergrate < cif_threshold) {
                // Below threshold: keep integrating.
                intergrate += alpha;
                list_fire.emplace_back(intergrate);
                for (j = 0; j < hidden_size; j++) {
                    frames[j] += alpha * hidden[i][j];
                }
            } else {
                // Threshold crossed: close the current embedding ("fire") and
                // start the next one with the leftover weight.
                for (j = 0; j < hidden_size; j++) {
                    frames[j] += (cif_threshold - intergrate) * hidden[i][j];
                }
                std::vector<float> frames_cp(frames);
                list_frame.emplace_back(frames_cp);
                intergrate += alpha;
                list_fire.emplace_back(intergrate);
                intergrate -= cif_threshold;
                for (j = 0; j < hidden_size; j++) {
                    frames[j] = intergrate * hidden[i][j];
                }
            }
        }

        // cache
        // Carry the residual weight and its (re-normalized) frame forward.
        alphas_cache_.emplace_back(intergrate);
        if (intergrate > 0.0) {
            std::vector<float> hidden_cache(hidden_size, 0);
            for (i = 0; i < hidden_size; i++) {
                hidden_cache[i] = frames[i] / intergrate;
            }
            hidden_cache_.emplace_back(hidden_cache);
        } else {
            std::vector<float> frames_cp(frames);
            hidden_cache_.emplace_back(frames_cp);
        }
    }catch (std::exception const &e)
    {
        LOG(ERROR)<<e.what();
    }
}
// (Re)initialize all per-stream decoding state: positional-embedding offset,
// chunk flags, CIF caches, the overlap feature cache, and one zeroed FSMN
// cache tensor per decoder layer.
void ParaformerOnline::InitCache(){
    start_idx_cache_ = 0;
    is_first_chunk = true;
    is_last_chunk = false;
    hidden_cache_.clear();
    alphas_cache_.clear();
    feats_cache_.clear();
    decoder_onnx.clear();

    // cif cache: start from a zero frame with zero accumulated weight.
    std::vector<float> hidden_cache(encoder_size, 0);
    hidden_cache_.emplace_back(hidden_cache);
    alphas_cache_.emplace_back(0);

    // feats: zero left+right overlap context for the first chunk.
    std::vector<float> feat_cache(feat_dims, 0);
    for(int i=0; i<(chunk_size[0]+chunk_size[2]); i++){
        feats_cache_.emplace_back(feat_cache);
    }

    // fsmn cache
    // NOTE: the tensors alias fsmn_init_cache_'s buffer, which must stay
    // alive while decoder_onnx holds these values.
#ifdef _WIN_X86
    Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
#else
    Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
#endif
    const int64_t fsmn_shape_[3] = {1, fsmn_dims, fsmn_lorder};
    for(int l=0; l<fsmn_layers; l++){
        Ort::Value onnx_fsmn_cache = Ort::Value::CreateTensor<float>(
            m_memoryInfo,
            fsmn_init_cache_.data(),
            fsmn_init_cache_.size(),
            fsmn_shape_,
            3);
        decoder_onnx.emplace_back(std::move(onnx_fsmn_cache));
    }
};
// Reset all decoding state back to the start-of-stream condition.
void ParaformerOnline::Reset()
{
    InitCache();
}
// Drop the feature-extraction caches (raw-sample reserve, partial-frame
// input cache, and the lfr splice buffer). The clears are independent, so
// their order does not matter.
void ParaformerOnline::ResetCache() {
    lfr_splice_cache_.clear();
    input_cache_.clear();
    reserve_waveforms_.clear();
}
// Prepend the cached overlap frames (left + right context of the previous
// chunk) to wav_feats and refresh the cache for the next call. On the final
// chunk, pad with zero frames up to the full chunk length unless this is the
// very last chunk of the stream.
// Fixes: the zero-frame template `tmp` was declared but never used — the loop
// relied on `emplace_back(feat_dims)` accidentally building an equivalent
// zero vector; padding arithmetic now avoids signed/unsigned mixing.
void ParaformerOnline::AddOverlapChunk(std::vector<std::vector<float>> &wav_feats, bool input_finished){
    wav_feats.insert(wav_feats.begin(), feats_cache_.begin(), feats_cache_.end());
    if(input_finished){
        // Keep only the left context for a potential follow-up chunk.
        feats_cache_.clear();
        feats_cache_.insert(feats_cache_.begin(), wav_feats.end()-chunk_size[0], wav_feats.end());
        if(!is_last_chunk){
            int total_len = std::accumulate(chunk_size.begin(), chunk_size.end(), 0);
            int padding_length = total_len - static_cast<int>(wav_feats.size());
            std::vector<float> zero_frame(feat_dims, 0);
            for(int i=0; i<padding_length; i++){
                wav_feats.emplace_back(zero_frame);
            }
        }
    }else{
        // Mid-stream: cache left + right context for the next chunk.
        feats_cache_.clear();
        feats_cache_.insert(feats_cache_.begin(), wav_feats.end()-chunk_size[0]-chunk_size[2], wav_feats.end());
    }
}
// Decode one overlap-padded feature chunk: run the ONNX encoder, apply CIF
// to pick acoustic embeddings, then run the decoder (with streaming FSMN
// caches) and greedy-search the logits into text. Returns "" when CIF fires
// no embedding or on any ONNX error.
string ParaformerOnline::ForwardChunk(std::vector<std::vector<float>> &chunk_feats, bool input_finished)
{
    string result;
    try{
        int32_t num_frames = chunk_feats.size();

#ifdef _WIN_X86
        Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
#else
        Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
#endif
        // Flatten [num_frames][feat_dims] into one contiguous buffer for ORT.
        const int64_t input_shape_[3] = {1, num_frames, feat_dims};
        std::vector<float> wav_feats;
        for (const auto &chunk_feat: chunk_feats) {
            wav_feats.insert(wav_feats.end(), chunk_feat.begin(), chunk_feat.end());
        }
        Ort::Value onnx_feats = Ort::Value::CreateTensor<float>(
            m_memoryInfo,
            wav_feats.data(),
            wav_feats.size(),
            input_shape_,
            3);

        const int64_t paraformer_length_shape[1] = {1};
        std::vector<int32_t> paraformer_length;
        paraformer_length.emplace_back(num_frames);
        Ort::Value onnx_feats_len = Ort::Value::CreateTensor<int32_t>(
            m_memoryInfo, paraformer_length.data(), paraformer_length.size(), paraformer_length_shape, 1);

        std::vector<Ort::Value> input_onnx;
        input_onnx.emplace_back(std::move(onnx_feats));
        input_onnx.emplace_back(std::move(onnx_feats_len));

        // Encoder outputs: [0] hidden states, [1] lengths, [2] CIF alphas.
        auto encoder_tensor = encoder_session_->Run(Ort::RunOptions{nullptr}, en_szInputNames_.data(), input_onnx.data(), input_onnx.size(), en_szOutputNames_.data(), en_szOutputNames_.size());

        // get enc_vec
        std::vector<int64_t> enc_shape = encoder_tensor[0].GetTensorTypeAndShapeInfo().GetShape();
        float* enc_data = encoder_tensor[0].GetTensorMutableData<float>();
        std::vector<std::vector<float>> enc_vec(enc_shape[1], std::vector<float>(enc_shape[2]));
        for (int i = 0; i < enc_shape[1]; i++) {
            for (int j = 0; j < enc_shape[2]; j++) {
                enc_vec[i][j] = enc_data[i * enc_shape[2] + j];
            }
        }

        // get alpha_vec
        std::vector<int64_t> alpha_shape = encoder_tensor[2].GetTensorTypeAndShapeInfo().GetShape();
        float* alpha_data = encoder_tensor[2].GetTensorMutableData<float>();
        std::vector<float> alpha_vec(alpha_shape[1]);
        for (int i = 0; i < alpha_shape[1]; i++) {
            alpha_vec[i] = alpha_data[i];
        }

        std::vector<std::vector<float>> list_frame;
        CifSearch(enc_vec, alpha_vec, input_finished, list_frame);

        if(list_frame.size()>0){
            // Build the decoder input list IN ORDER: enc, enc_lens, embeds,
            // embed_lens, then the fsmn caches already held in decoder_onnx.
            // enc
            decoder_onnx.insert(decoder_onnx.begin(), std::move(encoder_tensor[0]));
            // enc_lens
            decoder_onnx.insert(decoder_onnx.begin()+1, std::move(encoder_tensor[1]));
            // acoustic_embeds
            const int64_t emb_shape_[3] = {1, (int64_t)list_frame.size(), (int64_t)list_frame[0].size()};
            std::vector<float> emb_input;
            for (const auto &list_frame_: list_frame) {
                emb_input.insert(emb_input.end(), list_frame_.begin(), list_frame_.end());
            }
            Ort::Value onnx_emb = Ort::Value::CreateTensor<float>(
                m_memoryInfo,
                emb_input.data(),
                emb_input.size(),
                emb_shape_,
                3);
            decoder_onnx.insert(decoder_onnx.begin()+2, std::move(onnx_emb));
            // acoustic_embeds_len
            const int64_t emb_length_shape[1] = {1};
            std::vector<int32_t> emb_length;
            emb_length.emplace_back(list_frame.size());
            Ort::Value onnx_emb_len = Ort::Value::CreateTensor<int32_t>(
                m_memoryInfo, emb_length.data(), emb_length.size(), emb_length_shape, 1);
            decoder_onnx.insert(decoder_onnx.begin()+3, std::move(onnx_emb_len));

            auto decoder_tensor = decoder_session_->Run(Ort::RunOptions{nullptr}, de_szInputNames_.data(), decoder_onnx.data(), decoder_onnx.size(), de_szOutputNames_.data(), de_szOutputNames_.size());

            // fsmn cache
            // Replace the consumed caches with the updated ones the decoder
            // returned (outputs [2 .. 2+fsmn_layers)).
            try{
                decoder_onnx.clear();
            }catch (std::exception const &e)
            {
                LOG(ERROR)<<e.what();
                return result;
            }
            for(int l=0;l<fsmn_layers;l++){
                decoder_onnx.emplace_back(std::move(decoder_tensor[2+l]));
            }

            std::vector<int64_t> decoder_shape = decoder_tensor[0].GetTensorTypeAndShapeInfo().GetShape();
            float* float_data = decoder_tensor[0].GetTensorMutableData<float>();
            result = para_handle_->GreedySearch(float_data, list_frame.size(), decoder_shape[2]);
        }
    }catch (std::exception const &e)
    {
        LOG(ERROR)<<e.what();
        return result;
    }
    return result;
}
// Streaming entry point: accept `len` raw samples, extract features, split
// into overlap chunks and decode. When input_finished is set, the stream is
// flushed (possibly decoding a final partial chunk) and all caches reset.
// Returns the text decoded from this call (may be empty mid-stream).
string ParaformerOnline::Forward(float* din, int len, bool input_finished)
{
    std::vector<std::vector<float>> wav_feats;
    std::vector<float> waves(din, din+len);

    string result="";
    try{
        // Tiny final input (< 16*60 samples, i.e. < 60 ms at 16 kHz —
        // presumably below one feature frame): just flush the cached
        // features as the last chunk.
        if(len <16*60 && input_finished && !is_first_chunk){
            is_last_chunk = true;
            wav_feats = feats_cache_;
            result = ForwardChunk(wav_feats, is_last_chunk);
            // reset
            ResetCache();
            Reset();
            return result;
        }
        if(is_first_chunk){
            is_first_chunk = false;
        }
        ExtractFeats(MODEL_SAMPLE_RATE, wav_feats, waves, input_finished);
        if(wav_feats.size() == 0){
            return result;  // not enough audio for a single lfr frame yet
        }

        // Scale features by sqrt(encoder_size) before positional embedding.
        for (auto& row : wav_feats) {
            for (auto& val : row) {
                val *= sqrt_factor;
            }
        }
        GetPosEmb(wav_feats, wav_feats.size(), wav_feats[0].size());

        if(input_finished){
            if(wav_feats.size()+chunk_size[2] <= chunk_size[1]){
                // Everything fits in one last chunk.
                is_last_chunk = true;
                AddOverlapChunk(wav_feats, input_finished);
            }else{
                // Too long for one chunk: decode a full chunk first, then the
                // overlapping remainder as the final chunk.
                // first chunk
                std::vector<std::vector<float>> first_chunk;
                first_chunk.insert(first_chunk.begin(), wav_feats.begin(), wav_feats.end());
                AddOverlapChunk(first_chunk, input_finished);
                string str_first_chunk = ForwardChunk(first_chunk, is_last_chunk);

                // last chunk
                is_last_chunk = true;
                std::vector<std::vector<float>> last_chunk;
                last_chunk.insert(last_chunk.begin(), wav_feats.end()-(wav_feats.size()+chunk_size[2]-chunk_size[1]), wav_feats.end());
                AddOverlapChunk(last_chunk, input_finished);
                string str_last_chunk = ForwardChunk(last_chunk, is_last_chunk);

                result = str_first_chunk+str_last_chunk;
                // reset
                ResetCache();
                Reset();
                return result;
            }
        }else{
            AddOverlapChunk(wav_feats, input_finished);
        }

        result = ForwardChunk(wav_feats, is_last_chunk);
        if(input_finished){
            // reset
            ResetCache();
            Reset();
        }
    }catch (std::exception const &e)
    {
        LOG(ERROR)<<e.what();
        return result;
    }
    return result;
}
// Nothing to release explicitly: ONNX sessions are shared_ptr members, the
// caches are standard containers, and para_handle_ is borrowed (never
// deleted here).
ParaformerOnline::~ParaformerOnline()
{
}
// Rescoring is not implemented for the streaming model; logs an error and
// returns an empty string.
string ParaformerOnline::Rescoring()
{
    LOG(ERROR)<<"Not Imp!!!!!!";
    return "";
}
} // namespace funasr

View File

@ -0,0 +1,111 @@
/**
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
* MIT License (https://opensource.org/licenses/MIT)
*/
#pragma once
#include "precomp.h"
namespace funasr {
class ParaformerOnline : public Model {
    /**
     * Author: Speech Lab of DAMO Academy, Alibaba Group
     * ParaformerOnline: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
     * https://arxiv.org/pdf/2206.08317.pdf
     */
private:
    // Streaming fbank feature extraction from raw samples into wav_feats.
    void FbankKaldi(float sample_rate, std::vector<std::vector<float>> &wav_feats,
                    std::vector<float> &waves);
    // Incremental LFR stacking + CMVN over the feature chunk.
    // NOTE(review): return-value semantics are defined in the .cpp — confirm there.
    int OnlineLfrCmvn(vector<vector<float>> &wav_feats, bool input_finished);
    // Applies position embeddings to the feature chunk in place.
    void GetPosEmb(std::vector<std::vector<float>> &wav_feats, int timesteps, int feat_dim);
    // CIF predictor search over encoder states and alpha weights; fired
    // acoustic embeddings are appended to list_frame.
    void CifSearch(std::vector<std::vector<float>> hidden, std::vector<float> alphas, bool is_final, std::vector<std::vector<float>> &list_frame);
    // Number of whole analysis frames that fit in sample_length
    // (0 when the buffer is shorter than one frame).
    static int ComputeFrameNum(int sample_length, int frame_sample_length, int frame_shift_sample_length) {
        int frame_num = static_cast<int>((sample_length - frame_sample_length) / frame_shift_sample_length + 1);
        if (frame_num >= 1 && sample_length >= frame_sample_length)
            return frame_num;
        else
            return 0;
    }
    // Presumably copies the shared sessions/tensor names/CMVN stats out of the
    // offline Paraformer handle — implementation lives in the .cpp; confirm.
    void InitOnline(
        knf::FbankOptions &fbank_opts,
        std::shared_ptr<Ort::Session> &encoder_session,
        std::shared_ptr<Ort::Session> &decoder_session,
        vector<const char*> &en_szInputNames,
        vector<const char*> &en_szOutputNames,
        vector<const char*> &de_szInputNames,
        vector<const char*> &de_szOutputNames,
        vector<float> &means_list,
        vector<float> &vars_list);

    // Non-owning pointer to the shared Paraformer model holder
    // (never deleted by this class; see the empty destructor in the .cpp).
    Paraformer* para_handle_ = nullptr;
    // from para_handle_
    knf::FbankOptions fbank_opts_;
    std::shared_ptr<Ort::Session> encoder_session_ = nullptr;
    std::shared_ptr<Ort::Session> decoder_session_ = nullptr;
    Ort::SessionOptions session_options_;
    vector<const char*> en_szInputNames_;
    vector<const char*> en_szOutputNames_;
    vector<const char*> de_szInputNames_;
    vector<const char*> de_szOutputNames_;
    vector<float> means_list_;
    vector<float> vars_list_;

    // configs from para_handle_
    int frame_length = 25;     // ms
    int frame_shift = 10;      // ms
    int n_mels = 80;           // mel filterbank bins
    int lfr_m = PARA_LFR_M;    // LFR stacking window
    int lfr_n = PARA_LFR_N;    // LFR shift
    int encoder_size = 512;
    int fsmn_layers = 16;
    int fsmn_lorder = 10;
    int fsmn_dims = 512;
    float cif_threshold = 1.0;
    float tail_alphas = 0.45;

    // configs
    int feat_dims = lfr_m*n_mels;
    // chunk layout in frames; overridden by the constructor argument.
    std::vector<int> chunk_size = {5,10,5};
    int frame_sample_length_ = MODEL_SAMPLE_RATE / 1000 * frame_length;
    int frame_shift_sample_length_ = MODEL_SAMPLE_RATE / 1000 * frame_shift;

    // The reserved waveforms by fbank
    std::vector<float> reserve_waveforms_;
    // waveforms reserved after last shift position
    std::vector<float> input_cache_;
    // lfr reserved cache
    std::vector<std::vector<float>> lfr_splice_cache_;
    // position index cache
    int start_idx_cache_ = 0;
    // cif alpha
    std::vector<float> alphas_cache_;
    std::vector<std::vector<float>> hidden_cache_;
    std::vector<std::vector<float>> feats_cache_;
    // fsmn init caches
    std::vector<float> fsmn_init_cache_;
    std::vector<Ort::Value> decoder_onnx;

    bool is_first_chunk = true;
    bool is_last_chunk = false;
    // Per-feature scaling factor multiplied into each chunk (see Forward).
    // NOTE(review): not initialized here — presumably set in the constructor; confirm.
    double sqrt_factor;

public:
    ParaformerOnline(Paraformer* para_handle, std::vector<int> chunk_size);
    ~ParaformerOnline();
    // Resets decoding state so the instance can start a new utterance.
    void Reset();
    // Clears the feature/waveform caches.
    void ResetCache();
    // (Re)creates the caches carried across chunks.
    void InitCache();
    // Incremental feature extraction with caching across calls.
    void ExtractFeats(float sample_rate, vector<vector<float>> &wav_feats, vector<float> &waves, bool input_finished);
    // Splices cached context frames around the chunk for overlapped decoding.
    void AddOverlapChunk(std::vector<std::vector<float>> &wav_feats, bool input_finished);
    // Decodes one prepared chunk and returns the partial transcript.
    string ForwardChunk(std::vector<std::vector<float>> &wav_feats, bool input_finished);
    // Main streaming entry point: raw samples in, text decoded so far out.
    string Forward(float* din, int len, bool input_finished);
    string Rescoring();

    // 2pass
    std::string online_res;
    int chunk_len;
};
} // namespace funasr

View File

@ -10,29 +10,30 @@ using namespace std;
namespace funasr {
Paraformer::Paraformer()
:env_(ORT_LOGGING_LEVEL_ERROR, "paraformer"),session_options{}{
:env_(ORT_LOGGING_LEVEL_ERROR, "paraformer"),session_options_{}{
}
// offline
void Paraformer::InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){
// knf options
fbank_opts.frame_opts.dither = 0;
fbank_opts.mel_opts.num_bins = 80;
fbank_opts.frame_opts.samp_freq = MODEL_SAMPLE_RATE;
fbank_opts.frame_opts.window_type = "hamming";
fbank_opts.frame_opts.frame_shift_ms = 10;
fbank_opts.frame_opts.frame_length_ms = 25;
fbank_opts.energy_floor = 0;
fbank_opts.mel_opts.debug_mel = false;
fbank_opts_.frame_opts.dither = 0;
fbank_opts_.mel_opts.num_bins = n_mels;
fbank_opts_.frame_opts.samp_freq = MODEL_SAMPLE_RATE;
fbank_opts_.frame_opts.window_type = window_type;
fbank_opts_.frame_opts.frame_shift_ms = frame_shift;
fbank_opts_.frame_opts.frame_length_ms = frame_length;
fbank_opts_.energy_floor = 0;
fbank_opts_.mel_opts.debug_mel = false;
// fbank_ = std::make_unique<knf::OnlineFbank>(fbank_opts);
// session_options.SetInterOpNumThreads(1);
session_options.SetIntraOpNumThreads(thread_num);
session_options.SetGraphOptimizationLevel(ORT_ENABLE_ALL);
// session_options_.SetInterOpNumThreads(1);
session_options_.SetIntraOpNumThreads(thread_num);
session_options_.SetGraphOptimizationLevel(ORT_ENABLE_ALL);
// DisableCpuMemArena can improve performance
session_options.DisableCpuMemArena();
session_options_.DisableCpuMemArena();
try {
m_session = std::make_unique<Ort::Session>(env_, am_model.c_str(), session_options);
m_session_ = std::make_unique<Ort::Session>(env_, am_model.c_str(), session_options_);
LOG(INFO) << "Successfully load model from " << am_model;
} catch (std::exception const &e) {
LOG(ERROR) << "Error when load am onnx model: " << e.what();
@ -40,14 +41,14 @@ void Paraformer::InitAsr(const std::string &am_model, const std::string &am_cmvn
}
string strName;
GetInputName(m_session.get(), strName);
GetInputName(m_session_.get(), strName);
m_strInputNames.push_back(strName.c_str());
GetInputName(m_session.get(), strName,1);
GetInputName(m_session_.get(), strName,1);
m_strInputNames.push_back(strName);
GetOutputName(m_session.get(), strName);
GetOutputName(m_session_.get(), strName);
m_strOutputNames.push_back(strName);
GetOutputName(m_session.get(), strName,1);
GetOutputName(m_session_.get(), strName,1);
m_strOutputNames.push_back(strName);
for (auto& item : m_strInputNames)
@ -58,6 +59,152 @@ void Paraformer::InitAsr(const std::string &am_model, const std::string &am_cmvn
LoadCmvn(am_cmvn.c_str());
}
// online
// Online (streaming) initialization: loads hyper-parameters from the model's
// YAML config, configures fbank extraction, creates the encoder/decoder ONNX
// sessions and resolves their tensor names, then loads the vocab and CMVN.
// Exits the process if either ONNX model fails to load.
void Paraformer::InitAsr(const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){
    LoadOnlineConfigFromYaml(am_config.c_str());
    // knf options
    fbank_opts_.frame_opts.dither = 0;
    fbank_opts_.mel_opts.num_bins = n_mels;
    fbank_opts_.frame_opts.samp_freq = MODEL_SAMPLE_RATE;
    fbank_opts_.frame_opts.window_type = window_type;
    fbank_opts_.frame_opts.frame_shift_ms = frame_shift;
    fbank_opts_.frame_opts.frame_length_ms = frame_length;
    fbank_opts_.energy_floor = 0;
    fbank_opts_.mel_opts.debug_mel = false;

    // session_options_.SetInterOpNumThreads(1);
    session_options_.SetIntraOpNumThreads(thread_num);
    session_options_.SetGraphOptimizationLevel(ORT_ENABLE_ALL);
    // DisableCpuMemArena can improve performance
    session_options_.DisableCpuMemArena();

    try {
        encoder_session_ = std::make_unique<Ort::Session>(env_, en_model.c_str(), session_options_);
        LOG(INFO) << "Successfully load model from " << en_model;
    } catch (std::exception const &e) {
        LOG(ERROR) << "Error when load am encoder model: " << e.what();
        exit(0);
    }

    try {
        decoder_session_ = std::make_unique<Ort::Session>(env_, de_model.c_str(), session_options_);
        LOG(INFO) << "Successfully load model from " << de_model;
    } catch (std::exception const &e) {
        LOG(ERROR) << "Error when load am decoder model: " << e.what();
        exit(0);
    }

    // encoder
    // The encoder session is queried for two input names and three output names.
    string strName;
    GetInputName(encoder_session_.get(), strName);
    en_strInputNames.push_back(strName.c_str());
    GetInputName(encoder_session_.get(), strName,1);
    en_strInputNames.push_back(strName);
    GetOutputName(encoder_session_.get(), strName);
    en_strOutputNames.push_back(strName);
    GetOutputName(encoder_session_.get(), strName,1);
    en_strOutputNames.push_back(strName);
    GetOutputName(encoder_session_.get(), strName,2);
    en_strOutputNames.push_back(strName);

    for (auto& item : en_strInputNames)
        en_szInputNames_.push_back(item.c_str());
    for (auto& item : en_strOutputNames)
        en_szOutputNames_.push_back(item.c_str());

    // decoder
    // Decoder I/O counts scale with fsmn_layers — presumably one cache tensor
    // per FSMN layer plus the fixed inputs/outputs; confirm against the export.
    int de_input_len = 4 + fsmn_layers;
    int de_out_len = 2 + fsmn_layers;
    for(int i=0;i<de_input_len; i++){
        GetInputName(decoder_session_.get(), strName, i);
        de_strInputNames.push_back(strName.c_str());
    }
    for(int i=0;i<de_out_len; i++){
        GetOutputName(decoder_session_.get(), strName,i);
        de_strOutputNames.push_back(strName);
    }
    for (auto& item : de_strInputNames)
        de_szInputNames_.push_back(item.c_str());
    for (auto& item : de_strOutputNames)
        de_szOutputNames_.push_back(item.c_str());

    // NOTE(review): raw new — presumably released in ~Paraformer (it tests
    // `if(vocab)` there); confirm the delete.
    vocab = new Vocab(am_config.c_str());
    LoadCmvn(am_cmvn.c_str());
}
// 2pass
// Two-pass initialization: sets up the streaming encoder/decoder first (which
// also loads the YAML config, vocab and CMVN), then loads the offline model
// used for the second (offline) decoding pass. Exits if the offline model
// fails to load.
void Paraformer::InitAsr(const std::string &am_model, const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){
    // online
    InitAsr(en_model, de_model, am_cmvn, am_config, thread_num);

    // offline
    try {
        m_session_ = std::make_unique<Ort::Session>(env_, am_model.c_str(), session_options_);
        LOG(INFO) << "Successfully load model from " << am_model;
    } catch (std::exception const &e) {
        LOG(ERROR) << "Error when load am onnx model: " << e.what();
        exit(0);
    }

    // The offline session is queried for two input and two output names.
    string strName;
    GetInputName(m_session_.get(), strName);
    m_strInputNames.push_back(strName.c_str());
    GetInputName(m_session_.get(), strName,1);
    m_strInputNames.push_back(strName);
    GetOutputName(m_session_.get(), strName);
    m_strOutputNames.push_back(strName);
    GetOutputName(m_session_.get(), strName,1);
    m_strOutputNames.push_back(strName);

    for (auto& item : m_strInputNames)
        m_szInputNames.push_back(item.c_str());
    for (auto& item : m_strOutputNames)
        m_szOutputNames.push_back(item.c_str());
}
/**
 * Loads the streaming-ASR hyper-parameters (front-end, encoder, decoder and
 * CIF predictor settings) from the model's YAML configuration file.
 *
 * @param filename path to the YAML config file in the model directory.
 *
 * Exits the process with -1 if the file cannot be read, or if a required key
 * is missing or has the wrong type.
 */
void Paraformer::LoadOnlineConfigFromYaml(const char* filename){
    YAML::Node config;
    try{
        config = YAML::LoadFile(filename);
    }catch(exception const &e){
        LOG(ERROR) << "Error loading file, yaml file error or not exist.";
        exit(-1);
    }

    try{
        YAML::Node frontend_conf = config["frontend_conf"];
        YAML::Node encoder_conf = config["encoder_conf"];
        YAML::Node decoder_conf = config["decoder_conf"];
        YAML::Node predictor_conf = config["predictor_conf"];

        // front-end (fbank / LFR) settings
        this->window_type = frontend_conf["window"].as<string>();
        this->n_mels = frontend_conf["n_mels"].as<int>();
        this->frame_length = frontend_conf["frame_length"].as<int>();
        this->frame_shift = frontend_conf["frame_shift"].as<int>();
        this->lfr_m = frontend_conf["lfr_m"].as<int>();
        this->lfr_n = frontend_conf["lfr_n"].as<int>();

        // model topology
        this->encoder_size = encoder_conf["output_size"].as<int>();
        this->fsmn_dims = encoder_conf["output_size"].as<int>();
        this->fsmn_layers = decoder_conf["num_blocks"].as<int>();
        // FSMN look-back order is kernel_size - 1
        this->fsmn_lorder = decoder_conf["kernel_size"].as<int>()-1;

        // CIF predictor settings
        this->cif_threshold = predictor_conf["threshold"].as<double>();
        this->tail_alphas = predictor_conf["tail_threshold"].as<double>();
    }catch(exception const &e){
        // Fixed copy-paste error: this loader reads the AM (ASR) config,
        // not the VAD config.
        LOG(ERROR) << "Error when load argument from am config YAML.";
        exit(-1);
    }
}
Paraformer::~Paraformer()
{
if(vocab)
@ -69,7 +216,7 @@ void Paraformer::Reset()
}
vector<float> Paraformer::FbankKaldi(float sample_rate, const float* waves, int len) {
knf::OnlineFbank fbank_(fbank_opts);
knf::OnlineFbank fbank_(fbank_opts_);
std::vector<float> buf(len);
for (int32_t i = 0; i != len; ++i) {
buf[i] = waves[i] * 32768;
@ -77,7 +224,7 @@ vector<float> Paraformer::FbankKaldi(float sample_rate, const float* waves, int
fbank_.AcceptWaveform(sample_rate, buf.data(), buf.size());
//fbank_->InputFinished();
int32_t frames = fbank_.NumFramesReady();
int32_t feature_dim = fbank_opts.mel_opts.num_bins;
int32_t feature_dim = fbank_opts_.mel_opts.num_bins;
vector<float> features(frames * feature_dim);
float *p = features.data();
@ -108,7 +255,7 @@ void Paraformer::LoadCmvn(const char *filename)
vector<string> means_lines{istream_iterator<string>{means_lines_stream}, istream_iterator<string>{}};
if (means_lines[0] == "<LearnRateCoef>") {
for (int j = 3; j < means_lines.size() - 1; j++) {
means_list.push_back(stof(means_lines[j]));
means_list_.push_back(stof(means_lines[j]));
}
continue;
}
@ -119,7 +266,7 @@ void Paraformer::LoadCmvn(const char *filename)
vector<string> vars_lines{istream_iterator<string>{vars_lines_stream}, istream_iterator<string>{}};
if (vars_lines[0] == "<LearnRateCoef>") {
for (int j = 3; j < vars_lines.size() - 1; j++) {
vars_list.push_back(stof(vars_lines[j])*scale);
vars_list_.push_back(stof(vars_lines[j])*scale);
}
continue;
}
@ -143,11 +290,11 @@ string Paraformer::GreedySearch(float * in, int n_len, int64_t token_nums)
vector<float> Paraformer::ApplyLfr(const std::vector<float> &in)
{
int32_t in_feat_dim = fbank_opts.mel_opts.num_bins;
int32_t in_feat_dim = fbank_opts_.mel_opts.num_bins;
int32_t in_num_frames = in.size() / in_feat_dim;
int32_t out_num_frames =
(in_num_frames - lfr_window_size) / lfr_window_shift + 1;
int32_t out_feat_dim = in_feat_dim * lfr_window_size;
(in_num_frames - lfr_m) / lfr_n + 1;
int32_t out_feat_dim = in_feat_dim * lfr_m;
std::vector<float> out(out_num_frames * out_feat_dim);
@ -158,7 +305,7 @@ vector<float> Paraformer::ApplyLfr(const std::vector<float> &in)
std::copy(p_in, p_in + out_feat_dim, p_out);
p_out += out_feat_dim;
p_in += lfr_window_shift * in_feat_dim;
p_in += lfr_n * in_feat_dim;
}
return out;
@ -166,29 +313,29 @@ vector<float> Paraformer::ApplyLfr(const std::vector<float> &in)
void Paraformer::ApplyCmvn(std::vector<float> *v)
{
int32_t dim = means_list.size();
int32_t dim = means_list_.size();
int32_t num_frames = v->size() / dim;
float *p = v->data();
for (int32_t i = 0; i != num_frames; ++i) {
for (int32_t k = 0; k != dim; ++k) {
p[k] = (p[k] + means_list[k]) * vars_list[k];
p[k] = (p[k] + means_list_[k]) * vars_list_[k];
}
p += dim;
}
}
string Paraformer::Forward(float* din, int len, int flag)
string Paraformer::Forward(float* din, int len, bool input_finished)
{
int32_t in_feat_dim = fbank_opts.mel_opts.num_bins;
int32_t in_feat_dim = fbank_opts_.mel_opts.num_bins;
std::vector<float> wav_feats = FbankKaldi(MODEL_SAMPLE_RATE, din, len);
wav_feats = ApplyLfr(wav_feats);
ApplyCmvn(&wav_feats);
int32_t feat_dim = lfr_window_size*in_feat_dim;
int32_t feat_dim = lfr_m*in_feat_dim;
int32_t num_frames = wav_feats.size() / feat_dim;
#ifdef _WIN_X86
@ -216,7 +363,7 @@ string Paraformer::Forward(float* din, int len, int flag)
string result;
try {
auto outputTensor = m_session->Run(Ort::RunOptions{nullptr}, m_szInputNames.data(), input_onnx.data(), input_onnx.size(), m_szOutputNames.data(), m_szOutputNames.size());
auto outputTensor = m_session_->Run(Ort::RunOptions{nullptr}, m_szInputNames.data(), input_onnx.data(), input_onnx.size(), m_szOutputNames.data(), m_szOutputNames.size());
std::vector<int64_t> outputShape = outputTensor[0].GetTensorTypeAndShapeInfo().GetShape();
int64_t outputCount = std::accumulate(outputShape.begin(), outputShape.end(), 1, std::multiplies<int64_t>());
@ -232,13 +379,6 @@ string Paraformer::Forward(float* din, int len, int flag)
return result;
}
string Paraformer::ForwardChunk(float* din, int len, int flag)
{
LOG(ERROR)<<"Not Imp!!!!!!";
return "";
}
string Paraformer::Rescoring()
{
LOG(ERROR)<<"Not Imp!!!!!!";

View File

@ -15,38 +15,66 @@ namespace funasr {
* https://arxiv.org/pdf/2206.08317.pdf
*/
private:
//std::unique_ptr<knf::OnlineFbank> fbank_;
knf::FbankOptions fbank_opts;
Vocab* vocab = nullptr;
vector<float> means_list;
vector<float> vars_list;
const float scale = 22.6274169979695;
int32_t lfr_window_size = 7;
int32_t lfr_window_shift = 6;
//const float scale = 22.6274169979695;
const float scale = 1.0;
void LoadOnlineConfigFromYaml(const char* filename);
void LoadCmvn(const char *filename);
vector<float> ApplyLfr(const vector<float> &in);
void ApplyCmvn(vector<float> *v);
string GreedySearch( float* in, int n_len, int64_t token_nums);
std::shared_ptr<Ort::Session> m_session = nullptr;
Ort::Env env_;
Ort::SessionOptions session_options;
vector<string> m_strInputNames, m_strOutputNames;
vector<const char*> m_szInputNames;
vector<const char*> m_szOutputNames;
public:
Paraformer();
~Paraformer();
void InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num);
// online
void InitAsr(const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, int thread_num);
// 2pass
void InitAsr(const std::string &am_model, const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, int thread_num);
void Reset();
vector<float> FbankKaldi(float sample_rate, const float* waves, int len);
string ForwardChunk(float* din, int len, int flag);
string Forward(float* din, int len, int flag);
string Forward(float* din, int len, bool input_finished=true);
string GreedySearch( float* in, int n_len, int64_t token_nums);
string Rescoring();
knf::FbankOptions fbank_opts_;
vector<float> means_list_;
vector<float> vars_list_;
int lfr_m = PARA_LFR_M;
int lfr_n = PARA_LFR_N;
// paraformer-offline
std::shared_ptr<Ort::Session> m_session_ = nullptr;
Ort::Env env_;
Ort::SessionOptions session_options_;
vector<string> m_strInputNames, m_strOutputNames;
vector<const char*> m_szInputNames;
vector<const char*> m_szOutputNames;
// paraformer-online
std::shared_ptr<Ort::Session> encoder_session_ = nullptr;
std::shared_ptr<Ort::Session> decoder_session_ = nullptr;
vector<string> en_strInputNames, en_strOutputNames;
vector<const char*> en_szInputNames_;
vector<const char*> en_szOutputNames_;
vector<string> de_strInputNames, de_strOutputNames;
vector<const char*> de_szInputNames_;
vector<const char*> de_szOutputNames_;
string window_type = "hamming";
int frame_length = 25;
int frame_shift = 10;
int n_mels = 80;
int encoder_size = 512;
int fsmn_layers = 16;
int fsmn_lorder = 10;
int fsmn_dims = 512;
float cif_threshold = 1.0;
float tail_alphas = 0.45;
};
} // namespace funasr

View File

@ -33,18 +33,20 @@ using namespace std;
#include "model.h"
#include "vad-model.h"
#include "punc-model.h"
#include "offline-stream.h"
#include "tokenizer.h"
#include "ct-transformer.h"
#include "ct-transformer-online.h"
#include "e2e-vad.h"
#include "fsmn-vad.h"
#include "fsmn-vad-online.h"
#include "vocab.h"
#include "audio.h"
#include "fsmn-vad-online.h"
#include "tensor.h"
#include "util.h"
#include "resample.h"
#include "paraformer.h"
#include "paraformer-online.h"
#include "offline-stream.h"
#include "tpass-stream.h"
#include "tpass-online-stream.h"
#include "funasrruntime.h"

View File

@ -0,0 +1,29 @@
#include "precomp.h"
#include <unistd.h>
namespace funasr {
/**
 * Builds the per-session online handles on top of a shared TpassStream that
 * owns the loaded models: a streaming VAD wrapper and a streaming ASR wrapper.
 *
 * @param tpass_stream shared stream holding the loaded vad/asr models.
 * @param chunk_size   streaming chunk layout forwarded to ParaformerOnline.
 *
 * Exits the process with -1 if either underlying handle is missing.
 */
TpassOnlineStream::TpassOnlineStream(TpassStream* tpass_stream, std::vector<int> chunk_size){
    // Dropped the redundant self-cast: the parameter is already a TpassStream*.
    if(tpass_stream->vad_handle){
        vad_online_handle = make_unique<FsmnVadOnline>((FsmnVad*)(tpass_stream->vad_handle).get());
    }else{
        // Fixed copy-paste error: this branch reports the VAD handle, not the ASR handle.
        LOG(ERROR)<<"vad_handle is null";
        exit(-1);
    }
    if(tpass_stream->asr_handle){
        asr_online_handle = make_unique<ParaformerOnline>((Paraformer*)(tpass_stream->asr_handle).get(), chunk_size);
    }else{
        LOG(ERROR)<<"asr_handle is null";
        exit(-1);
    }
}
// C-style factory for the online stream: casts the opaque handle back to a
// TpassStream and wraps it in a freshly allocated TpassOnlineStream.
TpassOnlineStream* CreateTpassOnlineStream(void* tpass_stream, std::vector<int> chunk_size)
{
    return new TpassOnlineStream(static_cast<TpassStream*>(tpass_stream), chunk_size);
}
} // namespace funasr

View File

@ -0,0 +1,87 @@
#include "precomp.h"
#include <unistd.h>
namespace funasr {
// Loads the full two-pass model stack from the configured directories:
// optional VAD, mandatory offline+online AM pair, and optional punctuation.
// Exits the process if the AM directories are not configured; missing VAD or
// PUNC files are logged and skipped.
TpassStream::TpassStream(std::map<std::string, std::string>& model_path, int thread_num)
{
    // VAD model (optional): loaded only when VAD_DIR is configured and the
    // model/cmvn/config files all exist on disk.
    if(model_path.find(VAD_DIR) != model_path.end()){
        string vad_model_path;
        string vad_cmvn_path;
        string vad_config_path;
        vad_model_path = PathAppend(model_path.at(VAD_DIR), MODEL_NAME);
        // Prefer the quantized model file when VAD_QUANT == "true".
        if(model_path.find(VAD_QUANT) != model_path.end() && model_path.at(VAD_QUANT) == "true"){
            vad_model_path = PathAppend(model_path.at(VAD_DIR), QUANT_MODEL_NAME);
        }
        vad_cmvn_path = PathAppend(model_path.at(VAD_DIR), VAD_CMVN_NAME);
        vad_config_path = PathAppend(model_path.at(VAD_DIR), VAD_CONFIG_NAME);
        if (access(vad_model_path.c_str(), F_OK) != 0 ||
            access(vad_cmvn_path.c_str(), F_OK) != 0 ||
            access(vad_config_path.c_str(), F_OK) != 0 )
        {
            LOG(INFO) << "VAD model file is not exist, skip load vad model.";
        }else{
            vad_handle = make_unique<FsmnVad>();
            vad_handle->InitVad(vad_model_path, vad_cmvn_path, vad_config_path, thread_num);
            use_vad = true;
        }
    }

    // AM model (required): 2-pass decoding needs the offline model plus the
    // online encoder/decoder pair; missing either directory is fatal.
    if(model_path.find(OFFLINE_MODEL_DIR) != model_path.end() && model_path.find(ONLINE_MODEL_DIR) != model_path.end()){
        // 2pass
        string am_model_path;
        string en_model_path;
        string de_model_path;
        string am_cmvn_path;
        string am_config_path;
        am_model_path = PathAppend(model_path.at(OFFLINE_MODEL_DIR), MODEL_NAME);
        en_model_path = PathAppend(model_path.at(ONLINE_MODEL_DIR), ENCODER_NAME);
        de_model_path = PathAppend(model_path.at(ONLINE_MODEL_DIR), DECODER_NAME);
        // NOTE(review): the AM branch keys off the generic QUANTIZE flag while
        // VAD/PUNC use VAD_QUANT/PUNC_QUANT — confirm this is intentional.
        if(model_path.find(QUANTIZE) != model_path.end() && model_path.at(QUANTIZE) == "true"){
            am_model_path = PathAppend(model_path.at(OFFLINE_MODEL_DIR), QUANT_MODEL_NAME);
            en_model_path = PathAppend(model_path.at(ONLINE_MODEL_DIR), QUANT_ENCODER_NAME);
            de_model_path = PathAppend(model_path.at(ONLINE_MODEL_DIR), QUANT_DECODER_NAME);
        }
        // CMVN and config are taken from the online model directory.
        am_cmvn_path = PathAppend(model_path.at(ONLINE_MODEL_DIR), AM_CMVN_NAME);
        am_config_path = PathAppend(model_path.at(ONLINE_MODEL_DIR), AM_CONFIG_NAME);
        // NOTE(review): unlike VAD/PUNC, no access() existence check here —
        // a missing AM file is only reported when InitAsr fails to load it.
        asr_handle = make_unique<Paraformer>();
        asr_handle->InitAsr(am_model_path, en_model_path, de_model_path, am_cmvn_path, am_config_path, thread_num);
    }else{
        LOG(ERROR) <<"Can not find offline-model-dir or online-model-dir";
        exit(-1);
    }

    // PUNC model (optional): streaming punctuation restoration.
    if(model_path.find(PUNC_DIR) != model_path.end()){
        string punc_model_path;
        string punc_config_path;
        punc_model_path = PathAppend(model_path.at(PUNC_DIR), MODEL_NAME);
        if(model_path.find(PUNC_QUANT) != model_path.end() && model_path.at(PUNC_QUANT) == "true"){
            punc_model_path = PathAppend(model_path.at(PUNC_DIR), QUANT_MODEL_NAME);
        }
        punc_config_path = PathAppend(model_path.at(PUNC_DIR), PUNC_CONFIG_NAME);
        if (access(punc_model_path.c_str(), F_OK) != 0 ||
            access(punc_config_path.c_str(), F_OK) != 0 )
        {
            LOG(INFO) << "PUNC model file is not exist, skip load punc model.";
        }else{
            punc_online_handle = make_unique<CTTransformerOnline>();
            punc_online_handle->InitPunc(punc_model_path, punc_config_path, thread_num);
            use_punc = true;
        }
    }
}
// C-style factory: allocates the shared two-pass stream that owns the models.
TpassStream *CreateTpassStream(std::map<std::string, std::string>& model_path, int thread_num)
{
    return new TpassStream(model_path, thread_num);
}
} // namespace funasr

View File

@ -1,73 +1,27 @@
# Service with grpc-python
We can send streaming audio data to server in real-time with grpc client every 10 ms e.g., and get transcribed text when stop speaking.
The audio data is in streaming, the asr inference process is in offline.
# GRPC python client for 2-pass decoding
The client can send either streaming or complete audio data to the server, and receives the transcribed text whenever the server responds (the timing depends on the decoding mode).
## For the Server
### Prepare server environment
Install the modelscope and funasr
In the demo client, audio_chunk_duration is set to 1000ms, and send_interval is set to 100ms
### 1. Install the requirements
```shell
pip install -U modelscope funasr
# For the users in China, you could install with the command:
# pip install -U modelscope funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple
git clone https://github.com/alibaba/FunASR.git && cd FunASR
git clone https://github.com/alibaba/FunASR.git && cd FunASR/funasr/runtime/python/grpc
pip install -r requirements.txt
```
Install the requirements
```shell
cd funasr/runtime/python/grpc
pip install -r requirements_server.txt
```
### Generate protobuf file
Run on server, the two generated pb files are both used for server and client
### 2. Generate protobuf file
```shell
# paraformer_pb2.py and paraformer_pb2_grpc.py are already generated,
# regenerate it only when you make changes to ./proto/paraformer.proto file.
python -m grpc_tools.protoc --proto_path=./proto -I ./proto --python_out=. --grpc_python_out=./ ./proto/paraformer.proto
python -m grpc_tools.protoc --proto_path=./proto -I ./proto --python_out=. --grpc_python_out=./ ./proto/paraformer.proto
```
### Start grpc server
```
# Start server.
python grpc_main_server.py --port 10095 --backend pipeline
```
## For the client
### Install the requirements
```shell
git clone https://github.com/alibaba/FunASR.git && cd FunASR
cd funasr/runtime/python/grpc
pip install -r requirements_client.txt
```
### Generate protobuf file
Run on server, the two generated pb files are both used for server and client
```shell
# paraformer_pb2.py and paraformer_pb2_grpc.py are already generated,
# regenerate it only when you make changes to ./proto/paraformer.proto file.
python -m grpc_tools.protoc --proto_path=./proto -I ./proto --python_out=. --grpc_python_out=./ ./proto/paraformer.proto
```
### Start grpc client
### 3. Start grpc client
```
# Start client.
python grpc_main_client_mic.py --host 127.0.0.1 --port 10095
python grpc_main_client.py --host 127.0.0.1 --port 10100 --wav_path /path/to/your_test_wav.wav
```
## Workflow in design
<div align="left"><img src="proto/workflow.png" width="400"/>
## Acknowledge
1. This project is maintained by [FunASR community](https://github.com/alibaba-damo-academy/FunASR).
1. This project is maintained by [FunASR community](https://github.com/alibaba-damo-academy/FunASR).
2. We acknowledge burkliu (刘柏基, liubaiji@xverse.cn) for contributing the grpc service.

View File

@ -1,17 +0,0 @@
import queue
import paraformer_pb2
def transcribe_audio_bytes(stub, chunk, user='zksz', language='zh-CN', speaking = True, isEnd = False):
    """Build a single Request protobuf and hand it to the streaming Recognize RPC.

    ``chunk`` may be None (control-only message, e.g. end-of-session); the
    remaining flags describe the session state. Returns the server's response
    iterator.
    """
    request = paraformer_pb2.Request()
    if chunk is not None:
        request.audio_data = chunk
    request.user = user
    request.language = language
    request.speaking = speaking
    request.isEnd = isEnd
    pending = queue.SimpleQueue()
    pending.put(request)
    # iter(get, None): yields queued requests until a None sentinel appears.
    return stub.Recognize(iter(pending.get, None))

View File

@ -1,62 +1,78 @@
import grpc
import json
import time
import asyncio
import soundfile as sf
import logging
import argparse
import soundfile as sf
import time
from grpc_client import transcribe_audio_bytes
from paraformer_pb2_grpc import ASRStub
import grpc
import paraformer_pb2_grpc
from paraformer_pb2 import Request, WavFormat, DecodeMode
# send the audio data once
async def grpc_rec(wav_scp, grpc_uri, asr_user, language):
with grpc.insecure_channel(grpc_uri) as channel:
stub = ASRStub(channel)
for line in wav_scp:
wav_file = line.split()[1]
wav, _ = sf.read(wav_file, dtype='int16')
b = time.time()
response = transcribe_audio_bytes(stub, wav.tobytes(), user=asr_user, language=language, speaking=False, isEnd=False)
resp = response.next()
text = ''
if 'decoding' == resp.action:
resp = response.next()
if 'finish' == resp.action:
text = json.loads(resp.sentence)['text']
response = transcribe_audio_bytes(stub, None, user=asr_user, language=language, speaking=False, isEnd=True)
res= {'text': text, 'time': time.time() - b}
print(res)
class GrpcClient:
def __init__(self, wav_path, uri, mode):
self.wav, self.sampling_rate = sf.read(wav_path, dtype='int16')
self.wav_format = WavFormat.pcm
self.audio_chunk_duration = 1000 # ms
self.audio_chunk_size = int(self.sampling_rate * self.audio_chunk_duration / 1000)
self.send_interval = 100 # ms
self.mode = mode
async def test(args):
wav_scp = open(args.wav_scp, "r").readlines()
uri = '{}:{}'.format(args.host, args.port)
res = await grpc_rec(wav_scp, uri, args.user_allowed, language = 'zh-CN')
# connect to grpc server
channel = grpc.insecure_channel(uri)
self.stub = paraformer_pb2_grpc.ASRStub(channel)
# start request
for respond in self.stub.Recognize(self.request_iterator()):
logging.info("[receive] mode {}, text {}, is final {}".format(
DecodeMode.Name(respond.mode), respond.text, respond.is_final))
def request_iterator(self, mode = DecodeMode.two_pass):
is_first_pack = True
is_final = False
for start in range(0, len(self.wav), self.audio_chunk_size):
request = Request()
audio_chunk = self.wav[start : start + self.audio_chunk_size]
if is_first_pack:
is_first_pack = False
request.sampling_rate = self.sampling_rate
request.mode = self.mode
request.wav_format = self.wav_format
if request.mode == DecodeMode.two_pass or request.mode == DecodeMode.online:
request.chunk_size.extend([5, 10, 5])
if start + self.audio_chunk_size >= len(self.wav):
is_final = True
request.is_final = is_final
request.audio_data = audio_chunk.tobytes()
logging.info("[request] audio_data len {}, is final {}".format(
len(request.audio_data), request.is_final)) # int16 = 2bytes
time.sleep(self.send_interval / 1000)
yield request
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--host",
type=str,
default="127.0.0.1",
required=False,
help="grpc server host ip")
parser.add_argument("--port",
type=int,
default=10108,
required=False,
help="grpc server port")
parser.add_argument("--user_allowed",
type=str,
default="project1_user1",
help="allowed user for grpc client")
parser.add_argument("--sample_rate",
type=int,
default=16000,
help="audio sample_rate from client")
parser.add_argument("--wav_scp",
type=str,
required=True,
help="audio wav scp")
args = parser.parse_args()
asyncio.run(test(args))
logging.basicConfig(filename="", format="%(asctime)s %(message)s", level=logging.INFO)
parser = argparse.ArgumentParser()
parser.add_argument("--host",
type=str,
default="127.0.0.1",
required=False,
help="grpc server host ip")
parser.add_argument("--port",
type=int,
default=10100,
required=False,
help="grpc server port")
parser.add_argument("--wav_path",
type=str,
required=True,
help="audio wav path")
args = parser.parse_args()
for mode in [DecodeMode.offline, DecodeMode.online, DecodeMode.two_pass]:
mode_name = DecodeMode.Name(mode)
logging.info("[request] start requesting with mode {}".format(mode_name))
st = time.time()
uri = '{}:{}'.format(args.host, args.port)
client = GrpcClient(args.wav_path, uri, mode)
logging.info("mode {}, time pass: {}".format(mode_name, time.time() - st))

View File

@ -1,112 +0,0 @@
import pyaudio
import grpc
import json
import webrtcvad
import time
import asyncio
import argparse
from grpc_client import transcribe_audio_bytes
from paraformer_pb2_grpc import ASRStub
async def deal_chunk(sig_mic):
    """Handle one microphone chunk: while speech is detected, stream it to the
    server; on a speech->silence transition, ask the server to finalize and
    print the recognized sentence with the round-trip delay.

    NOTE(review): relies on module-level globals (stub, SPEAKING, vad, ...)
    set up by the __main__ block.
    """
    global stub,SPEAKING,asr_user,language,sample_rate
    if vad.is_speech(sig_mic, sample_rate): #speaking
        SPEAKING = True
        # NOTE(review): the returned response iterator is not consumed here.
        response = transcribe_audio_bytes(stub, sig_mic, user=asr_user, language=language, speaking = True, isEnd = False) #speaking, send audio to server.
    else: #silence
        begin_time = 0
        if SPEAKING: #means we have some audio recorded, send recognize order to server.
            SPEAKING = False
            begin_time = int(round(time.time() * 1000))
            response = transcribe_audio_bytes(stub, None, user=asr_user, language=language, speaking = False, isEnd = False) #speak end, call server for recognize one sentence
            resp = response.next()
            if "decoding" == resp.action:
                resp = response.next() #TODO, blocking operation may leads to miss some audio clips. C++ multi-threading is preferred.
                if "finish" == resp.action:
                    end_time = int(round(time.time() * 1000))
                    print (json.loads(resp.sentence))
                    print ("delay in ms: %d " % (end_time - begin_time))
        else:
            pass
async def record(host,port,sample_rate,mic_chunk,record_seconds,asr_user,language):
    """Read microphone chunks for ``record_seconds`` and feed each one to
    deal_chunk(), then send a final isEnd request to close the session.

    NOTE(review): ``stream`` is a module-level global opened by the __main__
    block; ``response.next()`` is Python-2-style iteration (Python 3 would
    need next(response)) — confirm the intended runtime.
    """
    with grpc.insecure_channel('{}:{}'.format(host, port)) as channel:
        global stub
        stub = ASRStub(channel)
        for i in range(0, int(sample_rate / mic_chunk * record_seconds)):
            sig_mic = stream.read(mic_chunk,exception_on_overflow = False)
            await asyncio.create_task(deal_chunk(sig_mic))
        #end grpc
        response = transcribe_audio_bytes(stub, None, user=asr_user, language=language, speaking = False, isEnd = True)
        print (response.next().action)
if __name__ == '__main__':
    # CLI wiring for the microphone streaming demo client.
    parser = argparse.ArgumentParser()
    parser.add_argument("--host",
                        type=str,
                        default="127.0.0.1",
                        required=True,
                        help="grpc server host ip")
    parser.add_argument("--port",
                        type=int,
                        default=10095,
                        required=True,
                        help="grpc server port")
    parser.add_argument("--user_allowed",
                        type=str,
                        default="project1_user1",
                        help="allowed user for grpc client")
    parser.add_argument("--sample_rate",
                        type=int,
                        default=16000,
                        help="audio sample_rate from client")
    parser.add_argument("--mic_chunk",
                        type=int,
                        default=160,
                        help="chunk size for mic")
    parser.add_argument("--record_seconds",
                        type=int,
                        default=120,
                        help="run specified seconds then exit ")
    args = parser.parse_args()

    # Module-level state shared with deal_chunk()/record() via `global`.
    SPEAKING = False
    asr_user = args.user_allowed
    sample_rate = args.sample_rate
    language = 'zh-CN'

    # WebRTC VAD, aggressiveness mode 1 (0 = least, 3 = most aggressive).
    vad = webrtcvad.Vad()
    vad.set_mode(1)

    # Open the microphone as 16-bit mono PCM.
    FORMAT = pyaudio.paInt16
    CHANNELS = 1
    p = pyaudio.PyAudio()
    stream = p.open(format=FORMAT,
                    channels=CHANNELS,
                    rate=args.sample_rate,
                    input=True,
                    frames_per_buffer=args.mic_chunk)
    print("* recording")
    asyncio.run(record(args.host,args.port,args.sample_rate,args.mic_chunk,args.record_seconds,args.user_allowed,language))
    # Release audio resources once the recording loop finishes.
    stream.stop_stream()
    stream.close()
    p.terminate()
    print("recording stop")
View File

@ -1,68 +0,0 @@
import grpc
from concurrent import futures
import argparse
import paraformer_pb2_grpc
from grpc_server import ASRServicer
def serve(args):
    """Build the gRPC ASR server, bind it, and block until termination.

    Args:
        args: parsed argparse namespace with port/model/backend options.
    """
    grpc_server = grpc.server(
        futures.ThreadPoolExecutor(max_workers=10),
        # interceptors=(AuthInterceptor('Bearer mysecrettoken'),)
    )
    servicer = ASRServicer(args.user_allowed, args.model, args.sample_rate,
                           args.backend, args.onnx_dir,
                           vad_model=args.vad_model, punc_model=args.punc_model)
    paraformer_pb2_grpc.add_ASRServicer_to_server(servicer, grpc_server)
    bind_addr = "[::]:" + str(args.port)
    grpc_server.add_insecure_port(bind_addr)
    grpc_server.start()
    print("grpc server started!")
    grpc_server.wait_for_termination()
if __name__ == '__main__':
    # CLI entry point for the gRPC ASR server.
    parser = argparse.ArgumentParser()
    parser.add_argument("--port",
                        type=int,
                        default=10095,
                        required=True,
                        help="grpc server port")
    parser.add_argument("--user_allowed",
                        type=str,
                        default="project1_user1|project1_user2|project2_user3",
                        help="allowed user for grpc client")
    parser.add_argument("--model",
                        type=str,
                        default="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
                        help="model from modelscope")
    parser.add_argument("--vad_model",
                        type=str,
                        default="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
                        help="model from modelscope")
    parser.add_argument("--punc_model",
                        type=str,
                        default="",
                        help="model from modelscope")
    parser.add_argument("--sample_rate",
                        type=int,
                        default=16000,
                        help="audio sample_rate from client")
    parser.add_argument("--backend",
                        type=str,
                        default="pipeline",
                        choices=("pipeline", "onnxruntime"),
                        help="backend, optional modelscope pipeline or onnxruntime")
    parser.add_argument("--onnx_dir",
                        type=str,
                        default="/nfs/models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
                        help="onnx model dir")
    args = parser.parse_args()
    serve(args)

View File

@ -1,132 +0,0 @@
from concurrent import futures
import grpc
import json
import time
import paraformer_pb2_grpc
from paraformer_pb2 import Response
class ASRServicer(paraformer_pb2_grpc.ASRServicer):
    """Bidirectional-streaming ASR service.

    Buffers audio per user while the client is speaking; when the client
    reports silence, decodes the accumulated buffer with either a modelscope
    pipeline or an ONNX runtime Paraformer and streams the result back.
    """
    def __init__(self, user_allowed, model, sample_rate, backend, onnx_dir, vad_model='', punc_model=''):
        """
        Args:
            user_allowed: '|'-separated whitelist of user names.
            model: modelscope model id (pipeline backend).
            sample_rate: expected input sample rate in Hz.
            backend: 'pipeline' or 'onnxruntime'.
            onnx_dir: model directory for the onnxruntime backend.
            vad_model / punc_model: optional modelscope VAD/punctuation models.
        """
        print("ASRServicer init")
        self.backend = backend
        self.init_flag = 0
        # Per-user audio byte buffers accumulated across 'speaking' requests.
        self.client_buffers = {}
        # Per-user transcriptions (reserved; cleared together with buffers).
        self.client_transcription = {}
        self.auth_user = user_allowed.split("|")
        if self.backend == "pipeline":
            try:
                from modelscope.pipelines import pipeline
                from modelscope.utils.constant import Tasks
            except ImportError:
                raise ImportError(f"Please install modelscope")
            self.inference_16k_pipeline = pipeline(task=Tasks.auto_speech_recognition, model=model, vad_model=vad_model, punc_model=punc_model)
        elif self.backend == "onnxruntime":
            try:
                from funasr_onnx import Paraformer
            except ImportError:
                raise ImportError(f"Please install onnxruntime environment")
            self.inference_16k_pipeline = Paraformer(model_dir=onnx_dir)
        self.sample_rate = sample_rate

    def clear_states(self, user):
        # Drop all per-user state (audio buffer and transcription).
        self.clear_buffers(user)
        self.clear_transcriptions(user)

    def clear_buffers(self, user):
        # Remove the user's buffered audio, if any.
        if user in self.client_buffers:
            del self.client_buffers[user]

    def clear_transcriptions(self, user):
        # Remove the user's stored transcription, if any.
        if user in self.client_transcription:
            del self.client_transcription[user]

    def disconnect(self, user):
        # Clean up everything associated with a departing user.
        self.clear_states(user)
        print("Disconnecting user: %s" % str(user))

    def Recognize(self, request_iterator, context):
        """stream-stream RPC: consume Request messages, yield Response messages.

        Per request, one of four paths is taken:
        unauthorized user -> terminate; isEnd -> terminate after cleanup;
        speaking -> append audio and ack; silence -> decode buffered audio.
        """
        for req in request_iterator:
            if req.user not in self.auth_user:
                result = {}
                result["success"] = False
                result["detail"] = "Not Authorized user: %s " % req.user
                result["text"] = ""
                yield Response(sentence=json.dumps(result), user=req.user, action="terminate", language=req.language)
            elif req.isEnd: #end grpc
                print("asr end")
                self.disconnect(req.user)
                result = {}
                result["success"] = True
                result["detail"] = "asr end"
                result["text"] = ""
                yield Response(sentence=json.dumps(result), user=req.user, action="terminate",language=req.language)
            elif req.speaking: #continue speaking
                if req.audio_data is not None and len(req.audio_data) > 0:
                    if req.user in self.client_buffers:
                        self.client_buffers[req.user] += req.audio_data #append audio
                    else:
                        self.client_buffers[req.user] = req.audio_data
                result = {}
                result["success"] = True
                result["detail"] = "speaking"
                result["text"] = ""
                yield Response(sentence=json.dumps(result), user=req.user, action="speaking", language=req.language)
            elif not req.speaking: #silence
                if req.user not in self.client_buffers:
                    # Nothing buffered yet for this user; keep waiting.
                    result = {}
                    result["success"] = True
                    result["detail"] = "waiting_for_more_voice"
                    result["text"] = ""
                    yield Response(sentence=json.dumps(result), user=req.user, action="waiting", language=req.language)
                else:
                    begin_time = int(round(time.time() * 1000))
                    # Take ownership of the buffer, then reset state so new
                    # audio arriving during decoding starts a fresh utterance.
                    tmp_data = self.client_buffers[req.user]
                    self.clear_states(req.user)
                    result = {}
                    result["success"] = True
                    result["detail"] = "decoding data: %d bytes" % len(tmp_data)
                    result["text"] = ""
                    yield Response(sentence=json.dumps(result), user=req.user, action="decoding", language=req.language)
                    if len(tmp_data) < 9600: #min input_len for asr model , 300ms
                        end_time = int(round(time.time() * 1000))
                        delay_str = str(end_time - begin_time)
                        result = {}
                        result["success"] = True
                        result["detail"] = "waiting_for_more_voice"
                        result["server_delay_ms"] = delay_str
                        result["text"] = ""
                        print ("user: %s , delay(ms): %s, info: %s " % (req.user, delay_str, "waiting_for_more_voice"))
                        yield Response(sentence=json.dumps(result), user=req.user, action="waiting", language=req.language)
                    else:
                        if self.backend == "pipeline":
                            asr_result = self.inference_16k_pipeline(audio_in=tmp_data, audio_fs = self.sample_rate)
                            if "text" in asr_result:
                                asr_result = asr_result['text']
                            else:
                                asr_result = ""
                        elif self.backend == "onnxruntime":
                            # Convert raw bytes to a float array for the ONNX model.
                            from funasr_onnx.utils.frontend import load_bytes
                            array = load_bytes(tmp_data)
                            asr_result = self.inference_16k_pipeline(array)[0]
                        end_time = int(round(time.time() * 1000))
                        delay_str = str(end_time - begin_time)
                        print ("user: %s , delay(ms): %s, text: %s " % (req.user, delay_str, asr_result))
                        result = {}
                        result["success"] = True
                        result["detail"] = "finish_sentence"
                        result["server_delay_ms"] = delay_str
                        result["text"] = asr_result
                        yield Response(sentence=json.dumps(result), user=req.user, action="finish", language=req.language)
            else:
                # Defensive fallback: speaking/not-speaking above is exhaustive,
                # so this branch should be unreachable.
                result = {}
                result["success"] = False
                result["detail"] = "error, no condition matched! Unknown reason."
                result["text"] = ""
                self.disconnect(req.user)
                yield Response(sentence=json.dumps(result), user=req.user, action="terminate", language=req.language)

View File

@ -1,30 +0,0 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: paraformer.proto
"""Generated protocol buffer code."""
from google.protobuf.internal import builder as _builder
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x10paraformer.proto\x12\nparaformer\"^\n\x07Request\x12\x12\n\naudio_data\x18\x01 \x01(\x0c\x12\x0c\n\x04user\x18\x02 \x01(\t\x12\x10\n\x08language\x18\x03 \x01(\t\x12\x10\n\x08speaking\x18\x04 \x01(\x08\x12\r\n\x05isEnd\x18\x05 \x01(\x08\"L\n\x08Response\x12\x10\n\x08sentence\x18\x01 \x01(\t\x12\x0c\n\x04user\x18\x02 \x01(\t\x12\x10\n\x08language\x18\x03 \x01(\t\x12\x0e\n\x06\x61\x63tion\x18\x04 \x01(\t2C\n\x03\x41SR\x12<\n\tRecognize\x12\x13.paraformer.Request\x1a\x14.paraformer.Response\"\x00(\x01\x30\x01\x42\x16\n\x07\x65x.grpc\xa2\x02\nparaformerb\x06proto3')
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'paraformer_pb2', globals())
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
DESCRIPTOR._serialized_options = b'\n\007ex.grpc\242\002\nparaformer'
_REQUEST._serialized_start=32
_REQUEST._serialized_end=126
_RESPONSE._serialized_start=128
_RESPONSE._serialized_end=204
_ASR._serialized_start=206
_ASR._serialized_end=273
# @@protoc_insertion_point(module_scope)

View File

@ -1,66 +0,0 @@
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
import paraformer_pb2 as paraformer__pb2
class ASRStub(object):
"""Missing associated documentation comment in .proto file."""
def __init__(self, channel):
"""Constructor.
Args:
channel: A grpc.Channel.
"""
self.Recognize = channel.stream_stream(
'/paraformer.ASR/Recognize',
request_serializer=paraformer__pb2.Request.SerializeToString,
response_deserializer=paraformer__pb2.Response.FromString,
)
class ASRServicer(object):
"""Missing associated documentation comment in .proto file."""
def Recognize(self, request_iterator, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_ASRServicer_to_server(servicer, server):
rpc_method_handlers = {
'Recognize': grpc.stream_stream_rpc_method_handler(
servicer.Recognize,
request_deserializer=paraformer__pb2.Request.FromString,
response_serializer=paraformer__pb2.Response.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'paraformer.ASR', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))
# This class is part of an EXPERIMENTAL API.
class ASR(object):
"""Missing associated documentation comment in .proto file."""
@staticmethod
def Recognize(request_iterator,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.stream_stream(request_iterator, target, '/paraformer.ASR/Recognize',
paraformer__pb2.Request.SerializeToString,
paraformer__pb2.Response.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

View File

@ -1,3 +1,8 @@
// Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
// Reserved. MIT License (https://opensource.org/licenses/MIT)
//
// 2023 by burkliu() liubaiji@xverse.cn
syntax = "proto3";
option objc_class_prefix = "paraformer";
@ -8,17 +13,27 @@ service ASR {
rpc Recognize (stream Request) returns (stream Response) {}
}
enum WavFormat {
pcm = 0;
}
enum DecodeMode {
offline = 0;
online = 1;
two_pass = 2;
}
message Request {
bytes audio_data = 1;
string user = 2;
string language = 3;
bool speaking = 4;
bool isEnd = 5;
DecodeMode mode = 1;
WavFormat wav_format = 2;
int32 sampling_rate = 3;
repeated int32 chunk_size = 4;
bool is_final = 5;
bytes audio_data = 6;
}
message Response {
string sentence = 1;
string user = 2;
string language = 3;
string action = 4;
DecodeMode mode = 1;
string text = 2;
bool is_final = 3;
}

View File

@ -1,4 +1,2 @@
pyaudio
webrtcvad
grpcio
grpcio-tools

View File

@ -1,2 +0,0 @@
grpcio
grpcio-tools

View File

@ -0,0 +1,30 @@
import soundfile
from funasr_onnx.paraformer_online_bin import Paraformer
from pathlib import Path

# Demo: streaming (chunk-by-chunk) recognition with the ONNX paraformer-online
# model, feeding the waveform in hops of chunk_size[1] * 960 samples.
model_dir = "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online"
wav_path = '{}/.cache/modelscope/hub/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/example/asr_example.wav'.format(Path.home())
chunk_size = [5, 10, 5]

model = Paraformer(model_dir, batch_size=1, quantize=True, chunk_size=chunk_size, intra_op_num_threads=4) # only support batch_size = 1

##online asr
speech, sample_rate = soundfile.read(wav_path)
speech_length = speech.shape[0]
step = chunk_size[1] * 960  # samples per streaming hop
param_dict = {'cache': dict()}
final_result = ""

# BUGFIX: the original loop used
#   for sample_offset in range(0, speech_length, min(step, speech_length - sample_offset)):
# but a range() stride is evaluated once, so reassigning `step` inside the
# loop never changed the stride. An explicit while loop makes the chunking
# correct and obvious, and also handles an empty waveform safely.
sample_offset = 0
while sample_offset < speech_length:
    if sample_offset + step >= speech_length - 1:
        step = speech_length - sample_offset  # final (possibly short) chunk
        is_final = True
    else:
        is_final = False
    param_dict['is_final'] = is_final
    rec_result = model(audio_in=speech[sample_offset: sample_offset + step],
                       param_dict=param_dict)
    if len(rec_result) > 0:
        final_result += rec_result[0]["preds"][0]
    print(rec_result)
    sample_offset += step
print(final_result)

View File

@ -0,0 +1,309 @@
# -*- encoding: utf-8 -*-
import os.path
from pathlib import Path
from typing import List, Union, Tuple
import copy
import librosa
import numpy as np
from .utils.utils import (CharTokenizer, Hypothesis, ONNXRuntimeError,
OrtInferSession, TokenIDConverter, get_logger,
read_yaml)
from .utils.postprocess_utils import sentence_postprocess
from .utils.frontend import WavFrontendOnline, SinusoidalPositionEncoderOnline
logging = get_logger()
class Paraformer():
    """Streaming (online) Paraformer ASR model backed by ONNX Runtime.

    Holds an encoder session, a CIF predictor (implemented in numpy in
    cif_search) and a decoder session, and maintains per-stream state in a
    caller-supplied `cache` dict so audio can be fed chunk by chunk.
    """
    def __init__(self, model_dir: Union[str, Path] = None,
                 batch_size: int = 1,
                 chunk_size: List = [5, 10, 5],
                 device_id: Union[str, int] = "-1",
                 quantize: bool = False,
                 intra_op_num_threads: int = 4,
                 cache_dir: str = None
                 ):
        """
        Args:
            model_dir: modelscope model name or a local directory that already
                contains model.onnx / decoder.onnx (exported on demand if absent).
            batch_size: only 1 is supported by the streaming pipeline.
            chunk_size: [left-context, hop, right-context] in frames.
            device_id: ONNX Runtime device id; "-1" selects CPU.
            quantize: use the *_quant.onnx variants when True.
            intra_op_num_threads: ONNX Runtime intra-op thread count.
            cache_dir: modelscope download cache directory.
        """
        if not Path(model_dir).exists():
            from modelscope.hub.snapshot_download import snapshot_download
            try:
                model_dir = snapshot_download(model_dir, cache_dir=cache_dir)
            except:
                raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(model_dir)

        encoder_model_file = os.path.join(model_dir, 'model.onnx')
        decoder_model_file = os.path.join(model_dir, 'decoder.onnx')
        if quantize:
            encoder_model_file = os.path.join(model_dir, 'model_quant.onnx')
            decoder_model_file = os.path.join(model_dir, 'decoder_quant.onnx')
        if not os.path.exists(encoder_model_file) or not os.path.exists(decoder_model_file):
            # Export ONNX files from the torch checkpoint on first use.
            print(".onnx is not exist, begin to export onnx")
            from funasr.export.export_model import ModelExport
            export_model = ModelExport(
                cache_dir=cache_dir,
                onnx=True,
                device="cpu",
                quant=quantize,
            )
            export_model.export(model_dir)

        config_file = os.path.join(model_dir, 'config.yaml')
        cmvn_file = os.path.join(model_dir, 'am.mvn')
        config = read_yaml(config_file)

        self.converter = TokenIDConverter(config['token_list'])
        self.tokenizer = CharTokenizer()
        self.frontend = WavFrontendOnline(
            cmvn_file=cmvn_file,
            **config['frontend_conf']
        )
        self.pe = SinusoidalPositionEncoderOnline()
        self.ort_encoder_infer = OrtInferSession(encoder_model_file, device_id,
                                                 intra_op_num_threads=intra_op_num_threads)
        self.ort_decoder_infer = OrtInferSession(decoder_model_file, device_id,
                                                 intra_op_num_threads=intra_op_num_threads)

        self.batch_size = batch_size
        self.chunk_size = chunk_size
        self.encoder_output_size = config["encoder_conf"]["output_size"]
        # FSMN decoder cache geometry (per-layer, kernel_size - 1 look-back).
        self.fsmn_layer = config["decoder_conf"]["num_blocks"]
        self.fsmn_lorder = config["decoder_conf"]["kernel_size"] - 1
        self.fsmn_dims = config["encoder_conf"]["output_size"]
        self.feats_dims = config["frontend_conf"]["n_mels"] * config["frontend_conf"]["lfr_m"]
        # CIF firing threshold and tail padding weight.
        self.cif_threshold = config["predictor_conf"]["threshold"]
        self.tail_threshold = config["predictor_conf"]["tail_threshold"]

    def prepare_cache(self, cache: dict = {}, batch_size=1):
        """Initialise the streaming cache in place; no-op if already populated.

        NOTE(review): mutable default argument — harmless here because every
        caller in this file passes its own dict, but worth confirming.
        """
        if len(cache) > 0:
            return cache
        cache["start_idx"] = 0
        cache["cif_hidden"] = np.zeros((batch_size, 1, self.encoder_output_size)).astype(np.float32)
        cache["cif_alphas"] = np.zeros((batch_size, 1)).astype(np.float32)
        cache["chunk_size"] = self.chunk_size
        cache["last_chunk"] = False
        # Rolling feature window holding left + right context frames.
        cache["feats"] = np.zeros((batch_size, self.chunk_size[0] + self.chunk_size[2], self.feats_dims)).astype(np.float32)
        cache["decoder_fsmn"] = []
        for i in range(self.fsmn_layer):
            fsmn_cache = np.zeros((batch_size, self.fsmn_dims, self.fsmn_lorder)).astype(np.float32)
            cache["decoder_fsmn"].append(fsmn_cache)
        return cache

    def add_overlap_chunk(self, feats: np.ndarray, cache: dict = {}):
        """Prepend the cached context frames to `feats` and update the cache.

        Pads the final (non-last) chunk to the full chunk length; otherwise
        keeps the trailing frames as context for the next call.
        """
        if len(cache) == 0:
            return feats
        # process last chunk
        overlap_feats = np.concatenate((cache["feats"], feats), axis=1)
        if cache["is_final"]:
            cache["feats"] = overlap_feats[:, -self.chunk_size[0]:, :]
            if not cache["last_chunk"]:
                padding_length = sum(self.chunk_size) - overlap_feats.shape[1]
                overlap_feats = np.pad(overlap_feats, ((0, 0), (0, padding_length), (0, 0)))
        else:
            cache["feats"] = overlap_feats[:, -(self.chunk_size[0] + self.chunk_size[2]):, :]
        return overlap_feats

    def __call__(self, audio_in: np.ndarray, **kwargs):
        """Run one streaming step on a chunk of audio samples.

        kwargs:
            param_dict: dict carrying 'is_final' (bool) and 'cache' (dict,
                mutated across calls to hold the stream state).
        Returns:
            list of {'preds': ...} results (may be empty for this chunk).
        """
        waveforms = np.expand_dims(audio_in, axis=0)
        param_dict = kwargs.get('param_dict', dict())
        is_final = param_dict.get('is_final', False)
        cache = param_dict.get('cache', dict())
        asr_res = []
        # Final chunk too short to produce new frames (< 60 ms at 16 kHz):
        # flush whatever context is cached.
        if waveforms.shape[1] < 16 * 60 and is_final and len(cache) > 0:
            cache["last_chunk"] = True
            feats = cache["feats"]
            feats_len = np.array([feats.shape[1]]).astype(np.int32)
            asr_res = self.infer(feats, feats_len, cache)
            return asr_res

        feats, feats_len = self.extract_feat(waveforms, is_final)
        if feats.shape[1] != 0:
            # Scale features as the encoder expects (sqrt of model dim).
            feats *= self.encoder_output_size ** 0.5
            cache = self.prepare_cache(cache)
            cache["is_final"] = is_final

            # fbank -> position encoding -> overlap chunk
            feats = self.pe.forward(feats, cache["start_idx"])
            cache["start_idx"] += feats.shape[1]
            if is_final:
                if feats.shape[1] + self.chunk_size[2] <= self.chunk_size[1]:
                    # Everything fits in one (last) chunk.
                    cache["last_chunk"] = True
                    feats = self.add_overlap_chunk(feats, cache)
                else:
                    # Split the remainder into a normal chunk plus a last chunk
                    # and merge their results.
                    # first chunk
                    feats_chunk1 = self.add_overlap_chunk(feats[:, :self.chunk_size[1], :], cache)
                    feats_len = np.array([feats_chunk1.shape[1]]).astype(np.int32)
                    asr_res_chunk1 = self.infer(feats_chunk1, feats_len, cache)

                    # last chunk
                    cache["last_chunk"] = True
                    feats_chunk2 = self.add_overlap_chunk(feats[:, -(feats.shape[1] + self.chunk_size[2] - self.chunk_size[1]):, :], cache)
                    feats_len = np.array([feats_chunk2.shape[1]]).astype(np.int32)
                    asr_res_chunk2 = self.infer(feats_chunk2, feats_len, cache)

                    asr_res_chunk = asr_res_chunk1 + asr_res_chunk2
                    res = {}
                    # Concatenate the two partial results key by key.
                    for pred in asr_res_chunk:
                        for key, value in pred.items():
                            if key in res:
                                res[key][0] += value[0]
                                res[key][1].extend(value[1])
                            else:
                                res[key] = [value[0], value[1]]
                    return [res]
            else:
                feats = self.add_overlap_chunk(feats, cache)

            feats_len = np.array([feats.shape[1]]).astype(np.int32)
            asr_res = self.infer(feats, feats_len, cache)

        return asr_res

    def infer(self, feats: np.ndarray, feats_len: np.ndarray, cache):
        """Encoder -> CIF predictor -> decoder for one prepared chunk."""
        # encoder forward
        enc_input = [feats, feats_len]
        enc, enc_lens, cif_alphas = self.ort_encoder_infer(enc_input)

        # predictor forward
        acoustic_embeds, acoustic_embeds_len = self.cif_search(enc, cif_alphas, cache)

        # decoder forward
        asr_res = []
        if acoustic_embeds.shape[1] > 0:
            dec_input = [enc, enc_lens, acoustic_embeds, acoustic_embeds_len]
            dec_input.extend(cache["decoder_fsmn"])
            dec_output = self.ort_decoder_infer(dec_input)
            logits, sample_ids, cache["decoder_fsmn"] = dec_output[0], dec_output[1], dec_output[2:]
            # Keep only the FSMN look-back window for the next chunk.
            cache["decoder_fsmn"] = [item[:, :, -self.fsmn_lorder:] for item in cache["decoder_fsmn"]]

            preds = self.decode(logits, acoustic_embeds_len)
            for pred in preds:
                pred = sentence_postprocess(pred)
                asr_res.append({'preds': pred})

        return asr_res

    def load_data(self,
                  wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List:
        """Normalise input (array, path, or list of paths) to a list of waveforms."""
        def load_wav(path: str) -> np.ndarray:
            waveform, _ = librosa.load(path, sr=fs)
            return waveform

        if isinstance(wav_content, np.ndarray):
            return [wav_content]

        if isinstance(wav_content, str):
            return [load_wav(wav_content)]

        if isinstance(wav_content, list):
            return [load_wav(path) for path in wav_content]

        raise TypeError(
            f'The type of {wav_content} is not in [str, np.ndarray, list]')

    def extract_feat(self,
                     waveforms: np.ndarray, is_final: bool = False
                     ) -> Tuple[np.ndarray, np.ndarray]:
        """Compute streaming fbank features for a batch of waveforms."""
        waveforms_lens = np.zeros(waveforms.shape[0]).astype(np.int32)
        for idx, waveform in enumerate(waveforms):
            waveforms_lens[idx] = waveform.shape[-1]

        feats, feats_len = self.frontend.extract_fbank(waveforms, waveforms_lens, is_final)
        return feats.astype(np.float32), feats_len.astype(np.int32)

    def decode(self, am_scores: np.ndarray, token_nums: int) -> List[str]:
        # Greedy decode each item in the batch.
        return [self.decode_one(am_score, token_num)
                for am_score, token_num in zip(am_scores, token_nums)]

    def decode_one(self,
                   am_score: np.ndarray,
                   valid_token_num: int) -> List[str]:
        """Greedy argmax decode of one utterance's acoustic scores to tokens."""
        yseq = am_score.argmax(axis=-1)
        score = am_score.max(axis=-1)
        score = np.sum(score, axis=-1)

        # pad with mask tokens to ensure compatibility with sos/eos tokens
        # asr_model.sos:1  asr_model.eos:2
        yseq = np.array([1] + yseq.tolist() + [2])
        hyp = Hypothesis(yseq=yseq, score=score)

        # remove sos/eos and get results
        last_pos = -1
        token_int = hyp.yseq[1:last_pos].tolist()

        # remove blank symbol id, which is assumed to be 0
        token_int = list(filter(lambda x: x not in (0, 2), token_int))

        # Change integer-ids to tokens
        token = self.converter.ids2tokens(token_int)
        token = token[:valid_token_num]
        # texts = sentence_postprocess(token)
        return token

    def cif_search(self, hidden, alphas, cache=None):
        """Numpy CIF (continuous integrate-and-fire) predictor.

        Integrates the per-frame firing weights `alphas` over the encoder
        output `hidden`, emitting one acoustic embedding each time the
        integrated weight crosses self.cif_threshold. Leftover weight/frames
        are carried in the cache to the next chunk.
        """
        batch_size, len_time, hidden_size = hidden.shape
        token_length = []
        list_fires = []
        list_frames = []
        cache_alphas = []
        cache_hiddens = []
        # Zero out context regions so only the hop frames can fire.
        alphas[:, :self.chunk_size[0]] = 0.0
        alphas[:, sum(self.chunk_size[:2]):] = 0.0
        if cache is not None and "cif_alphas" in cache and "cif_hidden" in cache:
            # Prepend the leftover integration state from the previous chunk.
            hidden = np.concatenate((cache["cif_hidden"], hidden), axis=1)
            alphas = np.concatenate((cache["cif_alphas"], alphas), axis=1)
        if cache is not None and "last_chunk" in cache and cache["last_chunk"]:
            # Append a tail firing weight so the final token is emitted.
            tail_hidden = np.zeros((batch_size, 1, hidden_size)).astype(np.float32)
            tail_alphas = np.array([[self.tail_threshold]]).astype(np.float32)
            tail_alphas =np.tile(tail_alphas, (batch_size, 1))
            hidden = np.concatenate((hidden, tail_hidden), axis=1)
            alphas = np.concatenate((alphas, tail_alphas), axis=1)

        len_time = alphas.shape[1]
        for b in range(batch_size):
            integrate = 0.0
            frames = np.zeros(hidden_size).astype(np.float32)
            list_frame = []
            list_fire = []
            for t in range(len_time):
                alpha = alphas[b][t]
                if alpha + integrate < self.cif_threshold:
                    # Keep integrating; no fire at this frame.
                    integrate += alpha
                    list_fire.append(integrate)
                    frames += alpha * hidden[b][t]
                else:
                    # Fire: complete the current embedding and start the next
                    # one with the leftover weight.
                    frames += (self.cif_threshold - integrate) * hidden[b][t]
                    list_frame.append(frames)
                    integrate += alpha
                    list_fire.append(integrate)
                    integrate -= self.cif_threshold
                    frames = integrate * hidden[b][t]

            # Carry the partially-integrated state to the next chunk.
            cache_alphas.append(integrate)
            if integrate > 0.0:
                cache_hiddens.append(frames / integrate)
            else:
                cache_hiddens.append(frames)

            token_length.append(len(list_frame))
            list_fires.append(list_fire)
            list_frames.append(list_frame)

        # Pad every batch item to the maximum emitted-token count.
        max_token_len = max(token_length)
        list_ls = []
        for b in range(batch_size):
            pad_frames = np.zeros((max_token_len - token_length[b], hidden_size)).astype(np.float32)
            if token_length[b] == 0:
                list_ls.append(pad_frames)
            else:
                list_ls.append(np.concatenate((list_frames[b], pad_frames), axis=0))

        cache["cif_alphas"] = np.stack(cache_alphas, axis=0)
        cache["cif_alphas"] = np.expand_dims(cache["cif_alphas"], axis=0)
        cache["cif_hidden"] = np.stack(cache_hiddens, axis=0)
        cache["cif_hidden"] = np.expand_dims(cache["cif_hidden"], axis=0)

        return np.stack(list_ls, axis=0).astype(np.float32), np.stack(token_length, axis=0).astype(np.int32)

View File

@ -349,6 +349,28 @@ def load_bytes(input):
return array
class SinusoidalPositionEncoderOnline():
    '''Streaming Positional encoding.

    Stateless sinusoidal position encoding that can be applied to a chunk of
    frames starting at an arbitrary absolute offset (start_idx), as required
    for chunk-by-chunk streaming inference.
    '''

    def encode(self, positions: np.ndarray = None, depth: int = None, dtype: np.dtype = np.float32):
        """Return sinusoidal encodings of shape (1, len(positions), depth).

        First depth/2 channels are sines, last depth/2 are cosines, with the
        standard transformer geometric timescale spacing (1 .. 1/10000).
        """
        positions = positions.astype(dtype)
        log_timescale_increment = np.log(np.array([10000], dtype=dtype)) / (depth / 2 - 1)
        inv_timescales = np.exp(np.arange(depth / 2).astype(dtype) * (-log_timescale_increment))
        # BUGFIX: the original reshaped by positions.shape[0] (batch size) here,
        # but the depth/2 timescales are batch-independent — that reshape
        # silently produced wrong shapes/values for batch sizes > 1.
        inv_timescales = np.reshape(inv_timescales, [1, -1])
        scaled_time = np.reshape(positions, [1, -1, 1]) * np.reshape(inv_timescales, [1, 1, -1])
        encoding = np.concatenate((np.sin(scaled_time), np.cos(scaled_time)), axis=2)
        return encoding.astype(dtype)

    def forward(self, x, start_idx=0):
        """Add position encodings to x (batch, time, dim), offset by start_idx.

        Positions are 1-based and shifted by the number of frames already
        consumed, so successive chunks get contiguous encodings.
        """
        batch_size, timesteps, input_dim = x.shape
        positions = np.arange(1, timesteps + 1 + start_idx)[None, :]
        position_encoding = self.encode(positions, input_dim, x.dtype)

        return x + position_encoding[:, start_idx: start_idx + timesteps]
def test():
path = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav"
import librosa

View File

@ -58,7 +58,11 @@ add_subdirectory(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/glog glog)
find_package(OpenSSL REQUIRED)
add_executable(funasr-wss-server "funasr-wss-server.cpp" "websocket-server.cpp")
add_executable(funasr-wss-server-2pass "funasr-wss-server-2pass.cpp" "websocket-server-2pass.cpp")
add_executable(funasr-wss-client "funasr-wss-client.cpp")
add_executable(funasr-wss-client-2pass "funasr-wss-client-2pass.cpp")
target_link_libraries(funasr-wss-client PUBLIC funasr ssl crypto)
target_link_libraries(funasr-wss-client-2pass PUBLIC funasr ssl crypto)
target_link_libraries(funasr-wss-server PUBLIC funasr ssl crypto)
target_link_libraries(funasr-wss-server-2pass PUBLIC funasr ssl crypto)

View File

@ -0,0 +1,430 @@
/**
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
* Reserved. MIT License (https://opensource.org/licenses/MIT)
*/
/* 2022-2023 by zhaomingwork */
// client for websocket, support multiple threads
// ./funasr-wss-client --server-ip <string>
// --port <string>
// --wav-path <string>
// [--thread-num <int>]
// [--is-ssl <int>] [--]
// [--version] [-h]
// example:
// ./funasr-wss-client --server-ip 127.0.0.1 --port 10095 --wav-path test.wav --thread-num 1 --is-ssl 1
#define ASIO_STANDALONE 1
#include <websocketpp/client.hpp>
#include <websocketpp/common/thread.hpp>
#include <websocketpp/config/asio_client.hpp>
#include <iostream>
#include <fstream>
#include <sstream>
#include <atomic>
#include <thread>
#include <glog/logging.h>
#include "audio.h"
#include "nlohmann/json.hpp"
#include "tclap/CmdLine.h"
/**
 * Define a semi-cross platform helper method that waits/sleeps for a bit.
 * Used both to poll for the connection to open and to let the server flush
 * final results before the client exits.
 */
void WaitABit() {
#ifdef WIN32
  Sleep(1000);  // Windows: argument is milliseconds
#else
  sleep(1);  // POSIX: argument is seconds
#endif
}
std::atomic<int> wav_index(0);
// Return true when `filename`'s extension (the text after the last '.')
// equals `target`; false when there is no '.' in the name.
bool IsTargetFile(const std::string& filename, const std::string target) {
  const std::size_t dot_pos = filename.rfind('.');
  if (dot_pos == std::string::npos) {
    return false;
  }
  return filename.compare(dot_pos + 1, std::string::npos, target) == 0;
}
typedef websocketpp::config::asio_client::message_type::ptr message_ptr;
typedef websocketpp::lib::shared_ptr<websocketpp::lib::asio::ssl::context> context_ptr;
using websocketpp::lib::bind;
using websocketpp::lib::placeholders::_1;
using websocketpp::lib::placeholders::_2;
// TLS-init callback for the wss client: builds the SSL context used for the
// connection, rejecting legacy SSLv2/SSLv3 while keeping compatibility
// workarounds for buggy peers and forcing a fresh DH key per session.
context_ptr OnTlsInit(websocketpp::connection_hdl) {
  context_ptr ctx = websocketpp::lib::make_shared<asio::ssl::context>(
      asio::ssl::context::sslv23);

  try {
    ctx->set_options(
        asio::ssl::context::default_workarounds | asio::ssl::context::no_sslv2 |
        asio::ssl::context::no_sslv3 | asio::ssl::context::single_dh_use);
  } catch (std::exception& e) {
    // Log and fall through with the default options rather than aborting.
    LOG(ERROR) << e.what();
  }
  return ctx;
}
// template for tls or not config
// Websocket ASR client, templated on the websocketpp config type so the same
// implementation serves both plain (asio_client) and TLS (asio_tls_client)
// connections. One instance drives one connection; run() pulls wav files from
// the shared atomic `wav_index` so multiple threads can share one list.
template <typename T>
class WebsocketClient {
 public:
  // typedef websocketpp::client<T> client;
  // typedef websocketpp::client<websocketpp::config::asio_tls_client>
  // wss_client;
  typedef websocketpp::lib::lock_guard<websocketpp::lib::mutex> scoped_lock;

  WebsocketClient(int is_ssl) : m_open(false), m_done(false) {
    // set up access channels to only log interesting things
    m_client.clear_access_channels(websocketpp::log::alevel::all);
    m_client.set_access_channels(websocketpp::log::alevel::connect);
    m_client.set_access_channels(websocketpp::log::alevel::disconnect);
    m_client.set_access_channels(websocketpp::log::alevel::app);

    // Initialize the Asio transport policy
    m_client.init_asio();

    // Bind the handlers we are using
    using websocketpp::lib::bind;
    using websocketpp::lib::placeholders::_1;
    m_client.set_open_handler(bind(&WebsocketClient::on_open, this, _1));
    m_client.set_close_handler(bind(&WebsocketClient::on_close, this, _1));
    m_client.set_message_handler(
        [this](websocketpp::connection_hdl hdl, message_ptr msg) {
          on_message(hdl, msg);
        });
    m_client.set_fail_handler(bind(&WebsocketClient::on_fail, this, _1));
    m_client.clear_access_channels(websocketpp::log::alevel::all);
  }

  // Handle a server message: text frames carry JSON recognition results,
  // which are parsed (to validate) and logged verbatim.
  void on_message(websocketpp::connection_hdl hdl, message_ptr msg) {
    const std::string& payload = msg->get_payload();
    switch (msg->get_opcode()) {
      case websocketpp::frame::opcode::text:
        nlohmann::json jsonresult = nlohmann::json::parse(payload);
        LOG(INFO)<< "Thread: " << this_thread::get_id() <<",on_message = " << payload;
        // if (jsonresult["is_final"] == true){
        //   websocketpp::lib::error_code ec;
        //   m_client.close(m_hdl, websocketpp::close::status::going_away, "", ec);
        //   if (ec){
        //     LOG(ERROR)<< "Error closing connection " << ec.message();
        //   }
        // }
    }
  }

  // This method will block until the connection is complete
  // Opens the connection, spins the ASIO loop on a helper thread, and keeps
  // claiming wav files from the shared `wav_index` counter until the list is
  // exhausted.
  void run(const std::string& uri, const std::vector<string>& wav_list, const std::vector<string>& wav_ids, std::string asr_mode, std::vector<int> chunk_size) {
    // Create a new connection to the given URI
    websocketpp::lib::error_code ec;
    typename websocketpp::client<T>::connection_ptr con =
        m_client.get_connection(uri, ec);
    if (ec) {
      m_client.get_alog().write(websocketpp::log::alevel::app,
                                "Get Connection Error: " + ec.message());
      return;
    }

    // Grab a handle for this connection so we can talk to it in a thread
    // safe manor after the event loop starts.
    m_hdl = con->get_handle();

    // Queue the connection. No DNS queries or network connections will be
    // made until the io_service event loop is run.
    m_client.connect(con);

    // Create a thread to run the ASIO io_service event loop
    websocketpp::lib::thread asio_thread(&websocketpp::client<T>::run,
                                         &m_client);
    // fetch_add hands each wav file to exactly one client thread.
    while(true){
      int i = wav_index.fetch_add(1);
      if (i >= wav_list.size()) {
        break;
      }
      send_wav_data(wav_list[i], wav_ids[i], asr_mode, chunk_size);
    }
    // Give the server a moment to deliver the final results before joining.
    WaitABit();
    asio_thread.join();
  }

  // The open handler will signal that we are ready to start sending data
  void on_open(websocketpp::connection_hdl) {
    m_client.get_alog().write(websocketpp::log::alevel::app,
                              "Connection opened, starting data!");
    scoped_lock guard(m_lock);
    m_open = true;
  }

  // The close handler will signal that we should stop sending data
  void on_close(websocketpp::connection_hdl) {
    m_client.get_alog().write(websocketpp::log::alevel::app,
                              "Connection closed, stopping data!");
    scoped_lock guard(m_lock);
    m_done = true;
  }

  // The fail handler will signal that we should stop sending data
  void on_fail(websocketpp::connection_hdl) {
    m_client.get_alog().write(websocketpp::log::alevel::app,
                              "Connection failed, stopping data!");
    scoped_lock guard(m_lock);
    m_done = true;
  }

  // send wav to server
  // Sends one file: a JSON start message (mode/chunk_size/wav metadata),
  // then the audio payload in binary frames, then an is_speaking=false
  // JSON message to mark end of utterance.
  void send_wav_data(string wav_path, string wav_id, std::string asr_mode, std::vector<int> chunk_vector) {
    uint64_t count = 0;
    std::stringstream val;

    funasr::Audio audio(1);
    int32_t sampling_rate = 16000;
    std::string wav_format = "pcm";
    // Decode wav/pcm to float samples; any other extension is sent as raw
    // bytes with wav_format = "others" for server-side decoding.
    if(IsTargetFile(wav_path.c_str(), "wav")){
      int32_t sampling_rate = -1;
      if(!audio.LoadWav(wav_path.c_str(), &sampling_rate))
        return ;
    }else if(IsTargetFile(wav_path.c_str(), "pcm")){
      if (!audio.LoadPcmwav(wav_path.c_str(), &sampling_rate))
        return ;
    }else{
      wav_format = "others";
      if (!audio.LoadOthers2Char(wav_path.c_str()))
        return ;
    }

    float* buff;
    int len;
    int flag = 0;
    bool wait = false;
    // Busy-wait (with 1 s sleeps) until the connection opens or dies.
    while (1) {
      {
        scoped_lock guard(m_lock);
        // If the connection has been closed, stop generating data
        if (m_done) {
          break;
        }
        // If the connection hasn't been opened yet wait a bit and retry
        if (!m_open) {
          wait = true;
        } else {
          break;
        }
      }
      if (wait) {
        // LOG(INFO) << "wait.." << m_open;
        WaitABit();
        continue;
      }
    }

    websocketpp::lib::error_code ec;

    nlohmann::json jsonbegin;
    nlohmann::json chunk_size = nlohmann::json::array();
    chunk_size.push_back(chunk_vector[0]);
    chunk_size.push_back(chunk_vector[1]);
    chunk_size.push_back(chunk_vector[2]);
    jsonbegin["mode"] = asr_mode;
    jsonbegin["chunk_size"] = chunk_size;
    jsonbegin["wav_name"] = wav_id;
    jsonbegin["wav_format"] = wav_format;
    jsonbegin["is_speaking"] = true;
    m_client.send(m_hdl, jsonbegin.dump(), websocketpp::frame::opcode::text,
                  ec);

    // fetch wav data use asr engine api
    if(wav_format == "pcm"){
      while (audio.Fetch(buff, len, flag) > 0) {
        // Convert float samples back to 16-bit PCM for the wire.
        short* iArray = new short[len];
        for (size_t i = 0; i < len; ++i) {
          iArray[i] = (short)(buff[i]*32768);
        }

        // send data to server (in blocks of at most 102400 samples)
        int offset = 0;
        int block_size = 102400;
        while(offset < len){
          int send_block = 0;
          if (offset + block_size <= len){
            send_block = block_size;
          }else{
            send_block = len - offset;
          }
          m_client.send(m_hdl, iArray+offset, send_block * sizeof(short),
                        websocketpp::frame::opcode::binary, ec);
          offset += send_block;
        }
        LOG(INFO) << "sended data len=" << len * sizeof(short);
        // The most likely error that we will get is that the connection is
        // not in the right state. Usually this means we tried to send a
        // message to a connection that was closed or in the process of
        // closing. While many errors here can be easily recovered from,
        // in this simple example, we'll stop the data loop.
        if (ec) {
          m_client.get_alog().write(websocketpp::log::alevel::app,
                                    "Send Error: " + ec.message());
          break;
        }
        delete[] iArray;
        // WaitABit();
      }
    }else{
      // Non-pcm input: ship the raw file bytes in 204800-byte blocks.
      int offset = 0;
      int block_size = 204800;
      len = audio.GetSpeechLen();
      char* others_buff = audio.GetSpeechChar();

      while(offset < len){
        int send_block = 0;
        if (offset + block_size <= len){
          send_block = block_size;
        }else{
          send_block = len - offset;
        }
        m_client.send(m_hdl, others_buff+offset, send_block,
                      websocketpp::frame::opcode::binary, ec);
        offset += send_block;
      }
      LOG(INFO) << "sended data len=" << len;
      // The most likely error that we will get is that the connection is
      // not in the right state. Usually this means we tried to send a
      // message to a connection that was closed or in the process of
      // closing. While many errors here can be easily recovered from,
      // in this simple example, we'll stop the data loop.
      if (ec) {
        m_client.get_alog().write(websocketpp::log::alevel::app,
                                  "Send Error: " + ec.message());
      }
    }

    // End-of-utterance marker.
    nlohmann::json jsonresult;
    jsonresult["is_speaking"] = false;
    m_client.send(m_hdl, jsonresult.dump(), websocketpp::frame::opcode::text,
                  ec);
    // WaitABit();
  }

  websocketpp::client<T> m_client;

 private:
  websocketpp::connection_hdl m_hdl;
  websocketpp::lib::mutex m_lock;  // guards m_open / m_done
  bool m_open;   // connection established
  bool m_done;   // connection closed or failed
  int total_num=0;
};
int main(int argc, char* argv[]) {
google::InitGoogleLogging(argv[0]);
FLAGS_logtostderr = true;
TCLAP::CmdLine cmd("funasr-wss-client", ' ', "1.0");
TCLAP::ValueArg<std::string> server_ip_("", "server-ip", "server-ip", true,
"127.0.0.1", "string");
TCLAP::ValueArg<std::string> port_("", "port", "port", true, "10095", "string");
TCLAP::ValueArg<std::string> wav_path_("", "wav-path",
"the input could be: wav_path, e.g.: asr_example.wav; pcm_path, e.g.: asr_example.pcm; wav.scp, kaldi style wav list (wav_id \t wav_path)",
true, "", "string");
TCLAP::ValueArg<std::string> asr_mode_("", ASR_MODE, "offline, online, 2pass", false, "2pass", "string");
TCLAP::ValueArg<std::string> chunk_size_("", "chunk-size", "chunk_size: 5-10-5 or 5-12-5", false, "5-10-5", "string");
TCLAP::ValueArg<int> thread_num_("", "thread-num", "thread-num",
false, 1, "int");
TCLAP::ValueArg<int> is_ssl_(
"", "is-ssl", "is-ssl is 1 means use wss connection, or use ws connection",
false, 1, "int");
cmd.add(server_ip_);
cmd.add(port_);
cmd.add(wav_path_);
cmd.add(asr_mode_);
cmd.add(chunk_size_);
cmd.add(thread_num_);
cmd.add(is_ssl_);
cmd.parse(argc, argv);
std::string server_ip = server_ip_.getValue();
std::string port = port_.getValue();
std::string wav_path = wav_path_.getValue();
std::string asr_mode = asr_mode_.getValue();
std::string chunk_size_str = chunk_size_.getValue();
// get chunk_size
std::vector<int> chunk_size;
std::stringstream ss(chunk_size_str);
std::string item;
while (std::getline(ss, item, '-')) {
try {
chunk_size.push_back(stoi(item));
} catch (const invalid_argument&) {
LOG(ERROR) << "Invalid argument: " << item;
exit(-1);
}
}
int threads_num = thread_num_.getValue();
int is_ssl = is_ssl_.getValue();
std::vector<websocketpp::lib::thread> client_threads;
std::string uri = "";
if (is_ssl == 1) {
uri = "wss://" + server_ip + ":" + port;
} else {
uri = "ws://" + server_ip + ":" + port;
}
// read wav_path
std::vector<string> wav_list;
std::vector<string> wav_ids;
string default_id = "wav_default_id";
if(IsTargetFile(wav_path, "scp")){
ifstream in(wav_path);
if (!in.is_open()) {
printf("Failed to open scp file");
return 0;
}
string line;
while(getline(in, line))
{
istringstream iss(line);
string column1, column2;
iss >> column1 >> column2;
wav_list.emplace_back(column2);
wav_ids.emplace_back(column1);
}
in.close();
}else{
wav_list.emplace_back(wav_path);
wav_ids.emplace_back(default_id);
}
for (size_t i = 0; i < threads_num; i++) {
client_threads.emplace_back([uri, wav_list, wav_ids, asr_mode, chunk_size, is_ssl]() {
if (is_ssl == 1) {
WebsocketClient<websocketpp::config::asio_tls_client> c(is_ssl);
c.m_client.set_tls_init_handler(bind(&OnTlsInit, ::_1));
c.run(uri, wav_list, wav_ids, asr_mode, chunk_size);
} else {
WebsocketClient<websocketpp::config::asio_client> c(is_ssl);
c.run(uri, wav_list, wav_ids, asr_mode, chunk_size);
}
});
}
for (auto& t : client_threads) {
t.join();
}
}

View File

@ -0,0 +1,419 @@
/**
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
* Reserved. MIT License (https://opensource.org/licenses/MIT)
*/
/* 2022-2023 by zhaomingwork */
// io server
// Usage:funasr-wss-server [--model_thread_num <int>] [--decoder_thread_num
// <int>]
// [--io_thread_num <int>] [--port <int>] [--listen_ip
// <string>] [--punc-quant <string>] [--punc-dir <string>]
// [--vad-quant <string>] [--vad-dir <string>] [--quantize
// <string>] --model-dir <string> [--] [--version] [-h]
#include <unistd.h>

#include <cstdlib>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <thread>
#include <vector>

#include "websocket-server-2pass.h"
using namespace std;
// Record one parsed command-line value under `key` in the model_path map and
// echo it to the log for operator visibility.
void GetValue(TCLAP::ValueArg<std::string>& value_arg, string key,
              std::map<std::string, std::string>& model_path) {
  const std::string val = value_arg.getValue();
  model_path.insert({key, val});
  LOG(INFO) << key << " : " << val;
}
int main(int argc, char* argv[]) {
try {
google::InitGoogleLogging(argv[0]);
FLAGS_logtostderr = true;
TCLAP::CmdLine cmd("funasr-wss-server", ' ', "1.0");
TCLAP::ValueArg<std::string> download_model_dir(
"", "download-model-dir",
"Download model from Modelscope to download_model_dir", false,
"/workspace/models", "string");
TCLAP::ValueArg<std::string> offline_model_dir(
"", OFFLINE_MODEL_DIR,
"default: /workspace/models/offline_asr, the asr model path, which "
"contains model_quant.onnx, config.yaml, am.mvn",
false, "/workspace/models/offline_asr", "string");
TCLAP::ValueArg<std::string> online_model_dir(
"", ONLINE_MODEL_DIR,
"default: damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online-onnx, the asr model path, which "
"contains model_quant.onnx, config.yaml, am.mvn",
false, "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online-onnx", "string");
TCLAP::ValueArg<std::string> offline_model_revision(
"", "offline-model-revision", "ASR offline model revision", false,
"v1.2.1", "string");
TCLAP::ValueArg<std::string> online_model_revision(
"", "online-model-revision", "ASR online model revision", false,
"v1.0.6", "string");
TCLAP::ValueArg<std::string> quantize(
"", QUANTIZE,
"true (Default), load the model of model_quant.onnx in model_dir. If "
"set "
"false, load the model of model.onnx in model_dir",
false, "true", "string");
TCLAP::ValueArg<std::string> vad_dir(
"", VAD_DIR,
"default: /workspace/models/vad, the vad model path, which contains "
"model_quant.onnx, vad.yaml, vad.mvn",
false, "/workspace/models/vad", "string");
TCLAP::ValueArg<std::string> vad_revision(
"", "vad-revision", "VAD model revision", false, "v1.2.0", "string");
TCLAP::ValueArg<std::string> vad_quant(
"", VAD_QUANT,
"true (Default), load the model of model_quant.onnx in vad_dir. If set "
"false, load the model of model.onnx in vad_dir",
false, "true", "string");
TCLAP::ValueArg<std::string> punc_dir(
"", PUNC_DIR,
"default: /workspace/models/punc, the punc model path, which contains "
"model_quant.onnx, punc.yaml",
false, "/workspace/models/punc", "string");
TCLAP::ValueArg<std::string> punc_revision(
"", "punc-revision", "PUNC model revision", false, "v1.0.2", "string");
TCLAP::ValueArg<std::string> punc_quant(
"", PUNC_QUANT,
"true (Default), load the model of model_quant.onnx in punc_dir. If "
"set "
"false, load the model of model.onnx in punc_dir",
false, "true", "string");
TCLAP::ValueArg<std::string> listen_ip("", "listen-ip", "listen ip", false,
"0.0.0.0", "string");
TCLAP::ValueArg<int> port("", "port", "port", false, 10095, "int");
TCLAP::ValueArg<int> io_thread_num("", "io-thread-num", "io thread num",
false, 8, "int");
TCLAP::ValueArg<int> decoder_thread_num(
"", "decoder-thread-num", "decoder thread num", false, 8, "int");
TCLAP::ValueArg<int> model_thread_num("", "model-thread-num",
"model thread num", false, 4, "int");
TCLAP::ValueArg<std::string> certfile(
"", "certfile",
"default: ../../../ssl_key/server.crt, path of certficate for WSS "
"connection. if it is empty, it will be in WS mode.",
false, "../../../ssl_key/server.crt", "string");
TCLAP::ValueArg<std::string> keyfile(
"", "keyfile",
"default: ../../../ssl_key/server.key, path of keyfile for WSS "
"connection",
false, "../../../ssl_key/server.key", "string");
cmd.add(certfile);
cmd.add(keyfile);
cmd.add(download_model_dir);
cmd.add(offline_model_dir);
cmd.add(online_model_dir);
cmd.add(offline_model_revision);
cmd.add(online_model_revision);
cmd.add(quantize);
cmd.add(vad_dir);
cmd.add(vad_revision);
cmd.add(vad_quant);
cmd.add(punc_dir);
cmd.add(punc_revision);
cmd.add(punc_quant);
cmd.add(listen_ip);
cmd.add(port);
cmd.add(io_thread_num);
cmd.add(decoder_thread_num);
cmd.add(model_thread_num);
cmd.parse(argc, argv);
std::map<std::string, std::string> model_path;
GetValue(offline_model_dir, OFFLINE_MODEL_DIR, model_path);
GetValue(online_model_dir, ONLINE_MODEL_DIR, model_path);
GetValue(quantize, QUANTIZE, model_path);
GetValue(vad_dir, VAD_DIR, model_path);
GetValue(vad_quant, VAD_QUANT, model_path);
GetValue(punc_dir, PUNC_DIR, model_path);
GetValue(punc_quant, PUNC_QUANT, model_path);
GetValue(offline_model_revision, "offline-model-revision", model_path);
GetValue(online_model_revision, "online-model-revision", model_path);
GetValue(vad_revision, "vad-revision", model_path);
GetValue(punc_revision, "punc-revision", model_path);
// Download model form Modelscope
try {
std::string s_download_model_dir = download_model_dir.getValue();
std::string s_vad_path = model_path[VAD_DIR];
std::string s_vad_quant = model_path[VAD_QUANT];
std::string s_offline_asr_path = model_path[OFFLINE_MODEL_DIR];
std::string s_online_asr_path = model_path[ONLINE_MODEL_DIR];
std::string s_asr_quant = model_path[QUANTIZE];
std::string s_punc_path = model_path[PUNC_DIR];
std::string s_punc_quant = model_path[PUNC_QUANT];
std::string python_cmd =
"python -m funasr.utils.runtime_sdk_download_tool --type onnx --quantize True ";
if (vad_dir.isSet() && !s_vad_path.empty()) {
std::string python_cmd_vad;
std::string down_vad_path;
std::string down_vad_model;
if (access(s_vad_path.c_str(), F_OK) == 0) {
// local
python_cmd_vad = python_cmd + " --model-name " + s_vad_path +
" --export-dir ./ " + " --model_revision " +
model_path["vad-revision"];
down_vad_path = s_vad_path;
} else {
// modelscope
LOG(INFO) << "Download model: " << s_vad_path
<< " from modelscope: ";
python_cmd_vad = python_cmd + " --model-name " +
s_vad_path +
" --export-dir " + s_download_model_dir +
" --model_revision " + model_path["vad-revision"];
down_vad_path =
s_download_model_dir +
"/" + s_vad_path;
}
int ret = system(python_cmd_vad.c_str());
if (ret != 0) {
LOG(INFO) << "Failed to download model from modelscope. If you set local vad model path, you can ignore the errors.";
}
down_vad_model = down_vad_path + "/model_quant.onnx";
if (s_vad_quant == "false" || s_vad_quant == "False" ||
s_vad_quant == "FALSE") {
down_vad_model = down_vad_path + "/model.onnx";
}
if (access(down_vad_model.c_str(), F_OK) != 0) {
LOG(ERROR) << down_vad_model << " do not exists.";
exit(-1);
} else {
model_path[VAD_DIR] = down_vad_path;
LOG(INFO) << "Set " << VAD_DIR << " : " << model_path[VAD_DIR];
}
}
else {
LOG(INFO) << "VAD model is not set, use default.";
}
if (offline_model_dir.isSet() && !s_offline_asr_path.empty()) {
std::string python_cmd_asr;
std::string down_asr_path;
std::string down_asr_model;
if (access(s_offline_asr_path.c_str(), F_OK) == 0) {
// local
python_cmd_asr = python_cmd + " --model-name " + s_offline_asr_path +
" --export-dir ./ " + " --model_revision " +
model_path["offline-model-revision"];
down_asr_path = s_offline_asr_path;
} else {
// modelscope
LOG(INFO) << "Download model: " << s_offline_asr_path
<< " from modelscope : ";
python_cmd_asr = python_cmd + " --model-name " +
s_offline_asr_path +
" --export-dir " + s_download_model_dir +
" --model_revision " + model_path["offline-model-revision"];
down_asr_path
= s_download_model_dir + "/" + s_offline_asr_path;
}
int ret = system(python_cmd_asr.c_str());
if (ret != 0) {
LOG(INFO) << "Failed to download model from modelscope. If you set local asr model path, you can ignore the errors.";
}
down_asr_model = down_asr_path + "/model_quant.onnx";
if (s_asr_quant == "false" || s_asr_quant == "False" ||
s_asr_quant == "FALSE") {
down_asr_model = down_asr_path + "/model.onnx";
}
if (access(down_asr_model.c_str(), F_OK) != 0) {
LOG(ERROR) << down_asr_model << " do not exists.";
exit(-1);
} else {
model_path[OFFLINE_MODEL_DIR] = down_asr_path;
LOG(INFO) << "Set " << OFFLINE_MODEL_DIR << " : " << model_path[OFFLINE_MODEL_DIR];
}
} else {
LOG(INFO) << "ASR Offline model is not set, use default.";
}
if (!s_online_asr_path.empty()) {
std::string python_cmd_asr;
std::string down_asr_path;
std::string down_asr_model;
if (access(s_online_asr_path.c_str(), F_OK) == 0) {
// local
python_cmd_asr = python_cmd + " --model-name " + s_online_asr_path +
" --export-dir ./ " + " --model_revision " +
model_path["online-model-revision"];
down_asr_path = s_online_asr_path;
} else {
// modelscope
LOG(INFO) << "Download model: " << s_online_asr_path
<< " from modelscope : ";
python_cmd_asr = python_cmd + " --model-name " +
s_online_asr_path +
" --export-dir " + s_download_model_dir +
" --model_revision " + model_path["online-model-revision"];
down_asr_path
= s_download_model_dir + "/" + s_online_asr_path;
}
int ret = system(python_cmd_asr.c_str());
if (ret != 0) {
LOG(INFO) << "Failed to download model from modelscope. If you set local asr model path, you can ignore the errors.";
}
down_asr_model = down_asr_path + "/model_quant.onnx";
if (s_asr_quant == "false" || s_asr_quant == "False" ||
s_asr_quant == "FALSE") {
down_asr_model = down_asr_path + "/model.onnx";
}
if (access(down_asr_model.c_str(), F_OK) != 0) {
LOG(ERROR) << down_asr_model << " do not exists.";
exit(-1);
} else {
model_path[ONLINE_MODEL_DIR] = down_asr_path;
LOG(INFO) << "Set " << ONLINE_MODEL_DIR << " : " << model_path[ONLINE_MODEL_DIR];
}
} else {
LOG(INFO) << "ASR online model is not set, use default.";
}
if (punc_dir.isSet() && !s_punc_path.empty()) {
std::string python_cmd_punc;
std::string down_punc_path;
std::string down_punc_model;
if (access(s_punc_path.c_str(), F_OK) == 0) {
// local
python_cmd_punc = python_cmd + " --model-name " + s_punc_path +
" --export-dir ./ " + " --model_revision " +
model_path["punc-revision"];
down_punc_path = s_punc_path;
} else {
// modelscope
LOG(INFO) << "Download model: " << s_punc_path
<< " from modelscope : "; python_cmd_punc = python_cmd + " --model-name " +
s_punc_path +
" --export-dir " + s_download_model_dir +
" --model_revision " + model_path["punc-revision"];
down_punc_path =
s_download_model_dir +
"/" + s_punc_path;
}
int ret = system(python_cmd_punc.c_str());
if (ret != 0) {
LOG(INFO) << "Failed to download model from modelscope. If you set local punc model path, you can ignore the errors.";
}
down_punc_model = down_punc_path + "/model_quant.onnx";
if (s_punc_quant == "false" || s_punc_quant == "False" ||
s_punc_quant == "FALSE") {
down_punc_model = down_punc_path + "/model.onnx";
}
if (access(down_punc_model.c_str(), F_OK) != 0) {
LOG(ERROR) << down_punc_model << " do not exists.";
exit(-1);
} else {
model_path[PUNC_DIR] = down_punc_path;
LOG(INFO) << "Set " << PUNC_DIR << " : " << model_path[PUNC_DIR];
}
} else {
LOG(INFO) << "PUNC model is not set, use default.";
}
} catch (std::exception const& e) {
LOG(ERROR) << "Error: " << e.what();
}
std::string s_listen_ip = listen_ip.getValue();
int s_port = port.getValue();
int s_io_thread_num = io_thread_num.getValue();
int s_decoder_thread_num = decoder_thread_num.getValue();
int s_model_thread_num = model_thread_num.getValue();
asio::io_context io_decoder; // context for decoding
asio::io_context io_server; // context for server
std::vector<std::thread> decoder_threads;
std::string s_certfile = certfile.getValue();
std::string s_keyfile = keyfile.getValue();
bool is_ssl = false;
if (!s_certfile.empty()) {
is_ssl = true;
}
auto conn_guard = asio::make_work_guard(
io_decoder); // make sure threads can wait in the queue
auto server_guard = asio::make_work_guard(
io_server); // make sure threads can wait in the queue
// create threads pool
for (int32_t i = 0; i < s_decoder_thread_num; ++i) {
decoder_threads.emplace_back([&io_decoder]() { io_decoder.run(); });
}
server server_; // server for websocket
wss_server wss_server_;
if (is_ssl) {
wss_server_.init_asio(&io_server); // init asio
wss_server_.set_reuse_addr(
true); // reuse address as we create multiple threads
// list on port for accept
wss_server_.listen(asio::ip::address::from_string(s_listen_ip), s_port);
WebSocketServer websocket_srv(
io_decoder, is_ssl, nullptr, &wss_server_, s_certfile,
s_keyfile); // websocket server for asr engine
websocket_srv.initAsr(model_path, s_model_thread_num); // init asr model
} else {
server_.init_asio(&io_server); // init asio
server_.set_reuse_addr(
true); // reuse address as we create multiple threads
// list on port for accept
server_.listen(asio::ip::address::from_string(s_listen_ip), s_port);
WebSocketServer websocket_srv(
io_decoder, is_ssl, &server_, nullptr, s_certfile,
s_keyfile); // websocket server for asr engine
websocket_srv.initAsr(model_path, s_model_thread_num); // init asr model
}
std::cout << "asr model init finished. listen on port:" << s_port
<< std::endl;
// Start the ASIO network io_service run loop
std::vector<std::thread> ts;
// create threads for io network
for (size_t i = 0; i < s_io_thread_num; i++) {
ts.emplace_back([&io_server]() { io_server.run(); });
}
// wait for theads
for (size_t i = 0; i < s_io_thread_num; i++) {
ts[i].join();
}
// wait for theads
for (auto& t : decoder_threads) {
t.join();
}
} catch (std::exception const& e) {
std::cerr << "Error: " << e.what() << std::endl;
}
return 0;
}

View File

@ -79,7 +79,7 @@ int main(int argc, char* argv[]) {
TCLAP::ValueArg<int> decoder_thread_num(
"", "decoder-thread-num", "decoder thread num", false, 8, "int");
TCLAP::ValueArg<int> model_thread_num("", "model-thread-num",
"model thread num", false, 1, "int");
"model thread num", false, 4, "int");
TCLAP::ValueArg<std::string> certfile("", "certfile",
"default: ../../../ssl_key/server.crt, path of certficate for WSS connection. if it is empty, it will be in WS mode.",

View File

@ -116,6 +116,18 @@ Export Detailed Introduction[docs](https://github.com/alibaba-damo-academy/Fu
--punc-dir ./export/damo/punc_ct-transformer_zh-cn-common-vocab272727-onnx
```
##### Start the 2pass Service
```shell
./funasr-wss-server-2pass \
 --download-model-dir /workspace/models \
 --offline-model-dir ./export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-onnx \
 --vad-dir ./export/damo/speech_fsmn_vad_zh-cn-16k-common-onnx \
 --punc-dir ./export/damo/punc_ct-transformer_zh-cn-common-vocab272727-onnx \
 --online-model-dir ./export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online \
 --quantize false
```
### Client Usage

View File

@ -0,0 +1,369 @@
/**
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
* Reserved. MIT License (https://opensource.org/licenses/MIT)
*/
/* 2022-2023 by zhaomingwork */
// websocket server for asr engine
// take some ideas from https://github.com/k2-fsa/sherpa-onnx
// online-websocket-server-impl.cc, thanks. The websocket server has two threads
// pools, one for handle network data and one for asr decoder.
// supports offline, online and 2pass (streaming + final) decoding.
#include "websocket-server-2pass.h"
#include <thread>
#include <utility>
#include <vector>
// Build the asio TLS context for a new WSS connection: select a Mozilla
// cipher profile (the modern profile additionally disables TLSv1) and load
// the server certificate chain and private key. Errors are logged and an
// un-configured context is returned.
context_ptr WebSocketServer::on_tls_init(tls_mode mode,
                                         websocketpp::connection_hdl hdl,
                                         std::string& s_certfile,
                                         std::string& s_keyfile) {
  namespace asio = websocketpp::lib::asio;

  LOG(INFO) << "on_tls_init called with hdl: " << hdl.lock().get();
  LOG(INFO) << "using TLS mode: "
            << (mode == MOZILLA_MODERN ? "Mozilla Modern"
                                       : "Mozilla Intermediate");

  context_ptr ssl_ctx = websocketpp::lib::make_shared<asio::ssl::context>(
      asio::ssl::context::sslv23);
  try {
    auto options = asio::ssl::context::default_workarounds |
                   asio::ssl::context::no_sslv2 |
                   asio::ssl::context::no_sslv3 |
                   asio::ssl::context::single_dh_use;
    if (mode == MOZILLA_MODERN) {
      options |= asio::ssl::context::no_tlsv1;  // modern profile drops TLSv1
    }
    ssl_ctx->set_options(options);
    ssl_ctx->use_certificate_chain_file(s_certfile);
    ssl_ctx->use_private_key_file(s_keyfile, asio::ssl::context::pem);
  } catch (std::exception& e) {
    LOG(INFO) << "Exception: " << e.what();
  }
  return ssl_ctx;
}
// Convert one decoder result into the client-facing json message.
//
// The streaming (online) partial text and the final (2pass/offline) text are
// appended to the per-connection accumulators; the message's "text" carries
// whichever pass produced output, with "mode" set to "2pass-online" or
// "2pass-offline" accordingly.
//
// msg must be the connection's full json state object: "wav_name" is echoed
// back when present. (Passing a bare string such as msg["wav_name"] makes
// the contains() check below return false and the field is dropped.)
nlohmann::json handle_result(FUNASR_RESULT result, std::string& online_res,
                             std::string& tpass_res, nlohmann::json msg) {
  // fix: removed an unused websocketpp::lib::error_code local.
  nlohmann::json jsonresult;
  jsonresult["text"] = "";

  std::string tmp_online_msg = FunASRGetResult(result, 0);
  online_res += tmp_online_msg;
  if (tmp_online_msg != "") {
    LOG(INFO) << "online_res :" << tmp_online_msg;
    jsonresult["text"] = tmp_online_msg;
    jsonresult["mode"] = "2pass-online";
  }

  std::string tmp_tpass_msg = FunASRGetTpassResult(result, 0);
  tpass_res += tmp_tpass_msg;
  if (tmp_tpass_msg != "") {
    LOG(INFO) << "offline results : " << tmp_tpass_msg;
    jsonresult["text"] = tmp_tpass_msg;
    jsonresult["mode"] = "2pass-offline";
  }

  if (msg.contains("wav_name")) {
    jsonresult["wav_name"] = msg["wav_name"];
  }
  return jsonresult;
}
// feed buffer to asr engine for decoder
void WebSocketServer::do_decoder(
std::vector<char>& buffer, websocketpp::connection_hdl& hdl,
nlohmann::json& msg, std::vector<std::vector<std::string>>& punc_cache,
websocketpp::lib::mutex& thread_lock, bool& is_final,
FUNASR_HANDLE& tpass_online_handle, std::string& online_res,
std::string& tpass_res) {
// lock for each connection
scoped_lock guard(thread_lock);
FUNASR_RESULT Result = nullptr;
int asr_mode_ = 2;
if (msg.contains("mode")) {
std::string modeltype = msg["mode"];
if (modeltype == "offline") {
asr_mode_ = 0;
} else if (modeltype == "online") {
asr_mode_ = 1;
} else if (modeltype == "2pass") {
asr_mode_ = 2;
}
} else {
// default value
msg["mode"] = "2pass";
asr_mode_ = 2;
}
try {
// loop to send chunk_size 800*2 data to asr engine. TODO: chunk_size need get from client
while (buffer.size() >= 800 * 2) {
std::vector<char> subvector = {buffer.begin(),
buffer.begin() + 800 * 2};
buffer.erase(buffer.begin(), buffer.begin() + 800 * 2);
try{
Result =
FunTpassInferBuffer(tpass_handle, tpass_online_handle,
subvector.data(), subvector.size(), punc_cache,
false, msg["audio_fs"], msg["wav_format"], (ASR_TYPE)asr_mode_);
}catch (std::exception const &e)
{
LOG(ERROR)<<e.what();
}
if (Result) {
websocketpp::lib::error_code ec;
nlohmann::json jsonresult =
handle_result(Result, online_res, tpass_res, msg["wav_name"]);
jsonresult["is_final"] = false;
if(jsonresult["text"] != "") {
if (is_ssl) {
wss_server_->send(hdl, jsonresult.dump(),
websocketpp::frame::opcode::text, ec);
} else {
server_->send(hdl, jsonresult.dump(),
websocketpp::frame::opcode::text, ec);
}
}
FunASRFreeResult(Result);
}
}
if(is_final){
try{
Result = FunTpassInferBuffer(tpass_handle, tpass_online_handle,
buffer.data(), buffer.size(), punc_cache,
is_final, msg["audio_fs"], msg["wav_format"], (ASR_TYPE)asr_mode_);
}catch (std::exception const &e)
{
LOG(ERROR)<<e.what();
}
for(auto &vec:punc_cache){
vec.clear();
}
if (Result) {
websocketpp::lib::error_code ec;
nlohmann::json jsonresult =
handle_result(Result, online_res, tpass_res, msg["wav_name"]);
jsonresult["is_final"] = true;
if (is_ssl) {
wss_server_->send(hdl, jsonresult.dump(),
websocketpp::frame::opcode::text, ec);
} else {
server_->send(hdl, jsonresult.dump(),
websocketpp::frame::opcode::text, ec);
}
FunASRFreeResult(Result);
}
}
} catch (std::exception const& e) {
std::cerr << "Error: " << e.what() << std::endl;
}
}
// Register fresh per-connection state when a websocket session opens.
void WebSocketServer::on_open(websocketpp::connection_hdl hdl) {
  scoped_lock guard(m_lock);       // for threads safety
  check_and_clean_connection();    // sweep out entries of closed sessions

  auto session = std::make_shared<FUNASR_MESSAGE>();
  session->samples = std::make_shared<std::vector<char>>();
  session->thread_lock = new websocketpp::lib::mutex();
  session->msg = nlohmann::json::parse("{}");
  session->msg["wav_format"] = "pcm";
  session->msg["audio_fs"] = 16000;
  session->punc_cache =
      std::make_shared<std::vector<std::vector<std::string>>>(2);
  // tpass_online_handle stays null here; it is created lazily in on_message
  // once the client has sent its chunk_size.
  data_map.emplace(hdl, session);

  LOG(INFO) << "on_open, active connections: " << data_map.size();
}
// Tear down per-connection state when a websocket session closes.
void WebSocketServer::on_close(websocketpp::connection_hdl hdl) {
  scoped_lock guard(m_lock);
  std::shared_ptr<FUNASR_MESSAGE> data_msg = nullptr;
  auto it_data = data_map.find(hdl);
  if (it_data != data_map.end()) {
    data_msg = it_data->second;
  }
  else
  {
    // Unknown handle: already removed by check_and_clean_connection.
    return;
  }
  scoped_lock guard_decoder(*(data_msg->thread_lock)); //wait for do_decoder finished and avoid access freed tpass_online_handle
  FunTpassOnlineUninit(data_msg->tpass_online_handle);
  // NOTE(review): data_msg->thread_lock was allocated with new in on_open
  // and is never deleted — freeing it here would race with decoder tasks
  // still queued for this connection, so as written one mutex leaks per
  // connection. TODO: confirm and consider shared ownership of the mutex.
  data_map.erase(hdl); // remove data vector when connection is closed
  LOG(INFO) << "on_close, active connections: "<< data_map.size();
}
// remove closed connection
void WebSocketServer::check_and_clean_connection() {
std::vector<websocketpp::connection_hdl> to_remove; // remove list
auto iter = data_map.begin();
while (iter != data_map.end()) { // loop to find closed connection
websocketpp::connection_hdl hdl = iter->first;
if (is_ssl) {
wss_server::connection_ptr con = wss_server_->get_con_from_hdl(hdl);
if (con->get_state() != 1) { // session::state::open ==1
to_remove.push_back(hdl);
}
} else {
server::connection_ptr con = server_->get_con_from_hdl(hdl);
if (con->get_state() != 1) { // session::state::open ==1
to_remove.push_back(hdl);
}
}
iter++;
}
for (auto hdl : to_remove) {
data_map.erase(hdl);
LOG(INFO) << "remove one connection ";
}
}
// Dispatch one websocket frame for a connection.
//
// Text frames carry json control messages (wav_name / mode / wav_format /
// audio_fs / chunk_size / is_speaking); binary frames carry raw pcm audio,
// which is buffered and posted to the decoder pool in fixed-size steps.
void WebSocketServer::on_message(websocketpp::connection_hdl hdl,
                                 message_ptr msg) {
  unique_lock lock(m_lock);
  // find the sample data vector according to one connection
  std::shared_ptr<FUNASR_MESSAGE> msg_data = nullptr;
  auto it_data = data_map.find(hdl);
  if (it_data != data_map.end()) {
    msg_data = it_data->second;
  }
  // fix: the original dereferenced msg_data->samples without this check, so
  // a frame arriving for an already-cleaned connection crashed the server.
  if (msg_data == nullptr) {
    lock.unlock();
    LOG(INFO) << "error when fetch sample data vector";
    return;
  }
  std::shared_ptr<std::vector<char>> sample_data_p = msg_data->samples;
  std::shared_ptr<std::vector<std::vector<std::string>>> punc_cache_p =
      msg_data->punc_cache;
  websocketpp::lib::mutex* thread_lock_p = msg_data->thread_lock;
  lock.unlock();

  if (sample_data_p == nullptr) {
    LOG(INFO) << "error when fetch sample data vector";
    return;
  }

  const std::string& payload = msg->get_payload();  // get msg type
  switch (msg->get_opcode()) {
    case websocketpp::frame::opcode::text: {
      nlohmann::json jsonresult = nlohmann::json::parse(payload);
      if (jsonresult.contains("wav_name")) {
        msg_data->msg["wav_name"] = jsonresult["wav_name"];
      }
      if (jsonresult.contains("mode")) {
        msg_data->msg["mode"] = jsonresult["mode"];
      }
      if (jsonresult.contains("wav_format")) {
        msg_data->msg["wav_format"] = jsonresult["wav_format"];
      }
      if (jsonresult.contains("audio_fs")) {
        msg_data->msg["audio_fs"] = jsonresult["audio_fs"];
      }
      if (jsonresult.contains("chunk_size")) {
        // The streaming handle is created lazily once chunk_size is known.
        if (msg_data->tpass_online_handle == NULL) {
          std::vector<int> chunk_size_vec =
              jsonresult["chunk_size"].get<std::vector<int>>();
          FUNASR_HANDLE tpass_online_handle =
              FunTpassOnlineInit(tpass_handle, chunk_size_vec);
          msg_data->tpass_online_handle = tpass_online_handle;
        }
      }
      LOG(INFO) << "jsonresult=" << jsonresult
                << ", msg_data->msg=" << msg_data->msg;
      if (jsonresult["is_speaking"] == false ||
          jsonresult["is_finished"] == true) {
        LOG(INFO) << "client done";
        // Final control message: post the remaining samples with
        // is_final=true so the engine flushes its state.
        asio::post(
            io_decoder_,
            std::bind(&WebSocketServer::do_decoder, this,
                      std::move(*(sample_data_p.get())), std::move(hdl),
                      std::ref(msg_data->msg), std::ref(*(punc_cache_p.get())),
                      std::ref(*thread_lock_p), std::move(true),
                      std::ref(msg_data->tpass_online_handle),
                      std::ref(msg_data->online_res),
                      std::ref(msg_data->tpass_res)));
      }
      break;
    }
    case websocketpp::frame::opcode::binary: {
      // received binary pcm data
      const auto* pcm_data = static_cast<const char*>(payload.data());
      int32_t num_samples = payload.size();
      if (isonline) {
        sample_data_p->insert(sample_data_p->end(), pcm_data,
                              pcm_data + num_samples);
        int setpsize = 800 * 2;  // decode step in bytes; TODO get from client
        // Post whole multiples of setpsize to the decoder, keep remainder.
        if (sample_data_p->size() > (size_t)setpsize) {
          int chunksize = floor(sample_data_p->size() / setpsize);
          // make sure the subvector size is an integer multiple of setpsize
          std::vector<char> subvector = {
              sample_data_p->begin(),
              sample_data_p->begin() + chunksize * setpsize};
          // keep remain in sample_data
          sample_data_p->erase(sample_data_p->begin(),
                               sample_data_p->begin() + chunksize * setpsize);
          // post to decode
          asio::post(io_decoder_,
                     std::bind(&WebSocketServer::do_decoder, this,
                               std::move(subvector), std::move(hdl),
                               std::ref(msg_data->msg),
                               std::ref(*(punc_cache_p.get())),
                               std::ref(*thread_lock_p), std::move(false),
                               std::ref(msg_data->tpass_online_handle),
                               std::ref(msg_data->online_res),
                               std::ref(msg_data->tpass_res)));
        }
      } else {
        // Offline: just accumulate until the final control message arrives.
        sample_data_p->insert(sample_data_p->end(), pcm_data,
                              pcm_data + num_samples);
      }
      break;
    }
    default:
      break;
  }
}
// init asr model
void WebSocketServer::initAsr(std::map<std::string, std::string>& model_path,
int thread_num) {
try {
tpass_handle = FunTpassInit(model_path, thread_num);
if (!tpass_handle) {
LOG(ERROR) << "FunTpassInit init failed";
exit(-1);
}
} catch (const std::exception& e) {
LOG(INFO) << e.what();
}
}

View File

@ -0,0 +1,148 @@
/**
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
* Reserved. MIT License (https://opensource.org/licenses/MIT)
*/
/* 2022-2023 by zhaomingwork */
// websocket server for asr engine
// take some ideas from https://github.com/k2-fsa/sherpa-onnx
// online-websocket-server-impl.cc, thanks. The websocket server has two threads
// pools, one for handle network data and one for asr decoder.
// supports offline, online and 2pass (streaming + final) decoding.
#ifndef WEBSOCKET_SERVER_H_
#define WEBSOCKET_SERVER_H_
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <thread>
#include <utility>
#define ASIO_STANDALONE 1 // not boost
#include <glog/logging.h>
#include <fstream>
#include <functional>
#include <websocketpp/common/thread.hpp>
#include <websocketpp/config/asio.hpp>
#include <websocketpp/server.hpp>
#include "asio.hpp"
#include "com-define.h"
#include "funasrruntime.h"
#include "nlohmann/json.hpp"
#include "tclap/CmdLine.h"
typedef websocketpp::server<websocketpp::config::asio> server;
typedef websocketpp::server<websocketpp::config::asio_tls> wss_server;
typedef server::message_ptr message_ptr;
using websocketpp::lib::bind;
using websocketpp::lib::placeholders::_1;
using websocketpp::lib::placeholders::_2;
typedef websocketpp::lib::lock_guard<websocketpp::lib::mutex> scoped_lock;
typedef websocketpp::lib::unique_lock<websocketpp::lib::mutex> unique_lock;
typedef websocketpp::lib::shared_ptr<websocketpp::lib::asio::ssl::context>
context_ptr;
// One recognition result: text plus the audio duration it covers.
typedef struct {
  std::string msg;
  float snippet_time;
} FUNASR_RECOG_RESULT;

// Per-connection session state, created in on_open and released in on_close.
typedef struct {
  // Connection config json: mode, wav_name, wav_format, audio_fs, ...
  nlohmann::json msg;
  // Buffered pcm bytes not yet handed to the decoder.
  std::shared_ptr<std::vector<char>> samples;
  // Punctuation model cache, sized 2 in on_open.
  std::shared_ptr<std::vector<std::vector<std::string>>> punc_cache;
  websocketpp::lib::mutex* thread_lock; // lock for each connection
  // Streaming decoder handle; created lazily when chunk_size arrives.
  FUNASR_HANDLE tpass_online_handle=NULL;
  // Accumulated streaming (online) text.
  std::string online_res = "";
  // Accumulated final (2pass/offline) text.
  std::string tpass_res = "";
} FUNASR_MESSAGE;
// TLS configuration profiles. See
// https://wiki.mozilla.org/Security/Server_Side_TLS for details about the
// "modern" and "intermediate" cipher/protocol recommendations.
enum tls_mode {
  MOZILLA_INTERMEDIATE = 1,  // broader client compatibility
  MOZILLA_MODERN = 2         // stricter; newer clients only
};
// Websocket ASR server supporting plain (ws://) and TLS (wss://) endpoints.
// Network events run on websocketpp's own threads; decoding jobs are posted
// to `io_decoder_` so inference does not block the network loop.
class WebSocketServer {
 public:
  // Wires message/open/close handlers into exactly one endpoint, selected by
  // `is_ssl`. Both endpoint pointers are non-owning: the caller keeps the
  // server objects alive while this object is in use.
  WebSocketServer(asio::io_context& io_decoder, bool is_ssl, server* server,
                  wss_server* wss_server, std::string& s_certfile,
                  std::string& s_keyfile)
      : io_decoder_(io_decoder),
        is_ssl(is_ssl),
        server_(server),
        wss_server_(wss_server) {
    if (is_ssl) {
      std::cout << "certfile path is " << s_certfile << std::endl;
      // TLS setup; the Mozilla "intermediate" profile keeps compatibility
      // with older clients.
      wss_server->set_tls_init_handler(
          bind<context_ptr>(&WebSocketServer::on_tls_init, this,
                            MOZILLA_INTERMEDIATE, ::_1, s_certfile, s_keyfile));
      // set message handle
      wss_server_->set_message_handler(
          [this](websocketpp::connection_hdl hdl, message_ptr msg) {
            on_message(hdl, msg);
          });
      // set open handle
      wss_server_->set_open_handler(
          [this](websocketpp::connection_hdl hdl) { on_open(hdl); });
      // set close handle
      wss_server_->set_close_handler(
          [this](websocketpp::connection_hdl hdl) { on_close(hdl); });
      // begin accept
      wss_server_->start_accept();
      // not print log
      wss_server_->clear_access_channels(websocketpp::log::alevel::all);
    } else {
      // set message handle
      server_->set_message_handler(
          [this](websocketpp::connection_hdl hdl, message_ptr msg) {
            on_message(hdl, msg);
          });
      // set open handle
      server_->set_open_handler(
          [this](websocketpp::connection_hdl hdl) { on_open(hdl); });
      // set close handle
      server_->set_close_handler(
          [this](websocketpp::connection_hdl hdl) { on_close(hdl); });
      // begin accept
      server_->start_accept();
      // not print log
      server_->clear_access_channels(websocketpp::log::alevel::all);
    }
  }

  // Runs one decoding step on `buffer` for connection `hdl` and sends the
  // (partial or final) result back to the client. Executed on an io_decoder_
  // thread; `thread_lock` serializes work per connection.
  void do_decoder(std::vector<char>& buffer, websocketpp::connection_hdl& hdl,
                  nlohmann::json& msg,
                  std::vector<std::vector<std::string>>& punc_cache,
                  websocketpp::lib::mutex& thread_lock, bool& is_final,
                  FUNASR_HANDLE& tpass_online_handle, std::string& online_res,
                  std::string& tpass_res);

  // Loads the models described by `model_path` and starts `thread_num`
  // decoder threads.
  void initAsr(std::map<std::string, std::string>& model_path, int thread_num);

  // websocketpp callbacks.
  void on_message(websocketpp::connection_hdl hdl, message_ptr msg);
  void on_open(websocketpp::connection_hdl hdl);
  void on_close(websocketpp::connection_hdl hdl);
  context_ptr on_tls_init(tls_mode mode, websocketpp::connection_hdl hdl,
                          std::string& s_certfile, std::string& s_keyfile);

 private:
  // Drops data_map entries whose connections are no longer alive.
  void check_and_clean_connection();

  asio::io_context& io_decoder_;  // threads for asr decoder
  FUNASR_HANDLE tpass_handle = nullptr;  // two-pass asr engine handle
  bool isonline = true;  // streaming (online/2pass) decoding enabled
  bool is_ssl = true;
  server* server_;          // plain websocket endpoint (used when !is_ssl)
  wss_server* wss_server_;  // TLS websocket endpoint (used when is_ssl)
  // Per-connection decoding state: received samples, caches and partial
  // results for each live connection.
  std::map<websocketpp::connection_hdl, std::shared_ptr<FUNASR_MESSAGE>,
           std::owner_less<websocketpp::connection_hdl>>
      data_map;
  websocketpp::lib::mutex m_lock;  // guards data_map
};
#endif // WEBSOCKET_SERVER_H_