mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
fix wavhead reader; modify punc input to int32; add vad/punc/offline-stream apis; modify option parser
This commit is contained in:
parent
865a1bf208
commit
11f0ed89af
@ -1,10 +1,10 @@
|
||||
|
||||
#ifndef AUDIO_H
|
||||
#define AUDIO_H
|
||||
|
||||
#include <queue>
|
||||
#include <stdint.h>
|
||||
#include "model.h"
|
||||
#include "vad-model.h"
|
||||
#include "offline-stream.h"
|
||||
|
||||
#ifndef WAV_HEADER_SIZE
|
||||
#define WAV_HEADER_SIZE 44
|
||||
@ -54,7 +54,8 @@ class Audio {
|
||||
int FetchChunck(float *&dout, int len);
|
||||
int Fetch(float *&dout, int &len, int &flag);
|
||||
void Padding();
|
||||
void Split(Model* recog_obj);
|
||||
void Split(OfflineStream* offline_streamj);
|
||||
void Split(VadModel* vad_obj, vector<std::vector<int>>& vad_segments);
|
||||
float GetTimeLen();
|
||||
int GetQueueSize() { return (int)frame_queue.size(); }
|
||||
};
|
||||
|
||||
@ -12,20 +12,37 @@
|
||||
#define MODEL_SAMPLE_RATE 16000
|
||||
#endif
|
||||
|
||||
// model path
|
||||
#define VAD_MODEL_PATH "vad-model"
|
||||
#define VAD_CMVN_PATH "vad-cmvn"
|
||||
#define VAD_CONFIG_PATH "vad-config"
|
||||
#define AM_MODEL_PATH "am-model"
|
||||
#define AM_CMVN_PATH "am-cmvn"
|
||||
#define AM_CONFIG_PATH "am-config"
|
||||
#define PUNC_MODEL_PATH "punc-model"
|
||||
#define PUNC_CONFIG_PATH "punc-config"
|
||||
// parser option
|
||||
#define MODEL_DIR "model-dir"
|
||||
#define VAD_DIR "vad-dir"
|
||||
#define PUNC_DIR "punc-dir"
|
||||
#define QUANTIZE "quantize"
|
||||
#define VAD_QUANT "vad-quant"
|
||||
#define PUNC_QUANT "punc-quant"
|
||||
|
||||
#define WAV_PATH "wav-path"
|
||||
#define WAV_SCP "wav-scp"
|
||||
#define TXT_PATH "txt-path"
|
||||
#define THREAD_NUM "thread-num"
|
||||
#define PORT_ID "port-id"
|
||||
|
||||
// #define VAD_MODEL_PATH "vad-model"
|
||||
// #define VAD_CMVN_PATH "vad-cmvn"
|
||||
// #define VAD_CONFIG_PATH "vad-config"
|
||||
// #define AM_MODEL_PATH "am-model"
|
||||
// #define AM_CMVN_PATH "am-cmvn"
|
||||
// #define AM_CONFIG_PATH "am-config"
|
||||
// #define PUNC_MODEL_PATH "punc-model"
|
||||
// #define PUNC_CONFIG_PATH "punc-config"
|
||||
|
||||
#define MODEL_NAME "model.onnx"
|
||||
#define QUANT_MODEL_NAME "model_quant.onnx"
|
||||
#define VAD_CMVN_NAME "vad.mvn"
|
||||
#define VAD_CONFIG_NAME "vad.yaml"
|
||||
#define AM_CMVN_NAME "am.mvn"
|
||||
#define AM_CONFIG_NAME "config.yaml"
|
||||
#define PUNC_CONFIG_NAME "punc.yaml"
|
||||
|
||||
// vad
|
||||
#ifndef VAD_SILENCE_DURATION
|
||||
#define VAD_SILENCE_DURATION 800
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
#pragma once
|
||||
#include <map>
|
||||
#include <vector>
|
||||
|
||||
#ifdef WIN32
|
||||
#ifdef _FUNASR_API_EXPORT
|
||||
@ -47,8 +48,8 @@ typedef enum {
|
||||
|
||||
typedef void (* QM_CALLBACK)(int cur_step, int n_total); // n_total: total steps; cur_step: Current Step.
|
||||
|
||||
// // ASR
|
||||
_FUNASRAPI FUNASR_HANDLE FunASRInit(std::map<std::string, std::string>& model_path, int thread_num);
|
||||
// ASR
|
||||
_FUNASRAPI FUNASR_HANDLE FunASRInit(std::map<std::string, std::string>& model_path, int thread_num);
|
||||
|
||||
_FUNASRAPI FUNASR_RESULT FunASRRecogBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback);
|
||||
_FUNASRAPI FUNASR_RESULT FunASRRecogPCMBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, int sampling_rate, FUNASR_MODE mode, QM_CALLBACK fn_callback);
|
||||
@ -62,12 +63,23 @@ _FUNASRAPI void FunASRUninit(FUNASR_HANDLE handle);
|
||||
_FUNASRAPI const float FunASRGetRetSnippetTime(FUNASR_RESULT result);
|
||||
|
||||
// VAD
|
||||
_FUNASRAPI FUNASR_HANDLE FunVadInit(std::map<std::string, std::string>& model_path, int thread_num);
|
||||
_FUNASRAPI FUNASR_HANDLE FunVadInit(std::map<std::string, std::string>& model_path, int thread_num);
|
||||
|
||||
_FUNASRAPI FUNASR_RESULT FunASRVadBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback);
|
||||
_FUNASRAPI FUNASR_RESULT FunASRVadPCMBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, int sampling_rate, FUNASR_MODE mode, QM_CALLBACK fn_callback);
|
||||
_FUNASRAPI FUNASR_RESULT FunASRVadPCMFile(FUNASR_HANDLE handle, const char* sz_filename, int sampling_rate, FUNASR_MODE mode, QM_CALLBACK fn_callback);
|
||||
_FUNASRAPI FUNASR_RESULT FunASRVadFile(FUNASR_HANDLE handle, const char* sz_wavfile, FUNASR_MODE mode, QM_CALLBACK fn_callback);
|
||||
_FUNASRAPI FUNASR_RESULT FunVadWavFile(FUNASR_HANDLE handle, const char* sz_wavfile, FUNASR_MODE mode, QM_CALLBACK fn_callback);
|
||||
_FUNASRAPI std::vector<std::vector<int>>* FunVadGetResult(FUNASR_RESULT result,int n_index);
|
||||
_FUNASRAPI void FunVadFreeResult(FUNASR_RESULT result);
|
||||
_FUNASRAPI void FunVadUninit(FUNASR_HANDLE handle);
|
||||
_FUNASRAPI const float FunVadGetRetSnippetTime(FUNASR_RESULT result);
|
||||
|
||||
// PUNC
|
||||
_FUNASRAPI FUNASR_HANDLE FunPuncInit(std::map<std::string, std::string>& model_path, int thread_num);
|
||||
_FUNASRAPI const std::string FunPuncInfer(FUNASR_HANDLE handle, const char* sz_sentence, FUNASR_MODE mode, QM_CALLBACK fn_callback);
|
||||
_FUNASRAPI void FunPuncUninit(FUNASR_HANDLE handle);
|
||||
|
||||
//OfflineStream
|
||||
_FUNASRAPI FUNASR_HANDLE FunOfflineInit(std::map<std::string, std::string>& model_path, int thread_num);
|
||||
_FUNASRAPI FUNASR_RESULT FunOfflineStream(FUNASR_HANDLE handle, const char* sz_wavfile, FUNASR_MODE mode, QM_CALLBACK fn_callback);
|
||||
_FUNASRAPI void FunOfflineUninit(FUNASR_HANDLE handle);
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
||||
|
||||
@ -9,13 +9,10 @@ class Model {
|
||||
public:
|
||||
virtual ~Model(){};
|
||||
virtual void Reset() = 0;
|
||||
virtual void InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num)=0;
|
||||
virtual std::string ForwardChunk(float *din, int len, int flag) = 0;
|
||||
virtual std::string Forward(float *din, int len, int flag) = 0;
|
||||
virtual std::string Rescoring() = 0;
|
||||
virtual std::vector<std::vector<int>> VadSeg(std::vector<float>& pcm_data)=0;
|
||||
virtual std::string AddPunc(const char* sz_input)=0;
|
||||
virtual bool UseVad() =0;
|
||||
virtual bool UsePunc() =0;
|
||||
};
|
||||
|
||||
Model *CreateModel(std::map<std::string, std::string>& model_path,int thread_num=1);
|
||||
|
||||
28
funasr/runtime/onnxruntime/include/offline-stream.h
Normal file
28
funasr/runtime/onnxruntime/include/offline-stream.h
Normal file
@ -0,0 +1,28 @@
|
||||
#ifndef OFFLINE_STREAM_H
|
||||
#define OFFLINE_STREAM_H
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include "model.h"
|
||||
#include "punc-model.h"
|
||||
#include "vad-model.h"
|
||||
|
||||
class OfflineStream {
|
||||
public:
|
||||
OfflineStream(std::map<std::string, std::string>& model_path, int thread_num);
|
||||
~OfflineStream(){};
|
||||
|
||||
std::unique_ptr<VadModel> vad_handle;
|
||||
std::unique_ptr<Model> asr_handle;
|
||||
std::unique_ptr<PuncModel> punc_handle;
|
||||
bool UseVad(){return use_vad;};
|
||||
bool UsePunc(){return use_punc;};
|
||||
|
||||
private:
|
||||
bool use_vad=false;
|
||||
bool use_punc=false;
|
||||
};
|
||||
|
||||
OfflineStream *CreateOfflineStream(std::map<std::string, std::string>& model_path, int thread_num=1);
|
||||
#endif
|
||||
18
funasr/runtime/onnxruntime/include/punc-model.h
Normal file
18
funasr/runtime/onnxruntime/include/punc-model.h
Normal file
@ -0,0 +1,18 @@
|
||||
|
||||
#ifndef PUNC_MODEL_H
|
||||
#define PUNC_MODEL_H
|
||||
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
|
||||
class PuncModel {
|
||||
public:
|
||||
virtual ~PuncModel(){};
|
||||
virtual void InitPunc(const std::string &punc_model, const std::string &punc_config, int thread_num)=0;
|
||||
virtual std::vector<int> Infer(std::vector<int32_t> input_data)=0;
|
||||
virtual std::string AddPunc(const char* sz_input)=0;
|
||||
};
|
||||
|
||||
PuncModel *CreatePuncModel(std::map<std::string, std::string>& model_path, int thread_num);
|
||||
#endif
|
||||
27
funasr/runtime/onnxruntime/include/vad-model.h
Normal file
27
funasr/runtime/onnxruntime/include/vad-model.h
Normal file
@ -0,0 +1,27 @@
|
||||
|
||||
#ifndef VAD_MODEL_H
|
||||
#define VAD_MODEL_H
|
||||
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
|
||||
class VadModel {
|
||||
public:
|
||||
virtual ~VadModel(){};
|
||||
virtual void InitVad(const std::string &vad_model, const std::string &vad_cmvn, const std::string &vad_config, int thread_num)=0;
|
||||
virtual std::vector<std::vector<int>> Infer(const std::vector<float> &waves)=0;
|
||||
virtual void ReadModel(const char* vad_model)=0;
|
||||
virtual void LoadConfigFromYaml(const char* filename)=0;
|
||||
virtual void FbankKaldi(float sample_rate, std::vector<std::vector<float>> &vad_feats,
|
||||
const std::vector<float> &waves)=0;
|
||||
virtual std::vector<std::vector<float>> &LfrCmvn(std::vector<std::vector<float>> &vad_feats)=0;
|
||||
virtual void Forward(
|
||||
const std::vector<std::vector<float>> &chunk_feats,
|
||||
std::vector<std::vector<float>> *out_prob)=0;
|
||||
virtual void LoadCmvn(const char *filename)=0;
|
||||
virtual void InitCache()=0;
|
||||
};
|
||||
|
||||
VadModel *CreateVadModel(std::map<std::string, std::string>& model_path, int thread_num);
|
||||
#endif
|
||||
@ -26,7 +26,11 @@ include_directories(${CMAKE_SOURCE_DIR}/include)
|
||||
target_link_libraries(funasr PUBLIC onnxruntime ${EXTRA_LIBS})
|
||||
|
||||
add_executable(funasr-onnx-offline "funasr-onnx-offline.cpp")
|
||||
add_executable(funasr-onnx-offline-vad "funasr-onnx-offline-vad.cpp")
|
||||
add_executable(funasr-onnx-offline-punc "funasr-onnx-offline-punc.cpp")
|
||||
add_executable(funasr-onnx-offline-rtf "funasr-onnx-offline-rtf.cpp")
|
||||
target_link_libraries(funasr-onnx-offline PUBLIC funasr)
|
||||
target_link_libraries(funasr-onnx-offline-vad PUBLIC funasr)
|
||||
target_link_libraries(funasr-onnx-offline-punc PUBLIC funasr)
|
||||
target_link_libraries(funasr-onnx-offline-rtf PUBLIC funasr)
|
||||
|
||||
|
||||
@ -237,6 +237,15 @@ bool Audio::LoadWav(const char *filename, int32_t* sampling_rate)
|
||||
LOG(ERROR) << "Failed to read " << filename;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!header.Validate()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
header.SeekToDataChunk(is);
|
||||
if (!is) {
|
||||
return false;
|
||||
}
|
||||
|
||||
*sampling_rate = header.sample_rate;
|
||||
// header.subchunk2_size contains the number of bytes in the data.
|
||||
@ -494,7 +503,7 @@ void Audio::Padding()
|
||||
delete frame;
|
||||
}
|
||||
|
||||
void Audio::Split(Model* recog_obj)
|
||||
void Audio::Split(OfflineStream* offline_stream)
|
||||
{
|
||||
AudioFrame *frame;
|
||||
|
||||
@ -505,7 +514,7 @@ void Audio::Split(Model* recog_obj)
|
||||
frame = NULL;
|
||||
|
||||
std::vector<float> pcm_data(speech_data, speech_data+sp_len);
|
||||
vector<std::vector<int>> vad_segments = recog_obj->VadSeg(pcm_data);
|
||||
vector<std::vector<int>> vad_segments = (offline_stream->vad_handle)->Infer(pcm_data);
|
||||
int seg_sample = MODEL_SAMPLE_RATE/1000;
|
||||
for(vector<int> segment:vad_segments)
|
||||
{
|
||||
@ -518,3 +527,18 @@ void Audio::Split(Model* recog_obj)
|
||||
frame = NULL;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void Audio::Split(VadModel* vad_obj, vector<std::vector<int>>& vad_segments)
|
||||
{
|
||||
AudioFrame *frame;
|
||||
|
||||
frame = frame_queue.front();
|
||||
frame_queue.pop();
|
||||
int sp_len = frame->GetLen();
|
||||
delete frame;
|
||||
frame = NULL;
|
||||
|
||||
std::vector<float> pcm_data(speech_data, speech_data+sp_len);
|
||||
vad_segments = vad_obj->Infer(pcm_data);
|
||||
}
|
||||
@ -6,6 +6,12 @@ typedef struct
|
||||
float snippet_time;
|
||||
}FUNASR_RECOG_RESULT;
|
||||
|
||||
typedef struct
|
||||
{
|
||||
std::vector<std::vector<int>>* segments;
|
||||
float snippet_time;
|
||||
}FUNASR_VAD_RESULT;
|
||||
|
||||
|
||||
#ifdef _WIN32
|
||||
#include <codecvt>
|
||||
|
||||
@ -54,7 +54,7 @@ string CTTransformer::AddPunc(const char* sz_input)
|
||||
int nTotalBatch = ceil((float)InputData.size() / TOKEN_LEN);
|
||||
int nCurBatch = -1;
|
||||
int nSentEnd = -1, nLastCommaIndex = -1;
|
||||
vector<int64_t> RemainIDs; //
|
||||
vector<int32_t> RemainIDs; //
|
||||
vector<string> RemainStr; //
|
||||
vector<int> NewPunctuation; //
|
||||
vector<string> NewString; //
|
||||
@ -64,7 +64,7 @@ string CTTransformer::AddPunc(const char* sz_input)
|
||||
for (size_t i = 0; i < InputData.size(); i += TOKEN_LEN)
|
||||
{
|
||||
nDiff = (i + TOKEN_LEN) < InputData.size() ? (0) : (i + TOKEN_LEN - InputData.size());
|
||||
vector<int64_t> InputIDs(InputData.begin() + i, InputData.begin() + i + TOKEN_LEN - nDiff);
|
||||
vector<int32_t> InputIDs(InputData.begin() + i, InputData.begin() + i + TOKEN_LEN - nDiff);
|
||||
vector<string> InputStr(strOut.begin() + i, strOut.begin() + i + TOKEN_LEN - nDiff);
|
||||
InputIDs.insert(InputIDs.begin(), RemainIDs.begin(), RemainIDs.end()); // RemainIDs+InputIDs;
|
||||
InputStr.insert(InputStr.begin(), RemainStr.begin(), RemainStr.end()); // RemainStr+InputStr;
|
||||
@ -141,12 +141,13 @@ string CTTransformer::AddPunc(const char* sz_input)
|
||||
return strResult;
|
||||
}
|
||||
|
||||
vector<int> CTTransformer::Infer(vector<int64_t> input_data)
|
||||
vector<int> CTTransformer::Infer(vector<int32_t> input_data)
|
||||
{
|
||||
Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
|
||||
vector<int> punction;
|
||||
std::array<int64_t, 2> input_shape_{ 1, (int64_t)input_data.size()};
|
||||
Ort::Value onnx_input = Ort::Value::CreateTensor<int64_t>(m_memoryInfo,
|
||||
Ort::Value onnx_input = Ort::Value::CreateTensor<int32_t>(
|
||||
m_memoryInfo,
|
||||
input_data.data(),
|
||||
input_data.size(),
|
||||
input_shape_.data(),
|
||||
|
||||
@ -5,7 +5,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
class CTTransformer {
|
||||
class CTTransformer : public PuncModel {
|
||||
/**
|
||||
* Author: Speech Lab of DAMO Academy, Alibaba Group
|
||||
* CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
|
||||
@ -27,6 +27,6 @@ public:
|
||||
CTTransformer();
|
||||
void InitPunc(const std::string &punc_model, const std::string &punc_config, int thread_num);
|
||||
~CTTransformer();
|
||||
vector<int> Infer(vector<int64_t> input_data);
|
||||
vector<int> Infer(vector<int32_t> input_data);
|
||||
string AddPunc(const char* sz_input);
|
||||
};
|
||||
|
||||
@ -6,8 +6,8 @@
|
||||
#include <fstream>
|
||||
#include "precomp.h"
|
||||
|
||||
void FsmnVad::InitVad(const std::string &vad_model, const std::string &vad_cmvn, const std::string &vad_config) {
|
||||
session_options_.SetIntraOpNumThreads(1);
|
||||
void FsmnVad::InitVad(const std::string &vad_model, const std::string &vad_cmvn, const std::string &vad_config, int thread_num) {
|
||||
session_options_.SetIntraOpNumThreads(thread_num);
|
||||
session_options_.SetGraphOptimizationLevel(ORT_ENABLE_ALL);
|
||||
session_options_.DisableCpuMemArena();
|
||||
|
||||
@ -296,5 +296,8 @@ void FsmnVad::Reset(){
|
||||
void FsmnVad::Test() {
|
||||
}
|
||||
|
||||
FsmnVad::~FsmnVad() {
|
||||
}
|
||||
|
||||
FsmnVad::FsmnVad():env_(ORT_LOGGING_LEVEL_ERROR, ""),session_options_{} {
|
||||
}
|
||||
|
||||
@ -8,7 +8,7 @@
|
||||
|
||||
#include "precomp.h"
|
||||
|
||||
class FsmnVad {
|
||||
class FsmnVad : public VadModel {
|
||||
/**
|
||||
* Author: Speech Lab of DAMO Academy, Alibaba Group
|
||||
* Deep-FSMN for Large Vocabulary Continuous Speech Recognition
|
||||
@ -17,9 +17,9 @@ class FsmnVad {
|
||||
|
||||
public:
|
||||
FsmnVad();
|
||||
~FsmnVad();
|
||||
void Test();
|
||||
void InitVad(const std::string &vad_model, const std::string &vad_cmvn, const std::string &vad_config);
|
||||
|
||||
void InitVad(const std::string &vad_model, const std::string &vad_cmvn, const std::string &vad_config, int thread_num);
|
||||
std::vector<std::vector<int>> Infer(const std::vector<float> &waves);
|
||||
void Reset();
|
||||
|
||||
|
||||
98
funasr/runtime/onnxruntime/src/funasr-onnx-offline-punc.cpp
Normal file
98
funasr/runtime/onnxruntime/src/funasr-onnx-offline-punc.cpp
Normal file
@ -0,0 +1,98 @@
|
||||
/**
|
||||
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
|
||||
* MIT License (https://opensource.org/licenses/MIT)
|
||||
*/
|
||||
|
||||
#ifndef _WIN32
|
||||
#include <sys/time.h>
|
||||
#else
|
||||
#include <win_func.h>
|
||||
#endif
|
||||
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
#include <map>
|
||||
#include <glog/logging.h>
|
||||
#include "libfunasrapi.h"
|
||||
#include "tclap/CmdLine.h"
|
||||
#include "com-define.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
void GetValue(TCLAP::ValueArg<std::string>& value_arg, string key, std::map<std::string, std::string>& model_path)
|
||||
{
|
||||
if (value_arg.isSet()){
|
||||
model_path.insert({key, value_arg.getValue()});
|
||||
LOG(INFO)<< key << " : " << value_arg.getValue();
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char *argv[])
|
||||
{
|
||||
google::InitGoogleLogging(argv[0]);
|
||||
FLAGS_logtostderr = true;
|
||||
|
||||
TCLAP::CmdLine cmd("funasr-onnx-offline-punc", ' ', "1.0");
|
||||
TCLAP::ValueArg<std::string> model_dir("", MODEL_DIR, "the punc model path, which contains model.onnx, punc.yaml", true, "", "string");
|
||||
TCLAP::ValueArg<std::string> quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "false", "string");
|
||||
TCLAP::ValueArg<std::string> txt_path("", TXT_PATH, "txt file path, one sentence per line", false, "", "string");
|
||||
|
||||
cmd.add(model_dir);
|
||||
cmd.add(quantize);
|
||||
cmd.add(txt_path);
|
||||
cmd.parse(argc, argv);
|
||||
|
||||
std::map<std::string, std::string> model_path;
|
||||
GetValue(model_dir, MODEL_DIR, model_path);
|
||||
GetValue(quantize, QUANTIZE, model_path);
|
||||
GetValue(txt_path, TXT_PATH, model_path);
|
||||
|
||||
struct timeval start, end;
|
||||
gettimeofday(&start, NULL);
|
||||
int thread_num = 1;
|
||||
FUNASR_HANDLE punc_hanlde=FunPuncInit(model_path, thread_num);
|
||||
|
||||
if (!punc_hanlde)
|
||||
{
|
||||
LOG(ERROR) << "FunASR init failed";
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
gettimeofday(&end, NULL);
|
||||
long seconds = (end.tv_sec - start.tv_sec);
|
||||
long modle_init_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
|
||||
LOG(INFO) << "Model initialization takes " << (double)modle_init_micros / 1000000 << " s";
|
||||
|
||||
// read txt_path
|
||||
vector<string> txt_list;
|
||||
|
||||
if(model_path.find(TXT_PATH)!=model_path.end()){
|
||||
ifstream in(model_path.at(TXT_PATH));
|
||||
if (!in.is_open()) {
|
||||
LOG(ERROR) << "Failed to open file: " << model_path.at(TXT_PATH) ;
|
||||
return 0;
|
||||
}
|
||||
string line;
|
||||
while(getline(in, line))
|
||||
{
|
||||
txt_list.emplace_back(line);
|
||||
}
|
||||
in.close();
|
||||
}
|
||||
|
||||
long taking_micros = 0;
|
||||
for(auto& txt_str : txt_list){
|
||||
gettimeofday(&start, NULL);
|
||||
string result=FunPuncInfer(punc_hanlde, txt_str.c_str(), RASR_NONE, NULL);
|
||||
gettimeofday(&end, NULL);
|
||||
seconds = (end.tv_sec - start.tv_sec);
|
||||
taking_micros += ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
|
||||
LOG(INFO)<<"Results: "<<result;
|
||||
}
|
||||
|
||||
LOG(INFO) << "Model inference takes: " << (double)taking_micros / 1000000 <<" s";
|
||||
FunPuncUninit(punc_hanlde);
|
||||
return 0;
|
||||
}
|
||||
|
||||
@ -91,41 +91,21 @@ int main(int argc, char *argv[])
|
||||
FLAGS_logtostderr = true;
|
||||
|
||||
TCLAP::CmdLine cmd("funasr-onnx-offline-rtf", ' ', "1.0");
|
||||
TCLAP::ValueArg<std::string> vad_model("", VAD_MODEL_PATH, "vad model path", false, "", "string");
|
||||
TCLAP::ValueArg<std::string> vad_cmvn("", VAD_CMVN_PATH, "vad cmvn path", false, "", "string");
|
||||
TCLAP::ValueArg<std::string> vad_config("", VAD_CONFIG_PATH, "vad config path", false, "", "string");
|
||||
|
||||
TCLAP::ValueArg<std::string> am_model("", AM_MODEL_PATH, "am model path", false, "", "string");
|
||||
TCLAP::ValueArg<std::string> am_cmvn("", AM_CMVN_PATH, "am cmvn path", false, "", "string");
|
||||
TCLAP::ValueArg<std::string> am_config("", AM_CONFIG_PATH, "am config path", false, "", "string");
|
||||
|
||||
TCLAP::ValueArg<std::string> punc_model("", PUNC_MODEL_PATH, "punc model path", false, "", "string");
|
||||
TCLAP::ValueArg<std::string> punc_config("", PUNC_CONFIG_PATH, "punc config path", false, "", "string");
|
||||
TCLAP::ValueArg<std::string> model_dir("", MODEL_DIR, "the model path, which contains model.onnx, config.yaml, am.mvn", true, "", "string");
|
||||
TCLAP::ValueArg<std::string> quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "false", "string");
|
||||
|
||||
TCLAP::ValueArg<std::string> wav_scp("", WAV_SCP, "wave scp path", true, "", "string");
|
||||
TCLAP::ValueArg<std::int32_t> thread_num("", THREAD_NUM, "multi-thread num for rtf", true, 0, "int32_t");
|
||||
|
||||
cmd.add(vad_model);
|
||||
cmd.add(vad_cmvn);
|
||||
cmd.add(vad_config);
|
||||
cmd.add(am_model);
|
||||
cmd.add(am_cmvn);
|
||||
cmd.add(am_config);
|
||||
cmd.add(punc_model);
|
||||
cmd.add(punc_config);
|
||||
cmd.add(model_dir);
|
||||
cmd.add(quantize);
|
||||
cmd.add(wav_scp);
|
||||
cmd.add(thread_num);
|
||||
cmd.parse(argc, argv);
|
||||
|
||||
std::map<std::string, std::string> model_path;
|
||||
GetValue(vad_model, VAD_MODEL_PATH, model_path);
|
||||
GetValue(vad_cmvn, VAD_CMVN_PATH, model_path);
|
||||
GetValue(vad_config, VAD_CONFIG_PATH, model_path);
|
||||
GetValue(am_model, AM_MODEL_PATH, model_path);
|
||||
GetValue(am_cmvn, AM_CMVN_PATH, model_path);
|
||||
GetValue(am_config, AM_CONFIG_PATH, model_path);
|
||||
GetValue(punc_model, PUNC_MODEL_PATH, model_path);
|
||||
GetValue(punc_config, PUNC_CONFIG_PATH, model_path);
|
||||
GetValue(model_dir, MODEL_DIR, model_path);
|
||||
GetValue(quantize, QUANTIZE, model_path);
|
||||
GetValue(wav_scp, WAV_SCP, model_path);
|
||||
|
||||
struct timeval start, end;
|
||||
|
||||
143
funasr/runtime/onnxruntime/src/funasr-onnx-offline-vad.cpp
Normal file
143
funasr/runtime/onnxruntime/src/funasr-onnx-offline-vad.cpp
Normal file
@ -0,0 +1,143 @@
|
||||
/**
|
||||
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
|
||||
* MIT License (https://opensource.org/licenses/MIT)
|
||||
*/
|
||||
|
||||
#ifndef _WIN32
|
||||
#include <sys/time.h>
|
||||
#else
|
||||
#include <win_func.h>
|
||||
#endif
|
||||
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <glog/logging.h>
|
||||
#include "libfunasrapi.h"
|
||||
#include "tclap/CmdLine.h"
|
||||
#include "com-define.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
void GetValue(TCLAP::ValueArg<std::string>& value_arg, string key, std::map<std::string, std::string>& model_path)
|
||||
{
|
||||
if (value_arg.isSet()){
|
||||
model_path.insert({key, value_arg.getValue()});
|
||||
LOG(INFO)<< key << " : " << value_arg.getValue();
|
||||
}
|
||||
}
|
||||
|
||||
void print_segs(vector<vector<int>>* vec) {
|
||||
string seg_out="[";
|
||||
for (int i = 0; i < vec->size(); i++) {
|
||||
vector<int> inner_vec = (*vec)[i];
|
||||
seg_out += "[";
|
||||
for (int j = 0; j < inner_vec.size(); j++) {
|
||||
seg_out += to_string(inner_vec[j]);
|
||||
if (j != inner_vec.size() - 1) {
|
||||
seg_out += ",";
|
||||
}
|
||||
}
|
||||
seg_out += "]";
|
||||
if (i != vec->size() - 1) {
|
||||
seg_out += ",";
|
||||
}
|
||||
}
|
||||
seg_out += "]";
|
||||
LOG(INFO)<<seg_out;
|
||||
}
|
||||
|
||||
int main(int argc, char *argv[])
|
||||
{
|
||||
google::InitGoogleLogging(argv[0]);
|
||||
FLAGS_logtostderr = true;
|
||||
|
||||
TCLAP::CmdLine cmd("funasr-onnx-offline-vad", ' ', "1.0");
|
||||
TCLAP::ValueArg<std::string> model_dir("", MODEL_DIR, "the vad model path, which contains model.onnx, vad.yaml, vad.mvn", true, "", "string");
|
||||
TCLAP::ValueArg<std::string> quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "false", "string");
|
||||
|
||||
TCLAP::ValueArg<std::string> wav_path("", WAV_PATH, "wave file path", false, "", "string");
|
||||
TCLAP::ValueArg<std::string> wav_scp("", WAV_SCP, "wave scp path", false, "", "string");
|
||||
|
||||
cmd.add(model_dir);
|
||||
cmd.add(quantize);
|
||||
cmd.add(wav_path);
|
||||
cmd.add(wav_scp);
|
||||
cmd.parse(argc, argv);
|
||||
|
||||
std::map<std::string, std::string> model_path;
|
||||
GetValue(model_dir, MODEL_DIR, model_path);
|
||||
GetValue(quantize, QUANTIZE, model_path);
|
||||
GetValue(wav_path, WAV_PATH, model_path);
|
||||
GetValue(wav_scp, WAV_SCP, model_path);
|
||||
|
||||
struct timeval start, end;
|
||||
gettimeofday(&start, NULL);
|
||||
int thread_num = 1;
|
||||
FUNASR_HANDLE vad_hanlde=FunVadInit(model_path, thread_num);
|
||||
|
||||
if (!vad_hanlde)
|
||||
{
|
||||
LOG(ERROR) << "FunVad init failed";
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
gettimeofday(&end, NULL);
|
||||
long seconds = (end.tv_sec - start.tv_sec);
|
||||
long modle_init_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
|
||||
LOG(INFO) << "Model initialization takes " << (double)modle_init_micros / 1000000 << " s";
|
||||
|
||||
// read wav_path and wav_scp
|
||||
vector<string> wav_list;
|
||||
|
||||
if(model_path.find(WAV_PATH)!=model_path.end()){
|
||||
wav_list.emplace_back(model_path.at(WAV_PATH));
|
||||
}
|
||||
if(model_path.find(WAV_SCP)!=model_path.end()){
|
||||
ifstream in(model_path.at(WAV_SCP));
|
||||
if (!in.is_open()) {
|
||||
LOG(ERROR) << "Failed to open file: " << model_path.at(WAV_SCP) ;
|
||||
return 0;
|
||||
}
|
||||
string line;
|
||||
while(getline(in, line))
|
||||
{
|
||||
istringstream iss(line);
|
||||
string column1, column2;
|
||||
iss >> column1 >> column2;
|
||||
wav_list.emplace_back(column2);
|
||||
}
|
||||
in.close();
|
||||
}
|
||||
|
||||
float snippet_time = 0.0f;
|
||||
long taking_micros = 0;
|
||||
for(auto& wav_file : wav_list){
|
||||
gettimeofday(&start, NULL);
|
||||
FUNASR_RESULT result=FunVadWavFile(vad_hanlde, wav_file.c_str(), RASR_NONE, NULL);
|
||||
gettimeofday(&end, NULL);
|
||||
seconds = (end.tv_sec - start.tv_sec);
|
||||
taking_micros += ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
|
||||
|
||||
if (result)
|
||||
{
|
||||
vector<std::vector<int>>* vad_segments = FunVadGetResult(result, 0);
|
||||
print_segs(vad_segments);
|
||||
snippet_time += FunVadGetRetSnippetTime(result);
|
||||
FunVadFreeResult(result);
|
||||
}
|
||||
else
|
||||
{
|
||||
LOG(ERROR) << ("No return data!\n");
|
||||
}
|
||||
}
|
||||
|
||||
LOG(INFO) << "Audio length: " << (double)snippet_time << " s";
|
||||
LOG(INFO) << "Model inference takes: " << (double)taking_micros / 1000000 <<" s";
|
||||
LOG(INFO) << "Model inference RTF: " << (double)taking_micros/ (snippet_time*1000000);
|
||||
FunVadUninit(vad_hanlde);
|
||||
return 0;
|
||||
}
|
||||
|
||||
@ -28,55 +28,46 @@ void GetValue(TCLAP::ValueArg<std::string>& value_arg, string key, std::map<std:
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char *argv[])
|
||||
int main(int argc, char** argv)
|
||||
{
|
||||
google::InitGoogleLogging(argv[0]);
|
||||
FLAGS_logtostderr = true;
|
||||
|
||||
TCLAP::CmdLine cmd("funasr-onnx-offline", ' ', "1.0");
|
||||
TCLAP::ValueArg<std::string> vad_model("", VAD_MODEL_PATH, "vad model path", false, "", "string");
|
||||
TCLAP::ValueArg<std::string> vad_cmvn("", VAD_CMVN_PATH, "vad cmvn path", false, "", "string");
|
||||
TCLAP::ValueArg<std::string> vad_config("", VAD_CONFIG_PATH, "vad config path", false, "", "string");
|
||||
|
||||
TCLAP::ValueArg<std::string> am_model("", AM_MODEL_PATH, "am model path", true, "", "string");
|
||||
TCLAP::ValueArg<std::string> am_cmvn("", AM_CMVN_PATH, "am cmvn path", true, "", "string");
|
||||
TCLAP::ValueArg<std::string> am_config("", AM_CONFIG_PATH, "am config path", true, "", "string");
|
||||
|
||||
TCLAP::ValueArg<std::string> punc_model("", PUNC_MODEL_PATH, "punc model path", false, "", "string");
|
||||
TCLAP::ValueArg<std::string> punc_config("", PUNC_CONFIG_PATH, "punc config path", false, "", "string");
|
||||
TCLAP::ValueArg<std::string> model_dir("", MODEL_DIR, "the model path, which contains model.onnx, config.yaml, am.mvn", true, "", "string");
|
||||
TCLAP::ValueArg<std::string> quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "false", "string");
|
||||
TCLAP::ValueArg<std::string> vad_dir("", VAD_DIR, "the vad model path, which contains model.onnx, vad.yaml, vad.mvn", false, "", "string");
|
||||
TCLAP::ValueArg<std::string> vad_quant("", VAD_QUANT, "false (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir", false, "false", "string");
|
||||
TCLAP::ValueArg<std::string> punc_dir("", PUNC_DIR, "the punc model path, which contains model.onnx, punc.yaml", false, "", "string");
|
||||
TCLAP::ValueArg<std::string> punc_quant("", PUNC_QUANT, "false (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "false", "string");
|
||||
|
||||
TCLAP::ValueArg<std::string> wav_path("", WAV_PATH, "wave file path", false, "", "string");
|
||||
TCLAP::ValueArg<std::string> wav_scp("", WAV_SCP, "wave scp path", false, "", "string");
|
||||
|
||||
cmd.add(vad_model);
|
||||
cmd.add(vad_cmvn);
|
||||
cmd.add(vad_config);
|
||||
cmd.add(am_model);
|
||||
cmd.add(am_cmvn);
|
||||
cmd.add(am_config);
|
||||
cmd.add(punc_model);
|
||||
cmd.add(punc_config);
|
||||
cmd.add(model_dir);
|
||||
cmd.add(quantize);
|
||||
cmd.add(vad_dir);
|
||||
cmd.add(vad_quant);
|
||||
cmd.add(punc_dir);
|
||||
cmd.add(punc_quant);
|
||||
cmd.add(wav_path);
|
||||
cmd.add(wav_scp);
|
||||
cmd.parse(argc, argv);
|
||||
|
||||
std::map<std::string, std::string> model_path;
|
||||
GetValue(vad_model, VAD_MODEL_PATH, model_path);
|
||||
GetValue(vad_cmvn, VAD_CMVN_PATH, model_path);
|
||||
GetValue(vad_config, VAD_CONFIG_PATH, model_path);
|
||||
GetValue(am_model, AM_MODEL_PATH, model_path);
|
||||
GetValue(am_cmvn, AM_CMVN_PATH, model_path);
|
||||
GetValue(am_config, AM_CONFIG_PATH, model_path);
|
||||
GetValue(punc_model, PUNC_MODEL_PATH, model_path);
|
||||
GetValue(punc_config, PUNC_CONFIG_PATH, model_path);
|
||||
GetValue(model_dir, MODEL_DIR, model_path);
|
||||
GetValue(quantize, QUANTIZE, model_path);
|
||||
GetValue(vad_dir, VAD_DIR, model_path);
|
||||
GetValue(vad_quant, VAD_QUANT, model_path);
|
||||
GetValue(punc_dir, PUNC_DIR, model_path);
|
||||
GetValue(punc_quant, PUNC_QUANT, model_path);
|
||||
GetValue(wav_path, WAV_PATH, model_path);
|
||||
GetValue(wav_scp, WAV_SCP, model_path);
|
||||
|
||||
|
||||
struct timeval start, end;
|
||||
gettimeofday(&start, NULL);
|
||||
int thread_num = 1;
|
||||
FUNASR_HANDLE asr_hanlde=FunASRInit(model_path, thread_num);
|
||||
FUNASR_HANDLE asr_hanlde=FunOfflineInit(model_path, thread_num);
|
||||
|
||||
if (!asr_hanlde)
|
||||
{
|
||||
@ -116,7 +107,7 @@ int main(int argc, char *argv[])
|
||||
long taking_micros = 0;
|
||||
for(auto& wav_file : wav_list){
|
||||
gettimeofday(&start, NULL);
|
||||
FUNASR_RESULT result=FunASRRecogFile(asr_hanlde, wav_file.c_str(), RASR_NONE, NULL);
|
||||
FUNASR_RESULT result=FunOfflineStream(asr_hanlde, wav_file.c_str(), RASR_NONE, NULL);
|
||||
gettimeofday(&end, NULL);
|
||||
seconds = (end.tv_sec - start.tv_sec);
|
||||
taking_micros += ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
|
||||
@ -124,8 +115,7 @@ int main(int argc, char *argv[])
|
||||
if (result)
|
||||
{
|
||||
string msg = FunASRGetResult(result, 0);
|
||||
setbuf(stdout, NULL);
|
||||
printf("Result: %s \n", msg.c_str());
|
||||
LOG(INFO)<<"Result: "<<msg;
|
||||
snippet_time += FunASRGetRetSnippetTime(result);
|
||||
FunASRFreeResult(result);
|
||||
}
|
||||
@ -138,7 +128,7 @@ int main(int argc, char *argv[])
|
||||
LOG(INFO) << "Audio length: " << (double)snippet_time << " s";
|
||||
LOG(INFO) << "Model inference takes: " << (double)taking_micros / 1000000 <<" s";
|
||||
LOG(INFO) << "Model inference RTF: " << (double)taking_micros/ (snippet_time*1000000);
|
||||
FunASRUninit(asr_hanlde);
|
||||
FunOfflineUninit(asr_hanlde);
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
@ -4,7 +4,7 @@
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
// APIs for funasr
|
||||
// APIs for Init
|
||||
_FUNASRAPI FUNASR_HANDLE FunASRInit(std::map<std::string, std::string>& model_path, int thread_num)
|
||||
{
|
||||
Model* mm = CreateModel(model_path, thread_num);
|
||||
@ -13,10 +13,23 @@ extern "C" {
|
||||
|
||||
_FUNASRAPI FUNASR_HANDLE FunVadInit(std::map<std::string, std::string>& model_path, int thread_num)
|
||||
{
|
||||
Model* mm = CreateModel(model_path, thread_num);
|
||||
VadModel* mm = CreateVadModel(model_path, thread_num);
|
||||
return mm;
|
||||
}
|
||||
|
||||
_FUNASRAPI FUNASR_HANDLE FunPuncInit(std::map<std::string, std::string>& model_path, int thread_num)
|
||||
{
|
||||
PuncModel* mm = CreatePuncModel(model_path, thread_num);
|
||||
return mm;
|
||||
}
|
||||
|
||||
_FUNASRAPI FUNASR_HANDLE FunOfflineInit(std::map<std::string, std::string>& model_path, int thread_num)
|
||||
{
|
||||
OfflineStream* mm = CreateOfflineStream(model_path, thread_num);
|
||||
return mm;
|
||||
}
|
||||
|
||||
// APIs for ASR Infer
|
||||
_FUNASRAPI FUNASR_RESULT FunASRRecogBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback)
|
||||
{
|
||||
Model* recog_obj = (Model*)handle;
|
||||
@ -27,9 +40,6 @@ extern "C" {
|
||||
Audio audio(1);
|
||||
if (!audio.LoadWav(sz_buf, n_len, &sampling_rate))
|
||||
return nullptr;
|
||||
if(recog_obj->UseVad()){
|
||||
audio.Split(recog_obj);
|
||||
}
|
||||
|
||||
float* buff;
|
||||
int len;
|
||||
@ -45,10 +55,6 @@ extern "C" {
|
||||
if (fn_callback)
|
||||
fn_callback(n_step, n_total);
|
||||
}
|
||||
if(recog_obj->UsePunc()){
|
||||
string punc_res = recog_obj->AddPunc((p_result->msg).c_str());
|
||||
p_result->msg = punc_res;
|
||||
}
|
||||
|
||||
return p_result;
|
||||
}
|
||||
@ -62,9 +68,6 @@ extern "C" {
|
||||
Audio audio(1);
|
||||
if (!audio.LoadPcmwav(sz_buf, n_len, &sampling_rate))
|
||||
return nullptr;
|
||||
if(recog_obj->UseVad()){
|
||||
audio.Split(recog_obj);
|
||||
}
|
||||
|
||||
float* buff;
|
||||
int len;
|
||||
@ -80,10 +83,6 @@ extern "C" {
|
||||
if (fn_callback)
|
||||
fn_callback(n_step, n_total);
|
||||
}
|
||||
if(recog_obj->UsePunc()){
|
||||
string punc_res = recog_obj->AddPunc((p_result->msg).c_str());
|
||||
p_result->msg = punc_res;
|
||||
}
|
||||
|
||||
return p_result;
|
||||
}
|
||||
@ -97,9 +96,6 @@ extern "C" {
|
||||
Audio audio(1);
|
||||
if (!audio.LoadPcmwav(sz_filename, &sampling_rate))
|
||||
return nullptr;
|
||||
if(recog_obj->UseVad()){
|
||||
audio.Split(recog_obj);
|
||||
}
|
||||
|
||||
float* buff;
|
||||
int len;
|
||||
@ -115,10 +111,6 @@ extern "C" {
|
||||
if (fn_callback)
|
||||
fn_callback(n_step, n_total);
|
||||
}
|
||||
if(recog_obj->UsePunc()){
|
||||
string punc_res = recog_obj->AddPunc((p_result->msg).c_str());
|
||||
p_result->msg = punc_res;
|
||||
}
|
||||
|
||||
return p_result;
|
||||
}
|
||||
@ -133,9 +125,6 @@ extern "C" {
|
||||
Audio audio(1);
|
||||
if(!audio.LoadWav(sz_wavfile, &sampling_rate))
|
||||
return nullptr;
|
||||
if(recog_obj->UseVad()){
|
||||
audio.Split(recog_obj);
|
||||
}
|
||||
|
||||
float* buff;
|
||||
int len;
|
||||
@ -151,8 +140,74 @@ extern "C" {
|
||||
if (fn_callback)
|
||||
fn_callback(n_step, n_total);
|
||||
}
|
||||
if(recog_obj->UsePunc()){
|
||||
string punc_res = recog_obj->AddPunc((p_result->msg).c_str());
|
||||
|
||||
return p_result;
|
||||
}
|
||||
|
||||
// APIs for VAD Infer
|
||||
_FUNASRAPI FUNASR_RESULT FunVadWavFile(FUNASR_HANDLE handle, const char* sz_wavfile, FUNASR_MODE mode, QM_CALLBACK fn_callback)
|
||||
{
|
||||
VadModel* vad_obj = (VadModel*)handle;
|
||||
if (!vad_obj)
|
||||
return nullptr;
|
||||
|
||||
int32_t sampling_rate = -1;
|
||||
Audio audio(1);
|
||||
if(!audio.LoadWav(sz_wavfile, &sampling_rate))
|
||||
return nullptr;
|
||||
|
||||
FUNASR_VAD_RESULT* p_result = new FUNASR_VAD_RESULT;
|
||||
p_result->snippet_time = audio.GetTimeLen();
|
||||
|
||||
vector<std::vector<int>> vad_segments;
|
||||
audio.Split(vad_obj, vad_segments);
|
||||
p_result->segments = new vector<std::vector<int>>(vad_segments);
|
||||
|
||||
return p_result;
|
||||
}
|
||||
|
||||
// APIs for PUNC Infer
|
||||
_FUNASRAPI const std::string FunPuncInfer(FUNASR_HANDLE handle, const char* sz_sentence, FUNASR_MODE mode, QM_CALLBACK fn_callback)
|
||||
{
|
||||
PuncModel* punc_obj = (PuncModel*)handle;
|
||||
if (!punc_obj)
|
||||
return nullptr;
|
||||
|
||||
string punc_res = punc_obj->AddPunc(sz_sentence);
|
||||
return punc_res;
|
||||
}
|
||||
|
||||
// APIs for Offline-stream Infer
|
||||
_FUNASRAPI FUNASR_RESULT FunOfflineStream(FUNASR_HANDLE handle, const char* sz_wavfile, FUNASR_MODE mode, QM_CALLBACK fn_callback)
|
||||
{
|
||||
OfflineStream* offline_stream = (OfflineStream*)handle;
|
||||
if (!offline_stream)
|
||||
return nullptr;
|
||||
|
||||
int32_t sampling_rate = -1;
|
||||
Audio audio(1);
|
||||
if(!audio.LoadWav(sz_wavfile, &sampling_rate))
|
||||
return nullptr;
|
||||
if(offline_stream->UseVad()){
|
||||
audio.Split(offline_stream);
|
||||
}
|
||||
|
||||
float* buff;
|
||||
int len;
|
||||
int flag = 0;
|
||||
int n_step = 0;
|
||||
int n_total = audio.GetQueueSize();
|
||||
FUNASR_RECOG_RESULT* p_result = new FUNASR_RECOG_RESULT;
|
||||
p_result->snippet_time = audio.GetTimeLen();
|
||||
while (audio.Fetch(buff, len, flag) > 0) {
|
||||
string msg = (offline_stream->asr_handle)->Forward(buff, len, flag);
|
||||
p_result->msg+= msg;
|
||||
n_step++;
|
||||
if (fn_callback)
|
||||
fn_callback(n_step, n_total);
|
||||
}
|
||||
if(offline_stream->UsePunc()){
|
||||
string punc_res = (offline_stream->punc_handle)->AddPunc((p_result->msg).c_str());
|
||||
p_result->msg = punc_res;
|
||||
}
|
||||
|
||||
@ -167,7 +222,7 @@ extern "C" {
|
||||
return 1;
|
||||
}
|
||||
|
||||
|
||||
// APIs for GetRetSnippetTime
|
||||
_FUNASRAPI const float FunASRGetRetSnippetTime(FUNASR_RESULT result)
|
||||
{
|
||||
if (!result)
|
||||
@ -176,6 +231,15 @@ extern "C" {
|
||||
return ((FUNASR_RECOG_RESULT*)result)->snippet_time;
|
||||
}
|
||||
|
||||
_FUNASRAPI const float FunVadGetRetSnippetTime(FUNASR_RESULT result)
|
||||
{
|
||||
if (!result)
|
||||
return 0.0f;
|
||||
|
||||
return ((FUNASR_VAD_RESULT*)result)->snippet_time;
|
||||
}
|
||||
|
||||
// APIs for GetResult
|
||||
_FUNASRAPI const char* FunASRGetResult(FUNASR_RESULT result,int n_index)
|
||||
{
|
||||
FUNASR_RECOG_RESULT * p_result = (FUNASR_RECOG_RESULT*)result;
|
||||
@ -185,6 +249,16 @@ extern "C" {
|
||||
return p_result->msg.c_str();
|
||||
}
|
||||
|
||||
_FUNASRAPI vector<std::vector<int>>* FunVadGetResult(FUNASR_RESULT result,int n_index)
|
||||
{
|
||||
FUNASR_VAD_RESULT * p_result = (FUNASR_VAD_RESULT*)result;
|
||||
if(!p_result)
|
||||
return nullptr;
|
||||
|
||||
return p_result->segments;
|
||||
}
|
||||
|
||||
// APIs for FreeResult
|
||||
_FUNASRAPI void FunASRFreeResult(FUNASR_RESULT result)
|
||||
{
|
||||
if (result)
|
||||
@ -193,6 +267,19 @@ extern "C" {
|
||||
}
|
||||
}
|
||||
|
||||
_FUNASRAPI void FunVadFreeResult(FUNASR_RESULT result)
|
||||
{
|
||||
FUNASR_VAD_RESULT * p_result = (FUNASR_VAD_RESULT*)result;
|
||||
if (p_result)
|
||||
{
|
||||
if(p_result->segments){
|
||||
delete p_result->segments;
|
||||
}
|
||||
delete p_result;
|
||||
}
|
||||
}
|
||||
|
||||
// APIs for Uninit
|
||||
_FUNASRAPI void FunASRUninit(FUNASR_HANDLE handle)
|
||||
{
|
||||
Model* recog_obj = (Model*)handle;
|
||||
@ -203,6 +290,36 @@ extern "C" {
|
||||
delete recog_obj;
|
||||
}
|
||||
|
||||
_FUNASRAPI void FunVadUninit(FUNASR_HANDLE handle)
|
||||
{
|
||||
VadModel* recog_obj = (VadModel*)handle;
|
||||
|
||||
if (!recog_obj)
|
||||
return;
|
||||
|
||||
delete recog_obj;
|
||||
}
|
||||
|
||||
_FUNASRAPI void FunPuncUninit(FUNASR_HANDLE handle)
|
||||
{
|
||||
PuncModel* punc_obj = (PuncModel*)handle;
|
||||
|
||||
if (!punc_obj)
|
||||
return;
|
||||
|
||||
delete punc_obj;
|
||||
}
|
||||
|
||||
_FUNASRAPI void FunOfflineUninit(FUNASR_HANDLE handle)
|
||||
{
|
||||
OfflineStream* offline_stream = (OfflineStream*)handle;
|
||||
|
||||
if (!offline_stream)
|
||||
return;
|
||||
|
||||
delete offline_stream;
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
||||
}
|
||||
|
||||
@ -2,7 +2,19 @@
|
||||
|
||||
Model *CreateModel(std::map<std::string, std::string>& model_path, int thread_num)
|
||||
{
|
||||
string am_model_path;
|
||||
string am_cmvn_path;
|
||||
string am_config_path;
|
||||
|
||||
am_model_path = PathAppend(model_path.at(MODEL_DIR), MODEL_NAME);
|
||||
if(model_path.find(QUANTIZE) != model_path.end() && model_path.at(QUANTIZE) == "true"){
|
||||
am_model_path = PathAppend(model_path.at(MODEL_DIR), QUANT_MODEL_NAME);
|
||||
}
|
||||
am_cmvn_path = PathAppend(model_path.at(MODEL_DIR), AM_CMVN_NAME);
|
||||
am_config_path = PathAppend(model_path.at(MODEL_DIR), AM_CONFIG_NAME);
|
||||
|
||||
Model *mm;
|
||||
mm = new paraformer::Paraformer(model_path, thread_num);
|
||||
mm = new paraformer::Paraformer();
|
||||
mm->InitAsr(am_model_path, am_cmvn_path, am_config_path, thread_num);
|
||||
return mm;
|
||||
}
|
||||
|
||||
61
funasr/runtime/onnxruntime/src/offline-stream.cpp
Normal file
61
funasr/runtime/onnxruntime/src/offline-stream.cpp
Normal file
@ -0,0 +1,61 @@
|
||||
#include "precomp.h"
|
||||
|
||||
OfflineStream::OfflineStream(std::map<std::string, std::string>& model_path, int thread_num)
|
||||
{
|
||||
// VAD model
|
||||
if(model_path.find(VAD_DIR) != model_path.end()){
|
||||
use_vad = true;
|
||||
string vad_model_path;
|
||||
string vad_cmvn_path;
|
||||
string vad_config_path;
|
||||
|
||||
vad_model_path = PathAppend(model_path.at(VAD_DIR), MODEL_NAME);
|
||||
if(model_path.find(VAD_QUANT) != model_path.end() && model_path.at(VAD_QUANT) == "true"){
|
||||
vad_model_path = PathAppend(model_path.at(VAD_DIR), QUANT_MODEL_NAME);
|
||||
}
|
||||
vad_cmvn_path = PathAppend(model_path.at(VAD_DIR), VAD_CMVN_NAME);
|
||||
vad_config_path = PathAppend(model_path.at(VAD_DIR), VAD_CONFIG_NAME);
|
||||
vad_handle = make_unique<FsmnVad>();
|
||||
vad_handle->InitVad(vad_model_path, vad_cmvn_path, vad_config_path, thread_num);
|
||||
}
|
||||
|
||||
// AM model
|
||||
if(model_path.find(MODEL_DIR) != model_path.end()){
|
||||
string am_model_path;
|
||||
string am_cmvn_path;
|
||||
string am_config_path;
|
||||
|
||||
am_model_path = PathAppend(model_path.at(MODEL_DIR), MODEL_NAME);
|
||||
if(model_path.find(QUANTIZE) != model_path.end() && model_path.at(QUANTIZE) == "true"){
|
||||
am_model_path = PathAppend(model_path.at(MODEL_DIR), QUANT_MODEL_NAME);
|
||||
}
|
||||
am_cmvn_path = PathAppend(model_path.at(MODEL_DIR), AM_CMVN_NAME);
|
||||
am_config_path = PathAppend(model_path.at(MODEL_DIR), AM_CONFIG_NAME);
|
||||
|
||||
asr_handle = make_unique<Paraformer>();
|
||||
asr_handle->InitAsr(am_model_path, am_cmvn_path, am_config_path, thread_num);
|
||||
}
|
||||
|
||||
// PUNC model
|
||||
if(model_path.find(PUNC_DIR) != model_path.end()){
|
||||
use_punc = true;
|
||||
string punc_model_path;
|
||||
string punc_config_path;
|
||||
|
||||
punc_model_path = PathAppend(model_path.at(PUNC_DIR), MODEL_NAME);
|
||||
if(model_path.find(PUNC_QUANT) != model_path.end() && model_path.at(PUNC_QUANT) == "true"){
|
||||
punc_model_path = PathAppend(model_path.at(PUNC_DIR), QUANT_MODEL_NAME);
|
||||
}
|
||||
punc_config_path = PathAppend(model_path.at(PUNC_DIR), PUNC_CONFIG_NAME);
|
||||
|
||||
punc_handle = make_unique<CTTransformer>();
|
||||
punc_handle->InitPunc(punc_model_path, punc_config_path, thread_num);
|
||||
}
|
||||
}
|
||||
|
||||
OfflineStream *CreateOfflineStream(std::map<std::string, std::string>& model_path, int thread_num)
|
||||
{
|
||||
OfflineStream *mm;
|
||||
mm = new OfflineStream(model_path, thread_num);
|
||||
return mm;
|
||||
}
|
||||
@ -8,65 +8,11 @@
|
||||
using namespace std;
|
||||
using namespace paraformer;
|
||||
|
||||
Paraformer::Paraformer(std::map<std::string, std::string>& model_path,int thread_num)
|
||||
Paraformer::Paraformer()
|
||||
:env_(ORT_LOGGING_LEVEL_ERROR, "paraformer"),session_options{}{
|
||||
|
||||
// VAD model
|
||||
if(model_path.find(VAD_MODEL_PATH) != model_path.end()){
|
||||
use_vad = true;
|
||||
string vad_model_path;
|
||||
string vad_cmvn_path;
|
||||
string vad_config_path;
|
||||
|
||||
try{
|
||||
vad_model_path = model_path.at(VAD_MODEL_PATH);
|
||||
vad_cmvn_path = model_path.at(VAD_CMVN_PATH);
|
||||
vad_config_path = model_path.at(VAD_CONFIG_PATH);
|
||||
}catch(const out_of_range& e){
|
||||
LOG(ERROR) << "Error when read "<< VAD_CMVN_PATH << " or " << VAD_CONFIG_PATH <<" :" << e.what();
|
||||
exit(0);
|
||||
}
|
||||
vad_handle = make_unique<FsmnVad>();
|
||||
vad_handle->InitVad(vad_model_path, vad_cmvn_path, vad_config_path);
|
||||
}
|
||||
|
||||
// AM model
|
||||
if(model_path.find(AM_MODEL_PATH) != model_path.end()){
|
||||
string am_model_path;
|
||||
string am_cmvn_path;
|
||||
string am_config_path;
|
||||
|
||||
try{
|
||||
am_model_path = model_path.at(AM_MODEL_PATH);
|
||||
am_cmvn_path = model_path.at(AM_CMVN_PATH);
|
||||
am_config_path = model_path.at(AM_CONFIG_PATH);
|
||||
}catch(const out_of_range& e){
|
||||
LOG(ERROR) << "Error when read "<< AM_CONFIG_PATH << " or " << AM_CMVN_PATH <<" :" << e.what();
|
||||
exit(0);
|
||||
}
|
||||
InitAM(am_model_path, am_cmvn_path, am_config_path, thread_num);
|
||||
}
|
||||
|
||||
// PUNC model
|
||||
if(model_path.find(PUNC_MODEL_PATH) != model_path.end()){
|
||||
use_punc = true;
|
||||
string punc_model_path;
|
||||
string punc_config_path;
|
||||
|
||||
try{
|
||||
punc_model_path = model_path.at(PUNC_MODEL_PATH);
|
||||
punc_config_path = model_path.at(PUNC_CONFIG_PATH);
|
||||
}catch(const out_of_range& e){
|
||||
LOG(ERROR) << "Error when read "<< PUNC_CONFIG_PATH <<" :" << e.what();
|
||||
exit(0);
|
||||
}
|
||||
|
||||
punc_handle = make_unique<CTTransformer>();
|
||||
punc_handle->InitPunc(punc_model_path, punc_config_path, thread_num);
|
||||
}
|
||||
}
|
||||
|
||||
void Paraformer::InitAM(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){
|
||||
void Paraformer::InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){
|
||||
// knf options
|
||||
fbank_opts.frame_opts.dither = 0;
|
||||
fbank_opts.mel_opts.num_bins = 80;
|
||||
@ -120,14 +66,6 @@ void Paraformer::Reset()
|
||||
{
|
||||
}
|
||||
|
||||
vector<std::vector<int>> Paraformer::VadSeg(std::vector<float>& pcm_data){
|
||||
return vad_handle->Infer(pcm_data);
|
||||
}
|
||||
|
||||
string Paraformer::AddPunc(const char* sz_input){
|
||||
return punc_handle->AddPunc(sz_input);
|
||||
}
|
||||
|
||||
vector<float> Paraformer::FbankKaldi(float sample_rate, const float* waves, int len) {
|
||||
knf::OnlineFbank fbank_(fbank_opts);
|
||||
fbank_.AcceptWaveform(sample_rate, waves, len);
|
||||
@ -282,7 +220,7 @@ string Paraformer::Forward(float* din, int len, int flag)
|
||||
}
|
||||
catch (std::exception const &e)
|
||||
{
|
||||
printf(e.what());
|
||||
LOG(ERROR)<<e.what();
|
||||
}
|
||||
|
||||
return result;
|
||||
@ -291,12 +229,12 @@ string Paraformer::Forward(float* din, int len, int flag)
|
||||
string Paraformer::ForwardChunk(float* din, int len, int flag)
|
||||
{
|
||||
|
||||
printf("Not Imp!!!!!!\n");
|
||||
return "Hello";
|
||||
LOG(ERROR)<<"Not Imp!!!!!!";
|
||||
return "";
|
||||
}
|
||||
|
||||
string Paraformer::Rescoring()
|
||||
{
|
||||
printf("Not Imp!!!!!!\n");
|
||||
return "Hello";
|
||||
LOG(ERROR)<<"Not Imp!!!!!!";
|
||||
return "";
|
||||
}
|
||||
|
||||
@ -2,10 +2,8 @@
|
||||
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
|
||||
* MIT License (https://opensource.org/licenses/MIT)
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
|
||||
#ifndef PARAFORMER_MODELIMP_H
|
||||
#define PARAFORMER_MODELIMP_H
|
||||
|
||||
@ -23,9 +21,6 @@ namespace paraformer {
|
||||
//std::unique_ptr<knf::OnlineFbank> fbank_;
|
||||
knf::FbankOptions fbank_opts;
|
||||
|
||||
std::unique_ptr<FsmnVad> vad_handle;
|
||||
std::unique_ptr<CTTransformer> punc_handle;
|
||||
|
||||
Vocab* vocab;
|
||||
vector<float> means_list;
|
||||
vector<float> vars_list;
|
||||
@ -36,7 +31,6 @@ namespace paraformer {
|
||||
void LoadCmvn(const char *filename);
|
||||
vector<float> ApplyLfr(const vector<float> &in);
|
||||
void ApplyCmvn(vector<float> *v);
|
||||
|
||||
string GreedySearch( float* in, int n_len, int64_t token_nums);
|
||||
|
||||
std::shared_ptr<Ort::Session> m_session;
|
||||
@ -46,22 +40,16 @@ namespace paraformer {
|
||||
vector<string> m_strInputNames, m_strOutputNames;
|
||||
vector<const char*> m_szInputNames;
|
||||
vector<const char*> m_szOutputNames;
|
||||
bool use_vad=false;
|
||||
bool use_punc=false;
|
||||
|
||||
public:
|
||||
Paraformer(std::map<std::string, std::string>& model_path, int thread_num=0);
|
||||
Paraformer();
|
||||
~Paraformer();
|
||||
void InitAM(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num);
|
||||
void InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num);
|
||||
void Reset();
|
||||
vector<float> FbankKaldi(float sample_rate, const float* waves, int len);
|
||||
string ForwardChunk(float* din, int len, int flag);
|
||||
string Forward(float* din, int len, int flag);
|
||||
string Rescoring();
|
||||
std::vector<std::vector<int>> VadSeg(std::vector<float>& pcm_data);
|
||||
string AddPunc(const char* sz_input);
|
||||
bool UseVad(){return use_vad;};
|
||||
bool UsePunc(){return use_punc;};
|
||||
};
|
||||
|
||||
} // namespace paraformer
|
||||
|
||||
@ -30,6 +30,10 @@ using namespace std;
|
||||
#include "com-define.h"
|
||||
#include "commonfunc.h"
|
||||
#include "predefine-coe.h"
|
||||
#include "model.h"
|
||||
#include "vad-model.h"
|
||||
#include "punc-model.h"
|
||||
#include "offline-stream.h"
|
||||
#include "tokenizer.h"
|
||||
#include "ct-transformer.h"
|
||||
#include "fsmn-vad.h"
|
||||
@ -39,9 +43,8 @@ using namespace std;
|
||||
#include "tensor.h"
|
||||
#include "util.h"
|
||||
#include "resample.h"
|
||||
#include "model.h"
|
||||
#include "vad-model.h"
|
||||
#include "paraformer.h"
|
||||
#include "offline-stream.h"
|
||||
#include "libfunasrapi.h"
|
||||
|
||||
using namespace paraformer;
|
||||
|
||||
19
funasr/runtime/onnxruntime/src/punc-model.cpp
Normal file
19
funasr/runtime/onnxruntime/src/punc-model.cpp
Normal file
@ -0,0 +1,19 @@
|
||||
#include "precomp.h"
|
||||
|
||||
PuncModel *CreatePuncModel(std::map<std::string, std::string>& model_path, int thread_num)
|
||||
{
|
||||
PuncModel *mm;
|
||||
mm = new CTTransformer();
|
||||
|
||||
string punc_model_path;
|
||||
string punc_config_path;
|
||||
|
||||
punc_model_path = PathAppend(model_path.at(MODEL_DIR), MODEL_NAME);
|
||||
if(model_path.find(QUANTIZE) != model_path.end() && model_path.at(QUANTIZE) == "true"){
|
||||
punc_model_path = PathAppend(model_path.at(MODEL_DIR), QUANT_MODEL_NAME);
|
||||
}
|
||||
punc_config_path = PathAppend(model_path.at(MODEL_DIR), PUNC_CONFIG_NAME);
|
||||
|
||||
mm->InitPunc(punc_model_path, punc_config_path, thread_num);
|
||||
return mm;
|
||||
}
|
||||
@ -14,6 +14,10 @@ CTokenizer::CTokenizer():m_ready(false)
|
||||
{
|
||||
}
|
||||
|
||||
CTokenizer::~CTokenizer()
|
||||
{
|
||||
}
|
||||
|
||||
void CTokenizer::ReadYaml(const YAML::Node& node)
|
||||
{
|
||||
if (node.IsMap())
|
||||
|
||||
@ -17,6 +17,7 @@ public:
|
||||
|
||||
CTokenizer(const char* sz_yamlfile);
|
||||
CTokenizer();
|
||||
~CTokenizer();
|
||||
bool OpenYaml(const char* sz_yamlfile);
|
||||
void ReadYaml(const YAML::Node& node);
|
||||
vector<string> Id2String(vector<int> input);
|
||||
|
||||
21
funasr/runtime/onnxruntime/src/vad-model.cpp
Normal file
21
funasr/runtime/onnxruntime/src/vad-model.cpp
Normal file
@ -0,0 +1,21 @@
|
||||
#include "precomp.h"
|
||||
|
||||
VadModel *CreateVadModel(std::map<std::string, std::string>& model_path, int thread_num)
|
||||
{
|
||||
VadModel *mm;
|
||||
mm = new FsmnVad();
|
||||
|
||||
string vad_model_path;
|
||||
string vad_cmvn_path;
|
||||
string vad_config_path;
|
||||
|
||||
vad_model_path = PathAppend(model_path.at(MODEL_DIR), MODEL_NAME);
|
||||
if(model_path.find(QUANTIZE) != model_path.end() && model_path.at(QUANTIZE) == "true"){
|
||||
vad_model_path = PathAppend(model_path.at(MODEL_DIR), QUANT_MODEL_NAME);
|
||||
}
|
||||
vad_cmvn_path = PathAppend(model_path.at(MODEL_DIR), VAD_CMVN_NAME);
|
||||
vad_config_path = PathAppend(model_path.at(MODEL_DIR), VAD_CONFIG_NAME);
|
||||
|
||||
mm->InitVad(vad_model_path, vad_cmvn_path, vad_config_path, thread_num);
|
||||
return mm;
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user