mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
fix paraformer server for new apis
This commit is contained in:
parent
b78d47f1ef
commit
a539392ad4
@ -42,17 +42,23 @@ add_custom_command(
|
||||
"${rg_proto}"
|
||||
DEPENDS "${rg_proto}")
|
||||
|
||||
|
||||
# Include generated *.pb.h files
|
||||
include_directories("${CMAKE_CURRENT_BINARY_DIR}")
|
||||
|
||||
include_directories(../onnxruntime/include/)
|
||||
link_directories(../onnxruntime/build/src/)
|
||||
link_directories(../onnxruntime/build/third_party/yaml-cpp/)
|
||||
|
||||
link_directories(${ONNXRUNTIME_DIR}/lib)
|
||||
|
||||
include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/include/)
|
||||
include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/yaml-cpp/include/)
|
||||
include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/kaldi-native-fbank)
|
||||
|
||||
add_subdirectory(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/yaml-cpp yaml-cpp)
|
||||
add_subdirectory(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/kaldi-native-fbank/kaldi-native-fbank/csrc csrc)
|
||||
add_subdirectory("../onnxruntime/src" onnx_src)
|
||||
|
||||
include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/glog)
|
||||
set(BUILD_TESTING OFF)
|
||||
add_subdirectory(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/glog glog)
|
||||
|
||||
# rg_grpc_proto
|
||||
add_library(rg_grpc_proto
|
||||
${rg_grpc_srcs}
|
||||
@ -60,16 +66,13 @@ add_library(rg_grpc_proto
|
||||
${rg_proto_srcs}
|
||||
${rg_proto_hdrs})
|
||||
|
||||
|
||||
|
||||
target_link_libraries(rg_grpc_proto
|
||||
${_REFLECTION}
|
||||
${_GRPC_GRPCPP}
|
||||
${_PROTOBUF_LIBPROTOBUF})
|
||||
|
||||
# Targets paraformer_(server)
|
||||
foreach(_target
|
||||
paraformer_server)
|
||||
paraformer-server)
|
||||
add_executable(${_target}
|
||||
"${_target}.cc")
|
||||
target_link_libraries(${_target}
|
||||
|
||||
@ -4,15 +4,6 @@
|
||||
|
||||
### Build [onnxruntime](./onnxruntime_cpp.md) as it's document
|
||||
|
||||
```
|
||||
#put onnx-lib & onnx-asr-model into /path/to/asrmodel(eg: /data/asrmodel)
|
||||
ls /data/asrmodel/
|
||||
onnxruntime-linux-x64-1.14.0 speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch
|
||||
|
||||
#make sure you have config.yaml, am.mvn, model.onnx(or model_quant.onnx) under speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch
|
||||
|
||||
```
|
||||
|
||||
### Compile and install grpc v1.52.0 in case of grpc bugs
|
||||
```
|
||||
export GRPC_INSTALL_DIR=/data/soft/grpc
|
||||
@ -46,8 +37,39 @@ source ~/.bashrc
|
||||
|
||||
### Start grpc paraformer server
|
||||
```
|
||||
Usage: ./cmake/build/paraformer_server port thread_num /path/to/model_file quantize(true or false)
|
||||
./cmake/build/paraformer_server 10108 4 /data/asrmodel/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch false
|
||||
./cmake/build/paraformer-server --port-id <string> [--punc-config
|
||||
<string>] [--punc-model <string>]
|
||||
--am-config <string> --am-cmvn <string>
|
||||
--am-model <string> [--vad-config
|
||||
<string>] [--vad-cmvn <string>]
|
||||
[--vad-model <string>] [--] [--version]
|
||||
[-h]
|
||||
Where:
|
||||
--port-id <string>
|
||||
(required) port id
|
||||
|
||||
--am-config <string>
|
||||
(required) am config path
|
||||
--am-cmvn <string>
|
||||
(required) am cmvn path
|
||||
--am-model <string>
|
||||
(required) am model path
|
||||
|
||||
--punc-config <string>
|
||||
punc config path
|
||||
--punc-model <string>
|
||||
punc model path
|
||||
|
||||
--vad-config <string>
|
||||
vad config path
|
||||
--vad-cmvn <string>
|
||||
vad cmvn path
|
||||
--vad-model <string>
|
||||
vad model path
|
||||
|
||||
Required: --port-id <string> --am-config <string> --am-cmvn <string> --am-model <string>
|
||||
If use vad, please add: [--vad-config <string>] [--vad-cmvn <string>] [--vad-model <string>]
|
||||
If use punc, please add: [--punc-config <string>] [--punc-model <string>]
|
||||
```
|
||||
|
||||
## For the client
|
||||
|
||||
@ -13,7 +13,10 @@
|
||||
#include <grpcpp/security/server_credentials.h>
|
||||
|
||||
#include "paraformer.grpc.pb.h"
|
||||
#include "paraformer_server.h"
|
||||
#include "paraformer-server.h"
|
||||
#include "tclap/CmdLine.h"
|
||||
#include "com-define.h"
|
||||
#include "glog/logging.h"
|
||||
|
||||
using grpc::Server;
|
||||
using grpc::ServerBuilder;
|
||||
@ -27,31 +30,43 @@ using paraformer::Request;
|
||||
using paraformer::Response;
|
||||
using paraformer::ASR;
|
||||
|
||||
ASRServicer::ASRServicer(const char* model_path, int thread_num, bool quantize) {
|
||||
AsrHanlde=FunASRInit(model_path, thread_num, quantize);
|
||||
ASRServicer::ASRServicer(std::map<std::string, std::string>& model_path) {
|
||||
AsrHanlde=FunASRInit(model_path, 1);
|
||||
std::cout << "ASRServicer init" << std::endl;
|
||||
init_flag = 0;
|
||||
}
|
||||
|
||||
void ASRServicer::clear_states(const std::string& user) {
|
||||
clear_buffers(user);
|
||||
clear_transcriptions(user);
|
||||
}
|
||||
|
||||
void ASRServicer::clear_buffers(const std::string& user) {
|
||||
if (client_buffers.count(user)) {
|
||||
client_buffers.erase(user);
|
||||
}
|
||||
}
|
||||
|
||||
void ASRServicer::clear_transcriptions(const std::string& user) {
|
||||
if (client_transcription.count(user)) {
|
||||
client_transcription.erase(user);
|
||||
}
|
||||
}
|
||||
|
||||
void ASRServicer::disconnect(const std::string& user) {
|
||||
clear_states(user);
|
||||
std::cout << "Disconnecting user: " << user << std::endl;
|
||||
}
|
||||
|
||||
grpc::Status ASRServicer::Recognize(
|
||||
grpc::ServerContext* context,
|
||||
grpc::ServerReaderWriter<Response, Request>* stream) {
|
||||
|
||||
Request req;
|
||||
std::unordered_map<std::string, std::string> client_buffers;
|
||||
std::unordered_map<std::string, std::string> client_transcription;
|
||||
|
||||
while (stream->Read(&req)) {
|
||||
if (req.isend()) {
|
||||
std::cout << "asr end" << std::endl;
|
||||
// disconnect
|
||||
if (client_buffers.count(req.user())) {
|
||||
client_buffers.erase(req.user());
|
||||
}
|
||||
if (client_transcription.count(req.user())) {
|
||||
client_transcription.erase(req.user());
|
||||
}
|
||||
|
||||
disconnect(req.user());
|
||||
Response res;
|
||||
res.set_sentence(
|
||||
R"({"success": true, "detail": "asr end"})"
|
||||
@ -89,14 +104,8 @@ grpc::Status ASRServicer::Recognize(
|
||||
auto& buf = client_buffers[req.user()];
|
||||
buf.insert(buf.end(), req.audio_data().begin(), req.audio_data().end());
|
||||
}
|
||||
std::string tmp_data = client_buffers[req.user()];
|
||||
// clear_states
|
||||
if (client_buffers.count(req.user())) {
|
||||
client_buffers.erase(req.user());
|
||||
}
|
||||
if (client_transcription.count(req.user())) {
|
||||
client_transcription.erase(req.user());
|
||||
}
|
||||
std::string tmp_data = this->client_buffers[req.user()];
|
||||
this->clear_states(req.user());
|
||||
|
||||
Response res;
|
||||
res.set_sentence(
|
||||
@ -161,10 +170,17 @@ grpc::Status ASRServicer::Recognize(
|
||||
return Status::OK;
|
||||
}
|
||||
|
||||
void RunServer(const std::string& port, int thread_num, const char* model_path, bool quantize) {
|
||||
void RunServer(std::map<std::string, std::string>& model_path) {
|
||||
std::string port;
|
||||
try{
|
||||
port = model_path.at(PORT_ID);
|
||||
}catch(std::exception const &e){
|
||||
printf("Error when read port.\n");
|
||||
exit(0);
|
||||
}
|
||||
std::string server_address;
|
||||
server_address = "0.0.0.0:" + port;
|
||||
ASRServicer service(model_path, thread_num, quantize);
|
||||
ASRServicer service(model_path);
|
||||
|
||||
ServerBuilder builder;
|
||||
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
|
||||
@ -174,16 +190,54 @@ void RunServer(const std::string& port, int thread_num, const char* model_path,
|
||||
server->Wait();
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
if (argc < 5)
|
||||
{
|
||||
printf("Usage: %s port thread_num /path/to/model_file quantize(true or false) \n", argv[0]);
|
||||
exit(-1);
|
||||
void GetValue(TCLAP::ValueArg<std::string>& value_arg, std::string key, std::map<std::string, std::string>& model_path)
|
||||
{
|
||||
if (value_arg.isSet()){
|
||||
model_path.insert({key, value_arg.getValue()});
|
||||
LOG(INFO)<< key << " : " << value_arg.getValue();
|
||||
}
|
||||
}
|
||||
|
||||
// is quantize
|
||||
bool quantize = false;
|
||||
std::istringstream(argv[4]) >> std::boolalpha >> quantize;
|
||||
RunServer(argv[1], atoi(argv[2]), argv[3], quantize);
|
||||
int main(int argc, char* argv[]) {
|
||||
|
||||
google::InitGoogleLogging(argv[0]);
|
||||
FLAGS_logtostderr = true;
|
||||
|
||||
TCLAP::CmdLine cmd("paraformer-server", ' ', "1.0");
|
||||
TCLAP::ValueArg<std::string> vad_model("", VAD_MODEL_PATH, "vad model path", false, "", "string");
|
||||
TCLAP::ValueArg<std::string> vad_cmvn("", VAD_CMVN_PATH, "vad cmvn path", false, "", "string");
|
||||
TCLAP::ValueArg<std::string> vad_config("", VAD_CONFIG_PATH, "vad config path", false, "", "string");
|
||||
|
||||
TCLAP::ValueArg<std::string> am_model("", AM_MODEL_PATH, "am model path", true, "", "string");
|
||||
TCLAP::ValueArg<std::string> am_cmvn("", AM_CMVN_PATH, "am cmvn path", true, "", "string");
|
||||
TCLAP::ValueArg<std::string> am_config("", AM_CONFIG_PATH, "am config path", true, "", "string");
|
||||
|
||||
TCLAP::ValueArg<std::string> punc_model("", PUNC_MODEL_PATH, "punc model path", false, "", "string");
|
||||
TCLAP::ValueArg<std::string> punc_config("", PUNC_CONFIG_PATH, "punc config path", false, "", "string");
|
||||
TCLAP::ValueArg<std::string> port_id("", PORT_ID, "port id", true, "", "string");
|
||||
|
||||
cmd.add(vad_model);
|
||||
cmd.add(vad_cmvn);
|
||||
cmd.add(vad_config);
|
||||
cmd.add(am_model);
|
||||
cmd.add(am_cmvn);
|
||||
cmd.add(am_config);
|
||||
cmd.add(punc_model);
|
||||
cmd.add(punc_config);
|
||||
cmd.add(port_id);
|
||||
cmd.parse(argc, argv);
|
||||
|
||||
std::map<std::string, std::string> model_path;
|
||||
GetValue(vad_model, VAD_MODEL_PATH, model_path);
|
||||
GetValue(vad_cmvn, VAD_CMVN_PATH, model_path);
|
||||
GetValue(vad_config, VAD_CONFIG_PATH, model_path);
|
||||
GetValue(am_model, AM_MODEL_PATH, model_path);
|
||||
GetValue(am_cmvn, AM_CMVN_PATH, model_path);
|
||||
GetValue(am_config, AM_CONFIG_PATH, model_path);
|
||||
GetValue(punc_model, PUNC_MODEL_PATH, model_path);
|
||||
GetValue(punc_config, PUNC_CONFIG_PATH, model_path);
|
||||
GetValue(port_id, PORT_ID, model_path);
|
||||
|
||||
RunServer(model_path);
|
||||
return 0;
|
||||
}
|
||||
@ -37,13 +37,18 @@ typedef struct
|
||||
float snippet_time;
|
||||
}FUNASR_RECOG_RESULT;
|
||||
|
||||
|
||||
class ASRServicer final : public ASR::Service {
|
||||
private:
|
||||
int init_flag;
|
||||
std::unordered_map<std::string, std::string> client_buffers;
|
||||
std::unordered_map<std::string, std::string> client_transcription;
|
||||
|
||||
public:
|
||||
ASRServicer(const char* model_path, int thread_num, bool quantize);
|
||||
ASRServicer(std::map<std::string, std::string>& model_path);
|
||||
void clear_states(const std::string& user);
|
||||
void clear_buffers(const std::string& user);
|
||||
void clear_transcriptions(const std::string& user);
|
||||
void disconnect(const std::string& user);
|
||||
grpc::Status Recognize(grpc::ServerContext* context, grpc::ServerReaderWriter<Response, Request>* stream);
|
||||
FUNASR_HANDLE AsrHanlde;
|
||||
|
||||
@ -24,6 +24,7 @@
|
||||
#define WAV_PATH "wav-path"
|
||||
#define WAV_SCP "wav-scp"
|
||||
#define THREAD_NUM "thread-num"
|
||||
#define PORT_ID "port-id"
|
||||
|
||||
// vad
|
||||
#ifndef VAD_SILENCE_DURATION
|
||||
|
||||
@ -47,10 +47,9 @@ typedef enum {
|
||||
|
||||
typedef void (* QM_CALLBACK)(int cur_step, int n_total); // n_total: total steps; cur_step: Current Step.
|
||||
|
||||
// APIs for funasr
|
||||
// // ASR
|
||||
_FUNASRAPI FUNASR_HANDLE FunASRInit(std::map<std::string, std::string>& model_path, int thread_num);
|
||||
|
||||
// if not give a fn_callback ,it should be NULL
|
||||
_FUNASRAPI FUNASR_RESULT FunASRRecogBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback);
|
||||
_FUNASRAPI FUNASR_RESULT FunASRRecogPCMBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, int sampling_rate, FUNASR_MODE mode, QM_CALLBACK fn_callback);
|
||||
_FUNASRAPI FUNASR_RESULT FunASRRecogPCMFile(FUNASR_HANDLE handle, const char* sz_filename, int sampling_rate, FUNASR_MODE mode, QM_CALLBACK fn_callback);
|
||||
@ -62,6 +61,14 @@ _FUNASRAPI void FunASRFreeResult(FUNASR_RESULT result);
|
||||
_FUNASRAPI void FunASRUninit(FUNASR_HANDLE handle);
|
||||
_FUNASRAPI const float FunASRGetRetSnippetTime(FUNASR_RESULT result);
|
||||
|
||||
// VAD
|
||||
_FUNASRAPI FUNASR_HANDLE FunVadInit(std::map<std::string, std::string>& model_path, int thread_num);
|
||||
|
||||
_FUNASRAPI FUNASR_RESULT FunASRVadBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback);
|
||||
_FUNASRAPI FUNASR_RESULT FunASRVadPCMBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, int sampling_rate, FUNASR_MODE mode, QM_CALLBACK fn_callback);
|
||||
_FUNASRAPI FUNASR_RESULT FunASRVadPCMFile(FUNASR_HANDLE handle, const char* sz_filename, int sampling_rate, FUNASR_MODE mode, QM_CALLBACK fn_callback);
|
||||
_FUNASRAPI FUNASR_RESULT FunASRVadFile(FUNASR_HANDLE handle, const char* sz_wavfile, FUNASR_MODE mode, QM_CALLBACK fn_callback);
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
||||
}
|
||||
|
||||
@ -11,6 +11,12 @@ extern "C" {
|
||||
return mm;
|
||||
}
|
||||
|
||||
_FUNASRAPI FUNASR_HANDLE FunVadInit(std::map<std::string, std::string>& model_path, int thread_num)
|
||||
{
|
||||
Model* mm = CreateModel(model_path, thread_num);
|
||||
return mm;
|
||||
}
|
||||
|
||||
_FUNASRAPI FUNASR_RESULT FunASRRecogBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback)
|
||||
{
|
||||
Model* recog_obj = (Model*)handle;
|
||||
|
||||
@ -21,8 +21,8 @@ using namespace std;
|
||||
// third part
|
||||
#include "onnxruntime_run_options_config_keys.h"
|
||||
#include "onnxruntime_cxx_api.h"
|
||||
#include <kaldi-native-fbank/csrc/feature-fbank.h>
|
||||
#include <kaldi-native-fbank/csrc/online-feature.h>
|
||||
#include "kaldi-native-fbank/csrc/feature-fbank.h"
|
||||
#include "kaldi-native-fbank/csrc/online-feature.h"
|
||||
|
||||
// mine
|
||||
#include <glog/logging.h>
|
||||
@ -40,6 +40,7 @@ using namespace std;
|
||||
#include "util.h"
|
||||
#include "resample.h"
|
||||
#include "model.h"
|
||||
#include "vad-model.h"
|
||||
#include "paraformer.h"
|
||||
#include "libfunasrapi.h"
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user