mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
add fsmn-vad-online
This commit is contained in:
parent
dbffb47415
commit
3372b13d24
@ -7,6 +7,8 @@ option(ENABLE_GLOG "Whether to build glog" ON)
|
||||
# set(CMAKE_CXX_STANDARD 11)
|
||||
set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.")
|
||||
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
|
||||
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
|
||||
|
||||
|
||||
include(TestBigEndian)
|
||||
test_big_endian(BIG_ENDIAN)
|
||||
@ -30,12 +32,13 @@ endif()
|
||||
include_directories(${PROJECT_SOURCE_DIR}/third_party/kaldi-native-fbank)
|
||||
include_directories(${PROJECT_SOURCE_DIR}/third_party/yaml-cpp/include)
|
||||
|
||||
add_subdirectory(third_party/yaml-cpp)
|
||||
add_subdirectory(third_party/kaldi-native-fbank/kaldi-native-fbank/csrc)
|
||||
add_subdirectory(src)
|
||||
|
||||
if(ENABLE_GLOG)
|
||||
include_directories(${PROJECT_SOURCE_DIR}/third_party/glog)
|
||||
set(BUILD_TESTING OFF)
|
||||
add_subdirectory(third_party/glog)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
add_subdirectory(third_party/yaml-cpp)
|
||||
add_subdirectory(third_party/kaldi-native-fbank/kaldi-native-fbank/csrc)
|
||||
add_subdirectory(src)
|
||||
add_subdirectory(bin)
|
||||
|
||||
16
funasr/runtime/onnxruntime/bin/CMakeLists.txt
Normal file
16
funasr/runtime/onnxruntime/bin/CMakeLists.txt
Normal file
@ -0,0 +1,16 @@
|
||||
# Command-line demo programs. Every demo is built from the single source file
# that carries the target's name and is linked against the funasr library.
include_directories(${CMAKE_SOURCE_DIR}/include)

foreach(demo_target
        funasr-onnx-offline
        funasr-onnx-offline-vad
        funasr-onnx-online-vad
        funasr-onnx-offline-punc
        funasr-onnx-offline-rtf)
    add_executable(${demo_target} "${demo_target}.cpp")
    target_link_libraries(${demo_target} PUBLIC funasr)
endforeach()
|
||||
@ -125,7 +125,7 @@ int main(int argc, char *argv[])
|
||||
long taking_micros = 0;
|
||||
for(auto& wav_file : wav_list){
|
||||
gettimeofday(&start, NULL);
|
||||
FUNASR_RESULT result=FsmnVadInfer(vad_hanlde, wav_file.c_str(), FSMN_VAD_OFFLINE, NULL, 16000);
|
||||
FUNASR_RESULT result=FsmnVadInfer(vad_hanlde, wav_file.c_str(), NULL, 16000);
|
||||
gettimeofday(&end, NULL);
|
||||
seconds = (end.tv_sec - start.tv_sec);
|
||||
taking_micros += ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
|
||||
193
funasr/runtime/onnxruntime/bin/funasr-onnx-online-vad.cpp
Normal file
193
funasr/runtime/onnxruntime/bin/funasr-onnx-online-vad.cpp
Normal file
@ -0,0 +1,193 @@
|
||||
/**
|
||||
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
|
||||
* MIT License (https://opensource.org/licenses/MIT)
|
||||
*/
|
||||
|
||||
#ifndef _WIN32
|
||||
#include <sys/time.h>
|
||||
#else
|
||||
#include <win_func.h>
|
||||
#endif
|
||||
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <glog/logging.h>
|
||||
#include "funasrruntime.h"
|
||||
#include "tclap/CmdLine.h"
|
||||
#include "com-define.h"
|
||||
#include "audio.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
// Returns true when the text after the last '.' in `filename` is exactly
// `target` (case-sensitive). A name that contains no '.' never matches.
bool is_target_file(const std::string& filename, const std::string target) {
    const std::size_t dot = filename.rfind('.');
    if (dot != std::string::npos) {
        // Compare the suffix in place instead of materialising a substring.
        return filename.compare(dot + 1, std::string::npos, target) == 0;
    }
    return false;
}
|
||||
|
||||
// Copies one command-line option into `model_path` under `key` and logs the
// pair. Options that were not given on the command line are ignored, and an
// entry already stored under `key` is left untouched.
void GetValue(TCLAP::ValueArg<std::string>& value_arg, string key, std::map<std::string, std::string>& model_path)
{
    if (!value_arg.isSet()) {
        return;
    }
    const std::string& value = value_arg.getValue();
    model_path.emplace(key, value);
    LOG(INFO)<< key << " : " << value;
}
|
||||
|
||||
void print_segs(vector<vector<int>>* vec) {
|
||||
if((*vec).size() == 0){
|
||||
return;
|
||||
}
|
||||
string seg_out="[";
|
||||
for (int i = 0; i < vec->size(); i++) {
|
||||
vector<int> inner_vec = (*vec)[i];
|
||||
if(inner_vec.size() == 0){
|
||||
continue;
|
||||
}
|
||||
seg_out += "[";
|
||||
for (int j = 0; j < inner_vec.size(); j++) {
|
||||
seg_out += to_string(inner_vec[j]);
|
||||
if (j != inner_vec.size() - 1) {
|
||||
seg_out += ",";
|
||||
}
|
||||
}
|
||||
seg_out += "]";
|
||||
if (i != vec->size() - 1) {
|
||||
seg_out += ",";
|
||||
}
|
||||
}
|
||||
seg_out += "]";
|
||||
LOG(INFO)<<seg_out;
|
||||
}
|
||||
|
||||
int main(int argc, char *argv[])
|
||||
{
|
||||
google::InitGoogleLogging(argv[0]);
|
||||
FLAGS_logtostderr = true;
|
||||
|
||||
TCLAP::CmdLine cmd("funasr-onnx-offline-vad", ' ', "1.0");
|
||||
TCLAP::ValueArg<std::string> model_dir("", MODEL_DIR, "the vad model path, which contains model.onnx, vad.yaml, vad.mvn", true, "", "string");
|
||||
TCLAP::ValueArg<std::string> quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "false", "string");
|
||||
|
||||
TCLAP::ValueArg<std::string> wav_path("", WAV_PATH, "the input could be: wav_path, e.g.: asr_example.wav; pcm_path, e.g.: asr_example.pcm; wav.scp, kaldi style wav list (wav_id \t wav_path)", true, "", "string");
|
||||
|
||||
cmd.add(model_dir);
|
||||
cmd.add(quantize);
|
||||
cmd.add(wav_path);
|
||||
cmd.parse(argc, argv);
|
||||
|
||||
std::map<std::string, std::string> model_path;
|
||||
GetValue(model_dir, MODEL_DIR, model_path);
|
||||
GetValue(quantize, QUANTIZE, model_path);
|
||||
GetValue(wav_path, WAV_PATH, model_path);
|
||||
|
||||
struct timeval start, end;
|
||||
gettimeofday(&start, NULL);
|
||||
int thread_num = 1;
|
||||
FUNASR_HANDLE vad_hanlde=FsmnVadInit(model_path, thread_num);
|
||||
|
||||
if (!vad_hanlde)
|
||||
{
|
||||
LOG(ERROR) << "FunVad init failed";
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
gettimeofday(&end, NULL);
|
||||
long seconds = (end.tv_sec - start.tv_sec);
|
||||
long modle_init_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
|
||||
LOG(INFO) << "Model initialization takes " << (double)modle_init_micros / 1000000 << " s";
|
||||
|
||||
// read wav_path
|
||||
vector<string> wav_list;
|
||||
string wav_path_ = model_path.at(WAV_PATH);
|
||||
if(is_target_file(wav_path_, "wav") || is_target_file(wav_path_, "pcm")){
|
||||
wav_list.emplace_back(wav_path_);
|
||||
}
|
||||
else if(is_target_file(wav_path_, "scp")){
|
||||
ifstream in(wav_path_);
|
||||
if (!in.is_open()) {
|
||||
LOG(ERROR) << "Failed to open file: " << model_path.at(WAV_SCP) ;
|
||||
return 0;
|
||||
}
|
||||
string line;
|
||||
while(getline(in, line))
|
||||
{
|
||||
istringstream iss(line);
|
||||
string column1, column2;
|
||||
iss >> column1 >> column2;
|
||||
wav_list.emplace_back(column2);
|
||||
}
|
||||
in.close();
|
||||
}else{
|
||||
LOG(ERROR)<<"Please check the wav extension!";
|
||||
exit(-1);
|
||||
}
|
||||
// init online features
|
||||
FUNASR_HANDLE online_hanlde=FsmnVadOnlineInit(vad_hanlde);
|
||||
float snippet_time = 0.0f;
|
||||
long taking_micros = 0;
|
||||
for(auto& wav_file : wav_list){
|
||||
|
||||
int32_t sampling_rate_ = -1;
|
||||
funasr::Audio audio(1);
|
||||
if(is_target_file(wav_file.c_str(), "wav")){
|
||||
int32_t sampling_rate_ = -1;
|
||||
if(!audio.LoadWav2Char(wav_file.c_str(), &sampling_rate_)){
|
||||
LOG(ERROR)<<"Failed to load "<< wav_file;
|
||||
exit(-1);
|
||||
}
|
||||
}else if(is_target_file(wav_file.c_str(), "pcm")){
|
||||
if (!audio.LoadPcmwav2Char(wav_file.c_str(), &sampling_rate_)){
|
||||
LOG(ERROR)<<"Failed to load "<< wav_file;
|
||||
exit(-1);
|
||||
}
|
||||
}else{
|
||||
LOG(ERROR)<<"Wrong wav extension";
|
||||
exit(-1);
|
||||
}
|
||||
char* speech_buff = audio.GetSpeechChar();
|
||||
int buff_len = audio.GetSpeechLen()*2;
|
||||
|
||||
int step = 3200;
|
||||
bool is_final = false;
|
||||
|
||||
for (int sample_offset = 0; sample_offset < buff_len; sample_offset += std::min(step, buff_len - sample_offset)) {
|
||||
if (sample_offset + step >= buff_len - 1) {
|
||||
step = buff_len - sample_offset;
|
||||
is_final = true;
|
||||
} else {
|
||||
is_final = false;
|
||||
}
|
||||
gettimeofday(&start, NULL);
|
||||
FUNASR_RESULT result = FsmnVadInferBuffer(online_hanlde, speech_buff+sample_offset, step, NULL, is_final, 16000);
|
||||
gettimeofday(&end, NULL);
|
||||
seconds = (end.tv_sec - start.tv_sec);
|
||||
taking_micros += ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
|
||||
|
||||
if (result)
|
||||
{
|
||||
vector<std::vector<int>>* vad_segments = FsmnVadGetResult(result, 0);
|
||||
print_segs(vad_segments);
|
||||
snippet_time += FsmnVadGetRetSnippetTime(result);
|
||||
FsmnVadFreeResult(result);
|
||||
}
|
||||
else
|
||||
{
|
||||
LOG(ERROR) << ("No return data!\n");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
LOG(INFO) << "Audio length: " << (double)snippet_time << " s";
|
||||
LOG(INFO) << "Model inference takes: " << (double)taking_micros / 1000000 <<" s";
|
||||
LOG(INFO) << "Model inference RTF: " << (double)taking_micros/ (snippet_time*1000000);
|
||||
FsmnVadUninit(online_hanlde);
|
||||
FsmnVadUninit(vad_hanlde);
|
||||
return 0;
|
||||
}
|
||||
|
||||
@ -33,8 +33,9 @@ class AudioFrame {
|
||||
|
||||
class Audio {
|
||||
private:
|
||||
float *speech_data;
|
||||
int16_t *speech_buff;
|
||||
float *speech_data=nullptr;
|
||||
int16_t *speech_buff=nullptr;
|
||||
char* speech_char=nullptr;
|
||||
int speech_len;
|
||||
int speech_align_len;
|
||||
int offset;
|
||||
@ -47,18 +48,22 @@ class Audio {
|
||||
Audio(int data_type, int size);
|
||||
~Audio();
|
||||
void Disp();
|
||||
bool LoadWav(const char* filename, int32_t* sampling_rate);
|
||||
void WavResample(int32_t sampling_rate, const float *waveform, int32_t n);
|
||||
bool LoadWav(const char* buf, int n_len, int32_t* sampling_rate);
|
||||
bool LoadWav(const char* filename, int32_t* sampling_rate);
|
||||
bool LoadWav2Char(const char* filename, int32_t* sampling_rate);
|
||||
bool LoadPcmwav(const char* buf, int n_file_len, int32_t* sampling_rate);
|
||||
bool LoadPcmwav(const char* filename, int32_t* sampling_rate);
|
||||
bool LoadPcmwav2Char(const char* filename, int32_t* sampling_rate);
|
||||
int FetchChunck(float *&dout, int len);
|
||||
int Fetch(float *&dout, int &len, int &flag);
|
||||
void Padding();
|
||||
void Split(OfflineStream* offline_streamj);
|
||||
void Split(VadModel* vad_obj, vector<std::vector<int>>& vad_segments);
|
||||
void Split(VadModel* vad_obj, vector<std::vector<int>>& vad_segments, bool input_finished=true);
|
||||
float GetTimeLen();
|
||||
int GetQueueSize() { return (int)frame_queue.size(); }
|
||||
char* GetSpeechChar(){return speech_char;}
|
||||
int GetSpeechLen(){return speech_len;}
|
||||
};
|
||||
|
||||
} // namespace funasr
|
||||
|
||||
@ -46,12 +46,6 @@ typedef enum {
|
||||
FUNASR_MODEL_PARAFORMER = 3,
|
||||
}FUNASR_MODEL_TYPE;
|
||||
|
||||
typedef enum
|
||||
{
|
||||
FSMN_VAD_OFFLINE=0,
|
||||
FSMN_VAD_ONLINE = 1,
|
||||
}FSMN_VAD_MODE;
|
||||
|
||||
typedef void (* QM_CALLBACK)(int cur_step, int n_total); // n_total: total steps; cur_step: Current Step.
|
||||
|
||||
// ASR
|
||||
@ -68,11 +62,12 @@ _FUNASRAPI void FunASRUninit(FUNASR_HANDLE handle);
|
||||
_FUNASRAPI const float FunASRGetRetSnippetTime(FUNASR_RESULT result);
|
||||
|
||||
// VAD
|
||||
_FUNASRAPI FUNASR_HANDLE FsmnVadInit(std::map<std::string, std::string>& model_path, int thread_num, FSMN_VAD_MODE mode=FSMN_VAD_OFFLINE);
|
||||
_FUNASRAPI FUNASR_HANDLE FsmnVadInit(std::map<std::string, std::string>& model_path, int thread_num);
|
||||
_FUNASRAPI FUNASR_HANDLE FsmnVadOnlineInit(FUNASR_HANDLE fsmnvad_handle);
|
||||
// buffer
|
||||
_FUNASRAPI FUNASR_RESULT FsmnVadInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FSMN_VAD_MODE mode, QM_CALLBACK fn_callback, int sampling_rate=16000);
|
||||
_FUNASRAPI FUNASR_RESULT FsmnVadInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, QM_CALLBACK fn_callback, bool input_finished=true, int sampling_rate=16000);
|
||||
// file, support wav & pcm
|
||||
_FUNASRAPI FUNASR_RESULT FsmnVadInfer(FUNASR_HANDLE handle, const char* sz_filename, FSMN_VAD_MODE mode, QM_CALLBACK fn_callback, int sampling_rate=16000);
|
||||
_FUNASRAPI FUNASR_RESULT FsmnVadInfer(FUNASR_HANDLE handle, const char* sz_filename, QM_CALLBACK fn_callback, int sampling_rate=16000);
|
||||
|
||||
_FUNASRAPI std::vector<std::vector<int>>* FsmnVadGetResult(FUNASR_RESULT result,int n_index);
|
||||
_FUNASRAPI void FsmnVadFreeResult(FUNASR_RESULT result);
|
||||
|
||||
@ -12,14 +12,9 @@ class VadModel {
|
||||
virtual ~VadModel(){};
|
||||
virtual void InitVad(const std::string &vad_model, const std::string &vad_cmvn, const std::string &vad_config, int thread_num)=0;
|
||||
virtual std::vector<std::vector<int>> Infer(std::vector<float> &waves, bool input_finished=true)=0;
|
||||
virtual void ReadModel(const char* vad_model)=0;
|
||||
virtual void LoadConfigFromYaml(const char* filename)=0;
|
||||
virtual void FbankKaldi(float sample_rate, std::vector<std::vector<float>> &vad_feats,
|
||||
std::vector<float> &waves)=0;
|
||||
virtual void LoadCmvn(const char *filename)=0;
|
||||
virtual void InitCache()=0;
|
||||
};
|
||||
|
||||
VadModel *CreateVadModel(std::map<std::string, std::string>& model_path, int thread_num, int mode);
|
||||
VadModel *CreateVadModel(std::map<std::string, std::string>& model_path, int thread_num);
|
||||
VadModel *CreateVadModel(void* fsmnvad_handle);
|
||||
} // namespace funasr
|
||||
#endif
|
||||
|
||||
@ -1,11 +1,8 @@
|
||||
|
||||
file(GLOB files1 "*.cpp")
|
||||
file(GLOB files2 "*.cc")
|
||||
set(files ${files1})
|
||||
|
||||
set(files ${files1} ${files2})
|
||||
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
|
||||
|
||||
add_library(funasr ${files})
|
||||
add_library(funasr SHARED ${files})
|
||||
|
||||
if(WIN32)
|
||||
set(EXTRA_LIBS pthread yaml-cpp csrc glog)
|
||||
@ -24,13 +21,3 @@ endif()
|
||||
|
||||
include_directories(${CMAKE_SOURCE_DIR}/include)
|
||||
target_link_libraries(funasr PUBLIC onnxruntime ${EXTRA_LIBS})
|
||||
|
||||
add_executable(funasr-onnx-offline "funasr-onnx-offline.cpp")
|
||||
add_executable(funasr-onnx-offline-vad "funasr-onnx-offline-vad.cpp")
|
||||
add_executable(funasr-onnx-offline-punc "funasr-onnx-offline-punc.cpp")
|
||||
add_executable(funasr-onnx-offline-rtf "funasr-onnx-offline-rtf.cpp")
|
||||
target_link_libraries(funasr-onnx-offline PUBLIC funasr)
|
||||
target_link_libraries(funasr-onnx-offline-vad PUBLIC funasr)
|
||||
target_link_libraries(funasr-onnx-offline-punc PUBLIC funasr)
|
||||
target_link_libraries(funasr-onnx-offline-rtf PUBLIC funasr)
|
||||
|
||||
|
||||
@ -176,13 +176,13 @@ Audio::~Audio()
|
||||
{
|
||||
if (speech_buff != NULL) {
|
||||
free(speech_buff);
|
||||
|
||||
}
|
||||
|
||||
if (speech_data != NULL) {
|
||||
|
||||
free(speech_data);
|
||||
}
|
||||
if (speech_char != NULL) {
|
||||
free(speech_char);
|
||||
}
|
||||
}
|
||||
|
||||
void Audio::Disp()
|
||||
@ -296,8 +296,47 @@ bool Audio::LoadWav(const char *filename, int32_t* sampling_rate)
|
||||
return false;
|
||||
}
|
||||
|
||||
bool Audio::LoadWav(const char* buf, int n_file_len, int32_t* sampling_rate)
|
||||
bool Audio::LoadWav2Char(const char *filename, int32_t* sampling_rate)
|
||||
{
|
||||
WaveHeader header;
|
||||
if (speech_char != NULL) {
|
||||
free(speech_char);
|
||||
}
|
||||
offset = 0;
|
||||
std::ifstream is(filename, std::ifstream::binary);
|
||||
is.read(reinterpret_cast<char *>(&header), sizeof(header));
|
||||
if(!is){
|
||||
LOG(ERROR) << "Failed to read " << filename;
|
||||
return false;
|
||||
}
|
||||
if (!header.Validate()) {
|
||||
return false;
|
||||
}
|
||||
header.SeekToDataChunk(is);
|
||||
if (!is) {
|
||||
return false;
|
||||
}
|
||||
if (!header.Validate()) {
|
||||
return false;
|
||||
}
|
||||
header.SeekToDataChunk(is);
|
||||
if (!is) {
|
||||
return false;
|
||||
}
|
||||
|
||||
*sampling_rate = header.sample_rate;
|
||||
// header.subchunk2_size contains the number of bytes in the data.
|
||||
// As we assume each sample contains two bytes, so it is divided by 2 here
|
||||
speech_len = header.subchunk2_size / 2;
|
||||
speech_char = (char *)malloc(header.subchunk2_size);
|
||||
memset(speech_char, 0, header.subchunk2_size);
|
||||
is.read(speech_char, header.subchunk2_size);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Audio::LoadWav(const char* buf, int n_file_len, int32_t* sampling_rate)
|
||||
{
|
||||
WaveHeader header;
|
||||
if (speech_data != NULL) {
|
||||
free(speech_data);
|
||||
@ -441,6 +480,33 @@ bool Audio::LoadPcmwav(const char* filename, int32_t* sampling_rate)
|
||||
|
||||
}
|
||||
|
||||
bool Audio::LoadPcmwav2Char(const char* filename, int32_t* sampling_rate)
|
||||
{
|
||||
if (speech_char != NULL) {
|
||||
free(speech_char);
|
||||
}
|
||||
offset = 0;
|
||||
|
||||
FILE* fp;
|
||||
fp = fopen(filename, "rb");
|
||||
if (fp == nullptr)
|
||||
{
|
||||
LOG(ERROR) << "Failed to read " << filename;
|
||||
return false;
|
||||
}
|
||||
fseek(fp, 0, SEEK_END);
|
||||
uint32_t n_file_len = ftell(fp);
|
||||
fseek(fp, 0, SEEK_SET);
|
||||
|
||||
speech_len = (n_file_len) / 2;
|
||||
speech_char = (char *)malloc(n_file_len);
|
||||
memset(speech_char, 0, n_file_len);
|
||||
fread(speech_char, sizeof(int16_t), n_file_len/2, fp);
|
||||
fclose(fp);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
int Audio::FetchChunck(float *&dout, int len)
|
||||
{
|
||||
if (offset >= speech_align_len) {
|
||||
@ -541,7 +607,7 @@ void Audio::Split(OfflineStream* offline_stream)
|
||||
}
|
||||
|
||||
|
||||
void Audio::Split(VadModel* vad_obj, vector<std::vector<int>>& vad_segments)
|
||||
void Audio::Split(VadModel* vad_obj, vector<std::vector<int>>& vad_segments, bool input_finished)
|
||||
{
|
||||
AudioFrame *frame;
|
||||
|
||||
@ -552,7 +618,7 @@ void Audio::Split(VadModel* vad_obj, vector<std::vector<int>>& vad_segments)
|
||||
frame = NULL;
|
||||
|
||||
std::vector<float> pcm_data(speech_data, speech_data+sp_len);
|
||||
vad_segments = vad_obj->Infer(pcm_data);
|
||||
vad_segments = vad_obj->Infer(pcm_data, input_finished);
|
||||
}
|
||||
|
||||
} // namespace funasr
|
||||
198
funasr/runtime/onnxruntime/src/fsmn-vad-online.cpp
Normal file
198
funasr/runtime/onnxruntime/src/fsmn-vad-online.cpp
Normal file
@ -0,0 +1,198 @@
|
||||
/**
|
||||
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
|
||||
* MIT License (https://opensource.org/licenses/MIT)
|
||||
*/
|
||||
|
||||
#include <fstream>
|
||||
#include "precomp.h"
|
||||
|
||||
namespace funasr {
|
||||
|
||||
void FsmnVadOnline::FbankKaldi(float sample_rate, std::vector<std::vector<float>> &vad_feats,
|
||||
std::vector<float> &waves) {
|
||||
knf::OnlineFbank fbank(fbank_opts_);
|
||||
// cache merge
|
||||
waves.insert(waves.begin(), input_cache_.begin(), input_cache_.end());
|
||||
int frame_number = ComputeFrameNum(waves.size(), frame_sample_length_, frame_shift_sample_length_);
|
||||
// Send the audio after the last frame shift position to the cache
|
||||
input_cache_.clear();
|
||||
input_cache_.insert(input_cache_.begin(), waves.begin() + frame_number * frame_shift_sample_length_, waves.end());
|
||||
if (frame_number == 0) {
|
||||
return;
|
||||
}
|
||||
// Delete audio that haven't undergone fbank processing
|
||||
waves.erase(waves.begin() + (frame_number - 1) * frame_shift_sample_length_ + frame_sample_length_, waves.end());
|
||||
|
||||
std::vector<float> buf(waves.size());
|
||||
for (int32_t i = 0; i != waves.size(); ++i) {
|
||||
buf[i] = waves[i] * 32768;
|
||||
}
|
||||
fbank.AcceptWaveform(sample_rate, buf.data(), buf.size());
|
||||
// fbank.AcceptWaveform(sample_rate, &waves[0], waves.size());
|
||||
int32_t frames = fbank.NumFramesReady();
|
||||
for (int32_t i = 0; i != frames; ++i) {
|
||||
const float *frame = fbank.GetFrame(i);
|
||||
vector<float> frame_vector(frame, frame + fbank_opts_.mel_opts.num_bins);
|
||||
vad_feats.emplace_back(frame_vector);
|
||||
}
|
||||
}
|
||||
|
||||
void FsmnVadOnline::ExtractFeats(float sample_rate, vector<std::vector<float>> &vad_feats,
|
||||
vector<float> &waves, bool input_finished) {
|
||||
FbankKaldi(sample_rate, vad_feats, waves);
|
||||
// cache deal & online lfr,cmvn
|
||||
if (vad_feats.size() > 0) {
|
||||
if (!reserve_waveforms_.empty()) {
|
||||
waves.insert(waves.begin(), reserve_waveforms_.begin(), reserve_waveforms_.end());
|
||||
}
|
||||
if (lfr_splice_cache_.empty()) {
|
||||
for (int i = 0; i < (lfr_m - 1) / 2; i++) {
|
||||
lfr_splice_cache_.emplace_back(vad_feats[0]);
|
||||
}
|
||||
}
|
||||
if (vad_feats.size() + lfr_splice_cache_.size() >= lfr_m) {
|
||||
vad_feats.insert(vad_feats.begin(), lfr_splice_cache_.begin(), lfr_splice_cache_.end());
|
||||
int frame_from_waves = (waves.size() - frame_sample_length_) / frame_shift_sample_length_ + 1;
|
||||
int minus_frame = reserve_waveforms_.empty() ? (lfr_m - 1) / 2 : 0;
|
||||
int lfr_splice_frame_idxs = OnlineLfrCmvn(vad_feats, input_finished);
|
||||
int reserve_frame_idx = lfr_splice_frame_idxs - minus_frame;
|
||||
reserve_waveforms_.clear();
|
||||
reserve_waveforms_.insert(reserve_waveforms_.begin(),
|
||||
waves.begin() + reserve_frame_idx * frame_shift_sample_length_,
|
||||
waves.begin() + frame_from_waves * frame_shift_sample_length_);
|
||||
int sample_length = (frame_from_waves - 1) * frame_shift_sample_length_ + frame_sample_length_;
|
||||
waves.erase(waves.begin() + sample_length, waves.end());
|
||||
} else {
|
||||
reserve_waveforms_.clear();
|
||||
reserve_waveforms_.insert(reserve_waveforms_.begin(),
|
||||
waves.begin() + frame_sample_length_ - frame_shift_sample_length_, waves.end());
|
||||
lfr_splice_cache_.insert(lfr_splice_cache_.end(), vad_feats.begin(), vad_feats.end());
|
||||
}
|
||||
} else {
|
||||
if (input_finished) {
|
||||
if (!reserve_waveforms_.empty()) {
|
||||
waves = reserve_waveforms_;
|
||||
}
|
||||
vad_feats = lfr_splice_cache_;
|
||||
OnlineLfrCmvn(vad_feats, input_finished);
|
||||
}
|
||||
}
|
||||
if(input_finished){
|
||||
Reset();
|
||||
ResetCache();
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Splices groups of lfr_m consecutive frames (advancing lfr_n frames per
 * output) and applies CMVN to every spliced frame.
 *
 * @param vad_feats       In: per-frame features. Out: spliced and normalised
 *                        features.
 * @param input_finished  When true, incomplete trailing groups are completed
 *                        by repeating the last frame; when false, splicing
 *                        stops at the first incomplete group.
 * @return Index of the first input frame that was stored in
 *         lfr_splice_cache_ for the next call.
 *
 * Normalisation is (x + means_list_[j]) * vars_list_[j] over the first
 * means_list_.size() values of each spliced frame.
 */
int FsmnVadOnline::OnlineLfrCmvn(vector<vector<float>> &vad_feats, bool input_finished) {
    vector<vector<float>> out_feats;
    const int T = vad_feats.size();
    // Divide in floating point: with integer division ceil() had nothing left
    // to round up, so a partial last group was dropped whenever lfr_n > 1.
    const int T_lrf = static_cast<int>(ceil((T - (lfr_m - 1) / 2) / static_cast<double>(lfr_n)));
    int lfr_splice_frame_idxs = T_lrf;
    vector<float> p;
    for (int i = 0; i < T_lrf; i++) {
        const int first = i * lfr_n;
        const int remaining = T - first;
        if (lfr_m <= remaining) {
            // A full group of lfr_m frames is available.
            for (int j = 0; j < lfr_m; j++) {
                p.insert(p.end(), vad_feats[first + j].begin(), vad_feats[first + j].end());
            }
        } else if (input_finished) {
            // End of stream: take what is left and pad with the last frame.
            for (int j = 0; j < remaining; j++) {
                p.insert(p.end(), vad_feats[first + j].begin(), vad_feats[first + j].end());
            }
            const vector<float> &last_frame = vad_feats[T - 1];
            for (int j = 0; j < lfr_m - remaining; j++) {
                p.insert(p.end(), last_frame.begin(), last_frame.end());
            }
        } else {
            // Not enough frames yet; the rest waits for the next chunk.
            lfr_splice_frame_idxs = i;
            break;
        }
        out_feats.emplace_back(p);
        // Clear after every emitted frame. The padded branch used to skip
        // this, so consecutive padded frames piled up in one vector.
        p.clear();
    }
    // Clamp to a valid index: for T < 2 the unclamped value is negative and
    // would form an iterator before begin().
    lfr_splice_frame_idxs = std::max(0, std::min(T - 1, lfr_splice_frame_idxs * lfr_n));
    lfr_splice_cache_.clear();
    lfr_splice_cache_.insert(lfr_splice_cache_.begin(), vad_feats.begin() + lfr_splice_frame_idxs, vad_feats.end());

    // Apply cmvn
    for (auto &out_feat: out_feats) {
        for (size_t j = 0; j < means_list_.size(); j++) {
            out_feat[j] = (out_feat[j] + means_list_[j]) * vars_list_[j];
        }
    }
    vad_feats.swap(out_feats);
    return lfr_splice_frame_idxs;
}
|
||||
|
||||
std::vector<std::vector<int>>
|
||||
FsmnVadOnline::Infer(std::vector<float> &waves, bool input_finished) {
|
||||
std::vector<std::vector<float>> vad_feats;
|
||||
std::vector<std::vector<float>> vad_probs;
|
||||
ExtractFeats(vad_sample_rate_, vad_feats, waves, input_finished);
|
||||
fsmnvad_handle_->Forward(vad_feats, &vad_probs, &in_cache_, input_finished);
|
||||
|
||||
std::vector<std::vector<int>> vad_segments;
|
||||
vad_segments = vad_scorer(vad_probs, waves, input_finished, true, vad_silence_duration_, vad_max_len_,
|
||||
vad_speech_noise_thres_, vad_sample_rate_);
|
||||
return vad_segments;
|
||||
}
|
||||
|
||||
void FsmnVadOnline::InitCache(){
|
||||
std::vector<float> cache_feats(128 * 19 * 1, 0);
|
||||
for (int i=0;i<4;i++){
|
||||
in_cache_.emplace_back(cache_feats);
|
||||
}
|
||||
};
|
||||
|
||||
// Discards the current model cache and rebuilds it zero-filled.
void FsmnVadOnline::Reset()
{
    in_cache_.clear();
    InitCache();
}
|
||||
|
||||
// Placeholder; performs no work.
void FsmnVadOnline::Test()
{
}
|
||||
|
||||
/**
 * Copies the session, tensor names, fbank options, CMVN statistics and VAD
 * parameters of an already initialised model into this streaming instance.
 *
 * @param vad_session             Shared ONNX Runtime session.
 * @param env                     Accepted for interface symmetry; not used.
 * @param vad_in_names            Model input names (pointers are copied as is).
 * @param vad_out_names           Model output names (pointers are copied as is).
 * @param fbank_opts              Fbank extraction options.
 * @param means_list              CMVN means.
 * @param vars_list               CMVN scales.
 * @param vad_sample_rate         Sample rate the model expects.
 * @param vad_silence_duration    Passed on to the scorer.
 * @param vad_max_len             Passed on to the scorer.
 * @param vad_speech_noise_thres  Passed on to the scorer.
 */
void FsmnVadOnline::InitOnline(std::shared_ptr<Ort::Session> &vad_session,
                               Ort::Env &env,
                               std::vector<const char *> &vad_in_names,
                               std::vector<const char *> &vad_out_names,
                               knf::FbankOptions &fbank_opts,
                               std::vector<float> &means_list,
                               std::vector<float> &vars_list,
                               int vad_sample_rate,
                               int vad_silence_duration,
                               int vad_max_len,
                               double vad_speech_noise_thres) {
    vad_session_ = vad_session;
    vad_in_names_ = vad_in_names;
    vad_out_names_ = vad_out_names;
    fbank_opts_ = fbank_opts;
    means_list_ = means_list;
    vars_list_ = vars_list;
    vad_sample_rate_ = vad_sample_rate;
    vad_silence_duration_ = vad_silence_duration;
    vad_max_len_ = vad_max_len;
    vad_speech_noise_thres_ = vad_speech_noise_thres;

    // The frame sizes are member-initialised from the default sample rate;
    // recompute them (25 ms window, 10 ms shift, as in the class definition)
    // now that the model's real sample rate is known.
    frame_sample_length_ = vad_sample_rate_ / 1000 * 25;
    frame_shift_sample_length_ = vad_sample_rate_ / 1000 * 10;
}
|
||||
|
||||
// Nothing to release explicitly; in particular fsmnvad_handle_ is not deleted here.
FsmnVadOnline::~FsmnVadOnline()
{
}
|
||||
|
||||
// Builds a streaming VAD instance on top of an already initialised FsmnVad:
// creates this stream's own zeroed model cache, then copies the model's
// session, names, options and parameters.
// `fsmnvad_handle` is dereferenced here and must not be null.
FsmnVadOnline::FsmnVadOnline(FsmnVad* fsmnvad_handle)
    : fsmnvad_handle_(fsmnvad_handle), session_options_{}
{
    InitCache();

    FsmnVad &model = *fsmnvad_handle_;
    InitOnline(model.vad_session_,
               model.env_,
               model.vad_in_names_,
               model.vad_out_names_,
               model.fbank_opts_,
               model.means_list_,
               model.vars_list_,
               model.vad_sample_rate_,
               model.vad_silence_duration_,
               model.vad_max_len_,
               model.vad_speech_noise_thres_);
}
|
||||
|
||||
} // namespace funasr
|
||||
88
funasr/runtime/onnxruntime/src/fsmn-vad-online.h
Normal file
88
funasr/runtime/onnxruntime/src/fsmn-vad-online.h
Normal file
@ -0,0 +1,88 @@
|
||||
/**
|
||||
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
|
||||
* MIT License (https://opensource.org/licenses/MIT)
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include "precomp.h"
|
||||
|
||||
namespace funasr {
|
||||
class FsmnVadOnline : public VadModel {
|
||||
/**
|
||||
* Author: Speech Lab of DAMO Academy, Alibaba Group
|
||||
* Deep-FSMN for Large Vocabulary Continuous Speech Recognition
|
||||
* https://arxiv.org/abs/1803.05030
|
||||
*/
|
||||
|
||||
public:
|
||||
explicit FsmnVadOnline(FsmnVad* fsmnvad_handle);
|
||||
~FsmnVadOnline();
|
||||
void Test();
|
||||
std::vector<std::vector<int>> Infer(std::vector<float> &waves, bool input_finished);
|
||||
void ExtractFeats(float sample_rate, vector<vector<float>> &vad_feats, vector<float> &waves, bool input_finished);
|
||||
void Reset();
|
||||
|
||||
private:
|
||||
E2EVadModel vad_scorer = E2EVadModel();
|
||||
// std::unique_ptr<FsmnVad> fsmnvad_handle_;
|
||||
FsmnVad* fsmnvad_handle_ = nullptr;
|
||||
|
||||
void FbankKaldi(float sample_rate, std::vector<std::vector<float>> &vad_feats,
|
||||
std::vector<float> &waves);
|
||||
int OnlineLfrCmvn(vector<vector<float>> &vad_feats, bool input_finished);
|
||||
void InitVad(const std::string &vad_model, const std::string &vad_cmvn, const std::string &vad_config, int thread_num){}
|
||||
void InitCache();
|
||||
void InitOnline(std::shared_ptr<Ort::Session> &vad_session,
|
||||
Ort::Env &env,
|
||||
std::vector<const char *> &vad_in_names,
|
||||
std::vector<const char *> &vad_out_names,
|
||||
knf::FbankOptions &fbank_opts,
|
||||
std::vector<float> &means_list,
|
||||
std::vector<float> &vars_list,
|
||||
int vad_sample_rate,
|
||||
int vad_silence_duration,
|
||||
int vad_max_len,
|
||||
double vad_speech_noise_thres);
|
||||
|
||||
static int ComputeFrameNum(int sample_length, int frame_sample_length, int frame_shift_sample_length) {
|
||||
int frame_num = static_cast<int>((sample_length - frame_sample_length) / frame_shift_sample_length + 1);
|
||||
if (frame_num >= 1 && sample_length >= frame_sample_length)
|
||||
return frame_num;
|
||||
else
|
||||
return 0;
|
||||
}
|
||||
void ResetCache() {
|
||||
reserve_waveforms_.clear();
|
||||
input_cache_.clear();
|
||||
lfr_splice_cache_.clear();
|
||||
}
|
||||
|
||||
// from fsmnvad_handle_
|
||||
std::shared_ptr<Ort::Session> vad_session_ = nullptr;
|
||||
Ort::Env env_;
|
||||
Ort::SessionOptions session_options_;
|
||||
std::vector<const char *> vad_in_names_;
|
||||
std::vector<const char *> vad_out_names_;
|
||||
knf::FbankOptions fbank_opts_;
|
||||
std::vector<float> means_list_;
|
||||
std::vector<float> vars_list_;
|
||||
|
||||
std::vector<std::vector<float>> in_cache_;
|
||||
// The reserved waveforms by fbank
|
||||
std::vector<float> reserve_waveforms_;
|
||||
// waveforms reserved after last shift position
|
||||
std::vector<float> input_cache_;
|
||||
// lfr reserved cache
|
||||
std::vector<std::vector<float>> lfr_splice_cache_;
|
||||
|
||||
int vad_sample_rate_ = MODEL_SAMPLE_RATE;
|
||||
int vad_silence_duration_ = VAD_SILENCE_DURATION;
|
||||
int vad_max_len_ = VAD_MAX_LEN;
|
||||
double vad_speech_noise_thres_ = VAD_SPEECH_NOISE_THRES;
|
||||
int lfr_m = VAD_LFR_M;
|
||||
int lfr_n = VAD_LFR_N;
|
||||
int frame_sample_length_ = vad_sample_rate_ / 1000 * 25;;
|
||||
int frame_shift_sample_length_ = vad_sample_rate_ / 1000 * 10;
|
||||
};
|
||||
|
||||
} // namespace funasr
|
||||
@ -37,14 +37,14 @@ void FsmnVad::LoadConfigFromYaml(const char* filename){
|
||||
this->vad_max_len_ = post_conf["max_single_segment_time"].as<int>();
|
||||
this->vad_speech_noise_thres_ = post_conf["speech_noise_thres"].as<double>();
|
||||
|
||||
fbank_opts.frame_opts.dither = frontend_conf["dither"].as<float>();
|
||||
fbank_opts.mel_opts.num_bins = frontend_conf["n_mels"].as<int>();
|
||||
fbank_opts.frame_opts.samp_freq = (float)vad_sample_rate_;
|
||||
fbank_opts.frame_opts.window_type = frontend_conf["window"].as<string>();
|
||||
fbank_opts.frame_opts.frame_shift_ms = frontend_conf["frame_shift"].as<float>();
|
||||
fbank_opts.frame_opts.frame_length_ms = frontend_conf["frame_length"].as<float>();
|
||||
fbank_opts.energy_floor = 0;
|
||||
fbank_opts.mel_opts.debug_mel = false;
|
||||
fbank_opts_.frame_opts.dither = frontend_conf["dither"].as<float>();
|
||||
fbank_opts_.mel_opts.num_bins = frontend_conf["n_mels"].as<int>();
|
||||
fbank_opts_.frame_opts.samp_freq = (float)vad_sample_rate_;
|
||||
fbank_opts_.frame_opts.window_type = frontend_conf["window"].as<string>();
|
||||
fbank_opts_.frame_opts.frame_shift_ms = frontend_conf["frame_shift"].as<float>();
|
||||
fbank_opts_.frame_opts.frame_length_ms = frontend_conf["frame_length"].as<float>();
|
||||
fbank_opts_.energy_floor = 0;
|
||||
fbank_opts_.mel_opts.debug_mel = false;
|
||||
}catch(exception const &e){
|
||||
LOG(ERROR) << "Error when load argument from vad config YAML.";
|
||||
exit(-1);
|
||||
@ -55,6 +55,7 @@ void FsmnVad::ReadModel(const char* vad_model) {
|
||||
try {
|
||||
vad_session_ = std::make_shared<Ort::Session>(
|
||||
env_, vad_model, session_options_);
|
||||
LOG(INFO) << "Successfully load model from " << vad_model;
|
||||
} catch (std::exception const &e) {
|
||||
LOG(ERROR) << "Error when load vad onnx model: " << e.what();
|
||||
exit(0);
|
||||
@ -109,7 +110,9 @@ void FsmnVad::GetInputOutputInfo(
|
||||
|
||||
void FsmnVad::Forward(
|
||||
const std::vector<std::vector<float>> &chunk_feats,
|
||||
std::vector<std::vector<float>> *out_prob) {
|
||||
std::vector<std::vector<float>> *out_prob,
|
||||
std::vector<std::vector<float>> *in_cache,
|
||||
bool is_final) {
|
||||
Ort::MemoryInfo memory_info =
|
||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
|
||||
|
||||
@ -132,9 +135,9 @@ void FsmnVad::Forward(
|
||||
// 4 caches
|
||||
// cache node {batch,128,19,1}
|
||||
const int64_t cache_feats_shape[4] = {1, 128, 19, 1};
|
||||
for (int i = 0; i < in_cache_.size(); i++) {
|
||||
for (int i = 0; i < in_cache->size(); i++) {
|
||||
vad_inputs.emplace_back(std::move(Ort::Value::CreateTensor<float>(
|
||||
memory_info, in_cache_[i].data(), in_cache_[i].size(), cache_feats_shape, 4)));
|
||||
memory_info, (*in_cache)[i].data(), (*in_cache)[i].size(), cache_feats_shape, 4)));
|
||||
}
|
||||
|
||||
// 4. Onnx infer
|
||||
@ -162,15 +165,17 @@ void FsmnVad::Forward(
|
||||
}
|
||||
|
||||
// get 4 caches outputs,each size is 128*19
|
||||
// for (int i = 1; i < 5; i++) {
|
||||
// float* data = vad_ort_outputs[i].GetTensorMutableData<float>();
|
||||
// memcpy(in_cache_[i-1].data(), data, sizeof(float) * 128*19);
|
||||
// }
|
||||
if(!is_final){
|
||||
for (int i = 1; i < 5; i++) {
|
||||
float* data = vad_ort_outputs[i].GetTensorMutableData<float>();
|
||||
memcpy((*in_cache)[i-1].data(), data, sizeof(float) * 128*19);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void FsmnVad::FbankKaldi(float sample_rate, std::vector<std::vector<float>> &vad_feats,
|
||||
std::vector<float> &waves) {
|
||||
knf::OnlineFbank fbank(fbank_opts);
|
||||
knf::OnlineFbank fbank(fbank_opts_);
|
||||
|
||||
std::vector<float> buf(waves.size());
|
||||
for (int32_t i = 0; i != waves.size(); ++i) {
|
||||
@ -180,7 +185,7 @@ void FsmnVad::FbankKaldi(float sample_rate, std::vector<std::vector<float>> &vad
|
||||
int32_t frames = fbank.NumFramesReady();
|
||||
for (int32_t i = 0; i != frames; ++i) {
|
||||
const float *frame = fbank.GetFrame(i);
|
||||
std::vector<float> frame_vector(frame, frame + fbank_opts.mel_opts.num_bins);
|
||||
std::vector<float> frame_vector(frame, frame + fbank_opts_.mel_opts.num_bins);
|
||||
vad_feats.emplace_back(frame_vector);
|
||||
}
|
||||
}
|
||||
@ -205,7 +210,7 @@ void FsmnVad::LoadCmvn(const char *filename)
|
||||
vector<string> means_lines{istream_iterator<string>{means_lines_stream}, istream_iterator<string>{}};
|
||||
if (means_lines[0] == "<LearnRateCoef>") {
|
||||
for (int j = 3; j < means_lines.size() - 1; j++) {
|
||||
means_list.push_back(stof(means_lines[j]));
|
||||
means_list_.push_back(stof(means_lines[j]));
|
||||
}
|
||||
continue;
|
||||
}
|
||||
@ -216,8 +221,8 @@ void FsmnVad::LoadCmvn(const char *filename)
|
||||
vector<string> vars_lines{istream_iterator<string>{vars_lines_stream}, istream_iterator<string>{}};
|
||||
if (vars_lines[0] == "<LearnRateCoef>") {
|
||||
for (int j = 3; j < vars_lines.size() - 1; j++) {
|
||||
// vars_list.push_back(stof(vars_lines[j])*scale);
|
||||
vars_list.push_back(stof(vars_lines[j]));
|
||||
// vars_list_.push_back(stof(vars_lines[j])*scale);
|
||||
vars_list_.push_back(stof(vars_lines[j]));
|
||||
}
|
||||
continue;
|
||||
}
|
||||
@ -263,8 +268,8 @@ void FsmnVad::LfrCmvn(std::vector<std::vector<float>> &vad_feats) {
|
||||
}
|
||||
// Apply cmvn
|
||||
for (auto &out_feat: out_feats) {
|
||||
for (int j = 0; j < means_list.size(); j++) {
|
||||
out_feat[j] = (out_feat[j] + means_list[j]) * vars_list[j];
|
||||
for (int j = 0; j < means_list_.size(); j++) {
|
||||
out_feat[j] = (out_feat[j] + means_list_[j]) * vars_list_[j];
|
||||
}
|
||||
}
|
||||
vad_feats = out_feats;
|
||||
@ -276,7 +281,7 @@ FsmnVad::Infer(std::vector<float> &waves, bool input_finished) {
|
||||
std::vector<std::vector<float>> vad_probs;
|
||||
FbankKaldi(vad_sample_rate_, vad_feats, waves);
|
||||
LfrCmvn(vad_feats);
|
||||
Forward(vad_feats, &vad_probs);
|
||||
Forward(vad_feats, &vad_probs, &in_cache_, input_finished);
|
||||
|
||||
E2EVadModel vad_scorer = E2EVadModel();
|
||||
std::vector<std::vector<int>> vad_segments;
|
||||
|
||||
@ -22,7 +22,30 @@ public:
|
||||
void Test();
|
||||
void InitVad(const std::string &vad_model, const std::string &vad_cmvn, const std::string &vad_config, int thread_num);
|
||||
std::vector<std::vector<int>> Infer(std::vector<float> &waves, bool input_finished=true);
|
||||
void Forward(
|
||||
const std::vector<std::vector<float>> &chunk_feats,
|
||||
std::vector<std::vector<float>> *out_prob,
|
||||
std::vector<std::vector<float>> *in_cache,
|
||||
bool is_final);
|
||||
void Reset();
|
||||
|
||||
std::shared_ptr<Ort::Session> vad_session_ = nullptr;
|
||||
Ort::Env env_;
|
||||
Ort::SessionOptions session_options_;
|
||||
std::vector<const char *> vad_in_names_;
|
||||
std::vector<const char *> vad_out_names_;
|
||||
std::vector<std::vector<float>> in_cache_;
|
||||
|
||||
knf::FbankOptions fbank_opts_;
|
||||
std::vector<float> means_list_;
|
||||
std::vector<float> vars_list_;
|
||||
|
||||
int vad_sample_rate_ = MODEL_SAMPLE_RATE;
|
||||
int vad_silence_duration_ = VAD_SILENCE_DURATION;
|
||||
int vad_max_len_ = VAD_MAX_LEN;
|
||||
double vad_speech_noise_thres_ = VAD_SPEECH_NOISE_THRES;
|
||||
int lfr_m = VAD_LFR_M;
|
||||
int lfr_n = VAD_LFR_N;
|
||||
|
||||
private:
|
||||
|
||||
@ -37,31 +60,9 @@ private:
|
||||
std::vector<float> &waves);
|
||||
|
||||
void LfrCmvn(std::vector<std::vector<float>> &vad_feats);
|
||||
|
||||
void Forward(
|
||||
const std::vector<std::vector<float>> &chunk_feats,
|
||||
std::vector<std::vector<float>> *out_prob);
|
||||
|
||||
void LoadCmvn(const char *filename);
|
||||
void InitCache();
|
||||
|
||||
std::shared_ptr<Ort::Session> vad_session_ = nullptr;
|
||||
Ort::Env env_;
|
||||
Ort::SessionOptions session_options_;
|
||||
std::vector<const char *> vad_in_names_;
|
||||
std::vector<const char *> vad_out_names_;
|
||||
std::vector<std::vector<float>> in_cache_;
|
||||
|
||||
knf::FbankOptions fbank_opts;
|
||||
std::vector<float> means_list;
|
||||
std::vector<float> vars_list;
|
||||
|
||||
int vad_sample_rate_ = MODEL_SAMPLE_RATE;
|
||||
int vad_silence_duration_ = VAD_SILENCE_DURATION;
|
||||
int vad_max_len_ = VAD_MAX_LEN;
|
||||
double vad_speech_noise_thres_ = VAD_SPEECH_NOISE_THRES;
|
||||
int lfr_m = VAD_LFR_M;
|
||||
int lfr_n = VAD_LFR_N;
|
||||
};
|
||||
|
||||
} // namespace funasr
|
||||
|
||||
@ -11,9 +11,15 @@ extern "C" {
|
||||
return mm;
|
||||
}
|
||||
|
||||
_FUNASRAPI FUNASR_HANDLE FsmnVadInit(std::map<std::string, std::string>& model_path, int thread_num, FSMN_VAD_MODE mode)
|
||||
_FUNASRAPI FUNASR_HANDLE FsmnVadInit(std::map<std::string, std::string>& model_path, int thread_num)
|
||||
{
|
||||
funasr::VadModel* mm = funasr::CreateVadModel(model_path, thread_num, mode);
|
||||
funasr::VadModel* mm = funasr::CreateVadModel(model_path, thread_num);
|
||||
return mm;
|
||||
}
|
||||
|
||||
// Creates an online VAD model on top of an existing VAD handle.
//
// fsmnvad_handle: handle forwarded to funasr::CreateVadModel(void*).
// Returns the new model handle, or nullptr when fsmnvad_handle is null.
// The null guard mirrors the `if (!vad_obj)` checks used by the inference
// entry points in this file, so a missing handle is reported to the caller
// instead of being wrapped into a new model.
_FUNASRAPI FUNASR_HANDLE FsmnVadOnlineInit(FUNASR_HANDLE fsmnvad_handle)
{
	if (!fsmnvad_handle)
		return nullptr;

	return funasr::CreateVadModel(fsmnvad_handle);
}
|
||||
|
||||
@ -96,7 +102,7 @@ extern "C" {
|
||||
}
|
||||
|
||||
// APIs for VAD Infer
|
||||
_FUNASRAPI FUNASR_RESULT FsmnVadInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FSMN_VAD_MODE mode, QM_CALLBACK fn_callback, int sampling_rate)
|
||||
_FUNASRAPI FUNASR_RESULT FsmnVadInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, QM_CALLBACK fn_callback, bool input_finished, int sampling_rate)
|
||||
{
|
||||
funasr::VadModel* vad_obj = (funasr::VadModel*)handle;
|
||||
if (!vad_obj)
|
||||
@ -110,13 +116,13 @@ extern "C" {
|
||||
p_result->snippet_time = audio.GetTimeLen();
|
||||
|
||||
vector<std::vector<int>> vad_segments;
|
||||
audio.Split(vad_obj, vad_segments);
|
||||
audio.Split(vad_obj, vad_segments, input_finished);
|
||||
p_result->segments = new vector<std::vector<int>>(vad_segments);
|
||||
|
||||
return p_result;
|
||||
}
|
||||
|
||||
_FUNASRAPI FUNASR_RESULT FsmnVadInfer(FUNASR_HANDLE handle, const char* sz_filename, FSMN_VAD_MODE mode, QM_CALLBACK fn_callback, int sampling_rate)
|
||||
_FUNASRAPI FUNASR_RESULT FsmnVadInfer(FUNASR_HANDLE handle, const char* sz_filename, QM_CALLBACK fn_callback, int sampling_rate)
|
||||
{
|
||||
funasr::VadModel* vad_obj = (funasr::VadModel*)handle;
|
||||
if (!vad_obj)
|
||||
@ -139,7 +145,7 @@ extern "C" {
|
||||
p_result->snippet_time = audio.GetTimeLen();
|
||||
|
||||
vector<std::vector<int>> vad_segments;
|
||||
audio.Split(vad_obj, vad_segments);
|
||||
audio.Split(vad_obj, vad_segments, true);
|
||||
p_result->segments = new vector<std::vector<int>>(vad_segments);
|
||||
|
||||
return p_result;
|
||||
|
||||
@ -1,137 +0,0 @@
|
||||
/**
|
||||
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
|
||||
* MIT License (https://opensource.org/licenses/MIT)
|
||||
* Contributed by zhuzizyf(China Telecom).
|
||||
*/
|
||||
|
||||
#include "online-feature.h"
|
||||
#include <utility>
|
||||
|
||||
namespace funasr {
|
||||
/**
 * Builds an online feature extractor.
 *
 * @param sample_rate  Stored in sample_rate_; the frame sizes below are derived from it.
 * @param fbank_opts   Filter-bank options, moved into fbank_opts_.
 * @param lfr_m        Stored in lfr_m_.
 * @param lfr_n        Stored in lfr_n_.
 * @param cmvns        Moved into cmvns_.
 */
OnlineFeature::OnlineFeature(int sample_rate, knf::FbankOptions fbank_opts, int lfr_m, int lfr_n,
                             std::vector<std::vector<float>> cmvns)
    : fbank_opts_(std::move(fbank_opts)),
      cmvns_(std::move(cmvns)),
      sample_rate_(sample_rate),
      lfr_m_(lfr_m),
      lfr_n_(lfr_n) {
    // Integer samples-per-millisecond, scaled to 25 ms (frame) and 10 ms (shift).
    const int samples_per_ms = sample_rate_ / 1000;
    frame_sample_length_ = samples_per_ms * 25;
    frame_shift_sample_length_ = samples_per_ms * 10;
}
|
||||
|
||||
/**
 * Runs one chunk of audio through OnlineFbank and OnlineLfrCmvn, keeping the
 * leftovers (waveform tail and un-spliced frames) in member caches for the
 * next call.
 *
 * @param vad_feats       In/out: filled by OnlineFbank, then replaced by the
 *                        LFR/CMVN output when enough frames are available.
 * @param waves           Audio samples of this chunk. Taken by value, so the
 *                        edits made to it here never reach the caller.
 * @param input_finished  Stored in input_finished_; when true and no new frame
 *                        was produced, the cached frames are flushed and the
 *                        caches are reset.
 */
void OnlineFeature::ExtractFeats(vector<std::vector<float>> &vad_feats,
                                 vector<float> waves, bool input_finished) {
    input_finished_ = input_finished;
    OnlineFbank(vad_feats, waves);

    // No frame came out of fbank: nothing to do unless this is the final call.
    if (vad_feats.empty()) {
        if (!input_finished_) {
            return;
        }
        if (!reserve_waveforms_.empty()) {
            waves = reserve_waveforms_;  // local copy only
        }
        vad_feats = lfr_splice_cache_;
        OnlineLfrCmvn(vad_feats);
        ResetCache();
        return;
    }

    // Remember this before reserve_waveforms_ is overwritten further down.
    const bool had_reserved_waves = !reserve_waveforms_.empty();
    if (had_reserved_waves) {
        waves.insert(waves.begin(), reserve_waveforms_.begin(), reserve_waveforms_.end());
    }

    // First frames of a stream: seed the splice cache with copies of frame 0.
    if (lfr_splice_cache_.empty()) {
        const int left_context = (lfr_m_ - 1) / 2;
        for (int k = 0; k < left_context; ++k) {
            lfr_splice_cache_.emplace_back(vad_feats[0]);
        }
    }

    // Not enough frames for one LFR output yet: stash everything and wait.
    if (vad_feats.size() + lfr_splice_cache_.size() < lfr_m_) {
        reserve_waveforms_.assign(waves.begin() + frame_sample_length_ - frame_shift_sample_length_,
                                  waves.end());
        lfr_splice_cache_.insert(lfr_splice_cache_.end(), vad_feats.begin(), vad_feats.end());
        return;
    }

    vad_feats.insert(vad_feats.begin(), lfr_splice_cache_.begin(), lfr_splice_cache_.end());

    int frame_from_waves = (waves.size() - frame_sample_length_) / frame_shift_sample_length_ + 1;
    const int minus_frame = had_reserved_waves ? 0 : (lfr_m_ - 1) / 2;
    const int lfr_splice_frame_idxs = OnlineLfrCmvn(vad_feats);
    const int reserve_frame_idx = lfr_splice_frame_idxs - minus_frame;

    reserve_waveforms_.assign(waves.begin() + reserve_frame_idx * frame_shift_sample_length_,
                              waves.begin() + frame_from_waves * frame_shift_sample_length_);

    // Trim the local waveform to the samples covered by whole frames.
    const int sample_length = (frame_from_waves - 1) * frame_shift_sample_length_ + frame_sample_length_;
    waves.erase(waves.begin() + sample_length, waves.end());
}
|
||||
|
||||
/**
 * Splices consecutive frames (LFR) and applies CMVN in place.
 *
 * Each output frame is the concatenation of lfr_m_ input frames starting at
 * i * lfr_n_. When fewer than lfr_m_ frames remain:
 *   - if input_finished_ is set, the frame is completed by repeating the last
 *     input frame;
 *   - otherwise splicing stops and the remaining frames are cached.
 * Frames from the returned index onward are stored in lfr_splice_cache_.
 *
 * Fixes relative to the previous version:
 *   - the output-frame count is a real ceiling; the old integer division
 *     truncated before ceil() was applied, dropping the last frame when
 *     lfr_n_ > 1;
 *   - every output frame starts from an empty buffer; the old code did not
 *     clear the scratch vector after a padded frame, so later padded frames
 *     contained the earlier ones;
 *   - the splice index is clamped to [0, T-1] and empty input returns early,
 *     so no iterator before begin() is ever formed.
 *
 * NOTE(review): lfr_n_ is assumed to be > 0; it is used as a divisor.
 *
 * @param vad_feats  In: frames to splice. Out: spliced, normalised frames.
 * @return Index into the input frames from which frames were cached.
 */
int OnlineFeature::OnlineLfrCmvn(vector<vector<float>> &vad_feats) {
    const int T = static_cast<int>(vad_feats.size());
    if (T == 0) {
        // Nothing to splice and nothing to carry over.
        lfr_splice_cache_.clear();
        return 0;
    }

    const int left_context = (lfr_m_ - 1) / 2;
    const int T_lrf = static_cast<int>(
        std::ceil(static_cast<double>(T - left_context) / lfr_n_));
    int lfr_splice_frame_idxs = T_lrf;

    vector<vector<float>> out_feats;
    for (int i = 0; i < T_lrf; i++) {
        const int start = i * lfr_n_;
        const int remaining = T - start;

        if (lfr_m_ > remaining && !input_finished_) {
            // Incomplete window in streaming mode: keep the rest for next call.
            lfr_splice_frame_idxs = i;
            break;
        }

        // Fresh buffer per output frame.
        vector<float> spliced;
        const int available = std::min(lfr_m_, remaining);
        for (int j = 0; j < available; j++) {
            const vector<float> &src = vad_feats[start + j];
            spliced.insert(spliced.end(), src.begin(), src.end());
        }
        // Only non-zero for the trailing frames of a finished stream.
        const vector<float> &last = vad_feats[T - 1];
        for (int j = available; j < lfr_m_; j++) {
            spliced.insert(spliced.end(), last.begin(), last.end());
        }
        out_feats.emplace_back(std::move(spliced));
    }

    lfr_splice_frame_idxs = std::max(0, std::min(T - 1, lfr_splice_frame_idxs * lfr_n_));
    lfr_splice_cache_.assign(vad_feats.begin() + lfr_splice_frame_idxs, vad_feats.end());

    // Apply cmvn: add row 0, then multiply by row 1.
    for (auto &out_feat: out_feats) {
        for (int j = 0; j < cmvns_[0].size(); j++) {
            out_feat[j] = (out_feat[j] + cmvns_[0][j]) * cmvns_[1][j];
        }
    }
    vad_feats = std::move(out_feats);
    return lfr_splice_frame_idxs;
}
|
||||
|
||||
/**
 * Computes fbank frames for the cached samples plus the new chunk.
 *
 * @param vad_feats  Output: one vector of num_bins values is appended per
 *                   frame reported by the fbank computer.
 * @param waves      In/out: on entry the new samples; input_cache_ is
 *                   prepended, and on return it holds exactly the samples
 *                   covered by the computed frames (unchanged beyond the
 *                   prepend when no whole frame fits).
 */
void OnlineFeature::OnlineFbank(vector<std::vector<float>> &vad_feats,
                                vector<float> &waves) {
    knf::OnlineFbank fbank(fbank_opts_);

    // Put the samples left over from the previous call in front of the new ones.
    waves.insert(waves.begin(), input_cache_.begin(), input_cache_.end());
    const int frame_number =
        ComputeFrameNum(waves.size(), frame_sample_length_, frame_shift_sample_length_);

    // Everything after the last frame-shift position is kept for the next call.
    input_cache_.assign(waves.begin() + frame_number * frame_shift_sample_length_, waves.end());
    if (frame_number == 0) {
        return;
    }

    // Drop the samples that are not covered by a whole frame.
    waves.resize((frame_number - 1) * frame_shift_sample_length_ + frame_sample_length_);

    fbank.AcceptWaveform(sample_rate_, waves.data(), waves.size());
    const int32_t frames = fbank.NumFramesReady();
    for (int32_t idx = 0; idx != frames; ++idx) {
        const float *frame = fbank.GetFrame(idx);
        vad_feats.emplace_back(frame, frame + fbank_opts_.mel_opts.num_bins);
    }
}
|
||||
|
||||
} // namespace funasr
|
||||
@ -1,58 +0,0 @@
|
||||
/**
|
||||
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
|
||||
* MIT License (https://opensource.org/licenses/MIT)
|
||||
* Contributed by zhuzizyf(China Telecom).
|
||||
*/
|
||||
#pragma once
|
||||
#include <vector>
|
||||
#include "precomp.h"
|
||||
|
||||
using namespace std;
|
||||
namespace funasr {
|
||||
class OnlineFeature {
|
||||
|
||||
public:
|
||||
OnlineFeature(int sample_rate, knf::FbankOptions fbank_opts, int lfr_m_, int lfr_n_,
|
||||
std::vector<std::vector<float>> cmvns_);
|
||||
|
||||
void ExtractFeats(vector<vector<float>> &vad_feats, vector<float> waves, bool input_finished);
|
||||
|
||||
private:
|
||||
void OnlineFbank(vector<vector<float>> &vad_feats, vector<float> &waves);
|
||||
int OnlineLfrCmvn(vector<vector<float>> &vad_feats);
|
||||
|
||||
static int ComputeFrameNum(int sample_length, int frame_sample_length, int frame_shift_sample_length) {
|
||||
int frame_num = static_cast<int>((sample_length - frame_sample_length) / frame_shift_sample_length + 1);
|
||||
if (frame_num >= 1 && sample_length >= frame_sample_length)
|
||||
return frame_num;
|
||||
else
|
||||
return 0;
|
||||
}
|
||||
|
||||
void ResetCache() {
|
||||
reserve_waveforms_.clear();
|
||||
input_cache_.clear();
|
||||
lfr_splice_cache_.clear();
|
||||
input_finished_ = false;
|
||||
|
||||
}
|
||||
|
||||
knf::FbankOptions fbank_opts_;
|
||||
// The reserved waveforms by fbank
|
||||
std::vector<float> reserve_waveforms_;
|
||||
// waveforms reserved after last shift position
|
||||
std::vector<float> input_cache_;
|
||||
// lfr reserved cache
|
||||
std::vector<std::vector<float>> lfr_splice_cache_;
|
||||
std::vector<std::vector<float>> cmvns_;
|
||||
|
||||
int sample_rate_ = 16000;
|
||||
int frame_sample_length_ = sample_rate_ / 1000 * 25;;
|
||||
int frame_shift_sample_length_ = sample_rate_ / 1000 * 10;
|
||||
int lfr_m_;
|
||||
int lfr_n_;
|
||||
bool input_finished_ = false;
|
||||
|
||||
};
|
||||
|
||||
} // namespace funasr
|
||||
@ -18,7 +18,7 @@ namespace funasr {
|
||||
//std::unique_ptr<knf::OnlineFbank> fbank_;
|
||||
knf::FbankOptions fbank_opts;
|
||||
|
||||
Vocab* vocab;
|
||||
Vocab* vocab = nullptr;
|
||||
vector<float> means_list;
|
||||
vector<float> vars_list;
|
||||
const float scale = 22.6274169979695;
|
||||
@ -30,7 +30,7 @@ namespace funasr {
|
||||
void ApplyCmvn(vector<float> *v);
|
||||
string GreedySearch( float* in, int n_len, int64_t token_nums);
|
||||
|
||||
std::shared_ptr<Ort::Session> m_session;
|
||||
std::shared_ptr<Ort::Session> m_session = nullptr;
|
||||
Ort::Env env_;
|
||||
Ort::SessionOptions session_options;
|
||||
|
||||
|
||||
@ -36,8 +36,9 @@ using namespace std;
|
||||
#include "offline-stream.h"
|
||||
#include "tokenizer.h"
|
||||
#include "ct-transformer.h"
|
||||
#include "fsmn-vad.h"
|
||||
#include "e2e-vad.h"
|
||||
#include "fsmn-vad.h"
|
||||
#include "fsmn-vad-online.h"
|
||||
#include "vocab.h"
|
||||
#include "audio.h"
|
||||
#include "tensor.h"
|
||||
|
||||
@ -1,14 +1,10 @@
|
||||
#include "precomp.h"
|
||||
|
||||
namespace funasr {
|
||||
VadModel *CreateVadModel(std::map<std::string, std::string>& model_path, int thread_num, int mode)
|
||||
VadModel *CreateVadModel(std::map<std::string, std::string>& model_path, int thread_num)
|
||||
{
|
||||
VadModel *mm;
|
||||
if(mode == FSMN_VAD_OFFLINE){
|
||||
mm = new FsmnVad();
|
||||
}else{
|
||||
LOG(ERROR)<<"Online fsmn vad not imp!";
|
||||
}
|
||||
mm = new FsmnVad();
|
||||
|
||||
string vad_model_path;
|
||||
string vad_cmvn_path;
|
||||
@ -25,4 +21,11 @@ VadModel *CreateVadModel(std::map<std::string, std::string>& model_path, int thr
|
||||
return mm;
|
||||
}
|
||||
|
||||
// Wraps an existing FsmnVad, passed as an opaque pointer, in a newly
// allocated FsmnVadOnline. The result is returned as a raw pointer.
VadModel *CreateVadModel(void* fsmnvad_handle)
{
    return new FsmnVadOnline(static_cast<FsmnVad*>(fsmnvad_handle));
}
|
||||
|
||||
} // namespace funasr
|
||||
Loading…
Reference in New Issue
Block a user