diff --git a/funasr/runtime/onnxruntime/include/audio.h b/funasr/runtime/onnxruntime/include/audio.h index ab9f420a2..a61a68fe9 100644 --- a/funasr/runtime/onnxruntime/include/audio.h +++ b/funasr/runtime/onnxruntime/include/audio.h @@ -1,10 +1,10 @@ - #ifndef AUDIO_H #define AUDIO_H #include #include -#include "model.h" +#include "vad-model.h" +#include "offline-stream.h" #ifndef WAV_HEADER_SIZE #define WAV_HEADER_SIZE 44 @@ -54,7 +54,8 @@ class Audio { int FetchChunck(float *&dout, int len); int Fetch(float *&dout, int &len, int &flag); void Padding(); - void Split(Model* recog_obj); + void Split(OfflineStream* offline_streamj); + void Split(VadModel* vad_obj, vector>& vad_segments); float GetTimeLen(); int GetQueueSize() { return (int)frame_queue.size(); } }; diff --git a/funasr/runtime/onnxruntime/include/com-define.h b/funasr/runtime/onnxruntime/include/com-define.h index 9b7b212b7..ad3bd35d3 100644 --- a/funasr/runtime/onnxruntime/include/com-define.h +++ b/funasr/runtime/onnxruntime/include/com-define.h @@ -12,20 +12,37 @@ #define MODEL_SAMPLE_RATE 16000 #endif -// model path -#define VAD_MODEL_PATH "vad-model" -#define VAD_CMVN_PATH "vad-cmvn" -#define VAD_CONFIG_PATH "vad-config" -#define AM_MODEL_PATH "am-model" -#define AM_CMVN_PATH "am-cmvn" -#define AM_CONFIG_PATH "am-config" -#define PUNC_MODEL_PATH "punc-model" -#define PUNC_CONFIG_PATH "punc-config" +// parser option +#define MODEL_DIR "model-dir" +#define VAD_DIR "vad-dir" +#define PUNC_DIR "punc-dir" +#define QUANTIZE "quantize" +#define VAD_QUANT "vad-quant" +#define PUNC_QUANT "punc-quant" + #define WAV_PATH "wav-path" #define WAV_SCP "wav-scp" +#define TXT_PATH "txt-path" #define THREAD_NUM "thread-num" #define PORT_ID "port-id" +// #define VAD_MODEL_PATH "vad-model" +// #define VAD_CMVN_PATH "vad-cmvn" +// #define VAD_CONFIG_PATH "vad-config" +// #define AM_MODEL_PATH "am-model" +// #define AM_CMVN_PATH "am-cmvn" +// #define AM_CONFIG_PATH "am-config" +// #define PUNC_MODEL_PATH "punc-model" +// #define PUNC_CONFIG_PATH "punc-config" + +#define MODEL_NAME "model.onnx" +#define QUANT_MODEL_NAME "model_quant.onnx" +#define VAD_CMVN_NAME "vad.mvn" +#define VAD_CONFIG_NAME "vad.yaml" +#define AM_CMVN_NAME "am.mvn" +#define AM_CONFIG_NAME "config.yaml" +#define PUNC_CONFIG_NAME "punc.yaml" + // vad #ifndef VAD_SILENCE_DURATION #define VAD_SILENCE_DURATION 800 diff --git a/funasr/runtime/onnxruntime/include/libfunasrapi.h b/funasr/runtime/onnxruntime/include/libfunasrapi.h index f65efccfc..152db6183 100644 --- a/funasr/runtime/onnxruntime/include/libfunasrapi.h +++ b/funasr/runtime/onnxruntime/include/libfunasrapi.h @@ -1,5 +1,6 @@ #pragma once #include +#include #ifdef WIN32 #ifdef _FUNASR_API_EXPORT @@ -47,8 +48,8 @@ typedef enum { typedef void (* QM_CALLBACK)(int cur_step, int n_total); // n_total: total steps; cur_step: Current Step. -// // ASR -_FUNASRAPI FUNASR_HANDLE FunASRInit(std::map& model_path, int thread_num); +// ASR +_FUNASRAPI FUNASR_HANDLE FunASRInit(std::map& model_path, int thread_num); _FUNASRAPI FUNASR_RESULT FunASRRecogBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback); _FUNASRAPI FUNASR_RESULT FunASRRecogPCMBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, int sampling_rate, FUNASR_MODE mode, QM_CALLBACK fn_callback); @@ -62,12 +63,23 @@ _FUNASRAPI void FunASRUninit(FUNASR_HANDLE handle); _FUNASRAPI const float FunASRGetRetSnippetTime(FUNASR_RESULT result); // VAD -_FUNASRAPI FUNASR_HANDLE FunVadInit(std::map& model_path, int thread_num); +_FUNASRAPI FUNASR_HANDLE FunVadInit(std::map& model_path, int thread_num); -_FUNASRAPI FUNASR_RESULT FunASRVadBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback); -_FUNASRAPI FUNASR_RESULT FunASRVadPCMBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, int sampling_rate, FUNASR_MODE mode, QM_CALLBACK fn_callback); -_FUNASRAPI FUNASR_RESULT FunASRVadPCMFile(FUNASR_HANDLE handle, const char* sz_filename, int sampling_rate, FUNASR_MODE mode, QM_CALLBACK fn_callback); -_FUNASRAPI FUNASR_RESULT FunASRVadFile(FUNASR_HANDLE handle, const char* sz_wavfile, FUNASR_MODE mode, QM_CALLBACK fn_callback); +_FUNASRAPI FUNASR_RESULT FunVadWavFile(FUNASR_HANDLE handle, const char* sz_wavfile, FUNASR_MODE mode, QM_CALLBACK fn_callback); +_FUNASRAPI std::vector>* FunVadGetResult(FUNASR_RESULT result,int n_index); +_FUNASRAPI void FunVadFreeResult(FUNASR_RESULT result); +_FUNASRAPI void FunVadUninit(FUNASR_HANDLE handle); +_FUNASRAPI const float FunVadGetRetSnippetTime(FUNASR_RESULT result); + +// PUNC +_FUNASRAPI FUNASR_HANDLE FunPuncInit(std::map& model_path, int thread_num); +_FUNASRAPI const std::string FunPuncInfer(FUNASR_HANDLE handle, const char* sz_sentence, FUNASR_MODE mode, QM_CALLBACK fn_callback); +_FUNASRAPI void FunPuncUninit(FUNASR_HANDLE handle); + +//OfflineStream +_FUNASRAPI FUNASR_HANDLE FunOfflineInit(std::map& model_path, int thread_num); +_FUNASRAPI FUNASR_RESULT FunOfflineStream(FUNASR_HANDLE handle, const char* sz_wavfile, FUNASR_MODE mode, QM_CALLBACK fn_callback); +_FUNASRAPI void FunOfflineUninit(FUNASR_HANDLE handle); #ifdef __cplusplus diff --git a/funasr/runtime/onnxruntime/include/model.h b/funasr/runtime/onnxruntime/include/model.h index 4b4b582ff..786fd28d1 100644 --- a/funasr/runtime/onnxruntime/include/model.h +++ b/funasr/runtime/onnxruntime/include/model.h @@ -9,13 +9,10 @@ class Model { public: virtual ~Model(){}; virtual void Reset() = 0; + virtual void InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num)=0; virtual std::string ForwardChunk(float *din, int len, int flag) = 0; virtual std::string Forward(float *din, int len, int flag) = 0; virtual std::string Rescoring() = 0; - virtual std::vector> VadSeg(std::vector& pcm_data)=0; - virtual std::string AddPunc(const char* sz_input)=0; - virtual bool UseVad() =0; - virtual bool UsePunc() =0; }; Model *CreateModel(std::map& model_path,int thread_num=1); diff --git a/funasr/runtime/onnxruntime/include/offline-stream.h b/funasr/runtime/onnxruntime/include/offline-stream.h new file mode 100644 index 000000000..caa4ea62b --- /dev/null +++ b/funasr/runtime/onnxruntime/include/offline-stream.h @@ -0,0 +1,28 @@ +#ifndef OFFLINE_STREAM_H +#define OFFLINE_STREAM_H + +#include +#include +#include +#include "model.h" +#include "punc-model.h" +#include "vad-model.h" + +class OfflineStream { + public: + OfflineStream(std::map& model_path, int thread_num); + ~OfflineStream(){}; + + std::unique_ptr vad_handle; + std::unique_ptr asr_handle; + std::unique_ptr punc_handle; + bool UseVad(){return use_vad;}; + bool UsePunc(){return use_punc;}; + + private: + bool use_vad=false; + bool use_punc=false; +}; + +OfflineStream *CreateOfflineStream(std::map& model_path, int thread_num=1); +#endif diff --git a/funasr/runtime/onnxruntime/include/punc-model.h b/funasr/runtime/onnxruntime/include/punc-model.h new file mode 100644 index 000000000..0bb353abe --- /dev/null +++ b/funasr/runtime/onnxruntime/include/punc-model.h @@ -0,0 +1,18 @@ + +#ifndef PUNC_MODEL_H +#define PUNC_MODEL_H + +#include +#include +#include + +class PuncModel { + public: + virtual ~PuncModel(){}; + virtual void InitPunc(const std::string &punc_model, const std::string &punc_config, int thread_num)=0; + virtual std::vector Infer(std::vector input_data)=0; + virtual std::string AddPunc(const char* sz_input)=0; +}; + +PuncModel *CreatePuncModel(std::map& model_path, int thread_num); +#endif diff --git a/funasr/runtime/onnxruntime/include/vad-model.h b/funasr/runtime/onnxruntime/include/vad-model.h new file mode 100644 index 000000000..646a1e954 --- /dev/null +++ b/funasr/runtime/onnxruntime/include/vad-model.h @@ -0,0 +1,27 @@ + +#ifndef VAD_MODEL_H +#define VAD_MODEL_H + +#include +#include +#include + +class VadModel { + public: + virtual ~VadModel(){}; + virtual void InitVad(const std::string &vad_model, const std::string &vad_cmvn, const std::string &vad_config, int thread_num)=0; + virtual std::vector> Infer(const std::vector &waves)=0; + virtual void ReadModel(const char* vad_model)=0; + virtual void LoadConfigFromYaml(const char* filename)=0; + virtual void FbankKaldi(float sample_rate, std::vector> &vad_feats, + const std::vector &waves)=0; + virtual std::vector> &LfrCmvn(std::vector> &vad_feats)=0; + virtual void Forward( + const std::vector> &chunk_feats, + std::vector> *out_prob)=0; + virtual void LoadCmvn(const char *filename)=0; + virtual void InitCache()=0; +}; + +VadModel *CreateVadModel(std::map& model_path, int thread_num); +#endif diff --git a/funasr/runtime/onnxruntime/src/CMakeLists.txt b/funasr/runtime/onnxruntime/src/CMakeLists.txt index 28a67b4be..341a16a7a 100644 --- a/funasr/runtime/onnxruntime/src/CMakeLists.txt +++ b/funasr/runtime/onnxruntime/src/CMakeLists.txt @@ -26,7 +26,11 @@ include_directories(${CMAKE_SOURCE_DIR}/include) target_link_libraries(funasr PUBLIC onnxruntime ${EXTRA_LIBS}) add_executable(funasr-onnx-offline "funasr-onnx-offline.cpp") +add_executable(funasr-onnx-offline-vad "funasr-onnx-offline-vad.cpp") +add_executable(funasr-onnx-offline-punc "funasr-onnx-offline-punc.cpp") add_executable(funasr-onnx-offline-rtf "funasr-onnx-offline-rtf.cpp") target_link_libraries(funasr-onnx-offline PUBLIC funasr) +target_link_libraries(funasr-onnx-offline-vad PUBLIC funasr) +target_link_libraries(funasr-onnx-offline-punc PUBLIC funasr) target_link_libraries(funasr-onnx-offline-rtf PUBLIC funasr) diff --git a/funasr/runtime/onnxruntime/src/audio.cpp b/funasr/runtime/onnxruntime/src/audio.cpp index d104500d1..6113614ed 100644 --- a/funasr/runtime/onnxruntime/src/audio.cpp +++ b/funasr/runtime/onnxruntime/src/audio.cpp @@ -237,6 +237,15 @@ bool Audio::LoadWav(const char *filename, int32_t* sampling_rate) LOG(ERROR) << "Failed to read " << filename; return false; } + + if (!header.Validate()) { + return false; + } + + header.SeekToDataChunk(is); + if (!is) { + return false; + } *sampling_rate = header.sample_rate; // header.subchunk2_size contains the number of bytes in the data. @@ -494,7 +503,7 @@ void Audio::Padding() delete frame; } -void Audio::Split(Model* recog_obj) +void Audio::Split(OfflineStream* offline_stream) { AudioFrame *frame; @@ -505,7 +514,7 @@ void Audio::Split(Model* recog_obj) frame = NULL; std::vector pcm_data(speech_data, speech_data+sp_len); - vector> vad_segments = recog_obj->VadSeg(pcm_data); + vector> vad_segments = (offline_stream->vad_handle)->Infer(pcm_data); int seg_sample = MODEL_SAMPLE_RATE/1000; for(vector segment:vad_segments) { @@ -518,3 +527,18 @@ void Audio::Split(Model* recog_obj) frame = NULL; } } + + +void Audio::Split(VadModel* vad_obj, vector>& vad_segments) +{ + AudioFrame *frame; + + frame = frame_queue.front(); + frame_queue.pop(); + int sp_len = frame->GetLen(); + delete frame; + frame = NULL; + + std::vector pcm_data(speech_data, speech_data+sp_len); + vad_segments = vad_obj->Infer(pcm_data); +} \ No newline at end of file diff --git a/funasr/runtime/onnxruntime/src/commonfunc.h b/funasr/runtime/onnxruntime/src/commonfunc.h index fbbda74e9..d5298c31c 100644 --- a/funasr/runtime/onnxruntime/src/commonfunc.h +++ b/funasr/runtime/onnxruntime/src/commonfunc.h @@ -6,6 +6,12 @@ typedef struct float snippet_time; }FUNASR_RECOG_RESULT; +typedef struct +{ + std::vector>* segments; + float snippet_time; +}FUNASR_VAD_RESULT; + #ifdef _WIN32 #include diff --git a/funasr/runtime/onnxruntime/src/ct-transformer.cpp b/funasr/runtime/onnxruntime/src/ct-transformer.cpp index ecde636ab..91e795c45 100644 --- a/funasr/runtime/onnxruntime/src/ct-transformer.cpp +++ b/funasr/runtime/onnxruntime/src/ct-transformer.cpp @@ -54,7 +54,7 @@ string CTTransformer::AddPunc(const char* sz_input) int nTotalBatch = ceil((float)InputData.size() / TOKEN_LEN); int nCurBatch = -1; int nSentEnd = -1, nLastCommaIndex = -1; - vector RemainIDs; // + vector RemainIDs; // vector RemainStr; // vector NewPunctuation; // vector NewString; // @@ -64,7 +64,7 @@ string CTTransformer::AddPunc(const char* sz_input) for (size_t i = 0; i < InputData.size(); i += TOKEN_LEN) { nDiff = (i + TOKEN_LEN) < InputData.size() ? (0) : (i + TOKEN_LEN - InputData.size()); - vector InputIDs(InputData.begin() + i, InputData.begin() + i + TOKEN_LEN - nDiff); + vector InputIDs(InputData.begin() + i, InputData.begin() + i + TOKEN_LEN - nDiff); vector InputStr(strOut.begin() + i, strOut.begin() + i + TOKEN_LEN - nDiff); InputIDs.insert(InputIDs.begin(), RemainIDs.begin(), RemainIDs.end()); // RemainIDs+InputIDs; InputStr.insert(InputStr.begin(), RemainStr.begin(), RemainStr.end()); // RemainStr+InputStr; @@ -141,12 +141,13 @@ string CTTransformer::AddPunc(const char* sz_input) return strResult; } -vector CTTransformer::Infer(vector input_data) +vector CTTransformer::Infer(vector input_data) { Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); vector punction; std::array input_shape_{ 1, (int64_t)input_data.size()}; - Ort::Value onnx_input = Ort::Value::CreateTensor(m_memoryInfo, + Ort::Value onnx_input = Ort::Value::CreateTensor( + m_memoryInfo, input_data.data(), input_data.size(), input_shape_.data(), diff --git a/funasr/runtime/onnxruntime/src/ct-transformer.h b/funasr/runtime/onnxruntime/src/ct-transformer.h index d965bb33a..cff4f4747 100644 --- a/funasr/runtime/onnxruntime/src/ct-transformer.h +++ b/funasr/runtime/onnxruntime/src/ct-transformer.h @@ -5,7 +5,7 @@ #pragma once -class CTTransformer { +class CTTransformer : public PuncModel { /** * Author: Speech Lab of DAMO Academy, Alibaba Group * CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection @@ -27,6 +27,6 @@ public: CTTransformer(); void InitPunc(const std::string &punc_model, const std::string &punc_config, int thread_num); ~CTTransformer(); - vector Infer(vector input_data); + vector Infer(vector input_data); string AddPunc(const char* sz_input); }; diff --git a/funasr/runtime/onnxruntime/src/fsmn-vad.cpp b/funasr/runtime/onnxruntime/src/fsmn-vad.cpp index fbb682b69..b1b0e639c 100644 --- a/funasr/runtime/onnxruntime/src/fsmn-vad.cpp +++ b/funasr/runtime/onnxruntime/src/fsmn-vad.cpp @@ -6,8 +6,8 @@ #include #include "precomp.h" -void FsmnVad::InitVad(const std::string &vad_model, const std::string &vad_cmvn, const std::string &vad_config) { - session_options_.SetIntraOpNumThreads(1); +void FsmnVad::InitVad(const std::string &vad_model, const std::string &vad_cmvn, const std::string &vad_config, int thread_num) { + session_options_.SetIntraOpNumThreads(thread_num); session_options_.SetGraphOptimizationLevel(ORT_ENABLE_ALL); session_options_.DisableCpuMemArena(); @@ -296,5 +296,8 @@ void FsmnVad::Reset(){ void FsmnVad::Test() { } +FsmnVad::~FsmnVad() { +} + FsmnVad::FsmnVad():env_(ORT_LOGGING_LEVEL_ERROR, ""),session_options_{} { } diff --git a/funasr/runtime/onnxruntime/src/fsmn-vad.h b/funasr/runtime/onnxruntime/src/fsmn-vad.h index 1d5f68c56..cf03ce91a 100644 --- a/funasr/runtime/onnxruntime/src/fsmn-vad.h +++ b/funasr/runtime/onnxruntime/src/fsmn-vad.h @@ -8,7 +8,7 @@ #include "precomp.h" -class FsmnVad { +class FsmnVad : public VadModel { /** * Author: Speech Lab of DAMO Academy, Alibaba Group * Deep-FSMN for Large Vocabulary Continuous Speech Recognition @@ -17,9 +17,9 @@ class FsmnVad { public: FsmnVad(); + ~FsmnVad(); void Test(); - void InitVad(const std::string &vad_model, const std::string &vad_cmvn, const std::string &vad_config); - + void InitVad(const std::string &vad_model, const std::string &vad_cmvn, const std::string &vad_config, int thread_num); std::vector> Infer(const std::vector &waves); void Reset(); diff --git a/funasr/runtime/onnxruntime/src/funasr-onnx-offline-punc.cpp b/funasr/runtime/onnxruntime/src/funasr-onnx-offline-punc.cpp new file mode 100644 index 000000000..e8f221f6f --- /dev/null +++ b/funasr/runtime/onnxruntime/src/funasr-onnx-offline-punc.cpp @@ -0,0 +1,98 @@ +/** + * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. + * MIT License (https://opensource.org/licenses/MIT) +*/ + +#ifndef _WIN32 +#include +#else +#include +#endif + +#include +#include +#include +#include +#include +#include "libfunasrapi.h" +#include "tclap/CmdLine.h" +#include "com-define.h" + +using namespace std; + +void GetValue(TCLAP::ValueArg& value_arg, string key, std::map& model_path) +{ + if (value_arg.isSet()){ + model_path.insert({key, value_arg.getValue()}); + LOG(INFO)<< key << " : " << value_arg.getValue(); + } +} + +int main(int argc, char *argv[]) +{ + google::InitGoogleLogging(argv[0]); + FLAGS_logtostderr = true; + + TCLAP::CmdLine cmd("funasr-onnx-offline-punc", ' ', "1.0"); + TCLAP::ValueArg model_dir("", MODEL_DIR, "the punc model path, which contains model.onnx, punc.yaml", true, "", "string"); + TCLAP::ValueArg quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "false", "string"); + TCLAP::ValueArg txt_path("", TXT_PATH, "txt file path, one sentence per line", false, "", "string"); + + cmd.add(model_dir); + cmd.add(quantize); + cmd.add(txt_path); + cmd.parse(argc, argv); + + std::map model_path; + GetValue(model_dir, MODEL_DIR, model_path); + GetValue(quantize, QUANTIZE, model_path); + GetValue(txt_path, TXT_PATH, model_path); + + struct timeval start, end; + gettimeofday(&start, NULL); + int thread_num = 1; + FUNASR_HANDLE punc_hanlde=FunPuncInit(model_path, thread_num); + + if (!punc_hanlde) + { + LOG(ERROR) << "FunASR init failed"; + exit(-1); + } + + gettimeofday(&end, NULL); + long seconds = (end.tv_sec - start.tv_sec); + long modle_init_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec); + LOG(INFO) << "Model initialization takes " << (double)modle_init_micros / 1000000 << " s"; + + // read txt_path + vector txt_list; + + if(model_path.find(TXT_PATH)!=model_path.end()){ + ifstream in(model_path.at(TXT_PATH)); + if (!in.is_open()) { + LOG(ERROR) << "Failed to open file: " << model_path.at(TXT_PATH) ; + return 0; + } + string line; + while(getline(in, line)) + { + txt_list.emplace_back(line); + } + in.close(); + } + + long taking_micros = 0; + for(auto& txt_str : txt_list){ + gettimeofday(&start, NULL); + string result=FunPuncInfer(punc_hanlde, txt_str.c_str(), RASR_NONE, NULL); + gettimeofday(&end, NULL); + seconds = (end.tv_sec - start.tv_sec); + taking_micros += ((seconds * 1000000) + end.tv_usec) - (start.tv_usec); + LOG(INFO)<<"Results: "< vad_model("", VAD_MODEL_PATH, "vad model path", false, "", "string"); - TCLAP::ValueArg vad_cmvn("", VAD_CMVN_PATH, "vad cmvn path", false, "", "string"); - TCLAP::ValueArg vad_config("", VAD_CONFIG_PATH, "vad config path", false, "", "string"); - - TCLAP::ValueArg am_model("", AM_MODEL_PATH, "am model path", false, "", "string"); - TCLAP::ValueArg am_cmvn("", AM_CMVN_PATH, "am cmvn path", false, "", "string"); - TCLAP::ValueArg am_config("", AM_CONFIG_PATH, "am config path", false, "", "string"); - - TCLAP::ValueArg punc_model("", PUNC_MODEL_PATH, "punc model path", false, "", "string"); - TCLAP::ValueArg punc_config("", PUNC_CONFIG_PATH, "punc config path", false, "", "string"); + TCLAP::ValueArg model_dir("", MODEL_DIR, "the model path, which contains model.onnx, config.yaml, am.mvn", true, "", "string"); + TCLAP::ValueArg quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "false", "string"); TCLAP::ValueArg wav_scp("", WAV_SCP, "wave scp path", true, "", "string"); TCLAP::ValueArg thread_num("", THREAD_NUM, "multi-thread num for rtf", true, 0, "int32_t"); - cmd.add(vad_model); - cmd.add(vad_cmvn); - cmd.add(vad_config); - cmd.add(am_model); - cmd.add(am_cmvn); - cmd.add(am_config); - cmd.add(punc_model); - cmd.add(punc_config); + cmd.add(model_dir); + cmd.add(quantize); cmd.add(wav_scp); cmd.add(thread_num); cmd.parse(argc, argv); std::map model_path; - GetValue(vad_model, VAD_MODEL_PATH, model_path); - GetValue(vad_cmvn, VAD_CMVN_PATH, model_path); - GetValue(vad_config, VAD_CONFIG_PATH, model_path); - GetValue(am_model, AM_MODEL_PATH, model_path); - GetValue(am_cmvn, AM_CMVN_PATH, model_path); - GetValue(am_config, AM_CONFIG_PATH, model_path); - GetValue(punc_model, PUNC_MODEL_PATH, model_path); - GetValue(punc_config, PUNC_CONFIG_PATH, model_path); + GetValue(model_dir, MODEL_DIR, model_path); + GetValue(quantize, QUANTIZE, model_path); GetValue(wav_scp, WAV_SCP, model_path); struct timeval start, end; diff --git a/funasr/runtime/onnxruntime/src/funasr-onnx-offline-vad.cpp b/funasr/runtime/onnxruntime/src/funasr-onnx-offline-vad.cpp new file mode 100644 index 000000000..278753484 --- /dev/null +++ b/funasr/runtime/onnxruntime/src/funasr-onnx-offline-vad.cpp @@ -0,0 +1,143 @@ +/** + * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. + * MIT License (https://opensource.org/licenses/MIT) +*/ + +#ifndef _WIN32 +#include +#else +#include +#endif + +#include +#include +#include +#include +#include +#include +#include "libfunasrapi.h" +#include "tclap/CmdLine.h" +#include "com-define.h" + +using namespace std; + +void GetValue(TCLAP::ValueArg& value_arg, string key, std::map& model_path) +{ + if (value_arg.isSet()){ + model_path.insert({key, value_arg.getValue()}); + LOG(INFO)<< key << " : " << value_arg.getValue(); + } +} + +void print_segs(vector>* vec) { + string seg_out="["; + for (int i = 0; i < vec->size(); i++) { + vector inner_vec = (*vec)[i]; + seg_out += "["; + for (int j = 0; j < inner_vec.size(); j++) { + seg_out += to_string(inner_vec[j]); + if (j != inner_vec.size() - 1) { + seg_out += ","; + } + } + seg_out += "]"; + if (i != vec->size() - 1) { + seg_out += ","; + } + } + seg_out += "]"; + LOG(INFO)< model_dir("", MODEL_DIR, "the vad model path, which contains model.onnx, vad.yaml, vad.mvn", true, "", "string"); + TCLAP::ValueArg quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "false", "string"); + + TCLAP::ValueArg wav_path("", WAV_PATH, "wave file path", false, "", "string"); + TCLAP::ValueArg wav_scp("", WAV_SCP, "wave scp path", false, "", "string"); + + cmd.add(model_dir); + cmd.add(quantize); + cmd.add(wav_path); + cmd.add(wav_scp); + cmd.parse(argc, argv); + + std::map model_path; + GetValue(model_dir, MODEL_DIR, model_path); + GetValue(quantize, QUANTIZE, model_path); + GetValue(wav_path, WAV_PATH, model_path); + GetValue(wav_scp, WAV_SCP, model_path); + + struct timeval start, end; + gettimeofday(&start, NULL); + int thread_num = 1; + FUNASR_HANDLE vad_hanlde=FunVadInit(model_path, thread_num); + + if (!vad_hanlde) + { + LOG(ERROR) << "FunVad init failed"; + exit(-1); + } + + gettimeofday(&end, NULL); + long seconds = (end.tv_sec - start.tv_sec); + long modle_init_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec); + LOG(INFO) << "Model initialization takes " << (double)modle_init_micros / 1000000 << " s"; + + // read wav_path and wav_scp + vector wav_list; + + if(model_path.find(WAV_PATH)!=model_path.end()){ + wav_list.emplace_back(model_path.at(WAV_PATH)); + } + if(model_path.find(WAV_SCP)!=model_path.end()){ + ifstream in(model_path.at(WAV_SCP)); + if (!in.is_open()) { + LOG(ERROR) << "Failed to open file: " << model_path.at(WAV_SCP) ; + return 0; + } + string line; + while(getline(in, line)) + { + istringstream iss(line); + string column1, column2; + iss >> column1 >> column2; + wav_list.emplace_back(column2); + } + in.close(); + } + + float snippet_time = 0.0f; + long taking_micros = 0; + for(auto& wav_file : wav_list){ + gettimeofday(&start, NULL); + FUNASR_RESULT result=FunVadWavFile(vad_hanlde, wav_file.c_str(), RASR_NONE, NULL); + gettimeofday(&end, NULL); + seconds = (end.tv_sec - start.tv_sec); + taking_micros += ((seconds * 1000000) + end.tv_usec) - (start.tv_usec); + + if (result) + { + vector>* vad_segments = FunVadGetResult(result, 0); + print_segs(vad_segments); + snippet_time += FunVadGetRetSnippetTime(result); + FunVadFreeResult(result); + } + else + { + LOG(ERROR) << ("No return data!\n"); + } + } + + LOG(INFO) << "Audio length: " << (double)snippet_time << " s"; + LOG(INFO) << "Model inference takes: " << (double)taking_micros / 1000000 <<" s"; + LOG(INFO) << "Model inference RTF: " << (double)taking_micros/ (snippet_time*1000000); + FunVadUninit(vad_hanlde); + return 0; +} + diff --git a/funasr/runtime/onnxruntime/src/funasr-onnx-offline.cpp b/funasr/runtime/onnxruntime/src/funasr-onnx-offline.cpp index 2d61bbb30..af6d0e3ce 100644 --- a/funasr/runtime/onnxruntime/src/funasr-onnx-offline.cpp +++ b/funasr/runtime/onnxruntime/src/funasr-onnx-offline.cpp @@ -28,55 +28,46 @@ void GetValue(TCLAP::ValueArg& value_arg, string key, std::map vad_model("", VAD_MODEL_PATH, "vad model path", false, "", "string"); - TCLAP::ValueArg vad_cmvn("", VAD_CMVN_PATH, "vad cmvn path", false, "", "string"); - TCLAP::ValueArg vad_config("", VAD_CONFIG_PATH, "vad config path", false, "", "string"); - - TCLAP::ValueArg am_model("", AM_MODEL_PATH, "am model path", true, "", "string"); - TCLAP::ValueArg am_cmvn("", AM_CMVN_PATH, "am cmvn path", true, "", "string"); - TCLAP::ValueArg am_config("", AM_CONFIG_PATH, "am config path", true, "", "string"); - - TCLAP::ValueArg punc_model("", PUNC_MODEL_PATH, "punc model path", false, "", "string"); - TCLAP::ValueArg punc_config("", PUNC_CONFIG_PATH, "punc config path", false, "", "string"); + TCLAP::ValueArg model_dir("", MODEL_DIR, "the model path, which contains model.onnx, config.yaml, am.mvn", true, "", "string"); + TCLAP::ValueArg quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "false", "string"); + TCLAP::ValueArg vad_dir("", VAD_DIR, "the vad model path, which contains model.onnx, vad.yaml, vad.mvn", false, "", "string"); + TCLAP::ValueArg vad_quant("", VAD_QUANT, "false (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir", false, "false", "string"); + TCLAP::ValueArg punc_dir("", PUNC_DIR, "the punc model path, which contains model.onnx, punc.yaml", false, "", "string"); + TCLAP::ValueArg punc_quant("", PUNC_QUANT, "false (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "false", "string"); TCLAP::ValueArg wav_path("", WAV_PATH, "wave file path", false, "", "string"); TCLAP::ValueArg wav_scp("", WAV_SCP, "wave scp path", false, "", "string"); - cmd.add(vad_model); - cmd.add(vad_cmvn); - cmd.add(vad_config); - cmd.add(am_model); - cmd.add(am_cmvn); - cmd.add(am_config); - cmd.add(punc_model); - cmd.add(punc_config); + cmd.add(model_dir); + cmd.add(quantize); + cmd.add(vad_dir); + cmd.add(vad_quant); + cmd.add(punc_dir); + cmd.add(punc_quant); cmd.add(wav_path); cmd.add(wav_scp); cmd.parse(argc, argv); std::map model_path; - GetValue(vad_model, VAD_MODEL_PATH, model_path); - GetValue(vad_cmvn, VAD_CMVN_PATH, model_path); - GetValue(vad_config, VAD_CONFIG_PATH, model_path); - GetValue(am_model, AM_MODEL_PATH, model_path); - GetValue(am_cmvn, AM_CMVN_PATH, model_path); - GetValue(am_config, AM_CONFIG_PATH, model_path); - GetValue(punc_model, PUNC_MODEL_PATH, model_path); - GetValue(punc_config, PUNC_CONFIG_PATH, model_path); + GetValue(model_dir, MODEL_DIR, model_path); + GetValue(quantize, QUANTIZE, model_path); + GetValue(vad_dir, VAD_DIR, model_path); + GetValue(vad_quant, VAD_QUANT, model_path); + GetValue(punc_dir, PUNC_DIR, model_path); + GetValue(punc_quant, PUNC_QUANT, model_path); GetValue(wav_path, WAV_PATH, model_path); GetValue(wav_scp, WAV_SCP, model_path); - struct timeval start, end; gettimeofday(&start, NULL); int thread_num = 1; - FUNASR_HANDLE asr_hanlde=FunASRInit(model_path, thread_num); + FUNASR_HANDLE asr_hanlde=FunOfflineInit(model_path, thread_num); if (!asr_hanlde) { @@ -116,7 +107,7 @@ int main(int argc, char *argv[]) long taking_micros = 0; for(auto& wav_file : wav_list){ gettimeofday(&start, NULL); - FUNASR_RESULT result=FunASRRecogFile(asr_hanlde, wav_file.c_str(), RASR_NONE, NULL); + FUNASR_RESULT result=FunOfflineStream(asr_hanlde, wav_file.c_str(), RASR_NONE, NULL); gettimeofday(&end, NULL); seconds = (end.tv_sec - start.tv_sec); taking_micros += ((seconds * 1000000) + end.tv_usec) - (start.tv_usec); @@ -124,8 +115,7 @@ int main(int argc, char *argv[]) if (result) { string msg = FunASRGetResult(result, 0); - setbuf(stdout, NULL); - printf("Result: %s \n", msg.c_str()); + LOG(INFO)<<"Result: "<& model_path, int thread_num) { Model* mm = CreateModel(model_path, thread_num); @@ -13,10 +13,23 @@ extern "C" { _FUNASRAPI FUNASR_HANDLE FunVadInit(std::map& model_path, int thread_num) { - Model* mm = CreateModel(model_path, thread_num); + VadModel* mm = CreateVadModel(model_path, thread_num); return mm; } + _FUNASRAPI FUNASR_HANDLE FunPuncInit(std::map& model_path, int thread_num) + { + PuncModel* mm = CreatePuncModel(model_path, thread_num); + return mm; + } + + _FUNASRAPI FUNASR_HANDLE FunOfflineInit(std::map& model_path, int thread_num) + { + OfflineStream* mm = CreateOfflineStream(model_path, thread_num); + return mm; + } + + // APIs for ASR Infer _FUNASRAPI FUNASR_RESULT FunASRRecogBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback) { Model* recog_obj = (Model*)handle; @@ -27,9 +40,6 @@ extern "C" { Audio audio(1); if (!audio.LoadWav(sz_buf, n_len, &sampling_rate)) return nullptr; - if(recog_obj->UseVad()){ - audio.Split(recog_obj); - } float* buff; int len; @@ -45,10 +55,6 @@ extern "C" { if (fn_callback) fn_callback(n_step, n_total); } - if(recog_obj->UsePunc()){ - string punc_res = recog_obj->AddPunc((p_result->msg).c_str()); - p_result->msg = punc_res; - } return p_result; } @@ -62,9 +68,6 @@ extern "C" { Audio audio(1); if (!audio.LoadPcmwav(sz_buf, n_len, &sampling_rate)) return nullptr; - if(recog_obj->UseVad()){ - audio.Split(recog_obj); - } float* buff; int len; @@ -80,10 +83,6 @@ extern "C" { if (fn_callback) fn_callback(n_step, n_total); } - if(recog_obj->UsePunc()){ - string punc_res = recog_obj->AddPunc((p_result->msg).c_str()); - p_result->msg = punc_res; - } return p_result; } @@ -97,9 +96,6 @@ extern "C" { Audio audio(1); if (!audio.LoadPcmwav(sz_filename, &sampling_rate)) return nullptr; - if(recog_obj->UseVad()){ - audio.Split(recog_obj); - } float* buff; int len; @@ -115,10 +111,6 @@ extern "C" { if (fn_callback) fn_callback(n_step, n_total); } - if(recog_obj->UsePunc()){ - string punc_res = recog_obj->AddPunc((p_result->msg).c_str()); - p_result->msg = punc_res; - } return p_result; } @@ -133,9 +125,6 @@ extern "C" { Audio audio(1); if(!audio.LoadWav(sz_wavfile, &sampling_rate)) return nullptr; - if(recog_obj->UseVad()){ - audio.Split(recog_obj); - } float* buff; int len; @@ -151,8 +140,74 @@ extern "C" { if (fn_callback) fn_callback(n_step, n_total); } - if(recog_obj->UsePunc()){ - string punc_res = recog_obj->AddPunc((p_result->msg).c_str()); + + return p_result; + } + + // APIs for VAD Infer + _FUNASRAPI FUNASR_RESULT FunVadWavFile(FUNASR_HANDLE handle, const char* sz_wavfile, FUNASR_MODE mode, QM_CALLBACK fn_callback) + { + VadModel* vad_obj = (VadModel*)handle; + if (!vad_obj) + return nullptr; + + int32_t sampling_rate = -1; + Audio audio(1); + if(!audio.LoadWav(sz_wavfile, &sampling_rate)) + return nullptr; + + FUNASR_VAD_RESULT* p_result = new FUNASR_VAD_RESULT; + p_result->snippet_time = audio.GetTimeLen(); + + vector> vad_segments; + audio.Split(vad_obj, vad_segments); + p_result->segments = new vector>(vad_segments); + + return p_result; + } + + // APIs for PUNC Infer + _FUNASRAPI const std::string FunPuncInfer(FUNASR_HANDLE handle, const char* sz_sentence, FUNASR_MODE mode, QM_CALLBACK fn_callback) + { + PuncModel* punc_obj = (PuncModel*)handle; + if (!punc_obj) + return nullptr; + + string punc_res = punc_obj->AddPunc(sz_sentence); + return punc_res; + } + + // APIs for Offline-stream Infer + _FUNASRAPI FUNASR_RESULT FunOfflineStream(FUNASR_HANDLE handle, const char* sz_wavfile, FUNASR_MODE mode, QM_CALLBACK fn_callback) + { + OfflineStream* offline_stream = (OfflineStream*)handle; + if (!offline_stream) + return nullptr; + + int32_t sampling_rate = -1; + Audio audio(1); + if(!audio.LoadWav(sz_wavfile, &sampling_rate)) + return nullptr; + if(offline_stream->UseVad()){ + audio.Split(offline_stream); + } + + float* buff; + int len; + int flag = 0; + int n_step = 0; + int n_total = audio.GetQueueSize(); + FUNASR_RECOG_RESULT* p_result = new FUNASR_RECOG_RESULT; + p_result->snippet_time = audio.GetTimeLen(); + while (audio.Fetch(buff, len, flag) > 0) { + string msg = (offline_stream->asr_handle)->Forward(buff, len, flag); + p_result->msg+= msg; + n_step++; + if (fn_callback) + fn_callback(n_step, n_total); + } + if(offline_stream->UsePunc()){ + string punc_res = (offline_stream->punc_handle)->AddPunc((p_result->msg).c_str()); p_result->msg = punc_res; } @@ -167,7 +222,7 @@ extern "C" { return 1; } - + // APIs for GetRetSnippetTime _FUNASRAPI const float FunASRGetRetSnippetTime(FUNASR_RESULT result) { if (!result) @@ -176,6 +231,15 @@ extern "C" { return ((FUNASR_RECOG_RESULT*)result)->snippet_time; } + _FUNASRAPI const float FunVadGetRetSnippetTime(FUNASR_RESULT result) + { + if (!result) + return 0.0f; + + return ((FUNASR_VAD_RESULT*)result)->snippet_time; + } + + // APIs for GetResult _FUNASRAPI const char* FunASRGetResult(FUNASR_RESULT result,int n_index) { FUNASR_RECOG_RESULT * p_result = (FUNASR_RECOG_RESULT*)result; @@ -185,6 +249,16 @@ extern "C" { return p_result->msg.c_str(); } + _FUNASRAPI vector>* FunVadGetResult(FUNASR_RESULT result,int n_index) + { + FUNASR_VAD_RESULT * p_result = (FUNASR_VAD_RESULT*)result; + if(!p_result) + return nullptr; + + return p_result->segments; + } + + // APIs for FreeResult _FUNASRAPI void FunASRFreeResult(FUNASR_RESULT result) { if (result) @@ -193,6 +267,19 @@ extern "C" { } } + _FUNASRAPI void FunVadFreeResult(FUNASR_RESULT result) + { + FUNASR_VAD_RESULT * p_result = (FUNASR_VAD_RESULT*)result; + if (p_result) + { + if(p_result->segments){ + delete p_result->segments; + } + delete p_result; + } + } + + // APIs for Uninit _FUNASRAPI void FunASRUninit(FUNASR_HANDLE handle) { Model* recog_obj = (Model*)handle; @@ -203,6 +290,36 @@ extern "C" { delete recog_obj; } + _FUNASRAPI void FunVadUninit(FUNASR_HANDLE handle) + { + VadModel* recog_obj = (VadModel*)handle; + + if (!recog_obj) + return; + + delete recog_obj; + } + + _FUNASRAPI void FunPuncUninit(FUNASR_HANDLE handle) + { + PuncModel* punc_obj = (PuncModel*)handle; + + if (!punc_obj) + return; + + delete punc_obj; + } + + _FUNASRAPI void FunOfflineUninit(FUNASR_HANDLE handle) + { + OfflineStream* offline_stream = (OfflineStream*)handle; + + if (!offline_stream) + return; + + delete offline_stream; + } + #ifdef __cplusplus } diff --git a/funasr/runtime/onnxruntime/src/model.cpp b/funasr/runtime/onnxruntime/src/model.cpp index 52ce7ba7c..65ea172f0 100644 --- a/funasr/runtime/onnxruntime/src/model.cpp +++ b/funasr/runtime/onnxruntime/src/model.cpp @@ -2,7 +2,19 @@ Model *CreateModel(std::map& model_path, int thread_num) { + string am_model_path; + string am_cmvn_path; + string am_config_path; + + am_model_path = PathAppend(model_path.at(MODEL_DIR), MODEL_NAME); + if(model_path.find(QUANTIZE) != model_path.end() && model_path.at(QUANTIZE) == "true"){ + am_model_path = PathAppend(model_path.at(MODEL_DIR), QUANT_MODEL_NAME); + } + am_cmvn_path = PathAppend(model_path.at(MODEL_DIR), AM_CMVN_NAME); + am_config_path = PathAppend(model_path.at(MODEL_DIR), AM_CONFIG_NAME); + Model *mm; - mm = new paraformer::Paraformer(model_path, thread_num); + mm = new paraformer::Paraformer(); + mm->InitAsr(am_model_path, am_cmvn_path, am_config_path, thread_num); return mm; } diff --git a/funasr/runtime/onnxruntime/src/offline-stream.cpp b/funasr/runtime/onnxruntime/src/offline-stream.cpp new file mode 100644 index 000000000..00c131844 --- /dev/null +++ b/funasr/runtime/onnxruntime/src/offline-stream.cpp @@ -0,0 +1,61 @@ +#include "precomp.h" + +OfflineStream::OfflineStream(std::map& model_path, int thread_num) +{ + // VAD model + if(model_path.find(VAD_DIR) != model_path.end()){ + use_vad = true; + string vad_model_path; + string vad_cmvn_path; + string vad_config_path; + + vad_model_path = PathAppend(model_path.at(VAD_DIR), MODEL_NAME); + if(model_path.find(VAD_QUANT) != model_path.end() && model_path.at(VAD_QUANT) == "true"){ + vad_model_path = PathAppend(model_path.at(VAD_DIR), QUANT_MODEL_NAME); + } + vad_cmvn_path = PathAppend(model_path.at(VAD_DIR), VAD_CMVN_NAME); + vad_config_path = PathAppend(model_path.at(VAD_DIR), VAD_CONFIG_NAME); + vad_handle = make_unique(); + vad_handle->InitVad(vad_model_path, vad_cmvn_path, vad_config_path, thread_num); + } + + // AM model + if(model_path.find(MODEL_DIR) != model_path.end()){ + string am_model_path; + string am_cmvn_path; + string am_config_path; + + am_model_path = PathAppend(model_path.at(MODEL_DIR), MODEL_NAME); + if(model_path.find(QUANTIZE) != model_path.end() && model_path.at(QUANTIZE) == "true"){ + am_model_path = PathAppend(model_path.at(MODEL_DIR), QUANT_MODEL_NAME); + } + am_cmvn_path = PathAppend(model_path.at(MODEL_DIR), AM_CMVN_NAME); + am_config_path = PathAppend(model_path.at(MODEL_DIR), AM_CONFIG_NAME); + + asr_handle = make_unique(); + asr_handle->InitAsr(am_model_path, am_cmvn_path, am_config_path, thread_num); + } + + // PUNC model + if(model_path.find(PUNC_DIR) != model_path.end()){ + use_punc = true; + string punc_model_path; + string punc_config_path; + + punc_model_path = PathAppend(model_path.at(PUNC_DIR), MODEL_NAME); + if(model_path.find(PUNC_QUANT) != model_path.end() && model_path.at(PUNC_QUANT) == "true"){ + punc_model_path = PathAppend(model_path.at(PUNC_DIR), QUANT_MODEL_NAME); + } + punc_config_path = PathAppend(model_path.at(PUNC_DIR), PUNC_CONFIG_NAME); + + punc_handle = make_unique(); + punc_handle->InitPunc(punc_model_path, punc_config_path, thread_num); + } +} + +OfflineStream *CreateOfflineStream(std::map& model_path, int thread_num) +{ + OfflineStream *mm; + mm = new OfflineStream(model_path, thread_num); + return mm; +} diff --git a/funasr/runtime/onnxruntime/src/paraformer.cpp b/funasr/runtime/onnxruntime/src/paraformer.cpp index 136d22808..244a706db 100644 --- a/funasr/runtime/onnxruntime/src/paraformer.cpp +++ b/funasr/runtime/onnxruntime/src/paraformer.cpp @@ -8,65 +8,11 @@ using namespace std; using namespace paraformer; -Paraformer::Paraformer(std::map& model_path,int thread_num) +Paraformer::Paraformer() :env_(ORT_LOGGING_LEVEL_ERROR, "paraformer"),session_options{}{ - - // VAD model - if(model_path.find(VAD_MODEL_PATH) != model_path.end()){ - use_vad = true; - string vad_model_path; - string vad_cmvn_path; - string vad_config_path; - - try{ - vad_model_path = model_path.at(VAD_MODEL_PATH); - vad_cmvn_path = model_path.at(VAD_CMVN_PATH); - vad_config_path = model_path.at(VAD_CONFIG_PATH); - }catch(const out_of_range& e){ - LOG(ERROR) << "Error when read "<< VAD_CMVN_PATH << " or " << VAD_CONFIG_PATH <<" :" << e.what(); - exit(0); - } - vad_handle = make_unique(); - vad_handle->InitVad(vad_model_path, vad_cmvn_path, vad_config_path); - } - - // AM model - if(model_path.find(AM_MODEL_PATH) != model_path.end()){ - string am_model_path; - string am_cmvn_path; - string am_config_path; - - try{ - am_model_path = model_path.at(AM_MODEL_PATH); - am_cmvn_path = model_path.at(AM_CMVN_PATH); - am_config_path = model_path.at(AM_CONFIG_PATH); - }catch(const out_of_range& e){ - LOG(ERROR) << "Error when read "<< AM_CONFIG_PATH << " or " << AM_CMVN_PATH <<" :" << e.what(); - exit(0); - } - InitAM(am_model_path, am_cmvn_path, am_config_path, thread_num); - } - - // PUNC model - if(model_path.find(PUNC_MODEL_PATH) != model_path.end()){ - use_punc = true; - string punc_model_path; - string punc_config_path; - - try{ - punc_model_path = model_path.at(PUNC_MODEL_PATH); - punc_config_path = model_path.at(PUNC_CONFIG_PATH); - }catch(const out_of_range& e){ - LOG(ERROR) << "Error when read "<< PUNC_CONFIG_PATH <<" :" << e.what(); - exit(0); - } - - punc_handle = make_unique(); - punc_handle->InitPunc(punc_model_path, punc_config_path, thread_num); - } } -void Paraformer::InitAM(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){ +void Paraformer::InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){ // knf options fbank_opts.frame_opts.dither = 0; fbank_opts.mel_opts.num_bins = 80; @@ -120,14 +66,6 @@ void Paraformer::Reset() { } -vector> Paraformer::VadSeg(std::vector& pcm_data){ - return vad_handle->Infer(pcm_data); -} - -string Paraformer::AddPunc(const char* sz_input){ - return punc_handle->AddPunc(sz_input); -} - vector Paraformer::FbankKaldi(float sample_rate, const float* waves, int len) { knf::OnlineFbank fbank_(fbank_opts); fbank_.AcceptWaveform(sample_rate, waves, len); @@ -282,7 +220,7 @@ string Paraformer::Forward(float* din, int len, int flag) } catch (std::exception const &e) { - printf(e.what()); + LOG(ERROR)< fbank_; knf::FbankOptions fbank_opts; - std::unique_ptr vad_handle; - std::unique_ptr punc_handle; - Vocab* vocab; vector means_list; vector vars_list; @@ -36,7 +31,6 @@ namespace paraformer { void LoadCmvn(const char *filename); vector ApplyLfr(const vector &in); void ApplyCmvn(vector *v); - string GreedySearch( float* in, int n_len, int64_t token_nums); std::shared_ptr m_session; @@ -46,22 +40,16 @@ namespace paraformer { vector m_strInputNames, m_strOutputNames; vector m_szInputNames; vector m_szOutputNames; - bool use_vad=false; - bool use_punc=false; public: - Paraformer(std::map& model_path, int thread_num=0); + Paraformer(); ~Paraformer(); - void InitAM(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num); + void InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num); void Reset(); vector FbankKaldi(float sample_rate, const float* waves, int len); string ForwardChunk(float* din, int len, int flag); string Forward(float* din, int len, int flag); string Rescoring(); - std::vector> VadSeg(std::vector& pcm_data); - string AddPunc(const char* sz_input); - bool UseVad(){return use_vad;}; - bool UsePunc(){return use_punc;}; }; } // namespace paraformer diff --git a/funasr/runtime/onnxruntime/src/precomp.h b/funasr/runtime/onnxruntime/src/precomp.h index 68e0fe840..0d3199ee7 100644 --- a/funasr/runtime/onnxruntime/src/precomp.h +++ b/funasr/runtime/onnxruntime/src/precomp.h @@ -30,6 +30,10 @@ using namespace std; #include "com-define.h" #include "commonfunc.h" #include "predefine-coe.h" +#include "model.h" +#include "vad-model.h" +#include "punc-model.h" +#include "offline-stream.h" #include "tokenizer.h" #include "ct-transformer.h" #include "fsmn-vad.h" @@ -39,9 +43,8 @@ using namespace std; #include "tensor.h" #include "util.h" #include "resample.h" -#include "model.h" -#include "vad-model.h" #include "paraformer.h" +#include "offline-stream.h" #include "libfunasrapi.h" using namespace paraformer; diff --git a/funasr/runtime/onnxruntime/src/punc-model.cpp b/funasr/runtime/onnxruntime/src/punc-model.cpp new file mode 100644 index 000000000..1e619ab9d --- /dev/null +++ b/funasr/runtime/onnxruntime/src/punc-model.cpp @@ -0,0 +1,19 @@ +#include "precomp.h" + +PuncModel *CreatePuncModel(std::map& model_path, int thread_num) +{ + PuncModel *mm; + mm = new CTTransformer(); + + string punc_model_path; + string punc_config_path; + + punc_model_path = PathAppend(model_path.at(MODEL_DIR), MODEL_NAME); + if(model_path.find(QUANTIZE) != model_path.end() && model_path.at(QUANTIZE) == "true"){ + punc_model_path = PathAppend(model_path.at(MODEL_DIR), QUANT_MODEL_NAME); + } + punc_config_path = PathAppend(model_path.at(MODEL_DIR), PUNC_CONFIG_NAME); + + mm->InitPunc(punc_model_path, punc_config_path, thread_num); + return mm; +} diff --git a/funasr/runtime/onnxruntime/src/tokenizer.cpp b/funasr/runtime/onnxruntime/src/tokenizer.cpp index 5f29b46f0..5aff058b3 100644 --- a/funasr/runtime/onnxruntime/src/tokenizer.cpp +++ b/funasr/runtime/onnxruntime/src/tokenizer.cpp @@ -14,6 +14,10 @@ CTokenizer::CTokenizer():m_ready(false) { } +CTokenizer::~CTokenizer() +{ +} + void CTokenizer::ReadYaml(const YAML::Node& node) { if (node.IsMap()) diff --git a/funasr/runtime/onnxruntime/src/tokenizer.h b/funasr/runtime/onnxruntime/src/tokenizer.h index 4ff1809cf..4ddd359e5 100644 --- a/funasr/runtime/onnxruntime/src/tokenizer.h +++ b/funasr/runtime/onnxruntime/src/tokenizer.h @@ -17,6 +17,7 @@ public: CTokenizer(const char* sz_yamlfile); CTokenizer(); + ~CTokenizer(); bool OpenYaml(const char* sz_yamlfile); void ReadYaml(const YAML::Node& node); vector Id2String(vector input); diff --git a/funasr/runtime/onnxruntime/src/vad-model.cpp b/funasr/runtime/onnxruntime/src/vad-model.cpp new file mode 100644 index 000000000..0a0ec84eb --- /dev/null +++ b/funasr/runtime/onnxruntime/src/vad-model.cpp @@ -0,0 +1,21 @@ +#include "precomp.h" + +VadModel *CreateVadModel(std::map& model_path, int thread_num) +{ + VadModel *mm; + mm = new FsmnVad(); + + string vad_model_path; + string vad_cmvn_path; + string vad_config_path; + + vad_model_path = PathAppend(model_path.at(MODEL_DIR), MODEL_NAME); + if(model_path.find(QUANTIZE) != model_path.end() && model_path.at(QUANTIZE) == "true"){ + vad_model_path = PathAppend(model_path.at(MODEL_DIR), QUANT_MODEL_NAME); + } + vad_cmvn_path = PathAppend(model_path.at(MODEL_DIR), VAD_CMVN_NAME); + vad_config_path = PathAppend(model_path.at(MODEL_DIR), VAD_CONFIG_NAME); + + mm->InitVad(vad_model_path, vad_cmvn_path, vad_config_path, thread_num); + return mm; +}