From 3372b13d24aceef7002cfa0fc8222b3085c15110 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=BE=E8=81=AA?= Date: Fri, 2 Jun 2023 22:02:31 +0800 Subject: [PATCH] add fsmn-vad-online --- funasr/runtime/onnxruntime/CMakeLists.txt | 13 +- funasr/runtime/onnxruntime/bin/CMakeLists.txt | 16 ++ .../{src => bin}/funasr-onnx-offline-punc.cpp | 0 .../{src => bin}/funasr-onnx-offline-rtf.cpp | 0 .../{src => bin}/funasr-onnx-offline-vad.cpp | 2 +- .../{src => bin}/funasr-onnx-offline.cpp | 0 .../bin/funasr-onnx-online-vad.cpp | 193 +++++++++++++++++ funasr/runtime/onnxruntime/include/audio.h | 13 +- .../onnxruntime/include/funasrruntime.h | 13 +- .../runtime/onnxruntime/include/vad-model.h | 9 +- funasr/runtime/onnxruntime/src/CMakeLists.txt | 17 +- funasr/runtime/onnxruntime/src/audio.cpp | 78 ++++++- .../onnxruntime/src/fsmn-vad-online.cpp | 198 ++++++++++++++++++ .../runtime/onnxruntime/src/fsmn-vad-online.h | 88 ++++++++ funasr/runtime/onnxruntime/src/fsmn-vad.cpp | 51 +++-- funasr/runtime/onnxruntime/src/fsmn-vad.h | 45 ++-- .../runtime/onnxruntime/src/funasrruntime.cpp | 18 +- .../onnxruntime/src/online-feature.cpp | 137 ------------ .../runtime/onnxruntime/src/online-feature.h | 58 ----- funasr/runtime/onnxruntime/src/paraformer.h | 4 +- funasr/runtime/onnxruntime/src/precomp.h | 3 +- funasr/runtime/onnxruntime/src/vad-model.cpp | 15 +- 22 files changed, 669 insertions(+), 302 deletions(-) create mode 100644 funasr/runtime/onnxruntime/bin/CMakeLists.txt rename funasr/runtime/onnxruntime/{src => bin}/funasr-onnx-offline-punc.cpp (100%) rename funasr/runtime/onnxruntime/{src => bin}/funasr-onnx-offline-rtf.cpp (100%) rename funasr/runtime/onnxruntime/{src => bin}/funasr-onnx-offline-vad.cpp (99%) rename funasr/runtime/onnxruntime/{src => bin}/funasr-onnx-offline.cpp (100%) create mode 100644 funasr/runtime/onnxruntime/bin/funasr-onnx-online-vad.cpp create mode 100644 funasr/runtime/onnxruntime/src/fsmn-vad-online.cpp create mode 100644 funasr/runtime/onnxruntime/src/fsmn-vad-online.h delete mode 100644 funasr/runtime/onnxruntime/src/online-feature.cpp delete mode 100644 funasr/runtime/onnxruntime/src/online-feature.h diff --git a/funasr/runtime/onnxruntime/CMakeLists.txt b/funasr/runtime/onnxruntime/CMakeLists.txt index 9f6013f76..0847d1fc6 100644 --- a/funasr/runtime/onnxruntime/CMakeLists.txt +++ b/funasr/runtime/onnxruntime/CMakeLists.txt @@ -7,6 +7,8 @@ option(ENABLE_GLOG "Whether to build glog" ON) # set(CMAKE_CXX_STANDARD 11) set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.") set(CMAKE_POSITION_INDEPENDENT_CODE ON) +set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) + include(TestBigEndian) test_big_endian(BIG_ENDIAN) @@ -30,12 +32,13 @@ endif() include_directories(${PROJECT_SOURCE_DIR}/third_party/kaldi-native-fbank) include_directories(${PROJECT_SOURCE_DIR}/third_party/yaml-cpp/include) -add_subdirectory(third_party/yaml-cpp) -add_subdirectory(third_party/kaldi-native-fbank/kaldi-native-fbank/csrc) -add_subdirectory(src) - if(ENABLE_GLOG) include_directories(${PROJECT_SOURCE_DIR}/third_party/glog) set(BUILD_TESTING OFF) add_subdirectory(third_party/glog) -endif() \ No newline at end of file +endif() + +add_subdirectory(third_party/yaml-cpp) +add_subdirectory(third_party/kaldi-native-fbank/kaldi-native-fbank/csrc) +add_subdirectory(src) +add_subdirectory(bin) diff --git a/funasr/runtime/onnxruntime/bin/CMakeLists.txt b/funasr/runtime/onnxruntime/bin/CMakeLists.txt new file mode 100644 index 000000000..962da0bbc --- /dev/null +++ b/funasr/runtime/onnxruntime/bin/CMakeLists.txt @@ -0,0 +1,16 @@ +include_directories(${CMAKE_SOURCE_DIR}/include) + +add_executable(funasr-onnx-offline "funasr-onnx-offline.cpp") +target_link_libraries(funasr-onnx-offline PUBLIC funasr) + +add_executable(funasr-onnx-offline-vad "funasr-onnx-offline-vad.cpp") +target_link_libraries(funasr-onnx-offline-vad PUBLIC funasr) + +add_executable(funasr-onnx-online-vad "funasr-onnx-online-vad.cpp") +target_link_libraries(funasr-onnx-online-vad PUBLIC funasr) + +add_executable(funasr-onnx-offline-punc "funasr-onnx-offline-punc.cpp") +target_link_libraries(funasr-onnx-offline-punc PUBLIC funasr) + +add_executable(funasr-onnx-offline-rtf "funasr-onnx-offline-rtf.cpp") +target_link_libraries(funasr-onnx-offline-rtf PUBLIC funasr) diff --git a/funasr/runtime/onnxruntime/src/funasr-onnx-offline-punc.cpp b/funasr/runtime/onnxruntime/bin/funasr-onnx-offline-punc.cpp similarity index 100% rename from funasr/runtime/onnxruntime/src/funasr-onnx-offline-punc.cpp rename to funasr/runtime/onnxruntime/bin/funasr-onnx-offline-punc.cpp diff --git a/funasr/runtime/onnxruntime/src/funasr-onnx-offline-rtf.cpp b/funasr/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp similarity index 100% rename from funasr/runtime/onnxruntime/src/funasr-onnx-offline-rtf.cpp rename to funasr/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp diff --git a/funasr/runtime/onnxruntime/src/funasr-onnx-offline-vad.cpp b/funasr/runtime/onnxruntime/bin/funasr-onnx-offline-vad.cpp similarity index 99% rename from funasr/runtime/onnxruntime/src/funasr-onnx-offline-vad.cpp rename to funasr/runtime/onnxruntime/bin/funasr-onnx-offline-vad.cpp index 0f606c6d8..912630b82 100644 --- a/funasr/runtime/onnxruntime/src/funasr-onnx-offline-vad.cpp +++ b/funasr/runtime/onnxruntime/bin/funasr-onnx-offline-vad.cpp @@ -125,7 +125,7 @@ int main(int argc, char *argv[]) long taking_micros = 0; for(auto& wav_file : wav_list){ gettimeofday(&start, NULL); - FUNASR_RESULT result=FsmnVadInfer(vad_hanlde, wav_file.c_str(), FSMN_VAD_OFFLINE, NULL, 16000); + FUNASR_RESULT result=FsmnVadInfer(vad_hanlde, wav_file.c_str(), NULL, 16000); gettimeofday(&end, NULL); seconds = (end.tv_sec - start.tv_sec); taking_micros += ((seconds * 1000000) + end.tv_usec) - (start.tv_usec); diff --git a/funasr/runtime/onnxruntime/src/funasr-onnx-offline.cpp b/funasr/runtime/onnxruntime/bin/funasr-onnx-offline.cpp similarity index 100% rename from funasr/runtime/onnxruntime/src/funasr-onnx-offline.cpp rename to funasr/runtime/onnxruntime/bin/funasr-onnx-offline.cpp diff --git a/funasr/runtime/onnxruntime/bin/funasr-onnx-online-vad.cpp b/funasr/runtime/onnxruntime/bin/funasr-onnx-online-vad.cpp new file mode 100644 index 000000000..d9944a0f4 --- /dev/null +++ b/funasr/runtime/onnxruntime/bin/funasr-onnx-online-vad.cpp @@ -0,0 +1,193 @@ +/** + * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. + * MIT License (https://opensource.org/licenses/MIT) +*/ + +#ifndef _WIN32 +#include +#else +#include +#endif + +#include +#include +#include +#include +#include +#include +#include "funasrruntime.h" +#include "tclap/CmdLine.h" +#include "com-define.h" +#include "audio.h" + +using namespace std; + +bool is_target_file(const std::string& filename, const std::string target) { + std::size_t pos = filename.find_last_of("."); + if (pos == std::string::npos) { + return false; + } + std::string extension = filename.substr(pos + 1); + return (extension == target); +} + +void GetValue(TCLAP::ValueArg& value_arg, string key, std::map& model_path) +{ + if (value_arg.isSet()){ + model_path.insert({key, value_arg.getValue()}); + LOG(INFO)<< key << " : " << value_arg.getValue(); + } +} + +void print_segs(vector>* vec) { + if((*vec).size() == 0){ + return; + } + string seg_out="["; + for (int i = 0; i < vec->size(); i++) { + vector inner_vec = (*vec)[i]; + if(inner_vec.size() == 0){ + continue; + } + seg_out += "["; + for (int j = 0; j < inner_vec.size(); j++) { + seg_out += to_string(inner_vec[j]); + if (j != inner_vec.size() - 1) { + seg_out += ","; + } + } + seg_out += "]"; + if (i != vec->size() - 1) { + seg_out += ","; + } + } + seg_out += "]"; + LOG(INFO)< model_dir("", MODEL_DIR, "the vad model path, which contains model.onnx, vad.yaml, vad.mvn", true, "", "string"); + TCLAP::ValueArg quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "false", "string"); + + TCLAP::ValueArg wav_path("", WAV_PATH, "the input could be: wav_path, e.g.: asr_example.wav; pcm_path, e.g.: asr_example.pcm; wav.scp, kaldi style wav list (wav_id \t wav_path)", true, "", "string"); + + cmd.add(model_dir); + cmd.add(quantize); + cmd.add(wav_path); + cmd.parse(argc, argv); + + std::map model_path; + GetValue(model_dir, MODEL_DIR, model_path); + GetValue(quantize, QUANTIZE, model_path); + GetValue(wav_path, WAV_PATH, model_path); + + struct timeval start, end; + gettimeofday(&start, NULL); + int thread_num = 1; + FUNASR_HANDLE vad_hanlde=FsmnVadInit(model_path, thread_num); + + if (!vad_hanlde) + { + LOG(ERROR) << "FunVad init failed"; + exit(-1); + } + + gettimeofday(&end, NULL); + long seconds = (end.tv_sec - start.tv_sec); + long modle_init_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec); + LOG(INFO) << "Model initialization takes " << (double)modle_init_micros / 1000000 << " s"; + + // read wav_path + vector wav_list; + string wav_path_ = model_path.at(WAV_PATH); + if(is_target_file(wav_path_, "wav") || is_target_file(wav_path_, "pcm")){ + wav_list.emplace_back(wav_path_); + } + else if(is_target_file(wav_path_, "scp")){ + ifstream in(wav_path_); + if (!in.is_open()) { + LOG(ERROR) << "Failed to open file: " << model_path.at(WAV_SCP) ; + return 0; + } + string line; + while(getline(in, line)) + { + istringstream iss(line); + string column1, column2; + iss >> column1 >> column2; + wav_list.emplace_back(column2); + } + in.close(); + }else{ + LOG(ERROR)<<"Please check the wav extension!"; + exit(-1); + } + // init online features + FUNASR_HANDLE online_hanlde=FsmnVadOnlineInit(vad_hanlde); + float snippet_time = 0.0f; + long taking_micros = 0; + for(auto& wav_file : wav_list){ + + int32_t sampling_rate_ = -1; + funasr::Audio audio(1); + if(is_target_file(wav_file.c_str(), "wav")){ + int32_t sampling_rate_ = -1; + if(!audio.LoadWav2Char(wav_file.c_str(), &sampling_rate_)){ + LOG(ERROR)<<"Failed to load "<< wav_file; + exit(-1); + } + }else if(is_target_file(wav_file.c_str(), "pcm")){ + if (!audio.LoadPcmwav2Char(wav_file.c_str(), &sampling_rate_)){ + LOG(ERROR)<<"Failed to load "<< wav_file; + exit(-1); + } + }else{ + LOG(ERROR)<<"Wrong wav extension"; + exit(-1); + } + char* speech_buff = audio.GetSpeechChar(); + int buff_len = audio.GetSpeechLen()*2; + + int step = 3200; + bool is_final = false; + + for (int sample_offset = 0; sample_offset < buff_len; sample_offset += std::min(step, buff_len - sample_offset)) { + if (sample_offset + step >= buff_len - 1) { + step = buff_len - sample_offset; + is_final = true; + } else { + is_final = false; + } + gettimeofday(&start, NULL); + FUNASR_RESULT result = FsmnVadInferBuffer(online_hanlde, speech_buff+sample_offset, step, NULL, is_final, 16000); + gettimeofday(&end, NULL); + seconds = (end.tv_sec - start.tv_sec); + taking_micros += ((seconds * 1000000) + end.tv_usec) - (start.tv_usec); + + if (result) + { + vector>* vad_segments = FsmnVadGetResult(result, 0); + print_segs(vad_segments); + snippet_time += FsmnVadGetRetSnippetTime(result); + FsmnVadFreeResult(result); + } + else + { + LOG(ERROR) << ("No return data!\n"); + } + } + } + + LOG(INFO) << "Audio length: " << (double)snippet_time << " s"; + LOG(INFO) << "Model inference takes: " << (double)taking_micros / 1000000 <<" s"; + LOG(INFO) << "Model inference RTF: " << (double)taking_micros/ (snippet_time*1000000); + FsmnVadUninit(online_hanlde); + FsmnVadUninit(vad_hanlde); + return 0; +} + diff --git a/funasr/runtime/onnxruntime/include/audio.h b/funasr/runtime/onnxruntime/include/audio.h index 1eabd3e7b..d2100a434 100644 --- a/funasr/runtime/onnxruntime/include/audio.h +++ b/funasr/runtime/onnxruntime/include/audio.h @@ -33,8 +33,9 @@ class AudioFrame { class Audio { private: - float *speech_data; - int16_t *speech_buff; + float *speech_data=nullptr; + int16_t *speech_buff=nullptr; + char* speech_char=nullptr; int speech_len; int speech_align_len; int offset; @@ -47,18 +48,22 @@ class Audio { Audio(int data_type, int size); ~Audio(); void Disp(); - bool LoadWav(const char* filename, int32_t* sampling_rate); void WavResample(int32_t sampling_rate, const float *waveform, int32_t n); bool LoadWav(const char* buf, int n_len, int32_t* sampling_rate); + bool LoadWav(const char* filename, int32_t* sampling_rate); + bool LoadWav2Char(const char* filename, int32_t* sampling_rate); bool LoadPcmwav(const char* buf, int n_file_len, int32_t* sampling_rate); bool LoadPcmwav(const char* filename, int32_t* sampling_rate); + bool LoadPcmwav2Char(const char* filename, int32_t* sampling_rate); int FetchChunck(float *&dout, int len); int Fetch(float *&dout, int &len, int &flag); void Padding(); void Split(OfflineStream* offline_streamj); - void Split(VadModel* vad_obj, vector>& vad_segments); + void Split(VadModel* vad_obj, vector>& vad_segments, bool input_finished=true); float GetTimeLen(); int GetQueueSize() { return (int)frame_queue.size(); } + char* GetSpeechChar(){return speech_char;} + int GetSpeechLen(){return speech_len;} }; } // namespace funasr diff --git a/funasr/runtime/onnxruntime/include/funasrruntime.h b/funasr/runtime/onnxruntime/include/funasrruntime.h index 5cfdb47d3..af430f795 100644 --- a/funasr/runtime/onnxruntime/include/funasrruntime.h +++ b/funasr/runtime/onnxruntime/include/funasrruntime.h @@ -46,12 +46,6 @@ typedef enum { FUNASR_MODEL_PARAFORMER = 3, }FUNASR_MODEL_TYPE; -typedef enum -{ - FSMN_VAD_OFFLINE=0, - FSMN_VAD_ONLINE = 1, -}FSMN_VAD_MODE; - typedef void (* QM_CALLBACK)(int cur_step, int n_total); // n_total: total steps; cur_step: Current Step. // ASR @@ -68,11 +62,12 @@ _FUNASRAPI void FunASRUninit(FUNASR_HANDLE handle); _FUNASRAPI const float FunASRGetRetSnippetTime(FUNASR_RESULT result); // VAD -_FUNASRAPI FUNASR_HANDLE FsmnVadInit(std::map& model_path, int thread_num, FSMN_VAD_MODE mode=FSMN_VAD_OFFLINE); +_FUNASRAPI FUNASR_HANDLE FsmnVadInit(std::map& model_path, int thread_num); +_FUNASRAPI FUNASR_HANDLE FsmnVadOnlineInit(FUNASR_HANDLE fsmnvad_handle); // buffer -_FUNASRAPI FUNASR_RESULT FsmnVadInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FSMN_VAD_MODE mode, QM_CALLBACK fn_callback, int sampling_rate=16000); +_FUNASRAPI FUNASR_RESULT FsmnVadInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, QM_CALLBACK fn_callback, bool input_finished=true, int sampling_rate=16000); // file, support wav & pcm -_FUNASRAPI FUNASR_RESULT FsmnVadInfer(FUNASR_HANDLE handle, const char* sz_filename, FSMN_VAD_MODE mode, QM_CALLBACK fn_callback, int sampling_rate=16000); +_FUNASRAPI FUNASR_RESULT FsmnVadInfer(FUNASR_HANDLE handle, const char* sz_filename, QM_CALLBACK fn_callback, int sampling_rate=16000); _FUNASRAPI std::vector>* FsmnVadGetResult(FUNASR_RESULT result,int n_index); _FUNASRAPI void FsmnVadFreeResult(FUNASR_RESULT result); diff --git a/funasr/runtime/onnxruntime/include/vad-model.h b/funasr/runtime/onnxruntime/include/vad-model.h index b1b1e9dbc..07f183327 100644 --- a/funasr/runtime/onnxruntime/include/vad-model.h +++ b/funasr/runtime/onnxruntime/include/vad-model.h @@ -12,14 +12,9 @@ class VadModel { virtual ~VadModel(){}; virtual void InitVad(const std::string &vad_model, const std::string &vad_cmvn, const std::string &vad_config, int thread_num)=0; virtual std::vector> Infer(std::vector &waves, bool input_finished=true)=0; - virtual void ReadModel(const char* vad_model)=0; - virtual void LoadConfigFromYaml(const char* filename)=0; - virtual void FbankKaldi(float sample_rate, std::vector> &vad_feats, - std::vector &waves)=0; - virtual void LoadCmvn(const char *filename)=0; - virtual void InitCache()=0; }; -VadModel *CreateVadModel(std::map& model_path, int thread_num, int mode); +VadModel *CreateVadModel(std::map& model_path, int thread_num); +VadModel *CreateVadModel(void* fsmnvad_handle); } // namespace funasr #endif diff --git a/funasr/runtime/onnxruntime/src/CMakeLists.txt b/funasr/runtime/onnxruntime/src/CMakeLists.txt index 341a16a7a..d083d8ea4 100644 --- a/funasr/runtime/onnxruntime/src/CMakeLists.txt +++ b/funasr/runtime/onnxruntime/src/CMakeLists.txt @@ -1,11 +1,8 @@ file(GLOB files1 "*.cpp") -file(GLOB files2 "*.cc") +set(files ${files1}) -set(files ${files1} ${files2}) -set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) - -add_library(funasr ${files}) +add_library(funasr SHARED ${files}) if(WIN32) set(EXTRA_LIBS pthread yaml-cpp csrc glog) @@ -24,13 +21,3 @@ endif() include_directories(${CMAKE_SOURCE_DIR}/include) target_link_libraries(funasr PUBLIC onnxruntime ${EXTRA_LIBS}) - -add_executable(funasr-onnx-offline "funasr-onnx-offline.cpp") -add_executable(funasr-onnx-offline-vad "funasr-onnx-offline-vad.cpp") -add_executable(funasr-onnx-offline-punc "funasr-onnx-offline-punc.cpp") -add_executable(funasr-onnx-offline-rtf "funasr-onnx-offline-rtf.cpp") -target_link_libraries(funasr-onnx-offline PUBLIC funasr) -target_link_libraries(funasr-onnx-offline-vad PUBLIC funasr) -target_link_libraries(funasr-onnx-offline-punc PUBLIC funasr) -target_link_libraries(funasr-onnx-offline-rtf PUBLIC funasr) - diff --git a/funasr/runtime/onnxruntime/src/audio.cpp b/funasr/runtime/onnxruntime/src/audio.cpp index 6d63d6757..23d001092 100644 --- a/funasr/runtime/onnxruntime/src/audio.cpp +++ b/funasr/runtime/onnxruntime/src/audio.cpp @@ -176,13 +176,13 @@ Audio::~Audio() { if (speech_buff != NULL) { free(speech_buff); - } - if (speech_data != NULL) { - free(speech_data); } + if (speech_char != NULL) { + free(speech_char); + } } void Audio::Disp() @@ -296,8 +296,47 @@ bool Audio::LoadWav(const char *filename, int32_t* sampling_rate) return false; } -bool Audio::LoadWav(const char* buf, int n_file_len, int32_t* sampling_rate) +bool Audio::LoadWav2Char(const char *filename, int32_t* sampling_rate) { + WaveHeader header; + if (speech_char != NULL) { + free(speech_char); + } + offset = 0; + std::ifstream is(filename, std::ifstream::binary); + is.read(reinterpret_cast(&header), sizeof(header)); + if(!is){ + LOG(ERROR) << "Failed to read " << filename; + return false; + } + if (!header.Validate()) { + return false; + } + header.SeekToDataChunk(is); + if (!is) { + return false; + } + if (!header.Validate()) { + return false; + } + header.SeekToDataChunk(is); + if (!is) { + return false; + } + + *sampling_rate = header.sample_rate; + // header.subchunk2_size contains the number of bytes in the data. + // As we assume each sample contains two bytes, so it is divided by 2 here + speech_len = header.subchunk2_size / 2; + speech_char = (char *)malloc(header.subchunk2_size); + memset(speech_char, 0, header.subchunk2_size); + is.read(speech_char, header.subchunk2_size); + + return true; +} + +bool Audio::LoadWav(const char* buf, int n_file_len, int32_t* sampling_rate) +{ WaveHeader header; if (speech_data != NULL) { free(speech_data); @@ -441,6 +480,33 @@ bool Audio::LoadPcmwav(const char* filename, int32_t* sampling_rate) } +bool Audio::LoadPcmwav2Char(const char* filename, int32_t* sampling_rate) +{ + if (speech_char != NULL) { + free(speech_char); + } + offset = 0; + + FILE* fp; + fp = fopen(filename, "rb"); + if (fp == nullptr) + { + LOG(ERROR) << "Failed to read " << filename; + return false; + } + fseek(fp, 0, SEEK_END); + uint32_t n_file_len = ftell(fp); + fseek(fp, 0, SEEK_SET); + + speech_len = (n_file_len) / 2; + speech_char = (char *)malloc(n_file_len); + memset(speech_char, 0, n_file_len); + fread(speech_char, sizeof(int16_t), n_file_len/2, fp); + fclose(fp); + + return true; +} + int Audio::FetchChunck(float *&dout, int len) { if (offset >= speech_align_len) { @@ -541,7 +607,7 @@ void Audio::Split(OfflineStream* offline_stream) } -void Audio::Split(VadModel* vad_obj, vector>& vad_segments) +void Audio::Split(VadModel* vad_obj, vector>& vad_segments, bool input_finished) { AudioFrame *frame; @@ -552,7 +618,7 @@ void Audio::Split(VadModel* vad_obj, vector>& vad_segments) frame = NULL; std::vector pcm_data(speech_data, speech_data+sp_len); - vad_segments = vad_obj->Infer(pcm_data); + vad_segments = vad_obj->Infer(pcm_data, input_finished); } } // namespace funasr \ No newline at end of file diff --git a/funasr/runtime/onnxruntime/src/fsmn-vad-online.cpp b/funasr/runtime/onnxruntime/src/fsmn-vad-online.cpp new file mode 100644 index 000000000..034691610 --- /dev/null +++ b/funasr/runtime/onnxruntime/src/fsmn-vad-online.cpp @@ -0,0 +1,198 @@ +/** + * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. + * MIT License (https://opensource.org/licenses/MIT) +*/ + +#include +#include "precomp.h" + +namespace funasr { + +void FsmnVadOnline::FbankKaldi(float sample_rate, std::vector> &vad_feats, + std::vector &waves) { + knf::OnlineFbank fbank(fbank_opts_); + // cache merge + waves.insert(waves.begin(), input_cache_.begin(), input_cache_.end()); + int frame_number = ComputeFrameNum(waves.size(), frame_sample_length_, frame_shift_sample_length_); + // Send the audio after the last frame shift position to the cache + input_cache_.clear(); + input_cache_.insert(input_cache_.begin(), waves.begin() + frame_number * frame_shift_sample_length_, waves.end()); + if (frame_number == 0) { + return; + } + // Delete audio that haven't undergone fbank processing + waves.erase(waves.begin() + (frame_number - 1) * frame_shift_sample_length_ + frame_sample_length_, waves.end()); + + std::vector buf(waves.size()); + for (int32_t i = 0; i != waves.size(); ++i) { + buf[i] = waves[i] * 32768; + } + fbank.AcceptWaveform(sample_rate, buf.data(), buf.size()); + // fbank.AcceptWaveform(sample_rate, &waves[0], waves.size()); + int32_t frames = fbank.NumFramesReady(); + for (int32_t i = 0; i != frames; ++i) { + const float *frame = fbank.GetFrame(i); + vector frame_vector(frame, frame + fbank_opts_.mel_opts.num_bins); + vad_feats.emplace_back(frame_vector); + } +} + +void FsmnVadOnline::ExtractFeats(float sample_rate, vector> &vad_feats, + vector &waves, bool input_finished) { + FbankKaldi(sample_rate, vad_feats, waves); + // cache deal & online lfr,cmvn + if (vad_feats.size() > 0) { + if (!reserve_waveforms_.empty()) { + waves.insert(waves.begin(), reserve_waveforms_.begin(), reserve_waveforms_.end()); + } + if (lfr_splice_cache_.empty()) { + for (int i = 0; i < (lfr_m - 1) / 2; i++) { + lfr_splice_cache_.emplace_back(vad_feats[0]); + } + } + if (vad_feats.size() + lfr_splice_cache_.size() >= lfr_m) { + vad_feats.insert(vad_feats.begin(), lfr_splice_cache_.begin(), lfr_splice_cache_.end()); + int frame_from_waves = (waves.size() - frame_sample_length_) / frame_shift_sample_length_ + 1; + int minus_frame = reserve_waveforms_.empty() ? (lfr_m - 1) / 2 : 0; + int lfr_splice_frame_idxs = OnlineLfrCmvn(vad_feats, input_finished); + int reserve_frame_idx = lfr_splice_frame_idxs - minus_frame; + reserve_waveforms_.clear(); + reserve_waveforms_.insert(reserve_waveforms_.begin(), + waves.begin() + reserve_frame_idx * frame_shift_sample_length_, + waves.begin() + frame_from_waves * frame_shift_sample_length_); + int sample_length = (frame_from_waves - 1) * frame_shift_sample_length_ + frame_sample_length_; + waves.erase(waves.begin() + sample_length, waves.end()); + } else { + reserve_waveforms_.clear(); + reserve_waveforms_.insert(reserve_waveforms_.begin(), + waves.begin() + frame_sample_length_ - frame_shift_sample_length_, waves.end()); + lfr_splice_cache_.insert(lfr_splice_cache_.end(), vad_feats.begin(), vad_feats.end()); + } + } else { + if (input_finished) { + if (!reserve_waveforms_.empty()) { + waves = reserve_waveforms_; + } + vad_feats = lfr_splice_cache_; + OnlineLfrCmvn(vad_feats, input_finished); + } + } + if(input_finished){ + Reset(); + ResetCache(); + } +} + +int FsmnVadOnline::OnlineLfrCmvn(vector> &vad_feats, bool input_finished) { + vector> out_feats; + int T = vad_feats.size(); + int T_lrf = ceil((T - (lfr_m - 1) / 2) / lfr_n); + int lfr_splice_frame_idxs = T_lrf; + vector p; + for (int i = 0; i < T_lrf; i++) { + if (lfr_m <= T - i * lfr_n) { + for (int j = 0; j < lfr_m; j++) { + p.insert(p.end(), vad_feats[i * lfr_n + j].begin(), vad_feats[i * lfr_n + j].end()); + } + out_feats.emplace_back(p); + p.clear(); + } else { + if (input_finished) { + int num_padding = lfr_m - (T - i * lfr_n); + for (int j = 0; j < (vad_feats.size() - i * lfr_n); j++) { + p.insert(p.end(), vad_feats[i * lfr_n + j].begin(), vad_feats[i * lfr_n + j].end()); + } + for (int j = 0; j < num_padding; j++) { + p.insert(p.end(), vad_feats[vad_feats.size() - 1].begin(), vad_feats[vad_feats.size() - 1].end()); + } + out_feats.emplace_back(p); + } else { + lfr_splice_frame_idxs = i; + break; + } + } + } + lfr_splice_frame_idxs = std::min(T - 1, lfr_splice_frame_idxs * lfr_n); + lfr_splice_cache_.clear(); + lfr_splice_cache_.insert(lfr_splice_cache_.begin(), vad_feats.begin() + lfr_splice_frame_idxs, vad_feats.end()); + + // Apply cmvn + for (auto &out_feat: out_feats) { + for (int j = 0; j < means_list_.size(); j++) { + out_feat[j] = (out_feat[j] + means_list_[j]) * vars_list_[j]; + } + } + vad_feats = out_feats; + return lfr_splice_frame_idxs; +} + +std::vector> +FsmnVadOnline::Infer(std::vector &waves, bool input_finished) { + std::vector> vad_feats; + std::vector> vad_probs; + ExtractFeats(vad_sample_rate_, vad_feats, waves, input_finished); + fsmnvad_handle_->Forward(vad_feats, &vad_probs, &in_cache_, input_finished); + + std::vector> vad_segments; + vad_segments = vad_scorer(vad_probs, waves, input_finished, true, vad_silence_duration_, vad_max_len_, + vad_speech_noise_thres_, vad_sample_rate_); + return vad_segments; +} + +void FsmnVadOnline::InitCache(){ + std::vector cache_feats(128 * 19 * 1, 0); + for (int i=0;i<4;i++){ + in_cache_.emplace_back(cache_feats); + } +}; + +void FsmnVadOnline::Reset(){ + in_cache_.clear(); + InitCache(); +}; + +void FsmnVadOnline::Test() { +} + +void FsmnVadOnline::InitOnline(std::shared_ptr &vad_session, + Ort::Env &env, + std::vector &vad_in_names, + std::vector &vad_out_names, + knf::FbankOptions &fbank_opts, + std::vector &means_list, + std::vector &vars_list, + int vad_sample_rate, + int vad_silence_duration, + int vad_max_len, + double vad_speech_noise_thres) { + vad_session_ = vad_session; + vad_in_names_ = vad_in_names; + vad_out_names_ = vad_out_names; + fbank_opts_ = fbank_opts; + means_list_ = means_list; + vars_list_ = vars_list; + vad_sample_rate_ = vad_sample_rate; + vad_silence_duration_ = vad_silence_duration; + vad_max_len_ = vad_max_len; + vad_speech_noise_thres_ = vad_speech_noise_thres; +} + +FsmnVadOnline::~FsmnVadOnline() { +} + +FsmnVadOnline::FsmnVadOnline(FsmnVad* fsmnvad_handle):fsmnvad_handle_(std::move(fsmnvad_handle)),session_options_{}{ + InitCache(); + InitOnline(fsmnvad_handle_->vad_session_, + fsmnvad_handle_->env_, + fsmnvad_handle_->vad_in_names_, + fsmnvad_handle_->vad_out_names_, + fsmnvad_handle_->fbank_opts_, + fsmnvad_handle_->means_list_, + fsmnvad_handle_->vars_list_, + fsmnvad_handle_->vad_sample_rate_, + fsmnvad_handle_->vad_silence_duration_, + fsmnvad_handle_->vad_max_len_, + fsmnvad_handle_->vad_speech_noise_thres_); +} + +} // namespace funasr diff --git a/funasr/runtime/onnxruntime/src/fsmn-vad-online.h b/funasr/runtime/onnxruntime/src/fsmn-vad-online.h new file mode 100644 index 000000000..4d429b669 --- /dev/null +++ b/funasr/runtime/onnxruntime/src/fsmn-vad-online.h @@ -0,0 +1,88 @@ +/** + * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. + * MIT License (https://opensource.org/licenses/MIT) +*/ + +#pragma once +#include "precomp.h" + +namespace funasr { +class FsmnVadOnline : public VadModel { +/** + * Author: Speech Lab of DAMO Academy, Alibaba Group + * Deep-FSMN for Large Vocabulary Continuous Speech Recognition + * https://arxiv.org/abs/1803.05030 +*/ + +public: + explicit FsmnVadOnline(FsmnVad* fsmnvad_handle); + ~FsmnVadOnline(); + void Test(); + std::vector> Infer(std::vector &waves, bool input_finished); + void ExtractFeats(float sample_rate, vector> &vad_feats, vector &waves, bool input_finished); + void Reset(); + +private: + E2EVadModel vad_scorer = E2EVadModel(); + // std::unique_ptr fsmnvad_handle_; + FsmnVad* fsmnvad_handle_ = nullptr; + + void FbankKaldi(float sample_rate, std::vector> &vad_feats, + std::vector &waves); + int OnlineLfrCmvn(vector> &vad_feats, bool input_finished); + void InitVad(const std::string &vad_model, const std::string &vad_cmvn, const std::string &vad_config, int thread_num){} + void InitCache(); + void InitOnline(std::shared_ptr &vad_session, + Ort::Env &env, + std::vector &vad_in_names, + std::vector &vad_out_names, + knf::FbankOptions &fbank_opts, + std::vector &means_list, + std::vector &vars_list, + int vad_sample_rate, + int vad_silence_duration, + int vad_max_len, + double vad_speech_noise_thres); + + static int ComputeFrameNum(int sample_length, int frame_sample_length, int frame_shift_sample_length) { + int frame_num = static_cast((sample_length - frame_sample_length) / frame_shift_sample_length + 1); + if (frame_num >= 1 && sample_length >= frame_sample_length) + return frame_num; + else + return 0; + } + void ResetCache() { + reserve_waveforms_.clear(); + input_cache_.clear(); + lfr_splice_cache_.clear(); + } + + // from fsmnvad_handle_ + std::shared_ptr vad_session_ = nullptr; + Ort::Env env_; + Ort::SessionOptions session_options_; + std::vector vad_in_names_; + std::vector vad_out_names_; + knf::FbankOptions fbank_opts_; + std::vector means_list_; + std::vector vars_list_; + + std::vector> in_cache_; + // The reserved waveforms by fbank + std::vector reserve_waveforms_; + // waveforms reserved after last shift position + std::vector input_cache_; + // lfr reserved cache + std::vector> lfr_splice_cache_; + + int vad_sample_rate_ = MODEL_SAMPLE_RATE; + int vad_silence_duration_ = VAD_SILENCE_DURATION; + int vad_max_len_ = VAD_MAX_LEN; + double vad_speech_noise_thres_ = VAD_SPEECH_NOISE_THRES; + int lfr_m = VAD_LFR_M; + int lfr_n = VAD_LFR_N; + int frame_sample_length_ = vad_sample_rate_ / 1000 * 25;; + int frame_shift_sample_length_ = vad_sample_rate_ / 1000 * 10; +}; + +} // namespace funasr diff --git a/funasr/runtime/onnxruntime/src/fsmn-vad.cpp b/funasr/runtime/onnxruntime/src/fsmn-vad.cpp index 516dc8822..697828b9f 100644 --- a/funasr/runtime/onnxruntime/src/fsmn-vad.cpp +++ b/funasr/runtime/onnxruntime/src/fsmn-vad.cpp @@ -37,14 +37,14 @@ void FsmnVad::LoadConfigFromYaml(const char* filename){ this->vad_max_len_ = post_conf["max_single_segment_time"].as(); this->vad_speech_noise_thres_ = post_conf["speech_noise_thres"].as(); - fbank_opts.frame_opts.dither = frontend_conf["dither"].as(); - fbank_opts.mel_opts.num_bins = frontend_conf["n_mels"].as(); - fbank_opts.frame_opts.samp_freq = (float)vad_sample_rate_; - fbank_opts.frame_opts.window_type = frontend_conf["window"].as(); - fbank_opts.frame_opts.frame_shift_ms = frontend_conf["frame_shift"].as(); - fbank_opts.frame_opts.frame_length_ms = frontend_conf["frame_length"].as(); - fbank_opts.energy_floor = 0; - fbank_opts.mel_opts.debug_mel = false; + fbank_opts_.frame_opts.dither = frontend_conf["dither"].as(); + fbank_opts_.mel_opts.num_bins = frontend_conf["n_mels"].as(); + fbank_opts_.frame_opts.samp_freq = (float)vad_sample_rate_; + fbank_opts_.frame_opts.window_type = frontend_conf["window"].as(); + fbank_opts_.frame_opts.frame_shift_ms = frontend_conf["frame_shift"].as(); + fbank_opts_.frame_opts.frame_length_ms = frontend_conf["frame_length"].as(); + fbank_opts_.energy_floor = 0; + fbank_opts_.mel_opts.debug_mel = false; }catch(exception const &e){ LOG(ERROR) << "Error when load argument from vad config YAML."; exit(-1); @@ -55,6 +55,7 @@ void FsmnVad::ReadModel(const char* vad_model) { try { vad_session_ = std::make_shared( env_, vad_model, session_options_); + LOG(INFO) << "Successfully load model from " << vad_model; } catch (std::exception const &e) { LOG(ERROR) << "Error when load vad onnx model: " << e.what(); exit(0); @@ -109,7 +110,9 @@ void FsmnVad::GetInputOutputInfo( void FsmnVad::Forward( const std::vector> &chunk_feats, - std::vector> *out_prob) { + std::vector> *out_prob, + std::vector> *in_cache, + bool is_final) { Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); @@ -132,9 +135,9 @@ void FsmnVad::Forward( // 4 caches // cache node {batch,128,19,1} const int64_t cache_feats_shape[4] = {1, 128, 19, 1}; - for (int i = 0; i < in_cache_.size(); i++) { + for (int i = 0; i < in_cache->size(); i++) { vad_inputs.emplace_back(std::move(Ort::Value::CreateTensor( - memory_info, in_cache_[i].data(), in_cache_[i].size(), cache_feats_shape, 4))); + memory_info, (*in_cache)[i].data(), (*in_cache)[i].size(), cache_feats_shape, 4))); } // 4. Onnx infer @@ -162,15 +165,17 @@ void FsmnVad::Forward( } // get 4 caches outputs,each size is 128*19 - // for (int i = 1; i < 5; i++) { - // float* data = vad_ort_outputs[i].GetTensorMutableData(); - // memcpy(in_cache_[i-1].data(), data, sizeof(float) * 128*19); - // } + if(!is_final){ + for (int i = 1; i < 5; i++) { + float* data = vad_ort_outputs[i].GetTensorMutableData(); + memcpy((*in_cache)[i-1].data(), data, sizeof(float) * 128*19); + } + } } void FsmnVad::FbankKaldi(float sample_rate, std::vector> &vad_feats, std::vector &waves) { - knf::OnlineFbank fbank(fbank_opts); + knf::OnlineFbank fbank(fbank_opts_); std::vector buf(waves.size()); for (int32_t i = 0; i != waves.size(); ++i) { @@ -180,7 +185,7 @@ void FsmnVad::FbankKaldi(float sample_rate, std::vector> &vad int32_t frames = fbank.NumFramesReady(); for (int32_t i = 0; i != frames; ++i) { const float *frame = fbank.GetFrame(i); - std::vector frame_vector(frame, frame + fbank_opts.mel_opts.num_bins); + std::vector frame_vector(frame, frame + fbank_opts_.mel_opts.num_bins); vad_feats.emplace_back(frame_vector); } } @@ -205,7 +210,7 @@ void FsmnVad::LoadCmvn(const char *filename) vector means_lines{istream_iterator{means_lines_stream}, istream_iterator{}}; if (means_lines[0] == "") { for (int j = 3; j < means_lines.size() - 1; j++) { - means_list.push_back(stof(means_lines[j])); + means_list_.push_back(stof(means_lines[j])); } continue; } @@ -216,8 +221,8 @@ void FsmnVad::LoadCmvn(const char *filename) vector vars_lines{istream_iterator{vars_lines_stream}, istream_iterator{}}; if (vars_lines[0] == "") { for (int j = 3; j < vars_lines.size() - 1; j++) { - // vars_list.push_back(stof(vars_lines[j])*scale); - vars_list.push_back(stof(vars_lines[j])); + // vars_list_.push_back(stof(vars_lines[j])*scale); + vars_list_.push_back(stof(vars_lines[j])); } continue; } @@ -263,8 +268,8 @@ void FsmnVad::LfrCmvn(std::vector> &vad_feats) { } // Apply cmvn for (auto &out_feat: out_feats) { - for (int j = 0; j < means_list.size(); j++) { - out_feat[j] = (out_feat[j] + means_list[j]) * vars_list[j]; + for (int j = 0; j < means_list_.size(); j++) { + out_feat[j] = (out_feat[j] + means_list_[j]) * vars_list_[j]; } } vad_feats = out_feats; @@ -276,7 +281,7 @@ FsmnVad::Infer(std::vector &waves, bool input_finished) { std::vector> vad_probs; FbankKaldi(vad_sample_rate_, vad_feats, waves); LfrCmvn(vad_feats); - Forward(vad_feats, &vad_probs); + Forward(vad_feats, &vad_probs, &in_cache_, input_finished); E2EVadModel vad_scorer = E2EVadModel(); std::vector> vad_segments; diff --git a/funasr/runtime/onnxruntime/src/fsmn-vad.h b/funasr/runtime/onnxruntime/src/fsmn-vad.h index a8ec4ce90..adceb1fab 100644 --- a/funasr/runtime/onnxruntime/src/fsmn-vad.h +++ b/funasr/runtime/onnxruntime/src/fsmn-vad.h @@ -22,7 +22,30 @@ public: void Test(); void InitVad(const std::string &vad_model, const std::string &vad_cmvn, const std::string &vad_config, int thread_num); std::vector> Infer(std::vector &waves, bool input_finished=true); + void Forward( + const std::vector> &chunk_feats, + std::vector> *out_prob, + std::vector> *in_cache, + bool is_final); void Reset(); + + std::shared_ptr vad_session_ = nullptr; + Ort::Env env_; + Ort::SessionOptions session_options_; + std::vector vad_in_names_; + std::vector vad_out_names_; + std::vector> in_cache_; + + knf::FbankOptions fbank_opts_; + std::vector means_list_; + std::vector vars_list_; + + int vad_sample_rate_ = MODEL_SAMPLE_RATE; + int vad_silence_duration_ = VAD_SILENCE_DURATION; + int vad_max_len_ = VAD_MAX_LEN; + double vad_speech_noise_thres_ = VAD_SPEECH_NOISE_THRES; + int lfr_m = VAD_LFR_M; + int lfr_n = VAD_LFR_N; private: @@ -37,31 +60,9 @@ private: std::vector &waves); void LfrCmvn(std::vector> &vad_feats); - - void Forward( - const std::vector> &chunk_feats, - std::vector> *out_prob); - void LoadCmvn(const char *filename); void InitCache(); - std::shared_ptr vad_session_ = nullptr; - Ort::Env env_; - Ort::SessionOptions session_options_; - std::vector vad_in_names_; - std::vector vad_out_names_; - std::vector> in_cache_; - - knf::FbankOptions fbank_opts; - std::vector means_list; - std::vector vars_list; - - int vad_sample_rate_ = MODEL_SAMPLE_RATE; - int vad_silence_duration_ = VAD_SILENCE_DURATION; - int vad_max_len_ = VAD_MAX_LEN; - double vad_speech_noise_thres_ = VAD_SPEECH_NOISE_THRES; - int lfr_m = VAD_LFR_M; - int lfr_n = VAD_LFR_N; }; } // namespace funasr diff --git a/funasr/runtime/onnxruntime/src/funasrruntime.cpp b/funasr/runtime/onnxruntime/src/funasrruntime.cpp index adef5049e..f504b39f0 100644 --- a/funasr/runtime/onnxruntime/src/funasrruntime.cpp +++ b/funasr/runtime/onnxruntime/src/funasrruntime.cpp @@ -11,9 +11,15 @@ extern "C" { return mm; } - _FUNASRAPI FUNASR_HANDLE FsmnVadInit(std::map& model_path, int thread_num, FSMN_VAD_MODE mode) + _FUNASRAPI FUNASR_HANDLE FsmnVadInit(std::map& model_path, int thread_num) { - funasr::VadModel* mm = funasr::CreateVadModel(model_path, thread_num, mode); + funasr::VadModel* mm = funasr::CreateVadModel(model_path, thread_num); + return mm; + } + + _FUNASRAPI FUNASR_HANDLE FsmnVadOnlineInit(FUNASR_HANDLE fsmnvad_handle) + { + funasr::VadModel* mm = funasr::CreateVadModel(fsmnvad_handle); return mm; } @@ -96,7 +102,7 @@ extern "C" { } // APIs for VAD Infer - _FUNASRAPI FUNASR_RESULT FsmnVadInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FSMN_VAD_MODE mode, QM_CALLBACK fn_callback, int sampling_rate) + _FUNASRAPI FUNASR_RESULT FsmnVadInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, QM_CALLBACK fn_callback, bool input_finished, int sampling_rate) { funasr::VadModel* vad_obj = (funasr::VadModel*)handle; if (!vad_obj) @@ -110,13 +116,13 @@ extern "C" { p_result->snippet_time = audio.GetTimeLen(); vector> vad_segments; - audio.Split(vad_obj, vad_segments); + audio.Split(vad_obj, vad_segments, input_finished); p_result->segments = new vector>(vad_segments); return p_result; } - _FUNASRAPI FUNASR_RESULT FsmnVadInfer(FUNASR_HANDLE handle, const char* sz_filename, FSMN_VAD_MODE mode, QM_CALLBACK fn_callback, int sampling_rate) + _FUNASRAPI FUNASR_RESULT FsmnVadInfer(FUNASR_HANDLE handle, const char* sz_filename, QM_CALLBACK fn_callback, int sampling_rate) { funasr::VadModel* vad_obj = (funasr::VadModel*)handle; if (!vad_obj) @@ -139,7 +145,7 @@ extern "C" { p_result->snippet_time = audio.GetTimeLen(); vector> vad_segments; - audio.Split(vad_obj, vad_segments); + audio.Split(vad_obj, vad_segments, true); p_result->segments = new vector>(vad_segments); return p_result; diff --git a/funasr/runtime/onnxruntime/src/online-feature.cpp b/funasr/runtime/onnxruntime/src/online-feature.cpp deleted file mode 100644 index a21589cf3..000000000 --- a/funasr/runtime/onnxruntime/src/online-feature.cpp +++ /dev/null @@ -1,137 +0,0 @@ -/** - * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. - * MIT License (https://opensource.org/licenses/MIT) - * Contributed by zhuzizyf(China Telecom). -*/ - -#include "online-feature.h" -#include - -namespace funasr { -OnlineFeature::OnlineFeature(int sample_rate, knf::FbankOptions fbank_opts, int lfr_m, int lfr_n, - std::vector> cmvns) - : sample_rate_(sample_rate), - fbank_opts_(std::move(fbank_opts)), - lfr_m_(lfr_m), - lfr_n_(lfr_n), - cmvns_(std::move(cmvns)) { - frame_sample_length_ = sample_rate_ / 1000 * 25;; - frame_shift_sample_length_ = sample_rate_ / 1000 * 10; -} - -void OnlineFeature::ExtractFeats(vector> &vad_feats, - vector waves, bool input_finished) { - input_finished_ = input_finished; - OnlineFbank(vad_feats, waves); - // cache deal & online lfr,cmvn - if (vad_feats.size() > 0) { - if (!reserve_waveforms_.empty()) { - waves.insert(waves.begin(), reserve_waveforms_.begin(), reserve_waveforms_.end()); - } - if (lfr_splice_cache_.empty()) { - for (int i = 0; i < (lfr_m_ - 1) / 2; i++) { - lfr_splice_cache_.emplace_back(vad_feats[0]); - } - } - if (vad_feats.size() + lfr_splice_cache_.size() >= lfr_m_) { - vad_feats.insert(vad_feats.begin(), lfr_splice_cache_.begin(), lfr_splice_cache_.end()); - int frame_from_waves = (waves.size() - frame_sample_length_) / frame_shift_sample_length_ + 1; - int minus_frame = reserve_waveforms_.empty() ? (lfr_m_ - 1) / 2 : 0; - int lfr_splice_frame_idxs = OnlineLfrCmvn(vad_feats); - int reserve_frame_idx = lfr_splice_frame_idxs - minus_frame; - reserve_waveforms_.clear(); - reserve_waveforms_.insert(reserve_waveforms_.begin(), - waves.begin() + reserve_frame_idx * frame_shift_sample_length_, - waves.begin() + frame_from_waves * frame_shift_sample_length_); - int sample_length = (frame_from_waves - 1) * frame_shift_sample_length_ + frame_sample_length_; - waves.erase(waves.begin() + sample_length, waves.end()); - } else { - reserve_waveforms_.clear(); - reserve_waveforms_.insert(reserve_waveforms_.begin(), - waves.begin() + frame_sample_length_ - frame_shift_sample_length_, waves.end()); - lfr_splice_cache_.insert(lfr_splice_cache_.end(), vad_feats.begin(), vad_feats.end()); - } - - } else { - if (input_finished_) { - if (!reserve_waveforms_.empty()) { - waves = reserve_waveforms_; - } - vad_feats = lfr_splice_cache_; - OnlineLfrCmvn(vad_feats); - ResetCache(); - } - } - -} - -int OnlineFeature::OnlineLfrCmvn(vector> &vad_feats) { - vector> out_feats; - int T = vad_feats.size(); - int T_lrf = ceil((T - (lfr_m_ - 1) / 2) / lfr_n_); - int lfr_splice_frame_idxs = T_lrf; - vector p; - for (int i = 0; i < T_lrf; i++) { - if (lfr_m_ <= T - i * lfr_n_) { - for (int j = 0; j < lfr_m_; j++) { - p.insert(p.end(), vad_feats[i * lfr_n_ + j].begin(), vad_feats[i * lfr_n_ + j].end()); - } - out_feats.emplace_back(p); - p.clear(); - } else { - if (input_finished_) { - int num_padding = lfr_m_ - (T - i * lfr_n_); - for (int j = 0; j < (vad_feats.size() - i * lfr_n_); j++) { - p.insert(p.end(), vad_feats[i * lfr_n_ + j].begin(), vad_feats[i * lfr_n_ + j].end()); - } - for (int j = 0; j < num_padding; j++) { - p.insert(p.end(), vad_feats[vad_feats.size() - 1].begin(), vad_feats[vad_feats.size() - 1].end()); - } - out_feats.emplace_back(p); - } else { - lfr_splice_frame_idxs = i; - break; - } - } - } - lfr_splice_frame_idxs = std::min(T - 1, lfr_splice_frame_idxs * lfr_n_); - lfr_splice_cache_.clear(); - lfr_splice_cache_.insert(lfr_splice_cache_.begin(), vad_feats.begin() + lfr_splice_frame_idxs, vad_feats.end()); - - // Apply cmvn - for (auto &out_feat: out_feats) { - for (int j = 0; j < cmvns_[0].size(); j++) { - out_feat[j] = (out_feat[j] + cmvns_[0][j]) * cmvns_[1][j]; - } - } - vad_feats = out_feats; - return lfr_splice_frame_idxs; -} - -void OnlineFeature::OnlineFbank(vector> &vad_feats, - vector &waves) { - - knf::OnlineFbank fbank(fbank_opts_); - // cache merge - waves.insert(waves.begin(), input_cache_.begin(), input_cache_.end()); - int frame_number = ComputeFrameNum(waves.size(), frame_sample_length_, frame_shift_sample_length_); - // Send the audio after the last frame shift position to the cache - input_cache_.clear(); - input_cache_.insert(input_cache_.begin(), waves.begin() + frame_number * frame_shift_sample_length_, waves.end()); - if (frame_number == 0) { - return; - } - // Delete audio that haven't undergone fbank processing - waves.erase(waves.begin() + (frame_number - 1) * frame_shift_sample_length_ + frame_sample_length_, waves.end()); - - fbank.AcceptWaveform(sample_rate_, &waves[0], waves.size()); - int32_t frames = fbank.NumFramesReady(); - for (int32_t i = 0; i != frames; ++i) { - const float *frame = fbank.GetFrame(i); - vector frame_vector(frame, frame + fbank_opts_.mel_opts.num_bins); - vad_feats.emplace_back(frame_vector); - } - -} - -} // namespace funasr \ No newline at end of file diff --git a/funasr/runtime/onnxruntime/src/online-feature.h b/funasr/runtime/onnxruntime/src/online-feature.h deleted file mode 100644 index 16e6e4bea..000000000 --- a/funasr/runtime/onnxruntime/src/online-feature.h +++ /dev/null @@ -1,58 +0,0 @@ -/** - * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. - * MIT License (https://opensource.org/licenses/MIT) - * Contributed by zhuzizyf(China Telecom). -*/ -#pragma once -#include -#include "precomp.h" - -using namespace std; -namespace funasr { -class OnlineFeature { - -public: - OnlineFeature(int sample_rate, knf::FbankOptions fbank_opts, int lfr_m_, int lfr_n_, - std::vector> cmvns_); - - void ExtractFeats(vector> &vad_feats, vector waves, bool input_finished); - -private: - void OnlineFbank(vector> &vad_feats, vector &waves); - int OnlineLfrCmvn(vector> &vad_feats); - - static int ComputeFrameNum(int sample_length, int frame_sample_length, int frame_shift_sample_length) { - int frame_num = static_cast((sample_length - frame_sample_length) / frame_shift_sample_length + 1); - if (frame_num >= 1 && sample_length >= frame_sample_length) - return frame_num; - else - return 0; - } - - void ResetCache() { - reserve_waveforms_.clear(); - input_cache_.clear(); - lfr_splice_cache_.clear(); - input_finished_ = false; - - } - - knf::FbankOptions fbank_opts_; - // The reserved waveforms by fbank - std::vector reserve_waveforms_; - // waveforms reserved after last shift position - std::vector input_cache_; - // lfr reserved cache - std::vector> lfr_splice_cache_; - std::vector> cmvns_; - - int sample_rate_ = 16000; - int frame_sample_length_ = sample_rate_ / 1000 * 25;; - int frame_shift_sample_length_ = sample_rate_ / 1000 * 10; - int lfr_m_; - int lfr_n_; - bool input_finished_ = false; - -}; - -} // namespace funasr diff --git a/funasr/runtime/onnxruntime/src/paraformer.h b/funasr/runtime/onnxruntime/src/paraformer.h index 533c16fdc..9df0977e5 100644 --- a/funasr/runtime/onnxruntime/src/paraformer.h +++ b/funasr/runtime/onnxruntime/src/paraformer.h @@ -18,7 +18,7 @@ namespace funasr { //std::unique_ptr fbank_; knf::FbankOptions fbank_opts; - Vocab* vocab; + Vocab* vocab = nullptr; vector means_list; vector vars_list; const float scale = 22.6274169979695; @@ -30,7 +30,7 @@ namespace funasr { void ApplyCmvn(vector *v); string GreedySearch( float* in, int n_len, int64_t token_nums); - std::shared_ptr m_session; + std::shared_ptr m_session = nullptr; Ort::Env env_; Ort::SessionOptions session_options; diff --git a/funasr/runtime/onnxruntime/src/precomp.h b/funasr/runtime/onnxruntime/src/precomp.h index e607dbff1..838dddc87 100644 --- a/funasr/runtime/onnxruntime/src/precomp.h +++ b/funasr/runtime/onnxruntime/src/precomp.h @@ -36,8 +36,9 @@ using namespace std; #include "offline-stream.h" #include "tokenizer.h" #include "ct-transformer.h" -#include "fsmn-vad.h" #include "e2e-vad.h" +#include "fsmn-vad.h" +#include "fsmn-vad-online.h" #include "vocab.h" #include "audio.h" #include "tensor.h" diff --git a/funasr/runtime/onnxruntime/src/vad-model.cpp b/funasr/runtime/onnxruntime/src/vad-model.cpp index 336758f87..c164c3ec1 100644 --- a/funasr/runtime/onnxruntime/src/vad-model.cpp +++ b/funasr/runtime/onnxruntime/src/vad-model.cpp @@ -1,14 +1,10 @@ #include "precomp.h" namespace funasr { -VadModel *CreateVadModel(std::map& model_path, int thread_num, int mode) +VadModel *CreateVadModel(std::map& model_path, int thread_num) { VadModel *mm; - if(mode == FSMN_VAD_OFFLINE){ - mm = new FsmnVad(); - }else{ - LOG(ERROR)<<"Online fsmn vad not imp!"; - } + mm = new FsmnVad(); string vad_model_path; string vad_cmvn_path; @@ -25,4 +21,11 @@ VadModel *CreateVadModel(std::map& model_path, int thr return mm; } +VadModel *CreateVadModel(void* fsmnvad_handle) +{ + VadModel *mm; + mm = new FsmnVadOnline((FsmnVad*)fsmnvad_handle); + return mm; +} + } // namespace funasr \ No newline at end of file