diff --git a/funasr/runtime/grpc/CMakeLists.txt b/funasr/runtime/grpc/CMakeLists.txt index c7727d57c..98c478752 100644 --- a/funasr/runtime/grpc/CMakeLists.txt +++ b/funasr/runtime/grpc/CMakeLists.txt @@ -42,17 +42,23 @@ add_custom_command( "${rg_proto}" DEPENDS "${rg_proto}") - # Include generated *.pb.h files include_directories("${CMAKE_CURRENT_BINARY_DIR}") -include_directories(../onnxruntime/include/) -link_directories(../onnxruntime/build/src/) -link_directories(../onnxruntime/build/third_party/yaml-cpp/) - link_directories(${ONNXRUNTIME_DIR}/lib) + +include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/include/) +include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/yaml-cpp/include/) +include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/kaldi-native-fbank) + +add_subdirectory(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/yaml-cpp yaml-cpp) +add_subdirectory(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/kaldi-native-fbank/kaldi-native-fbank/csrc csrc) add_subdirectory("../onnxruntime/src" onnx_src) +include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/glog) +set(BUILD_TESTING OFF) +add_subdirectory(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/glog glog) + # rg_grpc_proto add_library(rg_grpc_proto ${rg_grpc_srcs} @@ -60,16 +66,13 @@ add_library(rg_grpc_proto ${rg_proto_srcs} ${rg_proto_hdrs}) - - target_link_libraries(rg_grpc_proto ${_REFLECTION} ${_GRPC_GRPCPP} ${_PROTOBUF_LIBPROTOBUF}) -# Targets paraformer_(server) foreach(_target - paraformer_server) + paraformer-server) add_executable(${_target} "${_target}.cc") target_link_libraries(${_target} diff --git a/funasr/runtime/grpc/Readme.md b/funasr/runtime/grpc/Readme.md index 23e618c22..da925599b 100644 --- a/funasr/runtime/grpc/Readme.md +++ b/funasr/runtime/grpc/Readme.md @@ -4,15 +4,6 @@ ### Build [onnxruntime](./onnxruntime_cpp.md) as it's document -``` -#put onnx-lib & onnx-asr-model into /path/to/asrmodel(eg: /data/asrmodel) -ls /data/asrmodel/ -onnxruntime-linux-x64-1.14.0 speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch - -#make sure you have config.yaml, am.mvn, model.onnx(or model_quant.onnx) under speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch - -``` - ### Compile and install grpc v1.52.0 in case of grpc bugs ``` export GRPC_INSTALL_DIR=/data/soft/grpc @@ -46,8 +37,39 @@ source ~/.bashrc ### Start grpc paraformer server ``` -Usage: ./cmake/build/paraformer_server port thread_num /path/to/model_file quantize(true or false) -./cmake/build/paraformer_server 10108 4 /data/asrmodel/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch false +./cmake/build/paraformer-server --port-id [--punc-config + ] [--punc-model ] + --am-config --am-cmvn + --am-model [--vad-config + ] [--vad-cmvn ] + [--vad-model ] [--] [--version] + [-h] +Where: + --port-id + (required) port id + + --am-config + (required) am config path + --am-cmvn + (required) am cmvn path + --am-model + (required) am model path + + --punc-config + punc config path + --punc-model + punc model path + + --vad-config + vad config path + --vad-cmvn + vad cmvn path + --vad-model + vad model path + + Required: --port-id --am-config --am-cmvn --am-model + If use vad, please add: [--vad-config ] [--vad-cmvn ] [--vad-model ] + If use punc, please add: [--punc-config ] [--punc-model ] ``` ## For the client diff --git a/funasr/runtime/grpc/paraformer_server.cc b/funasr/runtime/grpc/paraformer-server.cc similarity index 65% rename from funasr/runtime/grpc/paraformer_server.cc rename to funasr/runtime/grpc/paraformer-server.cc index 2893d4cfb..31333c9eb 100644 --- a/funasr/runtime/grpc/paraformer_server.cc +++ b/funasr/runtime/grpc/paraformer-server.cc @@ -13,7 +13,10 @@ #include #include "paraformer.grpc.pb.h" -#include "paraformer_server.h" +#include "paraformer-server.h" +#include "tclap/CmdLine.h" +#include "com-define.h" +#include "glog/logging.h" using grpc::Server; using grpc::ServerBuilder; @@ -27,31 +30,43 @@ using paraformer::Request; using paraformer::Response; using paraformer::ASR; -ASRServicer::ASRServicer(const char* model_path, int thread_num, bool quantize) { - AsrHanlde=FunASRInit(model_path, thread_num, quantize); +ASRServicer::ASRServicer(std::map& model_path) { + AsrHanlde=FunASRInit(model_path, 1); std::cout << "ASRServicer init" << std::endl; init_flag = 0; } +void ASRServicer::clear_states(const std::string& user) { + clear_buffers(user); + clear_transcriptions(user); +} + +void ASRServicer::clear_buffers(const std::string& user) { + if (client_buffers.count(user)) { + client_buffers.erase(user); + } +} + +void ASRServicer::clear_transcriptions(const std::string& user) { + if (client_transcription.count(user)) { + client_transcription.erase(user); + } +} + +void ASRServicer::disconnect(const std::string& user) { + clear_states(user); + std::cout << "Disconnecting user: " << user << std::endl; +} + grpc::Status ASRServicer::Recognize( grpc::ServerContext* context, grpc::ServerReaderWriter* stream) { Request req; - std::unordered_map client_buffers; - std::unordered_map client_transcription; - while (stream->Read(&req)) { if (req.isend()) { std::cout << "asr end" << std::endl; - // disconnect - if (client_buffers.count(req.user())) { - client_buffers.erase(req.user()); - } - if (client_transcription.count(req.user())) { - client_transcription.erase(req.user()); - } - + disconnect(req.user()); Response res; res.set_sentence( R"({"success": true, "detail": "asr end"})" @@ -89,14 +104,8 @@ grpc::Status ASRServicer::Recognize( auto& buf = client_buffers[req.user()]; buf.insert(buf.end(), req.audio_data().begin(), req.audio_data().end()); } - std::string tmp_data = client_buffers[req.user()]; - // clear_states - if (client_buffers.count(req.user())) { - client_buffers.erase(req.user()); - } - if (client_transcription.count(req.user())) { - client_transcription.erase(req.user()); - } + std::string tmp_data = this->client_buffers[req.user()]; + this->clear_states(req.user()); Response res; res.set_sentence( @@ -161,10 +170,17 @@ grpc::Status ASRServicer::Recognize( return Status::OK; } -void RunServer(const std::string& port, int thread_num, const char* model_path, bool quantize) { +void RunServer(std::map& model_path) { + std::string port; + try{ + port = model_path.at(PORT_ID); + }catch(std::exception const &e){ + printf("Error when read port.\n"); + exit(0); + } std::string server_address; server_address = "0.0.0.0:" + port; - ASRServicer service(model_path, thread_num, quantize); + ASRServicer service(model_path); ServerBuilder builder; builder.AddListeningPort(server_address, grpc::InsecureServerCredentials()); @@ -174,16 +190,54 @@ void RunServer(const std::string& port, int thread_num, const char* model_path, server->Wait(); } -int main(int argc, char* argv[]) { - if (argc < 5) - { - printf("Usage: %s port thread_num /path/to/model_file quantize(true or false) \n", argv[0]); - exit(-1); +void GetValue(TCLAP::ValueArg& value_arg, std::string key, std::map& model_path) +{ + if (value_arg.isSet()){ + model_path.insert({key, value_arg.getValue()}); + LOG(INFO)<< key << " : " << value_arg.getValue(); } +} - // is quantize - bool quantize = false; - std::istringstream(argv[4]) >> std::boolalpha >> quantize; - RunServer(argv[1], atoi(argv[2]), argv[3], quantize); +int main(int argc, char* argv[]) { + + google::InitGoogleLogging(argv[0]); + FLAGS_logtostderr = true; + + TCLAP::CmdLine cmd("paraformer-server", ' ', "1.0"); + TCLAP::ValueArg vad_model("", VAD_MODEL_PATH, "vad model path", false, "", "string"); + TCLAP::ValueArg vad_cmvn("", VAD_CMVN_PATH, "vad cmvn path", false, "", "string"); + TCLAP::ValueArg vad_config("", VAD_CONFIG_PATH, "vad config path", false, "", "string"); + + TCLAP::ValueArg am_model("", AM_MODEL_PATH, "am model path", true, "", "string"); + TCLAP::ValueArg am_cmvn("", AM_CMVN_PATH, "am cmvn path", true, "", "string"); + TCLAP::ValueArg am_config("", AM_CONFIG_PATH, "am config path", true, "", "string"); + + TCLAP::ValueArg punc_model("", PUNC_MODEL_PATH, "punc model path", false, "", "string"); + TCLAP::ValueArg punc_config("", PUNC_CONFIG_PATH, "punc config path", false, "", "string"); + TCLAP::ValueArg port_id("", PORT_ID, "port id", true, "", "string"); + + cmd.add(vad_model); + cmd.add(vad_cmvn); + cmd.add(vad_config); + cmd.add(am_model); + cmd.add(am_cmvn); + cmd.add(am_config); + cmd.add(punc_model); + cmd.add(punc_config); + cmd.add(port_id); + cmd.parse(argc, argv); + + std::map model_path; + GetValue(vad_model, VAD_MODEL_PATH, model_path); + GetValue(vad_cmvn, VAD_CMVN_PATH, model_path); + GetValue(vad_config, VAD_CONFIG_PATH, model_path); + GetValue(am_model, AM_MODEL_PATH, model_path); + GetValue(am_cmvn, AM_CMVN_PATH, model_path); + GetValue(am_config, AM_CONFIG_PATH, model_path); + GetValue(punc_model, PUNC_MODEL_PATH, model_path); + GetValue(punc_config, PUNC_CONFIG_PATH, model_path); + GetValue(port_id, PORT_ID, model_path); + + RunServer(model_path); return 0; } diff --git a/funasr/runtime/grpc/paraformer_server.h b/funasr/runtime/grpc/paraformer-server.h similarity index 70% rename from funasr/runtime/grpc/paraformer_server.h rename to funasr/runtime/grpc/paraformer-server.h index dba1e45c2..108e3b688 100644 --- a/funasr/runtime/grpc/paraformer_server.h +++ b/funasr/runtime/grpc/paraformer-server.h @@ -37,13 +37,18 @@ typedef struct float snippet_time; }FUNASR_RECOG_RESULT; - class ASRServicer final : public ASR::Service { private: int init_flag; + std::unordered_map client_buffers; + std::unordered_map client_transcription; public: - ASRServicer(const char* model_path, int thread_num, bool quantize); + ASRServicer(std::map& model_path); + void clear_states(const std::string& user); + void clear_buffers(const std::string& user); + void clear_transcriptions(const std::string& user); + void disconnect(const std::string& user); grpc::Status Recognize(grpc::ServerContext* context, grpc::ServerReaderWriter* stream); FUNASR_HANDLE AsrHanlde; diff --git a/funasr/runtime/onnxruntime/include/com-define.h b/funasr/runtime/onnxruntime/include/com-define.h index 8c885178e..9b7b212b7 100644 --- a/funasr/runtime/onnxruntime/include/com-define.h +++ b/funasr/runtime/onnxruntime/include/com-define.h @@ -24,6 +24,7 @@ #define WAV_PATH "wav-path" #define WAV_SCP "wav-scp" #define THREAD_NUM "thread-num" +#define PORT_ID "port-id" // vad #ifndef VAD_SILENCE_DURATION diff --git a/funasr/runtime/onnxruntime/include/libfunasrapi.h b/funasr/runtime/onnxruntime/include/libfunasrapi.h index 8dca7f4d6..f65efccfc 100644 --- a/funasr/runtime/onnxruntime/include/libfunasrapi.h +++ b/funasr/runtime/onnxruntime/include/libfunasrapi.h @@ -47,10 +47,9 @@ typedef enum { typedef void (* QM_CALLBACK)(int cur_step, int n_total); // n_total: total steps; cur_step: Current Step. -// APIs for funasr +// // ASR _FUNASRAPI FUNASR_HANDLE FunASRInit(std::map& model_path, int thread_num); -// if not give a fn_callback ,it should be NULL _FUNASRAPI FUNASR_RESULT FunASRRecogBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback); _FUNASRAPI FUNASR_RESULT FunASRRecogPCMBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, int sampling_rate, FUNASR_MODE mode, QM_CALLBACK fn_callback); _FUNASRAPI FUNASR_RESULT FunASRRecogPCMFile(FUNASR_HANDLE handle, const char* sz_filename, int sampling_rate, FUNASR_MODE mode, QM_CALLBACK fn_callback); @@ -62,6 +61,14 @@ _FUNASRAPI void FunASRFreeResult(FUNASR_RESULT result); _FUNASRAPI void FunASRUninit(FUNASR_HANDLE handle); _FUNASRAPI const float FunASRGetRetSnippetTime(FUNASR_RESULT result); +// VAD +_FUNASRAPI FUNASR_HANDLE FunVadInit(std::map& model_path, int thread_num); + +_FUNASRAPI FUNASR_RESULT FunASRVadBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback); +_FUNASRAPI FUNASR_RESULT FunASRVadPCMBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, int sampling_rate, FUNASR_MODE mode, QM_CALLBACK fn_callback); +_FUNASRAPI FUNASR_RESULT FunASRVadPCMFile(FUNASR_HANDLE handle, const char* sz_filename, int sampling_rate, FUNASR_MODE mode, QM_CALLBACK fn_callback); +_FUNASRAPI FUNASR_RESULT FunASRVadFile(FUNASR_HANDLE handle, const char* sz_wavfile, FUNASR_MODE mode, QM_CALLBACK fn_callback); + #ifdef __cplusplus } diff --git a/funasr/runtime/onnxruntime/src/libfunasrapi.cpp b/funasr/runtime/onnxruntime/src/libfunasrapi.cpp index 93434bb73..01aa38a8c 100644 --- a/funasr/runtime/onnxruntime/src/libfunasrapi.cpp +++ b/funasr/runtime/onnxruntime/src/libfunasrapi.cpp @@ -11,6 +11,12 @@ extern "C" { return mm; } + _FUNASRAPI FUNASR_HANDLE FunVadInit(std::map& model_path, int thread_num) + { + Model* mm = CreateModel(model_path, thread_num); + return mm; + } + _FUNASRAPI FUNASR_RESULT FunASRRecogBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback) { Model* recog_obj = (Model*)handle; diff --git a/funasr/runtime/onnxruntime/src/precomp.h b/funasr/runtime/onnxruntime/src/precomp.h index cf69ad976..68e0fe840 100644 --- a/funasr/runtime/onnxruntime/src/precomp.h +++ b/funasr/runtime/onnxruntime/src/precomp.h @@ -21,8 +21,8 @@ using namespace std; // third part #include "onnxruntime_run_options_config_keys.h" #include "onnxruntime_cxx_api.h" -#include -#include +#include "kaldi-native-fbank/csrc/feature-fbank.h" +#include "kaldi-native-fbank/csrc/online-feature.h" // mine #include @@ -40,6 +40,7 @@ using namespace std; #include "util.h" #include "resample.h" #include "model.h" +#include "vad-model.h" #include "paraformer.h" #include "libfunasrapi.h"