mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
modify paraformer onnx init
This commit is contained in:
parent
ddfcd68c80
commit
58c119c508
@ -4,7 +4,7 @@ using namespace std;
|
||||
using namespace paraformer;
|
||||
|
||||
ModelImp::ModelImp(const char* path,int nNumThread, bool quantize)
|
||||
{
|
||||
:env_(ORT_LOGGING_LEVEL_ERROR, "paraformer"),sessionOptions{}{
|
||||
string model_path;
|
||||
string cmvn_path;
|
||||
string config_path;
|
||||
@ -29,20 +29,20 @@ ModelImp::ModelImp(const char* path,int nNumThread, bool quantize)
|
||||
|
||||
#ifdef _WIN32
|
||||
wstring wstrPath = strToWstr(model_path);
|
||||
m_session = new Ort::Session(env, wstrPath.c_str(), sessionOptions);
|
||||
m_session = std::make_unique<Ort::Session>(env_, model_path.c_str(), sessionOptions);
|
||||
#else
|
||||
m_session = new Ort::Session(env, model_path.c_str(), sessionOptions);
|
||||
m_session = std::make_unique<Ort::Session>(env_, model_path.c_str(), sessionOptions);
|
||||
#endif
|
||||
|
||||
string strName;
|
||||
getInputName(m_session, strName);
|
||||
getInputName(m_session.get(), strName);
|
||||
m_strInputNames.push_back(strName.c_str());
|
||||
getInputName(m_session, strName,1);
|
||||
getInputName(m_session.get(), strName,1);
|
||||
m_strInputNames.push_back(strName);
|
||||
|
||||
getOutputName(m_session, strName);
|
||||
getOutputName(m_session.get(), strName);
|
||||
m_strOutputNames.push_back(strName);
|
||||
getOutputName(m_session, strName,1);
|
||||
getOutputName(m_session.get(), strName,1);
|
||||
m_strOutputNames.push_back(strName);
|
||||
|
||||
for (auto& item : m_strInputNames)
|
||||
@ -55,11 +55,6 @@ ModelImp::ModelImp(const char* path,int nNumThread, bool quantize)
|
||||
|
||||
ModelImp::~ModelImp()
|
||||
{
|
||||
if (m_session)
|
||||
{
|
||||
delete m_session;
|
||||
m_session = nullptr;
|
||||
}
|
||||
if(vocab)
|
||||
delete vocab;
|
||||
fftwf_free(fft_input);
|
||||
@ -172,6 +167,12 @@ string ModelImp::forward(float* din, int len, int flag)
|
||||
apply_cmvn(in);
|
||||
Ort::RunOptions run_option;
|
||||
|
||||
#ifdef _WIN_X86
|
||||
Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
|
||||
#else
|
||||
Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
|
||||
#endif
|
||||
|
||||
std::array<int64_t, 3> input_shape_{ in->size[0],in->size[2],in->size[3] };
|
||||
Ort::Value onnx_feats = Ort::Value::CreateTensor<float>(m_memoryInfo,
|
||||
in->buff,
|
||||
|
||||
@ -24,15 +24,9 @@ namespace paraformer {
|
||||
|
||||
string greedy_search( float* in, int nLen);
|
||||
|
||||
#ifdef _WIN_X86
|
||||
Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
|
||||
#else
|
||||
Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
|
||||
#endif
|
||||
|
||||
Ort::Session* m_session = nullptr;
|
||||
Ort::Env env = Ort::Env(ORT_LOGGING_LEVEL_ERROR, "paraformer");
|
||||
Ort::SessionOptions sessionOptions = Ort::SessionOptions();
|
||||
std::unique_ptr<Ort::Session> m_session;
|
||||
Ort::Env env_;
|
||||
Ort::SessionOptions sessionOptions;
|
||||
|
||||
vector<string> m_strInputNames, m_strOutputNames;
|
||||
vector<const char*> m_szInputNames;
|
||||
|
||||
Loading…
Reference in New Issue
Block a user