Merge pull request #273 from alibaba-damo-academy/dev_onnx

Dev onnx
This commit is contained in:
zhifu gao 2023-03-21 17:13:32 +08:00 committed by GitHub
commit 4986ec2dd3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 148 additions and 95 deletions

View File

@ -44,8 +44,8 @@ source ~/.bashrc
#### Step 4. Start grpc paraformer server #### Step 4. Start grpc paraformer server
``` ```
Usage: ./cmake/build/paraformer_server port thread_num /path/to/model_file Usage: ./cmake/build/paraformer_server port thread_num /path/to/model_file quantize(true or false)
./cmake/build/paraformer_server 10108 4 /data/asrmodel/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch ./cmake/build/paraformer_server 10108 4 /data/asrmodel/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch false
``` ```

View File

@ -29,8 +29,8 @@ using paraformer::Request;
using paraformer::Response; using paraformer::Response;
using paraformer::ASR; using paraformer::ASR;
ASRServicer::ASRServicer(const char* model_path, int thread_num) { ASRServicer::ASRServicer(const char* model_path, int thread_num, bool quantize) {
AsrHanlde=RapidAsrInit(model_path, thread_num); AsrHanlde=RapidAsrInit(model_path, thread_num, quantize);
std::cout << "ASRServicer init" << std::endl; std::cout << "ASRServicer init" << std::endl;
init_flag = 0; init_flag = 0;
} }
@ -170,10 +170,10 @@ grpc::Status ASRServicer::Recognize(
} }
void RunServer(const std::string& port, int thread_num, const char* model_path) { void RunServer(const std::string& port, int thread_num, const char* model_path, bool quantize) {
std::string server_address; std::string server_address;
server_address = "0.0.0.0:" + port; server_address = "0.0.0.0:" + port;
ASRServicer service(model_path, thread_num); ASRServicer service(model_path, thread_num, quantize);
ServerBuilder builder; ServerBuilder builder;
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials()); builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
@ -184,12 +184,15 @@ void RunServer(const std::string& port, int thread_num, const char* model_path)
} }
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
if (argc < 3) if (argc < 5)
{ {
printf("Usage: %s port thread_num /path/to/model_file\n", argv[0]); printf("Usage: %s port thread_num /path/to/model_file quantize(true or false) \n", argv[0]);
exit(-1); exit(-1);
} }
RunServer(argv[1], atoi(argv[2]), argv[3]); // is quantize
bool quantize = false;
std::istringstream(argv[4]) >> std::boolalpha >> quantize;
RunServer(argv[1], atoi(argv[2]), argv[3], quantize);
return 0; return 0;
} }

View File

@ -45,7 +45,7 @@ class ASRServicer final : public ASR::Service {
std::unordered_map<std::string, std::string> client_transcription; std::unordered_map<std::string, std::string> client_transcription;
public: public:
ASRServicer(const char* model_path, int thread_num); ASRServicer(const char* model_path, int thread_num, bool quantize);
void clear_states(const std::string& user); void clear_states(const std::string& user);
void clear_buffers(const std::string& user); void clear_buffers(const std::string& user);
void clear_transcriptions(const std::string& user); void clear_transcriptions(const std::string& user);

View File

@ -13,5 +13,5 @@ class Model {
virtual std::string rescoring() = 0; virtual std::string rescoring() = 0;
}; };
Model *create_model(const char *path,int nThread=0); Model *create_model(const char *path,int nThread=0,bool quantize=false);
#endif #endif

View File

@ -1,33 +1,20 @@
#pragma once #pragma once
#ifdef WIN32 #ifdef WIN32
#ifdef _RPASR_API_EXPORT #ifdef _RPASR_API_EXPORT
#define _RAPIDASRAPI __declspec(dllexport) #define _RAPIDASRAPI __declspec(dllexport)
#else #else
#define _RAPIDASRAPI __declspec(dllimport) #define _RAPIDASRAPI __declspec(dllimport)
#endif #endif
#else #else
#define _RAPIDASRAPI #define _RAPIDASRAPI
#endif #endif
#ifndef _WIN32 #ifndef _WIN32
#define RPASR_CALLBCK_PREFIX __attribute__((__stdcall__)) #define RPASR_CALLBCK_PREFIX __attribute__((__stdcall__))
#else #else
#define RPASR_CALLBCK_PREFIX __stdcall #define RPASR_CALLBCK_PREFIX __stdcall
#endif #endif
#ifdef __cplusplus #ifdef __cplusplus
@ -35,16 +22,13 @@ extern "C" {
#endif #endif
typedef void* RPASR_HANDLE; typedef void* RPASR_HANDLE;
typedef void* RPASR_RESULT; typedef void* RPASR_RESULT;
typedef unsigned char RPASR_BOOL; typedef unsigned char RPASR_BOOL;
#define RPASR_TRUE 1 #define RPASR_TRUE 1
#define RPASR_FALSE 0 #define RPASR_FALSE 0
#define QM_DEFAULT_THREAD_NUM 4 #define QM_DEFAULT_THREAD_NUM 4
typedef enum typedef enum
{ {
RASR_NONE=-1, RASR_NONE=-1,
@ -55,7 +39,6 @@ typedef enum
}RPASR_MODE; }RPASR_MODE;
typedef enum { typedef enum {
RPASR_MODEL_PADDLE = 0, RPASR_MODEL_PADDLE = 0,
RPASR_MODEL_PADDLE_2 = 1, RPASR_MODEL_PADDLE_2 = 1,
RPASR_MODEL_K2 = 2, RPASR_MODEL_K2 = 2,
@ -63,17 +46,15 @@ typedef enum {
}RPASR_MODEL_TYPE; }RPASR_MODEL_TYPE;
typedef void (* QM_CALLBACK)(int nCurStep, int nTotal); // nTotal: total steps; nCurStep: Current Step. typedef void (* QM_CALLBACK)(int nCurStep, int nTotal); // nTotal: total steps; nCurStep: Current Step.
// APIs for qmasr // APIs for qmasr
_RAPIDASRAPI RPASR_HANDLE RapidAsrInit(const char* szModelDir, int nThread, bool quantize);
_RAPIDASRAPI RPASR_HANDLE RapidAsrInit(const char* szModelDir, int nThread);
// if not give a fnCallback ,it should be NULL // if not give a fnCallback ,it should be NULL
_RAPIDASRAPI RPASR_RESULT RapidAsrRecogBuffer(RPASR_HANDLE handle, const char* szBuf, int nLen, RPASR_MODE Mode, QM_CALLBACK fnCallback); _RAPIDASRAPI RPASR_RESULT RapidAsrRecogBuffer(RPASR_HANDLE handle, const char* szBuf, int nLen, RPASR_MODE Mode, QM_CALLBACK fnCallback);
_RAPIDASRAPI RPASR_RESULT RapidAsrRecogPCMBuffer(RPASR_HANDLE handle, const char* szBuf, int nLen, RPASR_MODE Mode, QM_CALLBACK fnCallback); _RAPIDASRAPI RPASR_RESULT RapidAsrRecogPCMBuffer(RPASR_HANDLE handle, const char* szBuf, int nLen, RPASR_MODE Mode, QM_CALLBACK fnCallback);
_RAPIDASRAPI RPASR_RESULT RapidAsrRecogPCMFile(RPASR_HANDLE handle, const char* szFileName, RPASR_MODE Mode, QM_CALLBACK fnCallback); _RAPIDASRAPI RPASR_RESULT RapidAsrRecogPCMFile(RPASR_HANDLE handle, const char* szFileName, RPASR_MODE Mode, QM_CALLBACK fnCallback);
@ -83,8 +64,8 @@ _RAPIDASRAPI RPASR_RESULT RapidAsrRecogFile(RPASR_HANDLE handle, const char* szW
_RAPIDASRAPI const char* RapidAsrGetResult(RPASR_RESULT Result,int nIndex); _RAPIDASRAPI const char* RapidAsrGetResult(RPASR_RESULT Result,int nIndex);
_RAPIDASRAPI const int RapidAsrGetRetNumber(RPASR_RESULT Result); _RAPIDASRAPI const int RapidAsrGetRetNumber(RPASR_RESULT Result);
_RAPIDASRAPI void RapidAsrFreeResult(RPASR_RESULT Result);
_RAPIDASRAPI void RapidAsrFreeResult(RPASR_RESULT Result);
_RAPIDASRAPI void RapidAsrUninit(RPASR_HANDLE Handle); _RAPIDASRAPI void RapidAsrUninit(RPASR_HANDLE Handle);

View File

@ -16,9 +16,9 @@ See the bottom of this page: Building Guidance
### 运行程序 ### 运行程序
tester /path/to/models/dir /path/to/wave/file tester /path/to/models/dir /path/to/wave/file quantize(true or false)
例如: tester /data/models /data/test.wav 例如: tester /data/models /data/test.wav false
/data/models 需要包括如下两个文件: model.onnx 和vocab.txt /data/models 需要包括如下两个文件: model.onnx 和vocab.txt

View File

@ -1,11 +1,10 @@
#include "precomp.h" #include "precomp.h"
Model *create_model(const char *path,int nThread) Model *create_model(const char *path, int nThread, bool quantize)
{ {
Model *mm; Model *mm;
mm = new paraformer::ModelImp(path, nThread, quantize);
mm = new paraformer::ModelImp(path, nThread);
return mm; return mm;
} }

View File

@ -4,24 +4,16 @@
extern "C" { extern "C" {
#endif #endif
// APIs for qmasr // APIs for qmasr
_RAPIDASRAPI RPASR_HANDLE RapidAsrInit(const char* szModelDir, int nThreadNum) _RAPIDASRAPI RPASR_HANDLE RapidAsrInit(const char* szModelDir, int nThreadNum, bool quantize)
{ {
Model* mm = create_model(szModelDir, nThreadNum, quantize);
Model* mm = create_model(szModelDir, nThreadNum);
return mm; return mm;
} }
_RAPIDASRAPI RPASR_RESULT RapidAsrRecogBuffer(RPASR_HANDLE handle, const char* szBuf, int nLen, RPASR_MODE Mode, QM_CALLBACK fnCallback) _RAPIDASRAPI RPASR_RESULT RapidAsrRecogBuffer(RPASR_HANDLE handle, const char* szBuf, int nLen, RPASR_MODE Mode, QM_CALLBACK fnCallback)
{ {
Model* pRecogObj = (Model*)handle; Model* pRecogObj = (Model*)handle;
if (!pRecogObj) if (!pRecogObj)
return nullptr; return nullptr;
@ -46,15 +38,12 @@ extern "C" {
fnCallback(nStep, nTotal); fnCallback(nStep, nTotal);
} }
return pResult; return pResult;
} }
_RAPIDASRAPI RPASR_RESULT RapidAsrRecogPCMBuffer(RPASR_HANDLE handle, const char* szBuf, int nLen, RPASR_MODE Mode, QM_CALLBACK fnCallback) _RAPIDASRAPI RPASR_RESULT RapidAsrRecogPCMBuffer(RPASR_HANDLE handle, const char* szBuf, int nLen, RPASR_MODE Mode, QM_CALLBACK fnCallback)
{ {
Model* pRecogObj = (Model*)handle; Model* pRecogObj = (Model*)handle;
if (!pRecogObj) if (!pRecogObj)
return nullptr; return nullptr;
@ -79,16 +68,12 @@ extern "C" {
fnCallback(nStep, nTotal); fnCallback(nStep, nTotal);
} }
return pResult; return pResult;
} }
_RAPIDASRAPI RPASR_RESULT RapidAsrRecogPCMFile(RPASR_HANDLE handle, const char* szFileName, RPASR_MODE Mode, QM_CALLBACK fnCallback) _RAPIDASRAPI RPASR_RESULT RapidAsrRecogPCMFile(RPASR_HANDLE handle, const char* szFileName, RPASR_MODE Mode, QM_CALLBACK fnCallback)
{ {
Model* pRecogObj = (Model*)handle; Model* pRecogObj = (Model*)handle;
if (!pRecogObj) if (!pRecogObj)
return nullptr; return nullptr;
@ -113,15 +98,12 @@ extern "C" {
fnCallback(nStep, nTotal); fnCallback(nStep, nTotal);
} }
return pResult; return pResult;
} }
_RAPIDASRAPI RPASR_RESULT RapidAsrRecogFile(RPASR_HANDLE handle, const char* szWavfile, RPASR_MODE Mode, QM_CALLBACK fnCallback) _RAPIDASRAPI RPASR_RESULT RapidAsrRecogFile(RPASR_HANDLE handle, const char* szWavfile, RPASR_MODE Mode, QM_CALLBACK fnCallback)
{ {
Model* pRecogObj = (Model*)handle; Model* pRecogObj = (Model*)handle;
if (!pRecogObj) if (!pRecogObj)
return nullptr; return nullptr;
@ -146,9 +128,6 @@ extern "C" {
fnCallback(nStep, nTotal); fnCallback(nStep, nTotal);
} }
return pResult; return pResult;
} }
@ -158,7 +137,6 @@ extern "C" {
return 0; return 0;
return 1; return 1;
} }
@ -168,7 +146,6 @@ extern "C" {
return 0.0f; return 0.0f;
return ((RPASR_RECOG_RESULT*)Result)->snippet_time; return ((RPASR_RECOG_RESULT*)Result)->snippet_time;
} }
_RAPIDASRAPI const char* RapidAsrGetResult(RPASR_RESULT Result,int nIndex) _RAPIDASRAPI const char* RapidAsrGetResult(RPASR_RESULT Result,int nIndex)
@ -178,34 +155,26 @@ extern "C" {
return nullptr; return nullptr;
return pResult->msg.c_str(); return pResult->msg.c_str();
} }
_RAPIDASRAPI void RapidAsrFreeResult(RPASR_RESULT Result) _RAPIDASRAPI void RapidAsrFreeResult(RPASR_RESULT Result)
{ {
if (Result) if (Result)
{ {
delete (RPASR_RECOG_RESULT*)Result; delete (RPASR_RECOG_RESULT*)Result;
} }
} }
_RAPIDASRAPI void RapidAsrUninit(RPASR_HANDLE handle) _RAPIDASRAPI void RapidAsrUninit(RPASR_HANDLE handle)
{ {
Model* pRecogObj = (Model*)handle; Model* pRecogObj = (Model*)handle;
if (!pRecogObj) if (!pRecogObj)
return; return;
delete pRecogObj; delete pRecogObj;
} }
#ifdef __cplusplus #ifdef __cplusplus
} }

View File

@ -3,14 +3,22 @@
using namespace std; using namespace std;
using namespace paraformer; using namespace paraformer;
ModelImp::ModelImp(const char* path,int nNumThread) ModelImp::ModelImp(const char* path,int nNumThread, bool quantize)
{ {
string model_path = pathAppend(path, "model.onnx"); string model_path;
string vocab_path = pathAppend(path, "vocab.txt"); string vocab_path;
if(quantize)
{
model_path = pathAppend(path, "model_quant.onnx");
}else{
model_path = pathAppend(path, "model.onnx");
}
vocab_path = pathAppend(path, "vocab.txt");
fe = new FeatureExtract(3); fe = new FeatureExtract(3);
sessionOptions.SetInterOpNumThreads(nNumThread); //sessionOptions.SetInterOpNumThreads(1);
sessionOptions.SetIntraOpNumThreads(nNumThread);
sessionOptions.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED); sessionOptions.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
#ifdef _WIN32 #ifdef _WIN32

View File

@ -4,10 +4,6 @@
#ifndef PARAFORMER_MODELIMP_H #ifndef PARAFORMER_MODELIMP_H
#define PARAFORMER_MODELIMP_H #define PARAFORMER_MODELIMP_H
namespace paraformer { namespace paraformer {
class ModelImp : public Model { class ModelImp : public Model {
@ -19,7 +15,6 @@ namespace paraformer {
void apply_lfr(Tensor<float>*& din); void apply_lfr(Tensor<float>*& din);
void apply_cmvn(Tensor<float>* din); void apply_cmvn(Tensor<float>* din);
string greedy_search( float* in, int nLen); string greedy_search( float* in, int nLen);
#ifdef _WIN_X86 #ifdef _WIN_X86
@ -39,7 +34,7 @@ namespace paraformer {
//string m_strOutputName, m_strOutputNameLen; //string m_strOutputName, m_strOutputNameLen;
public: public:
ModelImp(const char* path, int nNumThread=0); ModelImp(const char* path, int nNumThread=0, bool quantize=false);
~ModelImp(); ~ModelImp();
void reset(); void reset();
string forward_chunk(float* din, int len, int flag); string forward_chunk(float* din, int len, int flag);

View File

@ -13,8 +13,11 @@ set(EXTRA_LIBS rapidasr)
include_directories(${CMAKE_SOURCE_DIR}/include) include_directories(${CMAKE_SOURCE_DIR}/include)
set(EXECNAME "tester") set(EXECNAME "tester")
set(EXECNAMERTF "tester_rtf")
add_executable(${EXECNAME} "tester.cpp") add_executable(${EXECNAME} "tester.cpp")
target_link_libraries(${EXECNAME} PUBLIC ${EXTRA_LIBS}) target_link_libraries(${EXECNAME} PUBLIC ${EXTRA_LIBS})
add_executable(${EXECNAMERTF} "tester_rtf.cpp")
target_link_libraries(${EXECNAMERTF} PUBLIC ${EXTRA_LIBS})

View File

@ -9,41 +9,40 @@
#include <iostream> #include <iostream>
#include <fstream> #include <fstream>
#include <sstream>
using namespace std; using namespace std;
int main(int argc, char *argv[]) int main(int argc, char *argv[])
{ {
if (argc < 2) if (argc < 4)
{ {
printf("Usage: %s /path/to/model_dir /path/to/wav/file", argv[0]); printf("Usage: %s /path/to/model_dir /path/to/wav/file quantize(true or false) \n", argv[0]);
exit(-1); exit(-1);
} }
struct timeval start, end; struct timeval start, end;
gettimeofday(&start, NULL); gettimeofday(&start, NULL);
int nThreadNum = 4; int nThreadNum = 4;
RPASR_HANDLE AsrHanlde=RapidAsrInit(argv[1], nThreadNum); // is quantize
bool quantize = false;
istringstream(argv[3]) >> boolalpha >> quantize;
RPASR_HANDLE AsrHanlde=RapidAsrInit(argv[1], nThreadNum, quantize);
if (!AsrHanlde) if (!AsrHanlde)
{ {
printf("Cannot load ASR Model from: %s, there must be files model.onnx and vocab.txt", argv[1]); printf("Cannot load ASR Model from: %s, there must be files model.onnx and vocab.txt", argv[1]);
exit(-1); exit(-1);
} }
gettimeofday(&end, NULL); gettimeofday(&end, NULL);
long seconds = (end.tv_sec - start.tv_sec); long seconds = (end.tv_sec - start.tv_sec);
long modle_init_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec); long modle_init_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
printf("Model initialization takes %lfs.\n", (double)modle_init_micros / 1000000); printf("Model initialization takes %lfs.\n", (double)modle_init_micros / 1000000);
gettimeofday(&start, NULL); gettimeofday(&start, NULL);
float snippet_time = 0.0f; float snippet_time = 0.0f;
RPASR_RESULT Result=RapidAsrRecogFile(AsrHanlde, argv[2], RASR_NONE, NULL);
RPASR_RESULT Result=RapidAsrRecogFile(AsrHanlde, argv[2], RASR_NONE, NULL);
gettimeofday(&end, NULL); gettimeofday(&end, NULL);
@ -62,7 +61,6 @@ int main(int argc, char *argv[])
cout <<"no return data!"; cout <<"no return data!";
} }
//char* buff = nullptr; //char* buff = nullptr;
//int len = 0; //int len = 0;
//ifstream ifs(argv[2], std::ios::binary | std::ios::in); //ifstream ifs(argv[2], std::ios::binary | std::ios::in);
@ -101,13 +99,11 @@ int main(int argc, char *argv[])
// //
//delete[]buff; //delete[]buff;
//} //}
printf("Audio length %lfs.\n", (double)snippet_time); printf("Audio length %lfs.\n", (double)snippet_time);
seconds = (end.tv_sec - start.tv_sec); seconds = (end.tv_sec - start.tv_sec);
long taking_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec); long taking_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
printf("Model inference takes %lfs.\n", (double)taking_micros / 1000000); printf("Model inference takes %lfs.\n", (double)taking_micros / 1000000);
printf("Model inference RTF: %04lf.\n", (double)taking_micros/ (snippet_time*1000000)); printf("Model inference RTF: %04lf.\n", (double)taking_micros/ (snippet_time*1000000));
RapidAsrUninit(AsrHanlde); RapidAsrUninit(AsrHanlde);

View File

@ -0,0 +1,99 @@
#ifndef _WIN32
#include <sys/time.h>
#else
#include <win_func.h>
#endif
#include "librapidasrapi.h"
#include <iostream>
#include <fstream>
#include <sstream>
#include <vector>
using namespace std;
int main(int argc, char *argv[])
{
if (argc < 4)
{
printf("Usage: %s /path/to/model_dir /path/to/wav.scp quantize(true or false) \n", argv[0]);
exit(-1);
}
// read wav.scp
vector<string> wav_list;
ifstream in(argv[2]);
if (!in.is_open()) {
printf("Failed to open file: %s", argv[2]);
return 0;
}
string line;
while(getline(in, line))
{
istringstream iss(line);
string column1, column2;
iss >> column1 >> column2;
wav_list.push_back(column2);
}
in.close();
// model init
struct timeval start, end;
gettimeofday(&start, NULL);
int nThreadNum = 1;
// is quantize
bool quantize = false;
istringstream(argv[3]) >> boolalpha >> quantize;
RPASR_HANDLE AsrHanlde=RapidAsrInit(argv[1], nThreadNum, quantize);
if (!AsrHanlde)
{
printf("Cannot load ASR Model from: %s, there must be files model.onnx and vocab.txt", argv[1]);
exit(-1);
}
gettimeofday(&end, NULL);
long seconds = (end.tv_sec - start.tv_sec);
long modle_init_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
printf("Model initialization takes %lfs.\n", (double)modle_init_micros / 1000000);
// warm up
for (size_t i = 0; i < 30; i++)
{
RPASR_RESULT Result=RapidAsrRecogFile(AsrHanlde, wav_list[0].c_str(), RASR_NONE, NULL);
}
// forward
float snippet_time = 0.0f;
float total_length = 0.0f;
long total_time = 0.0f;
for (size_t i = 0; i < wav_list.size(); i++)
{
gettimeofday(&start, NULL);
RPASR_RESULT Result=RapidAsrRecogFile(AsrHanlde, wav_list[i].c_str(), RASR_NONE, NULL);
gettimeofday(&end, NULL);
seconds = (end.tv_sec - start.tv_sec);
long taking_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
total_time += taking_micros;
if(Result){
string msg = RapidAsrGetResult(Result, 0);
printf("Result: %s \n", msg);
snippet_time = RapidAsrGetRetSnippetTime(Result);
total_length += snippet_time;
RapidAsrFreeResult(Result);
}else{
cout <<"No return data!";
}
}
printf("total_time_wav %ld ms.\n", (long)(total_length * 1000));
printf("total_time_comput %ld ms.\n", total_time / 1000);
printf("total_rtf %05lf .\n", (double)total_time/ (total_length*1000000));
RapidAsrUninit(AsrHanlde);
return 0;
}