rename functions

This commit is contained in:
lyblsgo 2023-04-24 15:06:03 +08:00
parent 6406616c2d
commit 35a2bfffdf
26 changed files with 417 additions and 443 deletions

View File

@ -23,11 +23,11 @@ class AudioFrame {
AudioFrame(int len);
~AudioFrame();
int set_start(int val);
int set_end(int val);
int get_start();
int get_len();
int disp();
int SetStart(int val);
int SetEnd(int val);
int GetStart();
int GetLen();
int Disp();
};
class Audio {
@ -45,19 +45,18 @@ class Audio {
Audio(int data_type);
Audio(int data_type, int size);
~Audio();
void disp();
bool loadwav(const char* filename, int32_t* sampling_rate);
void wavResample(int32_t sampling_rate, const float *waveform, int32_t n);
bool loadwav(const char* buf, int nLen, int32_t* sampling_rate);
bool loadpcmwav(const char* buf, int nFileLen, int32_t* sampling_rate);
bool loadpcmwav(const char* filename, int32_t* sampling_rate);
int fetch_chunck(float *&dout, int len);
int fetch(float *&dout, int &len, int &flag);
void padding();
void split(Model* pRecogObj);
float get_time_len();
int get_queue_size() { return (int)frame_queue.size(); }
void Disp();
bool LoadWav(const char* filename, int32_t* sampling_rate);
void WavResample(int32_t sampling_rate, const float *waveform, int32_t n);
bool LoadWav(const char* buf, int n_len, int32_t* sampling_rate);
bool LoadPcmwav(const char* buf, int n_file_len, int32_t* sampling_rate);
bool LoadPcmwav(const char* filename, int32_t* sampling_rate);
int FetchChunck(float *&dout, int len);
int Fetch(float *&dout, int &len, int &flag);
void Padding();
void Split(Model* recog_obj);
float GetTimeLen();
// Number of AudioFrame entries currently waiting in frame_queue, narrowed to int.
int GetQueueSize() { return static_cast<int>(frame_queue.size()); }
};
#endif

View File

@ -35,7 +35,6 @@ typedef enum
RASRM_CTC_GREEDY_SEARCH=0,
RASRM_CTC_RPEFIX_BEAM_SEARCH = 1,
RASRM_ATTENSION_RESCORING = 2,
}FUNASR_MODE;
typedef enum {
@ -43,33 +42,31 @@ typedef enum {
FUNASR_MODEL_PADDLE_2 = 1,
FUNASR_MODEL_K2 = 2,
FUNASR_MODEL_PARAFORMER = 3,
}FUNASR_MODEL_TYPE;
typedef void (* QM_CALLBACK)(int nCurStep, int nTotal); // nTotal: total steps; nCurStep: Current Step.
typedef void (* QM_CALLBACK)(int cur_step, int n_total); // n_total: total steps; cur_step: Current Step.
// APIs for funasr
_FUNASRAPI FUNASR_HANDLE FunASRInit(const char* szModelDir, int nThread, bool quantize=false, bool use_vad=false, bool use_punc=false);
_FUNASRAPI FUNASR_HANDLE FunASRInit(const char* sz_model_dir, int thread_num, bool quantize=false, bool use_vad=false, bool use_punc=false);
// If no fn_callback is supplied, pass NULL.
_FUNASRAPI FUNASR_RESULT FunASRRecogBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback, bool use_vad=false, bool use_punc=false);
// if not give a fnCallback ,it should be NULL
_FUNASRAPI FUNASR_RESULT FunASRRecogBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad=false, bool use_punc=false);
_FUNASRAPI FUNASR_RESULT FunASRRecogPCMBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, int sampling_rate, FUNASR_MODE mode, QM_CALLBACK fn_callback, bool use_vad=false, bool use_punc=false);
_FUNASRAPI FUNASR_RESULT FunASRRecogPCMBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad=false, bool use_punc=false);
_FUNASRAPI FUNASR_RESULT FunASRRecogPCMFile(FUNASR_HANDLE handle, const char* sz_filename, int sampling_rate, FUNASR_MODE mode, QM_CALLBACK fn_callback, bool use_vad=false, bool use_punc=false);
_FUNASRAPI FUNASR_RESULT FunASRRecogPCMFile(FUNASR_HANDLE handle, const char* szFileName, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad=false, bool use_punc=false);
_FUNASRAPI FUNASR_RESULT FunASRRecogFile(FUNASR_HANDLE handle, const char* sz_wavfile, FUNASR_MODE mode, QM_CALLBACK fn_callback, bool use_vad=false, bool use_punc=false);
_FUNASRAPI FUNASR_RESULT FunASRRecogFile(FUNASR_HANDLE handle, const char* szWavfile, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad=false, bool use_punc=false);
_FUNASRAPI const char* FunASRGetResult(FUNASR_RESULT result,int n_index);
_FUNASRAPI const char* FunASRGetResult(FUNASR_RESULT Result,int nIndex);
_FUNASRAPI const int FunASRGetRetNumber(FUNASR_RESULT result);
_FUNASRAPI const int FunASRGetRetNumber(FUNASR_RESULT Result);
_FUNASRAPI void FunASRFreeResult(FUNASR_RESULT result);
_FUNASRAPI void FunASRFreeResult(FUNASR_RESULT Result);
_FUNASRAPI void FunASRUninit(FUNASR_HANDLE handle);
_FUNASRAPI void FunASRUninit(FUNASR_HANDLE Handle);
_FUNASRAPI const float FunASRGetRetSnippetTime(FUNASR_RESULT Result);
_FUNASRAPI const float FunASRGetRetSnippetTime(FUNASR_RESULT result);
#ifdef __cplusplus

View File

@ -7,13 +7,13 @@
class Model {
public:
virtual ~Model(){};
virtual void reset() = 0;
virtual std::string forward_chunk(float *din, int len, int flag) = 0;
virtual std::string forward(float *din, int len, int flag) = 0;
virtual std::string rescoring() = 0;
virtual std::vector<std::vector<int>> vad_seg(std::vector<float>& pcm_data)=0;
virtual std::string AddPunc(const char* szInput)=0;
virtual void Reset() = 0;
virtual std::string ForwardChunk(float *din, int len, int flag) = 0;
virtual std::string Forward(float *din, int len, int flag) = 0;
virtual std::string Rescoring() = 0;
virtual std::vector<std::vector<int>> VadSeg(std::vector<float>& pcm_data)=0;
virtual std::string AddPunc(const char* sz_input)=0;
};
Model *CreateModel(const char *path,int nThread=0,bool quantize=false, bool use_vad=false, bool use_punc=false);
Model *CreateModel(const char *path,int thread_num=1,bool quantize=false, bool use_vad=false, bool use_punc=false);
#endif

View File

@ -1,5 +1,5 @@
#include "precomp.h"
void *aligned_malloc(size_t alignment, size_t required_bytes)
void *AlignedMalloc(size_t alignment, size_t required_bytes)
{
void *p1; // original block
void **p2; // aligned block
@ -12,7 +12,7 @@ void *aligned_malloc(size_t alignment, size_t required_bytes)
return p2;
}
void aligned_free(void *p)
/**
 * Releases a block previously returned by AlignedMalloc.
 *
 * The pointer that is actually passed to free() is read from the slot
 * immediately before `p` (i.e. ((void **)p)[-1]); presumably AlignedMalloc
 * stores the original, unaligned allocation there -- confirm against its body.
 *
 * @param p  Aligned pointer obtained from AlignedMalloc, or nullptr.
 *           A null pointer is ignored, mirroring the behaviour of free(NULL),
 *           instead of reading the slot in front of address 0.
 */
void AlignedFree(void *p)
{
    // Guard first: indexing [-1] on a null pointer is an invalid read.
    if (p == nullptr) {
        return;
    }
    free(((void **)p)[-1]);
}

View File

@ -4,7 +4,7 @@
extern void *aligned_malloc(size_t alignment, size_t required_bytes);
extern void aligned_free(void *p);
extern void *AlignedMalloc(size_t alignment, size_t required_bytes);
extern void AlignedFree(void *p);
#endif

View File

@ -128,30 +128,30 @@ AudioFrame::AudioFrame(int len) : len(len)
start = 0;
};
AudioFrame::~AudioFrame(){};
int AudioFrame::set_start(int val)
int AudioFrame::SetStart(int val)
{
start = val < 0 ? 0 : val;
return start;
};
int AudioFrame::set_end(int val)
int AudioFrame::SetEnd(int val)
{
end = val;
len = end - start;
return end;
};
int AudioFrame::get_start()
int AudioFrame::GetStart()
{
return start;
};
int AudioFrame::get_len()
int AudioFrame::GetLen()
{
return len;
};
int AudioFrame::disp()
int AudioFrame::Disp()
{
printf("not imp!!!!\n");
@ -185,18 +185,18 @@ Audio::~Audio()
}
}
void Audio::disp()
// Prints the loaded audio's duration (speech_len divided by MODEL_SAMPLE_RATE)
// and its raw length to stdout.
void Audio::Disp()
{
    // `auto` keeps whatever type the division yields, so the value handed to
    // printf is the same as when the expression is written inline.
    const auto seconds = static_cast<float>(speech_len) / MODEL_SAMPLE_RATE;
    printf("Audio time is %f s. len is %d\n", seconds, speech_len);
}
float Audio::get_time_len()
// Returns speech_len divided by MODEL_SAMPLE_RATE as a float
// (presumably the audio duration in seconds -- confirm the unit of MODEL_SAMPLE_RATE).
float Audio::GetTimeLen()
{
    const auto duration = static_cast<float>(speech_len) / MODEL_SAMPLE_RATE;
    return duration;
}
void Audio::wavResample(int32_t sampling_rate, const float *waveform,
void Audio::WavResample(int32_t sampling_rate, const float *waveform,
int32_t n)
{
printf(
@ -226,7 +226,7 @@ void Audio::wavResample(int32_t sampling_rate, const float *waveform,
copy(samples.begin(), samples.end(), speech_data);
}
bool Audio::loadwav(const char *filename, int32_t* sampling_rate)
bool Audio::LoadWav(const char *filename, int32_t* sampling_rate)
{
WaveHeader header;
if (speech_data != NULL) {
@ -271,7 +271,7 @@ bool Audio::loadwav(const char *filename, int32_t* sampling_rate)
//resample
if(*sampling_rate != MODEL_SAMPLE_RATE){
wavResample(*sampling_rate, speech_data, speech_len);
WavResample(*sampling_rate, speech_data, speech_len);
}
AudioFrame* frame = new AudioFrame(speech_len);
@ -283,7 +283,7 @@ bool Audio::loadwav(const char *filename, int32_t* sampling_rate)
return false;
}
bool Audio::loadwav(const char* buf, int nFileLen, int32_t* sampling_rate)
bool Audio::LoadWav(const char* buf, int n_file_len, int32_t* sampling_rate)
{
WaveHeader header;
if (speech_data != NULL) {
@ -318,7 +318,7 @@ bool Audio::loadwav(const char* buf, int nFileLen, int32_t* sampling_rate)
//resample
if(*sampling_rate != MODEL_SAMPLE_RATE){
wavResample(*sampling_rate, speech_data, speech_len);
WavResample(*sampling_rate, speech_data, speech_len);
}
AudioFrame* frame = new AudioFrame(speech_len);
@ -330,7 +330,7 @@ bool Audio::loadwav(const char* buf, int nFileLen, int32_t* sampling_rate)
return false;
}
bool Audio::loadpcmwav(const char* buf, int nBufLen, int32_t* sampling_rate)
bool Audio::LoadPcmwav(const char* buf, int n_buf_len, int32_t* sampling_rate)
{
if (speech_data != NULL) {
free(speech_data);
@ -340,7 +340,7 @@ bool Audio::loadpcmwav(const char* buf, int nBufLen, int32_t* sampling_rate)
}
offset = 0;
speech_len = nBufLen / 2;
speech_len = n_buf_len / 2;
speech_buff = (int16_t*)malloc(sizeof(int16_t) * speech_len);
if (speech_buff)
{
@ -361,7 +361,7 @@ bool Audio::loadpcmwav(const char* buf, int nBufLen, int32_t* sampling_rate)
//resample
if(*sampling_rate != MODEL_SAMPLE_RATE){
wavResample(*sampling_rate, speech_data, speech_len);
WavResample(*sampling_rate, speech_data, speech_len);
}
AudioFrame* frame = new AudioFrame(speech_len);
@ -373,7 +373,7 @@ bool Audio::loadpcmwav(const char* buf, int nBufLen, int32_t* sampling_rate)
return false;
}
bool Audio::loadpcmwav(const char* filename, int32_t* sampling_rate)
bool Audio::LoadPcmwav(const char* filename, int32_t* sampling_rate)
{
if (speech_data != NULL) {
free(speech_data);
@ -388,10 +388,10 @@ bool Audio::loadpcmwav(const char* filename, int32_t* sampling_rate)
if (fp == nullptr)
return false;
fseek(fp, 0, SEEK_END);
uint32_t nFileLen = ftell(fp);
uint32_t n_file_len = ftell(fp);
fseek(fp, 0, SEEK_SET);
speech_len = (nFileLen) / 2;
speech_len = (n_file_len) / 2;
speech_buff = (int16_t*)malloc(sizeof(int16_t) * speech_len);
if (speech_buff)
{
@ -412,7 +412,7 @@ bool Audio::loadpcmwav(const char* filename, int32_t* sampling_rate)
//resample
if(*sampling_rate != MODEL_SAMPLE_RATE){
wavResample(*sampling_rate, speech_data, speech_len);
WavResample(*sampling_rate, speech_data, speech_len);
}
AudioFrame* frame = new AudioFrame(speech_len);
@ -425,7 +425,7 @@ bool Audio::loadpcmwav(const char* filename, int32_t* sampling_rate)
}
int Audio::fetch_chunck(float *&dout, int len)
int Audio::FetchChunck(float *&dout, int len)
{
if (offset >= speech_align_len) {
dout = NULL;
@ -446,14 +446,14 @@ int Audio::fetch_chunck(float *&dout, int len)
}
}
int Audio::fetch(float *&dout, int &len, int &flag)
int Audio::Fetch(float *&dout, int &len, int &flag)
{
if (frame_queue.size() > 0) {
AudioFrame *frame = frame_queue.front();
frame_queue.pop();
dout = speech_data + frame->get_start();
len = frame->get_len();
dout = speech_data + frame->GetStart();
len = frame->GetLen();
delete frame;
flag = S_END;
return 1;
@ -462,7 +462,7 @@ int Audio::fetch(float *&dout, int &len, int &flag)
}
}
void Audio::padding()
void Audio::Padding()
{
float num_samples = speech_len;
float frame_length = 400;
@ -499,26 +499,26 @@ void Audio::padding()
delete frame;
}
void Audio::split(Model* pRecogObj)
void Audio::Split(Model* recog_obj)
{
AudioFrame *frame;
frame = frame_queue.front();
frame_queue.pop();
int sp_len = frame->get_len();
int sp_len = frame->GetLen();
delete frame;
frame = NULL;
std::vector<float> pcm_data(speech_data, speech_data+sp_len);
vector<std::vector<int>> vad_segments = pRecogObj->vad_seg(pcm_data);
vector<std::vector<int>> vad_segments = recog_obj->VadSeg(pcm_data);
int seg_sample = MODEL_SAMPLE_RATE/1000;
for(vector<int> segment:vad_segments)
{
frame = new AudioFrame();
int start = segment[0]*seg_sample;
int end = segment[1]*seg_sample;
frame->set_start(start);
frame->set_end(end);
frame->SetStart(start);
frame->SetEnd(end);
frame_queue.push(frame);
frame = NULL;
}

View File

@ -10,23 +10,23 @@ typedef struct
#ifdef _WIN32
#include <codecvt>
inline std::wstring string2wstring(const std::string& str, const std::string& locale)
// Converts a narrow (multibyte) string to a wide string using the codecvt
// facet of the named locale. Only compiled under _WIN32 (see enclosing #ifdef).
//   str    - bytes to convert.
//   locale - locale name passed to std::codecvt_byname.
// Returns the converted std::wstring.
// NOTE(review): std::wstring_convert is deprecated since C++17, and
// from_bytes() throws std::range_error on a failed conversion because no
// error string is supplied here -- confirm callers tolerate that.
inline std::wstring String2wstring(const std::string& str, const std::string& locale)
{
typedef std::codecvt_byname<wchar_t, char, std::mbstate_t> F;
// The converter takes ownership of the facet allocated with new.
std::wstring_convert<F> strCnv(new F(locale));
return strCnv.from_bytes(str);
}
inline std::wstring strToWstr(std::string str) {
inline std::wstring StrToWstr(std::string str) {
if (str.length() == 0)
return L"";
return string2wstring(str, "zh-CN");
return String2wstring(str, "zh-CN");
}
#endif
inline void getInputName(Ort::Session* session, string& inputName,int nIndex=0) {
inline void GetInputName(Ort::Session* session, string& inputName,int nIndex=0) {
size_t numInputNodes = session->GetInputCount();
if (numInputNodes > 0) {
Ort::AllocatorWithDefaultOptions allocator;
@ -38,7 +38,7 @@ inline void getInputName(Ort::Session* session, string& inputName,int nIndex=0)
}
}
inline void getOutputName(Ort::Session* session, string& outputName, int nIndex = 0) {
inline void GetOutputName(Ort::Session* session, string& outputName, int nIndex = 0) {
size_t numOutputNodes = session->GetOutputCount();
if (numOutputNodes > 0) {
Ort::AllocatorWithDefaultOptions allocator;
@ -51,6 +51,6 @@ inline void getOutputName(Ort::Session* session, string& outputName, int nIndex
}
template <class ForwardIterator>
inline static size_t argmax(ForwardIterator first, ForwardIterator last) {
inline static size_t Argmax(ForwardIterator first, ForwardIterator last) {
return std::distance(first, std::max_element(first, last));
}

View File

@ -7,8 +7,8 @@ CTTransformer::CTTransformer(const char* sz_model_dir, int thread_num)
session_options.SetGraphOptimizationLevel(ORT_ENABLE_ALL);
session_options.DisableCpuMemArena();
string strModelPath = pathAppend(sz_model_dir, PUNC_MODEL_FILE);
string strYamlPath = pathAppend(sz_model_dir, PUNC_YAML_FILE);
string strModelPath = PathAppend(sz_model_dir, PUNC_MODEL_FILE);
string strYamlPath = PathAppend(sz_model_dir, PUNC_YAML_FILE);
try{
#ifdef _WIN32
@ -24,12 +24,12 @@ CTTransformer::CTTransformer(const char* sz_model_dir, int thread_num)
}
// read input names and output names
string strName;
getInputName(m_session.get(), strName);
GetInputName(m_session.get(), strName);
m_strInputNames.push_back(strName.c_str());
getInputName(m_session.get(), strName, 1);
GetInputName(m_session.get(), strName, 1);
m_strInputNames.push_back(strName);
getOutputName(m_session.get(), strName);
GetOutputName(m_session.get(), strName);
m_strOutputNames.push_back(strName);
for (auto& item : m_strInputNames)
@ -77,12 +77,12 @@ string CTTransformer::AddPunc(const char* sz_input)
nLastCommaIndex = -1;
for (int nIndex = Punction.size() - 2; nIndex > 0; nIndex--)
{
if (m_tokenizer.ID2Punc(Punction[nIndex]) == m_tokenizer.ID2Punc(PERIOD_INDEX) || m_tokenizer.ID2Punc(Punction[nIndex]) == m_tokenizer.ID2Punc(QUESTION_INDEX))
if (m_tokenizer.Id2Punc(Punction[nIndex]) == m_tokenizer.Id2Punc(PERIOD_INDEX) || m_tokenizer.Id2Punc(Punction[nIndex]) == m_tokenizer.Id2Punc(QUESTION_INDEX))
{
nSentEnd = nIndex;
break;
}
if (nLastCommaIndex < 0 && m_tokenizer.ID2Punc(Punction[nIndex]) == m_tokenizer.ID2Punc(COMMA_INDEX))
if (nLastCommaIndex < 0 && m_tokenizer.Id2Punc(Punction[nIndex]) == m_tokenizer.Id2Punc(COMMA_INDEX))
{
nLastCommaIndex = nIndex;
}
@ -110,7 +110,7 @@ string CTTransformer::AddPunc(const char* sz_input)
if (Punction[i] != NOTPUNC_INDEX) // skip the no-punctuation label (original comment was mis-encoded; presumably it read "underscore")
{
WordWithPunc.push_back(m_tokenizer.ID2Punc(Punction[i]));
WordWithPunc.push_back(m_tokenizer.Id2Punc(Punction[i]));
}
}
@ -120,17 +120,17 @@ string CTTransformer::AddPunc(const char* sz_input)
// last mini sentence
if(nCurBatch == nTotalBatch - 1)
{
if (NewString[NewString.size() - 1] == m_tokenizer.ID2Punc(COMMA_INDEX) || NewString[NewString.size() - 1] == m_tokenizer.ID2Punc(DUN_INDEX))
if (NewString[NewString.size() - 1] == m_tokenizer.Id2Punc(COMMA_INDEX) || NewString[NewString.size() - 1] == m_tokenizer.Id2Punc(DUN_INDEX))
{
NewSentenceOut.assign(NewString.begin(), NewString.end() - 1);
NewSentenceOut.push_back(m_tokenizer.ID2Punc(PERIOD_INDEX));
NewSentenceOut.push_back(m_tokenizer.Id2Punc(PERIOD_INDEX));
NewPuncOut.assign(NewPunctuation.begin(), NewPunctuation.end() - 1);
NewPuncOut.push_back(PERIOD_INDEX);
}
else if (NewString[NewString.size() - 1] == m_tokenizer.ID2Punc(PERIOD_INDEX) && NewString[NewString.size() - 1] == m_tokenizer.ID2Punc(QUESTION_INDEX))
else if (NewString[NewString.size() - 1] == m_tokenizer.Id2Punc(PERIOD_INDEX) && NewString[NewString.size() - 1] == m_tokenizer.Id2Punc(QUESTION_INDEX))
{
NewSentenceOut = NewString;
NewSentenceOut.push_back(m_tokenizer.ID2Punc(PERIOD_INDEX));
NewSentenceOut.push_back(m_tokenizer.Id2Punc(PERIOD_INDEX));
NewPuncOut = NewPunctuation;
NewPuncOut.push_back(PERIOD_INDEX);
}
@ -173,7 +173,7 @@ vector<int> CTTransformer::Infer(vector<int64_t> input_data)
for (int i = 0; i < outputCount; i += CANDIDATE_NUM)
{
int index = argmax(floatData + i, floatData + i + CANDIDATE_NUM-1);
int index = Argmax(floatData + i, floatData + i + CANDIDATE_NUM-1);
punction.push_back(index);
}
}

View File

@ -5,6 +5,12 @@
#include "precomp.h"
class FsmnVad {
/**
* Author: Speech Lab of DAMO Academy, Alibaba Group
* Deep-FSMN for Large Vocabulary Continuous Speech Recognition
* https://arxiv.org/abs/1803.05030
*/
public:
FsmnVad();
void Test();

View File

@ -19,15 +19,8 @@ using namespace std;
std::atomic<int> index(0);
std::mutex mtx;
void runReg(FUNASR_HANDLE AsrHandle, vector<string> wav_list,
void runReg(FUNASR_HANDLE asr_handle, vector<string> wav_list,
float* total_length, long* total_time, int core_id) {
// cpu_set_t cpuset;
// CPU_ZERO(&cpuset);
// CPU_SET(core_id, &cpuset);
// if(pthread_setaffinity_np(pthread_self(), sizeof(cpu_set_t), &cpuset) < 0){
// perror("pthread_setaffinity_np");
// }
struct timeval start, end;
long seconds = 0;
@ -37,7 +30,7 @@ void runReg(FUNASR_HANDLE AsrHandle, vector<string> wav_list,
// warm up
for (size_t i = 0; i < 1; i++)
{
FUNASR_RESULT Result=FunASRRecogFile(AsrHandle, wav_list[0].c_str(), RASR_NONE, NULL);
FUNASR_RESULT result=FunASRRecogFile(asr_handle, wav_list[0].c_str(), RASR_NONE, NULL);
}
while (true) {
@ -48,20 +41,20 @@ void runReg(FUNASR_HANDLE AsrHandle, vector<string> wav_list,
}
gettimeofday(&start, NULL);
FUNASR_RESULT Result=FunASRRecogFile(AsrHandle, wav_list[i].c_str(), RASR_NONE, NULL);
FUNASR_RESULT result=FunASRRecogFile(asr_handle, wav_list[i].c_str(), RASR_NONE, NULL);
gettimeofday(&end, NULL);
seconds = (end.tv_sec - start.tv_sec);
long taking_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
n_total_time += taking_micros;
if(Result){
string msg = FunASRGetResult(Result, 0);
if(result){
string msg = FunASRGetResult(result, 0);
printf("Thread: %d Result: %s \n", this_thread::get_id(), msg.c_str());
float snippet_time = FunASRGetRetSnippetTime(Result);
float snippet_time = FunASRGetRetSnippetTime(result);
n_total_length += snippet_time;
FunASRFreeResult(Result);
FunASRFreeResult(result);
}else{
cout <<"No return data!";
}
@ -109,11 +102,11 @@ int main(int argc, char *argv[])
bool quantize = false;
istringstream(argv[3]) >> boolalpha >> quantize;
// thread num
int nThreadNum = 1;
nThreadNum = atoi(argv[4]);
int thread_num = 1;
thread_num = atoi(argv[4]);
FUNASR_HANDLE AsrHandle=FunASRInit(argv[1], 1, quantize);
if (!AsrHandle)
FUNASR_HANDLE asr_handle=FunASRInit(argv[1], 1, quantize);
if (!asr_handle)
{
printf("Cannot load ASR Model from: %s, there must be files model.onnx and vocab.txt", argv[1]);
exit(-1);
@ -128,9 +121,9 @@ int main(int argc, char *argv[])
long total_time = 0;
std::vector<std::thread> threads;
for (int i = 0; i < nThreadNum; i++)
for (int i = 0; i < thread_num; i++)
{
threads.emplace_back(thread(runReg, AsrHandle, wav_list, &total_length, &total_time, i));
threads.emplace_back(thread(runReg, asr_handle, wav_list, &total_length, &total_time, i));
}
for (auto& thread : threads)
@ -143,6 +136,6 @@ int main(int argc, char *argv[])
printf("total_rtf %05lf .\n", (double)total_time/ (total_length*1000000));
printf("speedup %05lf .\n", 1.0/((double)total_time/ (total_length*1000000)));
FunASRUninit(AsrHandle);
FunASRUninit(asr_handle);
return 0;
}

View File

@ -18,7 +18,7 @@ int main(int argc, char *argv[])
}
struct timeval start, end;
gettimeofday(&start, NULL);
int nThreadNum = 1;
int thread_num = 1;
// is quantize
bool quantize = false;
bool use_vad = false;
@ -26,9 +26,9 @@ int main(int argc, char *argv[])
istringstream(argv[3]) >> boolalpha >> quantize;
istringstream(argv[4]) >> boolalpha >> use_vad;
istringstream(argv[5]) >> boolalpha >> use_punc;
FUNASR_HANDLE AsrHanlde=FunASRInit(argv[1], nThreadNum, quantize, use_vad, use_punc);
FUNASR_HANDLE asr_hanlde=FunASRInit(argv[1], thread_num, quantize, use_vad, use_punc);
if (!AsrHanlde)
if (!asr_hanlde)
{
printf("Cannot load ASR Model from: %s, there must be files model.onnx and vocab.txt", argv[1]);
exit(-1);
@ -40,17 +40,17 @@ int main(int argc, char *argv[])
printf("Model initialization takes %lfs.\n", (double)modle_init_micros / 1000000);
gettimeofday(&start, NULL);
FUNASR_RESULT Result=FunASRRecogFile(AsrHanlde, argv[2], RASR_NONE, NULL, use_vad, use_punc);
FUNASR_RESULT result=FunASRRecogFile(asr_hanlde, argv[2], RASR_NONE, NULL, use_vad, use_punc);
gettimeofday(&end, NULL);
float snippet_time = 0.0f;
if (Result)
if (result)
{
string msg = FunASRGetResult(Result, 0);
string msg = FunASRGetResult(result, 0);
setbuf(stdout, NULL);
printf("Result: %s \n", msg.c_str());
snippet_time = FunASRGetRetSnippetTime(Result);
FunASRFreeResult(Result);
snippet_time = FunASRGetRetSnippetTime(result);
FunASRFreeResult(result);
}
else
{
@ -63,7 +63,7 @@ int main(int argc, char *argv[])
printf("Model inference takes %lfs.\n", (double)taking_micros / 1000000);
printf("Model inference RTF: %04lf.\n", (double)taking_micros/ (snippet_time*1000000));
FunASRUninit(AsrHanlde);
FunASRUninit(asr_hanlde);
return 0;
}

View File

@ -5,196 +5,196 @@ extern "C" {
#endif
// APIs for funasr
_FUNASRAPI FUNASR_HANDLE FunASRInit(const char* szModelDir, int nThreadNum, bool quantize, bool use_vad, bool use_punc)
_FUNASRAPI FUNASR_HANDLE FunASRInit(const char* sz_model_dir, int thread_num, bool quantize, bool use_vad, bool use_punc)
{
Model* mm = CreateModel(szModelDir, nThreadNum, quantize, use_vad, use_punc);
Model* mm = CreateModel(sz_model_dir, thread_num, quantize, use_vad, use_punc);
return mm;
}
_FUNASRAPI FUNASR_RESULT FunASRRecogBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad, bool use_punc)
_FUNASRAPI FUNASR_RESULT FunASRRecogBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback, bool use_vad, bool use_punc)
{
Model* pRecogObj = (Model*)handle;
if (!pRecogObj)
Model* recog_obj = (Model*)handle;
if (!recog_obj)
return nullptr;
int32_t sampling_rate = -1;
Audio audio(1);
if (!audio.loadwav(szBuf, nLen, &sampling_rate))
if (!audio.LoadWav(sz_buf, n_len, &sampling_rate))
return nullptr;
if(use_vad){
audio.split(pRecogObj);
audio.Split(recog_obj);
}
float* buff;
int len;
int flag=0;
FUNASR_RECOG_RESULT* pResult = new FUNASR_RECOG_RESULT;
pResult->snippet_time = audio.get_time_len();
int nStep = 0;
int nTotal = audio.get_queue_size();
while (audio.fetch(buff, len, flag) > 0) {
string msg = pRecogObj->forward(buff, len, flag);
pResult->msg += msg;
nStep++;
if (fnCallback)
fnCallback(nStep, nTotal);
FUNASR_RECOG_RESULT* p_result = new FUNASR_RECOG_RESULT;
p_result->snippet_time = audio.GetTimeLen();
int n_step = 0;
int n_total = audio.GetQueueSize();
while (audio.Fetch(buff, len, flag) > 0) {
string msg = recog_obj->Forward(buff, len, flag);
p_result->msg += msg;
n_step++;
if (fn_callback)
fn_callback(n_step, n_total);
}
if(use_punc){
string punc_res = pRecogObj->AddPunc((pResult->msg).c_str());
pResult->msg = punc_res;
string punc_res = recog_obj->AddPunc((p_result->msg).c_str());
p_result->msg = punc_res;
}
return pResult;
return p_result;
}
_FUNASRAPI FUNASR_RESULT FunASRRecogPCMBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad, bool use_punc)
_FUNASRAPI FUNASR_RESULT FunASRRecogPCMBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, int sampling_rate, FUNASR_MODE mode, QM_CALLBACK fn_callback, bool use_vad, bool use_punc)
{
Model* pRecogObj = (Model*)handle;
if (!pRecogObj)
Model* recog_obj = (Model*)handle;
if (!recog_obj)
return nullptr;
Audio audio(1);
if (!audio.loadpcmwav(szBuf, nLen, &sampling_rate))
if (!audio.LoadPcmwav(sz_buf, n_len, &sampling_rate))
return nullptr;
if(use_vad){
audio.split(pRecogObj);
audio.Split(recog_obj);
}
float* buff;
int len;
int flag = 0;
FUNASR_RECOG_RESULT* pResult = new FUNASR_RECOG_RESULT;
pResult->snippet_time = audio.get_time_len();
int nStep = 0;
int nTotal = audio.get_queue_size();
while (audio.fetch(buff, len, flag) > 0) {
string msg = pRecogObj->forward(buff, len, flag);
pResult->msg += msg;
nStep++;
if (fnCallback)
fnCallback(nStep, nTotal);
FUNASR_RECOG_RESULT* p_result = new FUNASR_RECOG_RESULT;
p_result->snippet_time = audio.GetTimeLen();
int n_step = 0;
int n_total = audio.GetQueueSize();
while (audio.Fetch(buff, len, flag) > 0) {
string msg = recog_obj->Forward(buff, len, flag);
p_result->msg += msg;
n_step++;
if (fn_callback)
fn_callback(n_step, n_total);
}
if(use_punc){
string punc_res = pRecogObj->AddPunc((pResult->msg).c_str());
pResult->msg = punc_res;
string punc_res = recog_obj->AddPunc((p_result->msg).c_str());
p_result->msg = punc_res;
}
return pResult;
return p_result;
}
_FUNASRAPI FUNASR_RESULT FunASRRecogPCMFile(FUNASR_HANDLE handle, const char* szFileName, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad, bool use_punc)
_FUNASRAPI FUNASR_RESULT FunASRRecogPCMFile(FUNASR_HANDLE handle, const char* sz_filename, int sampling_rate, FUNASR_MODE mode, QM_CALLBACK fn_callback, bool use_vad, bool use_punc)
{
Model* pRecogObj = (Model*)handle;
if (!pRecogObj)
Model* recog_obj = (Model*)handle;
if (!recog_obj)
return nullptr;
Audio audio(1);
if (!audio.loadpcmwav(szFileName, &sampling_rate))
if (!audio.LoadPcmwav(sz_filename, &sampling_rate))
return nullptr;
if(use_vad){
audio.split(pRecogObj);
audio.Split(recog_obj);
}
float* buff;
int len;
int flag = 0;
FUNASR_RECOG_RESULT* pResult = new FUNASR_RECOG_RESULT;
pResult->snippet_time = audio.get_time_len();
int nStep = 0;
int nTotal = audio.get_queue_size();
while (audio.fetch(buff, len, flag) > 0) {
string msg = pRecogObj->forward(buff, len, flag);
pResult->msg += msg;
nStep++;
if (fnCallback)
fnCallback(nStep, nTotal);
FUNASR_RECOG_RESULT* p_result = new FUNASR_RECOG_RESULT;
p_result->snippet_time = audio.GetTimeLen();
int n_step = 0;
int n_total = audio.GetQueueSize();
while (audio.Fetch(buff, len, flag) > 0) {
string msg = recog_obj->Forward(buff, len, flag);
p_result->msg += msg;
n_step++;
if (fn_callback)
fn_callback(n_step, n_total);
}
if(use_punc){
string punc_res = pRecogObj->AddPunc((pResult->msg).c_str());
pResult->msg = punc_res;
string punc_res = recog_obj->AddPunc((p_result->msg).c_str());
p_result->msg = punc_res;
}
return pResult;
return p_result;
}
_FUNASRAPI FUNASR_RESULT FunASRRecogFile(FUNASR_HANDLE handle, const char* szWavfile, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad, bool use_punc)
_FUNASRAPI FUNASR_RESULT FunASRRecogFile(FUNASR_HANDLE handle, const char* sz_wavfile, FUNASR_MODE mode, QM_CALLBACK fn_callback, bool use_vad, bool use_punc)
{
Model* pRecogObj = (Model*)handle;
if (!pRecogObj)
Model* recog_obj = (Model*)handle;
if (!recog_obj)
return nullptr;
int32_t sampling_rate = -1;
Audio audio(1);
if(!audio.loadwav(szWavfile, &sampling_rate))
if(!audio.LoadWav(sz_wavfile, &sampling_rate))
return nullptr;
if(use_vad){
audio.split(pRecogObj);
audio.Split(recog_obj);
}
float* buff;
int len;
int flag = 0;
int nStep = 0;
int nTotal = audio.get_queue_size();
FUNASR_RECOG_RESULT* pResult = new FUNASR_RECOG_RESULT;
pResult->snippet_time = audio.get_time_len();
while (audio.fetch(buff, len, flag) > 0) {
string msg = pRecogObj->forward(buff, len, flag);
pResult->msg+= msg;
nStep++;
if (fnCallback)
fnCallback(nStep, nTotal);
int n_step = 0;
int n_total = audio.GetQueueSize();
FUNASR_RECOG_RESULT* p_result = new FUNASR_RECOG_RESULT;
p_result->snippet_time = audio.GetTimeLen();
while (audio.Fetch(buff, len, flag) > 0) {
string msg = recog_obj->Forward(buff, len, flag);
p_result->msg+= msg;
n_step++;
if (fn_callback)
fn_callback(n_step, n_total);
}
if(use_punc){
string punc_res = pRecogObj->AddPunc((pResult->msg).c_str());
pResult->msg = punc_res;
string punc_res = recog_obj->AddPunc((p_result->msg).c_str());
p_result->msg = punc_res;
}
return pResult;
return p_result;
}
_FUNASRAPI const int FunASRGetRetNumber(FUNASR_RESULT Result)
_FUNASRAPI const int FunASRGetRetNumber(FUNASR_RESULT result)
{
if (!Result)
if (!result)
return 0;
return 1;
}
_FUNASRAPI const float FunASRGetRetSnippetTime(FUNASR_RESULT Result)
_FUNASRAPI const float FunASRGetRetSnippetTime(FUNASR_RESULT result)
{
if (!Result)
if (!result)
return 0.0f;
return ((FUNASR_RECOG_RESULT*)Result)->snippet_time;
return ((FUNASR_RECOG_RESULT*)result)->snippet_time;
}
_FUNASRAPI const char* FunASRGetResult(FUNASR_RESULT Result,int nIndex)
_FUNASRAPI const char* FunASRGetResult(FUNASR_RESULT result,int n_index)
{
FUNASR_RECOG_RESULT * pResult = (FUNASR_RECOG_RESULT*)Result;
if(!pResult)
FUNASR_RECOG_RESULT * p_result = (FUNASR_RECOG_RESULT*)result;
if(!p_result)
return nullptr;
return pResult->msg.c_str();
return p_result->msg.c_str();
}
_FUNASRAPI void FunASRFreeResult(FUNASR_RESULT Result)
_FUNASRAPI void FunASRFreeResult(FUNASR_RESULT result)
{
if (Result)
if (result)
{
delete (FUNASR_RECOG_RESULT*)Result;
delete (FUNASR_RECOG_RESULT*)result;
}
}
_FUNASRAPI void FunASRUninit(FUNASR_HANDLE handle)
{
Model* pRecogObj = (Model*)handle;
Model* recog_obj = (Model*)handle;
if (!pRecogObj)
if (!recog_obj)
return;
delete pRecogObj;
delete recog_obj;
}
#ifdef __cplusplus

View File

@ -1,10 +1,10 @@
#include "precomp.h"
Model *CreateModel(const char *path, int nThread, bool quantize, bool use_vad, bool use_punc)
Model *CreateModel(const char *path, int thread_num, bool quantize, bool use_vad, bool use_punc)
{
Model *mm;
mm = new paraformer::ModelImp(path, nThread, quantize, use_vad, use_punc);
mm = new paraformer::Paraformer(path, thread_num, quantize, use_vad, use_punc);
return mm;
}

View File

@ -13,10 +13,10 @@ OnlineFeature::OnlineFeature(int sample_rate, knf::FbankOptions fbank_opts, int
frame_shift_sample_length_ = sample_rate_ / 1000 * 10;
}
void OnlineFeature::extractFeats(vector<std::vector<float>> &vad_feats,
void OnlineFeature::ExtractFeats(vector<std::vector<float>> &vad_feats,
vector<float> waves, bool input_finished) {
input_finished_ = input_finished;
onlineFbank(vad_feats, waves);
OnlineFbank(vad_feats, waves);
// cache deal & online lfr,cmvn
if (vad_feats.size() > 0) {
if (!reserve_waveforms_.empty()) {
@ -53,7 +53,7 @@ void OnlineFeature::extractFeats(vector<std::vector<float>> &vad_feats,
}
vad_feats = lfr_splice_cache_;
OnlineLfrCmvn(vad_feats);
reset_cache();
ResetCache();
}
}
@ -102,13 +102,13 @@ int OnlineFeature::OnlineLfrCmvn(vector<vector<float>> &vad_feats) {
return lfr_splice_frame_idxs;
}
void OnlineFeature::onlineFbank(vector<std::vector<float>> &vad_feats,
void OnlineFeature::OnlineFbank(vector<std::vector<float>> &vad_feats,
vector<float> &waves) {
knf::OnlineFbank fbank(fbank_opts_);
// cache merge
waves.insert(waves.begin(), input_cache_.begin(), input_cache_.end());
int frame_number = compute_frame_num(waves.size(), frame_sample_length_, frame_shift_sample_length_);
int frame_number = ComputeFrameNum(waves.size(), frame_sample_length_, frame_shift_sample_length_);
// Send the audio after the last frame shift position to the cache
input_cache_.clear();
input_cache_.insert(input_cache_.begin(), waves.begin() + frame_number * frame_shift_sample_length_, waves.end());

View File

@ -10,15 +10,15 @@ public:
OnlineFeature(int sample_rate, knf::FbankOptions fbank_opts, int lfr_m_, int lfr_n_,
std::vector<std::vector<float>> cmvns_);
void extractFeats(vector<vector<float>> &vad_feats, vector<float> waves, bool input_finished);
void ExtractFeats(vector<vector<float>> &vad_feats, vector<float> waves, bool input_finished);
private:
void onlineFbank(vector<vector<float>> &vad_feats, vector<float> &waves);
void OnlineFbank(vector<vector<float>> &vad_feats, vector<float> &waves);
int OnlineLfrCmvn(vector<vector<float>> &vad_feats);
static int compute_frame_num(int sample_length, int frame_sample_length, int frame_shift_sample_length) {
static int ComputeFrameNum(int sample_length, int frame_sample_length, int frame_shift_sample_length) {
int frame_num = static_cast<int>((sample_length - frame_sample_length) / frame_shift_sample_length + 1);
if (frame_num >= 1 && sample_length >= frame_sample_length)
@ -27,7 +27,7 @@ private:
return 0;
}
void reset_cache() {
void ResetCache() {
reserve_waveforms_.clear();
input_cache_.clear();
lfr_splice_cache_.clear();

View File

@ -3,33 +3,33 @@
using namespace std;
using namespace paraformer;
ModelImp::ModelImp(const char* path,int nNumThread, bool quantize, bool use_vad, bool use_punc)
:env_(ORT_LOGGING_LEVEL_ERROR, "paraformer"),sessionOptions{}{
Paraformer::Paraformer(const char* path,int thread_num, bool quantize, bool use_vad, bool use_punc)
:env_(ORT_LOGGING_LEVEL_ERROR, "paraformer"),session_options{}{
string model_path;
string cmvn_path;
string config_path;
// VAD model
if(use_vad){
string vad_path = pathAppend(path, "vad_model.onnx");
string mvn_path = pathAppend(path, "vad.mvn");
vadHandle = make_unique<FsmnVad>();
vadHandle->InitVad(vad_path, mvn_path, MODEL_SAMPLE_RATE, VAD_MAX_LEN, VAD_SILENCE_DYRATION, VAD_SPEECH_NOISE_THRES);
string vad_path = PathAppend(path, "vad_model.onnx");
string mvn_path = PathAppend(path, "vad.mvn");
vad_handle = make_unique<FsmnVad>();
vad_handle->InitVad(vad_path, mvn_path, MODEL_SAMPLE_RATE, VAD_MAX_LEN, VAD_SILENCE_DYRATION, VAD_SPEECH_NOISE_THRES);
}
// PUNC model
if(use_punc){
puncHandle = make_unique<CTTransformer>(path, nNumThread);
punc_handle = make_unique<CTTransformer>(path, thread_num);
}
if(quantize)
{
model_path = pathAppend(path, "model_quant.onnx");
model_path = PathAppend(path, "model_quant.onnx");
}else{
model_path = pathAppend(path, "model.onnx");
model_path = PathAppend(path, "model.onnx");
}
cmvn_path = pathAppend(path, "am.mvn");
config_path = pathAppend(path, "config.yaml");
cmvn_path = PathAppend(path, "am.mvn");
config_path = PathAppend(path, "config.yaml");
// knf options
fbank_opts.frame_opts.dither = 0;
@ -42,28 +42,28 @@ ModelImp::ModelImp(const char* path,int nNumThread, bool quantize, bool use_vad,
fbank_opts.mel_opts.debug_mel = false;
// fbank_ = std::make_unique<knf::OnlineFbank>(fbank_opts);
// sessionOptions.SetInterOpNumThreads(1);
sessionOptions.SetIntraOpNumThreads(nNumThread);
sessionOptions.SetGraphOptimizationLevel(ORT_ENABLE_ALL);
// session_options.SetInterOpNumThreads(1);
session_options.SetIntraOpNumThreads(thread_num);
session_options.SetGraphOptimizationLevel(ORT_ENABLE_ALL);
// DisableCpuMemArena can improve performance
sessionOptions.DisableCpuMemArena();
session_options.DisableCpuMemArena();
#ifdef _WIN32
wstring wstrPath = strToWstr(model_path);
m_session = std::make_unique<Ort::Session>(env_, model_path.c_str(), sessionOptions);
m_session = std::make_unique<Ort::Session>(env_, model_path.c_str(), session_options);
#else
m_session = std::make_unique<Ort::Session>(env_, model_path.c_str(), sessionOptions);
m_session = std::make_unique<Ort::Session>(env_, model_path.c_str(), session_options);
#endif
string strName;
getInputName(m_session.get(), strName);
GetInputName(m_session.get(), strName);
m_strInputNames.push_back(strName.c_str());
getInputName(m_session.get(), strName,1);
GetInputName(m_session.get(), strName,1);
m_strInputNames.push_back(strName);
getOutputName(m_session.get(), strName);
GetOutputName(m_session.get(), strName);
m_strOutputNames.push_back(strName);
getOutputName(m_session.get(), strName,1);
GetOutputName(m_session.get(), strName,1);
m_strOutputNames.push_back(strName);
for (auto& item : m_strInputNames)
@ -71,28 +71,28 @@ ModelImp::ModelImp(const char* path,int nNumThread, bool quantize, bool use_vad,
for (auto& item : m_strOutputNames)
m_szOutputNames.push_back(item.c_str());
vocab = new Vocab(config_path.c_str());
load_cmvn(cmvn_path.c_str());
LoadCmvn(cmvn_path.c_str());
}
ModelImp::~ModelImp()
// Releases the vocabulary object held through the raw `vocab` pointer.
Paraformer::~Paraformer()
{
    // delete on a null pointer is a no-op, so no explicit null test is required.
    delete vocab;
}
void ModelImp::reset()
// Currently a no-op: the body is empty.
void Paraformer::Reset() {}
vector<std::vector<int>> ModelImp::vad_seg(std::vector<float>& pcm_data){
return vadHandle->Infer(pcm_data);
// Runs voice-activity detection on `pcm_data` by delegating to the FsmnVad handle and
// returns the result of its Infer() call unchanged.
// NOTE(review): vad_handle is dereferenced without a null check — presumably callers only
// reach this when the model was built with VAD enabled; confirm.
vector<std::vector<int>> Paraformer::VadSeg(std::vector<float>& pcm_data)
{
    vector<std::vector<int>> segments = vad_handle->Infer(pcm_data);
    return segments;
}
string ModelImp::AddPunc(const char* szInput){
return puncHandle->AddPunc(szInput);
// Forwards `sz_input` to the punctuation model handle and returns the string it produces.
// NOTE(review): punc_handle is dereferenced without a null check — presumably callers only
// reach this when the model was built with punctuation enabled; confirm.
string Paraformer::AddPunc(const char* sz_input)
{
    string punctuated = punc_handle->AddPunc(sz_input);
    return punctuated;
}
vector<float> ModelImp::FbankKaldi(float sample_rate, const float* waves, int len) {
vector<float> Paraformer::FbankKaldi(float sample_rate, const float* waves, int len) {
knf::OnlineFbank fbank_(fbank_opts);
fbank_.AcceptWaveform(sample_rate, waves, len);
//fbank_->InputFinished();
@ -110,7 +110,7 @@ vector<float> ModelImp::FbankKaldi(float sample_rate, const float* waves, int le
return features;
}
void ModelImp::load_cmvn(const char *filename)
void Paraformer::LoadCmvn(const char *filename)
{
ifstream cmvn_stream(filename);
string line;
@ -143,21 +143,21 @@ void ModelImp::load_cmvn(const char *filename)
}
}
// Greedy decoding over a flat score buffer.
// For each of the n_len steps, the index of the largest value among 8404 consecutive floats
// is picked; the collected indices are turned into text by the vocabulary.
// `in` must therefore hold at least n_len * 8404 floats.
// NOTE(review): 8404 is hard-coded — presumably the model's output vocabulary size; confirm.
// NOTE(review): pre- and post-rename lines of the diff both appear in this span.
string ModelImp::greedy_search(float * in, int nLen )
string Paraformer::GreedySearch(float * in, int n_len )
{
vector<int> hyps;
int Tmax = nLen;
int Tmax = n_len;
for (int i = 0; i < Tmax; i++) {
int max_idx;
float max_val;
// The max-finding helper fills max_val / max_idx for the i-th row of 8404 scores.
findmax(in + i * 8404, 8404, max_val, max_idx);
FindMax(in + i * 8404, 8404, max_val, max_idx);
hyps.push_back(max_idx);
}
return vocab->vector2stringV2(hyps);
return vocab->Vector2StringV2(hyps);
}
vector<float> ModelImp::ApplyLFR(const std::vector<float> &in)
vector<float> Paraformer::ApplyLfr(const std::vector<float> &in)
{
int32_t in_feat_dim = fbank_opts.mel_opts.num_bins;
int32_t in_num_frames = in.size() / in_feat_dim;
@ -180,7 +180,7 @@ vector<float> ModelImp::ApplyLFR(const std::vector<float> &in)
return out;
}
void ModelImp::ApplyCMVN(std::vector<float> *v)
void Paraformer::ApplyCmvn(std::vector<float> *v)
{
int32_t dim = means_list.size();
int32_t num_frames = v->size() / dim;
@ -196,13 +196,13 @@ vector<float> ModelImp::ApplyLFR(const std::vector<float> &in)
}
}
string ModelImp::forward(float* din, int len, int flag)
string Paraformer::Forward(float* din, int len, int flag)
{
int32_t in_feat_dim = fbank_opts.mel_opts.num_bins;
std::vector<float> wav_feats = FbankKaldi(MODEL_SAMPLE_RATE, din, len);
wav_feats = ApplyLFR(wav_feats);
ApplyCMVN(&wav_feats);
wav_feats = ApplyLfr(wav_feats);
ApplyCmvn(&wav_feats);
int32_t feat_dim = lfr_window_size*in_feat_dim;
int32_t num_frames = wav_feats.size() / feat_dim;
@ -238,7 +238,7 @@ string ModelImp::forward(float* din, int len, int flag)
int64_t outputCount = std::accumulate(outputShape.begin(), outputShape.end(), 1, std::multiplies<int64_t>());
float* floatData = outputTensor[0].GetTensorMutableData<float>();
auto encoder_out_lens = outputTensor[1].GetTensorMutableData<int64_t>();
result = greedy_search(floatData, *encoder_out_lens);
result = GreedySearch(floatData, *encoder_out_lens);
}
catch (std::exception const &e)
{
@ -248,14 +248,14 @@ string ModelImp::forward(float* din, int len, int flag)
return result;
}
string ModelImp::forward_chunk(float* din, int len, int flag)
// Placeholder: chunked inference is not implemented. Prints a notice to stdout and
// returns the fixed string "Hello"; none of the arguments are used.
string Paraformer::ForwardChunk(float* din, int len, int flag)
{
    printf("Not Imp!!!!!!\n");
    string placeholder = "Hello";
    return placeholder;
}
string ModelImp::rescoring()
string Paraformer::Rescoring()
{
printf("Not Imp!!!!!!\n");
return "Hello";

View File

@ -0,0 +1,53 @@
#pragma once
#ifndef PARAFORMER_MODELIMP_H
#define PARAFORMER_MODELIMP_H
#include "precomp.h"
namespace paraformer {
// Paraformer speech-recognition model backed by an ONNX Runtime session.
// The constructor loads the model and its auxiliary files from a directory; Forward() turns
// raw samples into text (fbank features -> LFR -> CMVN -> session run -> greedy decoding).
class Paraformer : public Model {
private:
//std::unique_ptr<knf::OnlineFbank> fbank_;
// Options used to build the knf::OnlineFbank feature extractor.
knf::FbankOptions fbank_opts;
// Optional helpers: only created by the constructor when use_vad / use_punc are true.
std::unique_ptr<FsmnVad> vad_handle;
std::unique_ptr<CTTransformer> punc_handle;
// Token vocabulary; allocated with new in the constructor and deleted in the destructor.
Vocab* vocab;
// CMVN statistics; ApplyCmvn uses means_list.size() as the feature dimension.
// Presumably filled by LoadCmvn — confirm.
vector<float> means_list;
vector<float> vars_list;
// NOTE(review): numerically equal to sqrt(512); its use is not visible in this excerpt.
const float scale = 22.6274169979695;
// Low-frame-rate stacking parameters; Forward() uses lfr_window_size * num_bins as the
// per-frame feature dimension.
int32_t lfr_window_size = 7;
int32_t lfr_window_shift = 6;
// Reads the CMVN file given by `filename`.
void LoadCmvn(const char *filename);
// Returns the LFR-processed copy of the flat feature vector `in`.
vector<float> ApplyLfr(const vector<float> &in);
// Normalizes the flat feature vector `*v` in place.
void ApplyCmvn(vector<float> *v);
// Picks the best-scoring index for each of the n_len steps and converts them to text.
string GreedySearch( float* in, int n_len);
// ONNX Runtime objects; the session is created in the constructor from env_ and
// session_options.
std::shared_ptr<Ort::Session> m_session;
Ort::Env env_;
Ort::SessionOptions session_options;
// Input/output tensor names. The m_sz* vectors hold c_str() pointers into the m_str*
// strings, so the string vectors must not be modified while the pointers are in use.
vector<string> m_strInputNames, m_strOutputNames;
vector<const char*> m_szInputNames;
vector<const char*> m_szOutputNames;
public:
// Loads model.onnx (or model_quant.onnx when quantize is true), am.mvn and config.yaml
// from `path`; thread_num is passed to SetIntraOpNumThreads.
Paraformer(const char* path, int thread_num=0, bool quantize=false, bool use_vad=false, bool use_punc=false);
~Paraformer();
// Currently has an empty body.
void Reset();
// Computes fbank features for `len` samples of `waves`.
vector<float> FbankKaldi(float sample_rate, const float* waves, int len);
// Not implemented: prints a notice and returns "Hello".
string ForwardChunk(float* din, int len, int flag);
// Recognizes `len` samples starting at `din` and returns the decoded text.
string Forward(float* din, int len, int flag);
// Not implemented: prints a notice and returns "Hello".
string Rescoring();
// Delegates to vad_handle->Infer(); requires the model to be built with use_vad=true.
std::vector<std::vector<int>> VadSeg(std::vector<float>& pcm_data);
// Delegates to punc_handle->AddPunc(); requires the model to be built with use_punc=true.
string AddPunc(const char* sz_input);
};
} // namespace paraformer
#endif

View File

@ -1,54 +0,0 @@
#pragma once
#ifndef PARAFORMER_MODELIMP_H
#define PARAFORMER_MODELIMP_H
#include "precomp.h"
namespace paraformer {
// Pre-rename declaration of the paraformer model class (this header is removed by the
// commit and replaced by the Paraformer class with snake_case members).
class ModelImp : public Model {
private:
//std::unique_ptr<knf::OnlineFbank> fbank_;
// Options used to build the fbank feature extractor.
knf::FbankOptions fbank_opts;
// Optional helpers: only created by the constructor when use_vad / use_punc are true.
std::unique_ptr<FsmnVad> vadHandle;
std::unique_ptr<CTTransformer> puncHandle;
// Token vocabulary; allocated with new in the constructor and deleted in the destructor.
Vocab* vocab;
// CMVN statistics used by ApplyCMVN.
vector<float> means_list;
vector<float> vars_list;
// NOTE(review): numerically equal to sqrt(512); its use is not visible in this excerpt.
const float scale = 22.6274169979695;
// Low-frame-rate stacking parameters.
int32_t lfr_window_size = 7;
int32_t lfr_window_shift = 6;
void load_cmvn(const char *filename);
vector<float> ApplyLFR(const vector<float> &in);
void ApplyCMVN(vector<float> *v);
string greedy_search( float* in, int nLen);
// ONNX Runtime objects; the session is created in the constructor.
std::shared_ptr<Ort::Session> m_session;
Ort::Env env_;
Ort::SessionOptions sessionOptions;
// Tensor names; the m_sz* vectors hold c_str() pointers into the m_str* strings.
vector<string> m_strInputNames, m_strOutputNames;
vector<const char*> m_szInputNames;
vector<const char*> m_szOutputNames;
public:
// Loads the model files found under `path`; nNumThread is the intra-op thread count.
ModelImp(const char* path, int nNumThread=0, bool quantize=false, bool use_vad=false, bool use_punc=false);
~ModelImp();
void reset();
vector<float> FbankKaldi(float sample_rate, const float* waves, int len);
// Not implemented: prints a notice and returns "Hello".
string forward_chunk(float* din, int len, int flag);
// Recognizes `len` samples starting at `din` and returns the decoded text.
string forward(float* din, int len, int flag);
// Not implemented: prints a notice and returns "Hello".
string rescoring();
std::vector<std::vector<int>> vad_seg(std::vector<float>& pcm_data);
string AddPunc(const char* szInput);
};
} // namespace paraformer
#endif

View File

@ -39,7 +39,7 @@ using namespace std;
#include "util.h"
#include "resample.h"
#include "model.h"
#include "paraformer_onnx.h"
#include "paraformer.h"
#include "libfunasrapi.h"
using namespace paraformer;

View File

@ -71,7 +71,7 @@ template <typename T> void Tensor<T>::alloc_buff()
{
buff_size = size[0] * size[1] * size[2] * size[3];
mem_size = buff_size;
buff = (T *)aligned_malloc(32, buff_size * sizeof(T));
buff = (T *)AlignedMalloc(32, buff_size * sizeof(T));
}
template <typename T> void Tensor<T>::free_buff()

View File

@ -1,26 +1,26 @@
#include "precomp.h"
// Constructs a tokenizer and immediately tries to load the given YAML file.
// The result of OpenYaml is not checked here; on failure the ready flag simply stays false.
// NOTE(review): pre- and post-rename lines of the diff both appear in this span.
CTokenizer::CTokenizer(const char* szYmlFile):m_Ready(false)
CTokenizer::CTokenizer(const char* sz_yamlfile):m_ready(false)
{
OpenYaml(szYmlFile);
OpenYaml(sz_yamlfile);
}
CTokenizer::CTokenizer():m_Ready(false)
// Creates an empty tokenizer that is not ready until OpenYaml() succeeds.
CTokenizer::CTokenizer()
    : m_ready(false)
{}
void CTokenizer::read_yml(const YAML::Node& node)
void CTokenizer::ReadYaml(const YAML::Node& node)
{
if (node.IsMap())
{// node is a map
for (auto it = node.begin(); it != node.end(); ++it)
{
read_yml(it->second);
ReadYaml(it->second);
}
}
if (node.IsSequence()) {// node is a sequence
for (size_t i = 0; i < node.size(); ++i) {
read_yml(node[i]);
ReadYaml(node[i]);
}
}
if (node.IsScalar()) {// node is a scalar
@ -28,9 +28,9 @@ void CTokenizer::read_yml(const YAML::Node& node)
}
}
bool CTokenizer::OpenYaml(const char* szYmlFile)
bool CTokenizer::OpenYaml(const char* sz_yamlfile)
{
YAML::Node m_Config = YAML::LoadFile(szYmlFile);
YAML::Node m_Config = YAML::LoadFile(sz_yamlfile);
if (m_Config.IsNull())
return false;
try
@ -42,8 +42,8 @@ bool CTokenizer::OpenYaml(const char* szYmlFile)
{
if (Tokens[i].IsScalar())
{
m_ID2Token.push_back(Tokens[i].as<string>());
m_Token2ID.insert(make_pair<string, int>(Tokens[i].as<string>(), i));
m_id2token.push_back(Tokens[i].as<string>());
m_token2id.insert(make_pair<string, int>(Tokens[i].as<string>(), i));
}
}
}
@ -54,8 +54,8 @@ bool CTokenizer::OpenYaml(const char* szYmlFile)
{
if (Puncs[i].IsScalar())
{
m_ID2Punc.push_back(Puncs[i].as<string>());
m_Punc2ID.insert(make_pair<string, int>(Puncs[i].as<string>(), i));
m_id2punc.push_back(Puncs[i].as<string>());
m_punc2id.insert(make_pair<string, int>(Puncs[i].as<string>(), i));
}
}
}
@ -64,87 +64,87 @@ bool CTokenizer::OpenYaml(const char* szYmlFile)
std::cout << "read error!" << std::endl;
return false;
}
m_Ready = true;
return m_Ready;
m_ready = true;
return m_ready;
}
// Maps every token id in `input` to its token string via the id-to-token table and
// returns the strings in the same order.
// Ids are used as vector indices without a bounds check, so they must be valid.
// NOTE(review): pre- and post-rename lines of the diff both appear in this span.
vector<string> CTokenizer::ID2String(vector<int> Input)
vector<string> CTokenizer::Id2String(vector<int> input)
{
vector<string> result;
for (auto& item : Input)
for (auto& item : input)
{
result.push_back(m_ID2Token[item]);
result.push_back(m_id2token[item]);
}
return result;
}
// Returns the id of token `input`, or the id stored for UNK_CHAR when the token is unknown.
// The fallback uses map::operator[], which inserts UNK_CHAR with id 0 if it is missing.
// NOTE(review): pre- and post-rename lines of the diff both appear in this span.
int CTokenizer::String2ID(string Input)
int CTokenizer::String2Id(string input)
{
int nID = 0; // <blank>
if (m_Token2ID.find(Input) != m_Token2ID.end())
nID=(m_Token2ID[Input]);
if (m_token2id.find(input) != m_token2id.end())
nID=(m_token2id[input]);
else
nID=(m_Token2ID[UNK_CHAR]);
nID=(m_token2id[UNK_CHAR]);
return nID;
}
// Converts a list of token strings to ids, one id per input string.
// Each string is lower-cased first (on the by-value copy, so the caller's data is untouched);
// unknown tokens map to the id stored for UNK_CHAR.
// NOTE(review): pre- and post-rename lines of the diff both appear in this span.
vector<int> CTokenizer::String2IDs(vector<string> Input)
vector<int> CTokenizer::String2Ids(vector<string> input)
{
vector<int> result;
for (auto& item : Input)
for (auto& item : input)
{
// Lookup is case-insensitive: normalize to lower case before searching.
transform(item.begin(), item.end(), item.begin(), ::tolower);
if (m_Token2ID.find(item) != m_Token2ID.end())
result.push_back(m_Token2ID[item]);
if (m_token2id.find(item) != m_token2id.end())
result.push_back(m_token2id[item]);
else
result.push_back(m_Token2ID[UNK_CHAR]);
result.push_back(m_token2id[UNK_CHAR]);
}
return result;
}
// Maps every punctuation id in `input` to its punctuation string, preserving order.
// Ids are used as vector indices without a bounds check, so they must be valid.
// NOTE(review): pre- and post-rename lines of the diff both appear in this span.
vector<string> CTokenizer::ID2Punc(vector<int> Input)
vector<string> CTokenizer::Id2Punc(vector<int> input)
{
vector<string> result;
for (auto& item : Input)
for (auto& item : input)
{
result.push_back(m_ID2Punc[item]);
result.push_back(m_id2punc[item]);
}
return result;
}
// Returns the punctuation string stored at index `n_punc_id`; the index is not range-checked.
// NOTE(review): pre- and post-rename lines of the diff both appear in this span.
string CTokenizer::ID2Punc(int nPuncID)
string CTokenizer::Id2Punc(int n_punc_id)
{
return m_ID2Punc[nPuncID];
return m_id2punc[n_punc_id];
}
// Converts punctuation strings to ids, one id per input string.
// Lookup uses map::operator[], so an unknown punctuation string is inserted into the map
// and yields id 0.
// NOTE(review): pre- and post-rename lines of the diff both appear in this span.
vector<int> CTokenizer::Punc2IDs(vector<string> Input)
vector<int> CTokenizer::Punc2Ids(vector<string> input)
{
vector<int> result;
for (auto& item : Input)
for (auto& item : input)
{
result.push_back(m_Punc2ID[item]);
result.push_back(m_punc2id[item]);
}
return result;
}
// Splits a string into one substring per character.
// The byte length of each character is taken from the number of leading 1 bits in its first
// byte (at most 6); a first byte with zero or one leading 1 bit yields a one-byte piece.
// substr() clamps the final piece if the input ends in the middle of a character.
// NOTE(review): pre- and post-rename lines of the diff both appear in this span.
vector<string> CTokenizer::SplitChineseString(const string & strInfo)
vector<string> CTokenizer::SplitChineseString(const string & str_info)
{
vector<string> list;
int strSize = strInfo.size();
int strSize = str_info.size();
int i = 0;
while (i < strSize) {
int len = 1;
// Count consecutive set bits starting from the most significant bit of the lead byte.
for (int j = 0; j < 6 && (strInfo[i] & (0x80 >> j)); j++) {
for (int j = 0; j < 6 && (str_info[i] & (0x80 >> j)); j++) {
len = j + 1;
}
list.push_back(strInfo.substr(i, len));
list.push_back(str_info.substr(i, len));
i += len;
}
return list;
}
void CTokenizer::strSplit(const string& str, const char split, vector<string>& res)
void CTokenizer::StrSplit(const string& str, const char split, vector<string>& res)
{
if (str == "")
{
@ -161,10 +161,10 @@ void CTokenizer::strSplit(const string& str, const char split, vector<string>& r
}
}
void CTokenizer::Tokenize(const char* strInfo, vector<string> & strOut, vector<int> & IDOut)
void CTokenizer::Tokenize(const char* str_info, vector<string> & str_out, vector<int> & id_out)
{
vector<string> strList;
strSplit(strInfo,' ', strList);
StrSplit(str_info,' ', strList);
string current_eng,current_chinese;
for (auto& item : strList)
{
@ -178,7 +178,7 @@ void CTokenizer::strSplit(const string& str, const char split, vector<string>& r
{
// for utf-8 chinese
auto chineseList = SplitChineseString(current_chinese);
strOut.insert(strOut.end(), chineseList.begin(),chineseList.end());
str_out.insert(str_out.end(), chineseList.begin(),chineseList.end());
current_chinese = "";
}
current_eng += ch;
@ -187,7 +187,7 @@ void CTokenizer::strSplit(const string& str, const char split, vector<string>& r
{
if (current_eng.size() > 0)
{
strOut.push_back(current_eng);
str_out.push_back(current_eng);
current_eng = "";
}
current_chinese += ch;
@ -196,13 +196,13 @@ void CTokenizer::strSplit(const string& str, const char split, vector<string>& r
if (current_chinese.size() > 0)
{
auto chineseList = SplitChineseString(current_chinese);
strOut.insert(strOut.end(), chineseList.begin(), chineseList.end());
str_out.insert(str_out.end(), chineseList.begin(), chineseList.end());
current_chinese = "";
}
if (current_eng.size() > 0)
{
strOut.push_back(current_eng);
str_out.push_back(current_eng);
}
}
IDOut= String2IDs(strOut);
id_out= String2Ids(str_out);
}

View File

@ -4,24 +4,24 @@
// Tokenizer backed by token and punctuation tables loaded from a YAML file.
// NOTE(review): pre- and post-rename declarations of the diff both appear in this span;
// only one set belongs in the final header.
class CTokenizer {
private:
// Set to true once OpenYaml() has loaded the tables successfully.
bool m_Ready = false;
// Index -> string tables and their inverse string -> index maps.
vector<string> m_ID2Token,m_ID2Punc;
map<string, int> m_Token2ID,m_Punc2ID;
bool m_ready = false;
vector<string> m_id2token,m_id2punc;
map<string, int> m_token2id,m_punc2id;
public:
// Constructs the tokenizer and loads the given YAML file right away.
CTokenizer(const char* szYmlFile);
CTokenizer(const char* sz_yamlfile);
// Constructs an empty tokenizer; call OpenYaml() before use.
CTokenizer();
bool OpenYaml(const char* szYmlFile);
void read_yml(const YAML::Node& node);
vector<string> ID2String(vector<int> Input);
vector<int> String2IDs(vector<string> Input);
int String2ID(string Input);
vector<string> ID2Punc(vector<int> Input);
string ID2Punc(int nPuncID);
vector<int> Punc2IDs(vector<string> Input);
vector<string> SplitChineseString(const string& strInfo);
void strSplit(const string& str, const char split, vector<string>& res);
void Tokenize(const char* strInfo, vector<string>& strOut, vector<int>& IDOut);
// Loads the token and punctuation tables; returns false on failure.
bool OpenYaml(const char* sz_yamlfile);
// Recursively walks a YAML node (map, sequence or scalar).
void ReadYaml(const YAML::Node& node);
// Id <-> string conversions for tokens.
vector<string> Id2String(vector<int> input);
vector<int> String2Ids(vector<string> input);
int String2Id(string input);
// Id <-> string conversions for punctuation.
vector<string> Id2Punc(vector<int> input);
string Id2Punc(int n_punc_id);
vector<int> Punc2Ids(vector<string> input);
// Splits a string into one substring per multi-byte character.
vector<string> SplitChineseString(const string& str_info);
// Splits `str` on the character `split`, storing the pieces in `res`.
void StrSplit(const string& str, const char split, vector<string>& res);
// Tokenizes `str_info`, filling `str_out` with the tokens and `id_out` with their ids.
void Tokenize(const char* str_info, vector<string>& str_out, vector<int>& id_out);
};

View File

@ -1,7 +1,7 @@
#include "precomp.h"
float *loadparams(const char *filename)
float *LoadParams(const char *filename)
{
FILE *fp;
@ -10,20 +10,20 @@ float *loadparams(const char *filename)
uint32_t nFileLen = ftell(fp);
fseek(fp, 0, SEEK_SET);
float *params_addr = (float *)aligned_malloc(32, nFileLen);
float *params_addr = (float *)AlignedMalloc(32, nFileLen);
int n = fread(params_addr, 1, nFileLen, fp);
fclose(fp);
return params_addr;
}
int val_align(int val, int align)
// Rounds `val` up to a multiple of `align`, i.e. returns ceil(val / align) * align.
//
// The computation uses exact integer arithmetic. The previous float-based formula lost
// precision once the operands exceeded the 24-bit float mantissa, e.g.
// ValAlign(16777217, 2) came back as 16777216, which is smaller than the input.
//
// @param val    value to align (may be negative; the quotient is rounded toward +infinity)
// @param align  alignment step; 0 is treated as "no alignment" and returns val unchanged
// @return       ceil(val / align) * align
int ValAlign(int val, int align)
{
    if (align == 0) {
        // Dividing by zero has no meaningful result; leave the value untouched.
        return val;
    }

    // Widen to avoid intermediate overflow in the final multiplication.
    const long long wide_val = static_cast<long long>(val);
    const long long wide_align = static_cast<long long>(align);

    long long quotient = wide_val / wide_align;
    const long long remainder = wide_val % wide_align;

    // Integer division truncates toward zero. When the exact quotient is positive and
    // fractional (remainder and divisor share a sign), truncation rounded down, so bump it.
    if (remainder != 0 && ((remainder > 0) == (wide_align > 0))) {
        ++quotient;
    }
    return static_cast<int>(quotient * wide_align);
}
void disp_params(float *din, int size)
void DispParams(float *din, int size)
{
int i;
for (i = 0; i < size; i++) {
@ -39,7 +39,7 @@ void SaveDataFile(const char *filename, void *data, uint32_t len)
fclose(fp);
}
void basic_norm(Tensor<float> *&din, float norm)
void BasicNorm(Tensor<float> *&din, float norm)
{
int Tmax = din->size[2];
@ -59,7 +59,7 @@ void basic_norm(Tensor<float> *&din, float norm)
}
}
void findmax(float *din, int len, float &max_val, int &max_idx)
void FindMax(float *din, int len, float &max_val, int &max_idx)
{
int i;
max_val = -INFINITY;
@ -72,7 +72,7 @@ void findmax(float *din, int len, float &max_val, int &max_idx)
}
}
string pathAppend(const string &p1, const string &p2)
string PathAppend(const string &p1, const string &p2)
{
char sep = '/';
@ -89,7 +89,7 @@ string pathAppend(const string &p1, const string &p2)
return (p1 + p2);
}
void relu(Tensor<float> *din)
void Relu(Tensor<float> *din)
{
int i;
for (i = 0; i < din->buff_size; i++) {
@ -98,7 +98,7 @@ void relu(Tensor<float> *din)
}
}
void swish(Tensor<float> *din)
void Swish(Tensor<float> *din)
{
int i;
for (i = 0; i < din->buff_size; i++) {
@ -107,7 +107,7 @@ void swish(Tensor<float> *din)
}
}
void sigmoid(Tensor<float> *din)
void Sigmoid(Tensor<float> *din)
{
int i;
for (i = 0; i < din->buff_size; i++) {
@ -116,7 +116,7 @@ void sigmoid(Tensor<float> *din)
}
}
void doubleswish(Tensor<float> *din)
void DoubleSwish(Tensor<float> *din)
{
int i;
for (i = 0; i < din->buff_size; i++) {
@ -125,7 +125,7 @@ void doubleswish(Tensor<float> *din)
}
}
void softmax(float *din, int mask, int len)
void Softmax(float *din, int mask, int len)
{
float *tmp = (float *)malloc(mask * sizeof(float));
int i;
@ -149,7 +149,7 @@ void softmax(float *din, int mask, int len)
}
}
void log_softmax(float *din, int len)
void LogSoftmax(float *din, int len)
{
float *tmp = (float *)malloc(len * sizeof(float));
int i;
@ -164,7 +164,7 @@ void log_softmax(float *din, int len)
free(tmp);
}
void glu(Tensor<float> *din, Tensor<float> *dout)
void Glu(Tensor<float> *din, Tensor<float> *dout)
{
int mm = din->buff_size / 1024;
int i, j;

View File

@ -5,26 +5,26 @@
using namespace std;
extern float *loadparams(const char *filename);
extern float *LoadParams(const char *filename);
extern void SaveDataFile(const char *filename, void *data, uint32_t len);
extern void relu(Tensor<float> *din);
extern void swish(Tensor<float> *din);
extern void sigmoid(Tensor<float> *din);
extern void doubleswish(Tensor<float> *din);
extern void Relu(Tensor<float> *din);
extern void Swish(Tensor<float> *din);
extern void Sigmoid(Tensor<float> *din);
extern void DoubleSwish(Tensor<float> *din);
extern void softmax(float *din, int mask, int len);
extern void Softmax(float *din, int mask, int len);
extern void log_softmax(float *din, int len);
extern int val_align(int val, int align);
extern void disp_params(float *din, int size);
extern void LogSoftmax(float *din, int len);
extern int ValAlign(int val, int align);
extern void DispParams(float *din, int size);
extern void basic_norm(Tensor<float> *&din, float norm);
extern void BasicNorm(Tensor<float> *&din, float norm);
extern void findmax(float *din, int len, float &max_val, int &max_idx);
extern void FindMax(float *din, int len, float &max_val, int &max_idx);
extern void glu(Tensor<float> *din, Tensor<float> *dout);
extern void Glu(Tensor<float> *din, Tensor<float> *dout);
string pathAppend(const string &p1, const string &p2);
string PathAppend(const string &p1, const string &p2);
#endif

View File

@ -12,13 +12,13 @@ using namespace std;
// Builds the vocabulary by loading the token list from the YAML file `filename`.
// NOTE(review): the ifstream `in` is opened but never read in this body.
// NOTE(review): pre- and post-rename lines of the diff both appear in this span.
Vocab::Vocab(const char *filename)
{
ifstream in(filename);
loadVocabFromYaml(filename);
LoadVocabFromYaml(filename);
}
// Nothing to release explicitly: the body is empty.
Vocab::~Vocab() {}
void Vocab::loadVocabFromYaml(const char* filename){
void Vocab::LoadVocabFromYaml(const char* filename){
YAML::Node config;
try{
config = YAML::LoadFile(filename);
@ -26,72 +26,62 @@ void Vocab::loadVocabFromYaml(const char* filename){
printf("error loading file, yaml file error or not exist.\n");
exit(-1);
}
YAML::Node myList = config["token_list"];
for (YAML::const_iterator it = myList.begin(); it != myList.end(); ++it) {
vocab.push_back(it->as<string>());
}
}
string Vocab::vector2string(vector<int> in)
// Concatenates the vocabulary entries selected by the ids in `in`, in order, with no
// separator. Ids index the vocabulary directly and are not range-checked.
string Vocab::Vector2String(vector<int> in)
{
    string joined;
    for (int token_id : in) {
        joined += vocab[token_id];
    }
    return joined;
}
int str2int(string str)
// Decodes the three-byte UTF-8 sequence at the start of `str` into its code point.
// Returns 0 when the first three bytes do not have the form 1110xxxx 10xxxxxx 10xxxxxx
// (this includes strings shorter than three bytes, where the terminating NUL fails the test).
int Str2Int(string str)
{
    const char *bytes = str.c_str();

    // Check the bytes one at a time so that a later byte is only read after the
    // previous one has been validated.
    if ((bytes[0] & 0xf0) != 0xe0) {
        return 0;
    }
    if ((bytes[1] & 0xc0) != 0x80) {
        return 0;
    }
    if ((bytes[2] & 0xc0) != 0x80) {
        return 0;
    }

    // 4 payload bits from the lead byte, 6 from each continuation byte.
    const int high = (bytes[0] & 0x0f) << 12;
    const int mid = (bytes[1] & 0x3f) << 6;
    const int low = bytes[2] & 0x3f;
    return high | mid | low;
}
// Returns true when `ch` is exactly three bytes long and decodes to a code point in the
// range 19968..40959 (U+4E00..U+9FFF); everything else returns false.
// NOTE(review): pre- and post-rename lines of the diff both appear in this span.
bool Vocab::isChinese(string ch)
bool Vocab::IsChinese(string ch)
{
// Only three-byte strings are considered.
if (ch.size() != 3) {
return false;
}
// The decoder yields 0 for anything that is not a well-formed three-byte sequence.
int unicode = str2int(ch);
int unicode = Str2Int(ch);
if (unicode >= 19968 && unicode <= 40959) {
return true;
}
return false;
}
string Vocab::vector2stringV2(vector<int> in)
string Vocab::Vector2StringV2(vector<int> in)
{
int i;
list<string> words;
int is_pre_english = false;
int pre_english_len = 0;
int is_combining = false;
string combine = "";
for (auto it = in.begin(); it != in.end(); it++) {
string word = vocab[*it];
// step1: skip special tokens (<s>, </s>, <unk>)
if (word == "<s>" || word == "</s>" || word == "<unk>")
continue;
// step2: combine sub-word pieces into a full word
{
int sub_word = !(word.find("@@") == string::npos);
// process word start and middle part
if (sub_word) {
combine += word.erase(word.length() - 2);
@ -109,15 +99,13 @@ string Vocab::vector2stringV2(vector<int> in)
// step3: process English words: handle spacing and turn abbreviations to upper case
{
// input word is chinese, not need process
if (isChinese(word)) {
if (IsChinese(word)) {
words.push_back(word);
is_pre_english = false;
}
// input word is english word
else {
// pre word is chinese
if (!is_pre_english) {
word[0] = word[0] - 32;
@ -125,10 +113,8 @@ string Vocab::vector2stringV2(vector<int> in)
pre_english_len = word.size();
}
// pre word is english word
else {
// single letter turn to upper case
if (word.size() == 1) {
word[0] = word[0] - 32;
@ -147,17 +133,11 @@ string Vocab::vector2stringV2(vector<int> in)
pre_english_len = word.size();
}
}
is_pre_english = true;
}
}
}
// for (auto it = words.begin(); it != words.end(); it++) {
// cout << *it << endl;
// }
stringstream ss;
for (auto it = words.begin(); it != words.end(); it++) {
ss << *it;
@ -166,7 +146,7 @@ string Vocab::vector2stringV2(vector<int> in)
return ss.str();
}
int Vocab::size()
// Returns the number of entries in the vocabulary, narrowed to int.
int Vocab::Size()
{
    const auto entry_count = vocab.size();
    return static_cast<int>(entry_count);
}

View File

@ -10,16 +10,16 @@ using namespace std;
// Token vocabulary loaded from a YAML file; converts id sequences back into text.
// NOTE(review): pre- and post-rename declarations of the diff both appear in this span;
// only one set belongs in the final header.
class Vocab {
private:
// Token strings, indexed by token id.
vector<string> vocab;
bool isChinese(string ch);
bool isEnglish(string ch);
void loadVocabFromYaml(const char* filename);
// True for a three-byte character whose code point lies in U+4E00..U+9FFF.
bool IsChinese(string ch);
bool IsEnglish(string ch);
// Fills `vocab` from the "token_list" entry of the YAML file.
void LoadVocabFromYaml(const char* filename);
public:
// Loads the vocabulary from `filename`.
Vocab(const char *filename);
~Vocab();
int size();
string vector2string(vector<int> in);
string vector2stringV2(vector<int> in);
// Number of vocabulary entries.
int Size();
// Plain concatenation of the tokens selected by `in`.
string Vector2String(vector<int> in);
// Like Vector2String, but skips <s>, </s> and <unk>, merges "@@" sub-words and
// adjusts English words.
string Vector2StringV2(vector<int> in);
};
#endif