add online func for funasrruntime.cpp

This commit is contained in:
雾聪 2023-06-28 10:37:24 +08:00
parent 87c2c615e8
commit 287fd0202b
2 changed files with 37 additions and 5 deletions

View File

@ -23,9 +23,9 @@ extern "C" {
return mm;
}
_FUNASRAPI FUNASR_HANDLE CTTransformerInit(std::map<std::string, std::string>& model_path, int thread_num)
_FUNASRAPI FUNASR_HANDLE CTTransformerInit(std::map<std::string, std::string>& model_path, int thread_num, PUNC_TYPE type)
{
funasr::PuncModel* mm = funasr::CreatePuncModel(model_path, thread_num);
funasr::PuncModel* mm = funasr::CreatePuncModel(model_path, thread_num, type);
return mm;
}
@ -164,14 +164,28 @@ extern "C" {
}
// APIs for PUNC Infer
_FUNASRAPI const std::string CTTransformerInfer(FUNASR_HANDLE handle, const char* sz_sentence, FUNASR_MODE mode, QM_CALLBACK fn_callback)
_FUNASRAPI FUNASR_RESULT CTTransformerInfer(FUNASR_HANDLE handle, const char* sz_sentence, FUNASR_MODE mode, QM_CALLBACK fn_callback, PUNC_TYPE type, FUNASR_RESULT pre_result)
{
funasr::PuncModel* punc_obj = (funasr::PuncModel*)handle;
if (!punc_obj)
return nullptr;
FUNASR_RESULT p_result = nullptr;
if (type==PUNC_OFFLINE){
p_result = (FUNASR_RESULT)new funasr::FUNASR_PUNC_RESULT;
((funasr::FUNASR_PUNC_RESULT*)p_result)->msg = punc_obj->AddPunc(sz_sentence);
}else if(type==PUNC_ONLINE){
if (!pre_result)
p_result = (FUNASR_RESULT)new funasr::FUNASR_PUNC_RESULT;
else
p_result = pre_result;
((funasr::FUNASR_PUNC_RESULT*)p_result)->msg = punc_obj->AddPunc(sz_sentence, ((funasr::FUNASR_PUNC_RESULT*)p_result)->arr_cache);
}else{
LOG(ERROR) << "Wrong PUNC_TYPE";
exit(-1);
}
string punc_res = punc_obj->AddPunc(sz_sentence);
return punc_res;
return p_result;
}
// APIs for Offline-stream Infer
@ -296,6 +310,15 @@ extern "C" {
return p_result->msg.c_str();
}
_FUNASRAPI const char* CTTransformerGetResult(FUNASR_RESULT result,int n_index)
{
funasr::FUNASR_PUNC_RESULT * p_result = (funasr::FUNASR_PUNC_RESULT*)result;
if(!p_result)
return nullptr;
return p_result->msg.c_str();
}
_FUNASRAPI vector<std::vector<int>>* FsmnVadGetResult(FUNASR_RESULT result,int n_index)
{
funasr::FUNASR_VAD_RESULT * p_result = (funasr::FUNASR_VAD_RESULT*)result;
@ -314,6 +337,14 @@ extern "C" {
}
}
_FUNASRAPI void CTTransformerFreeResult(FUNASR_RESULT result)
{
if (result)
{
delete (funasr::FUNASR_PUNC_RESULT*)result;
}
}
_FUNASRAPI void FsmnVadFreeResult(FUNASR_RESULT result)
{
funasr::FUNASR_VAD_RESULT * p_result = (funasr::FUNASR_VAD_RESULT*)result;

View File

@ -36,6 +36,7 @@ using namespace std;
#include "offline-stream.h"
#include "tokenizer.h"
#include "ct-transformer.h"
#include "ct-transformer-online.h"
#include "e2e-vad.h"
#include "fsmn-vad.h"
#include "fsmn-vad-online.h"