From 1480dcf5d571c4920b4f18717d580646794b8d28 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=BE=E8=81=AA?= Date: Thu, 10 Oct 2024 17:45:45 +0800 Subject: [PATCH] add GetInputNames GetOutputNames --- runtime/onnxruntime/src/commonfunc.h | 26 ++++++++ .../onnxruntime/src/ct-transformer-online.cpp | 19 +----- runtime/onnxruntime/src/ct-transformer.cpp | 15 +---- runtime/onnxruntime/src/fsmn-vad.cpp | 24 +------ runtime/onnxruntime/src/fsmn-vad.h | 5 +- runtime/onnxruntime/src/paraformer.cpp | 62 +++++++------------ runtime/onnxruntime/src/sensevoice-small.cpp | 22 +------ 7 files changed, 58 insertions(+), 115 deletions(-) diff --git a/runtime/onnxruntime/src/commonfunc.h b/runtime/onnxruntime/src/commonfunc.h index 3449ebca1..6fd553fe0 100644 --- a/runtime/onnxruntime/src/commonfunc.h +++ b/runtime/onnxruntime/src/commonfunc.h @@ -65,6 +65,19 @@ inline void GetInputName(Ort::Session* session, string& inputName,int nIndex=0) } } +inline void GetInputNames(Ort::Session* session, std::vector &m_strInputNames, + std::vector &m_szInputNames) { + Ort::AllocatorWithDefaultOptions allocator; + size_t numNodes = session->GetInputCount(); + m_strInputNames.resize(numNodes); + m_szInputNames.resize(numNodes); + for (size_t i = 0; i != numNodes; ++i) { + auto t = session->GetInputNameAllocated(i, allocator); + m_strInputNames[i] = t.get(); + m_szInputNames[i] = m_strInputNames[i].c_str(); + } +} + inline void GetOutputName(Ort::Session* session, string& outputName, int nIndex = 0) { size_t numOutputNodes = session->GetOutputCount(); if (numOutputNodes > 0) { @@ -76,6 +89,19 @@ inline void GetOutputName(Ort::Session* session, string& outputName, int nIndex } } +inline void GetOutputNames(Ort::Session* session, std::vector &m_strOutputNames, + std::vector &m_szOutputNames) { + Ort::AllocatorWithDefaultOptions allocator; + size_t numNodes = session->GetOutputCount(); + m_strOutputNames.resize(numNodes); + m_szOutputNames.resize(numNodes); + for (size_t i = 0; i != numNodes; ++i) { + auto t = session->GetOutputNameAllocated(i, allocator); + m_strOutputNames[i] = t.get(); + m_szOutputNames[i] = m_strOutputNames[i].c_str(); + } +} + template inline static size_t Argmax(ForwardIterator first, ForwardIterator last) { return std::distance(first, std::max_element(first, last)); diff --git a/runtime/onnxruntime/src/ct-transformer-online.cpp b/runtime/onnxruntime/src/ct-transformer-online.cpp index 92fe41e96..769bb6544 100644 --- a/runtime/onnxruntime/src/ct-transformer-online.cpp +++ b/runtime/onnxruntime/src/ct-transformer-online.cpp @@ -25,23 +25,8 @@ void CTTransformerOnline::InitPunc(const std::string &punc_model, const std::str exit(-1); } // read inputnames outputnames - string strName; - GetInputName(m_session.get(), strName); - m_strInputNames.push_back(strName.c_str()); - GetInputName(m_session.get(), strName, 1); - m_strInputNames.push_back(strName); - GetInputName(m_session.get(), strName, 2); - m_strInputNames.push_back(strName); - GetInputName(m_session.get(), strName, 3); - m_strInputNames.push_back(strName); - - GetOutputName(m_session.get(), strName); - m_strOutputNames.push_back(strName); - - for (auto& item : m_strInputNames) - m_szInputNames.push_back(item.c_str()); - for (auto& item : m_strOutputNames) - m_szOutputNames.push_back(item.c_str()); + GetInputNames(m_session.get(), m_strInputNames, m_szInputNames); + GetOutputNames(m_session.get(), m_strOutputNames, m_szOutputNames); m_tokenizer.OpenYaml(punc_config.c_str(), token_file.c_str()); } diff --git a/runtime/onnxruntime/src/ct-transformer.cpp b/runtime/onnxruntime/src/ct-transformer.cpp index d1a7813b1..2139aa955 100644 --- a/runtime/onnxruntime/src/ct-transformer.cpp +++ b/runtime/onnxruntime/src/ct-transformer.cpp @@ -25,20 +25,9 @@ void CTTransformer::InitPunc(const std::string &punc_model, const std::string &p exit(-1); } // read inputnames outputnames - string strName; - GetInputName(m_session.get(), strName); - m_strInputNames.push_back(strName.c_str()); - GetInputName(m_session.get(), strName, 1); - m_strInputNames.push_back(strName); + GetInputNames(m_session.get(), m_strInputNames, m_szInputNames); + GetOutputNames(m_session.get(), m_strOutputNames, m_szOutputNames); - GetOutputName(m_session.get(), strName); - m_strOutputNames.push_back(strName); - - for (auto& item : m_strInputNames) - m_szInputNames.push_back(item.c_str()); - for (auto& item : m_strOutputNames) - m_szOutputNames.push_back(item.c_str()); - m_tokenizer.OpenYaml(punc_config.c_str(), token_file.c_str()); m_tokenizer.JiebaInit(punc_config); } diff --git a/runtime/onnxruntime/src/fsmn-vad.cpp b/runtime/onnxruntime/src/fsmn-vad.cpp index b120c1939..bcef5d43d 100644 --- a/runtime/onnxruntime/src/fsmn-vad.cpp +++ b/runtime/onnxruntime/src/fsmn-vad.cpp @@ -60,30 +60,10 @@ void FsmnVad::ReadModel(const char* vad_model) { LOG(ERROR) << "Error when load vad onnx model: " << e.what(); exit(-1); } - GetInputOutputInfo(vad_session_, &vad_in_names_, &vad_out_names_); + GetInputNames(vad_session_.get(), m_strInputNames, vad_in_names_); + GetOutputNames(vad_session_.get(), m_strOutputNames, vad_out_names_); } -void FsmnVad::GetInputOutputInfo( - const std::shared_ptr &session, - std::vector *in_names, std::vector *out_names) { - Ort::AllocatorWithDefaultOptions allocator; - // Input info - int num_nodes = session->GetInputCount(); - in_names->resize(num_nodes); - for (int i = 0; i < num_nodes; ++i) { - std::unique_ptr name = session->GetInputNameAllocated(i, allocator); - (*in_names)[i] = name.get(); - } - // Output info - num_nodes = session->GetOutputCount(); - out_names->resize(num_nodes); - for (int i = 0; i < num_nodes; ++i) { - std::unique_ptr name = session->GetOutputNameAllocated(i, allocator); - (*out_names)[i] = name.get(); - } -} - - void FsmnVad::Forward( const std::vector> &chunk_feats, std::vector> *out_prob, diff --git a/runtime/onnxruntime/src/fsmn-vad.h b/runtime/onnxruntime/src/fsmn-vad.h index f06a9651b..dc4726a90 100644 --- a/runtime/onnxruntime/src/fsmn-vad.h +++ b/runtime/onnxruntime/src/fsmn-vad.h @@ -34,6 +34,7 @@ public: std::shared_ptr vad_session_ = nullptr; Ort::Env env_; Ort::SessionOptions session_options_; + vector m_strInputNames, m_strOutputNames; std::vector vad_in_names_; std::vector vad_out_names_; std::vector> in_cache_; @@ -54,10 +55,6 @@ private: void ReadModel(const char* vad_model); void LoadConfigFromYaml(const char* filename); - static void GetInputOutputInfo( - const std::shared_ptr &session, - std::vector *in_names, std::vector *out_names); - void FbankKaldi(float sample_rate, std::vector> &vad_feats, std::vector &waves); diff --git a/runtime/onnxruntime/src/paraformer.cpp b/runtime/onnxruntime/src/paraformer.cpp index 1f1d48f55..24f5152e2 100644 --- a/runtime/onnxruntime/src/paraformer.cpp +++ b/runtime/onnxruntime/src/paraformer.cpp @@ -45,26 +45,8 @@ void Paraformer::InitAsr(const std::string &am_model, const std::string &am_cmvn exit(-1); } - string strName; - GetInputName(m_session_.get(), strName); - m_strInputNames.push_back(strName.c_str()); - GetInputName(m_session_.get(), strName,1); - m_strInputNames.push_back(strName); - if (use_hotword) { - GetInputName(m_session_.get(), strName, 2); - m_strInputNames.push_back(strName); - } - - size_t numOutputNodes = m_session_->GetOutputCount(); - for(int index=0; indexGetOutputCount(); - for(int index=0; indexGetOutputCount(); + // for(int index=0; indexGetOutputCount(); - for(int index=0; index