cpp http post server support (#1739)

* add cpp http server

* add some comment

* remove some comments
This commit is contained in:
zhaomingwork 2024-05-23 17:34:52 +08:00 committed by GitHub
parent f47d43c020
commit 4b388768d0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
21 changed files with 2249 additions and 0 deletions

142
runtime/http/CMakeLists.txt Normal file
View File

@ -0,0 +1,142 @@
cmake_minimum_required(VERSION 3.16)
project(FunASRWebscoket)
set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.")
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
option(ENABLE_HTTP "Whether to build http server" ON)
option(ENABLE_PORTAUDIO "Whether to build portaudio" ON)
if(WIN32)
file(REMOVE ${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/glog/src/config.h
${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/glog/src/glog/export.h
${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/glog/src/glog/logging.h
${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/glog/src/glog/raw_logging.h
${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/glog/src/glog/stl_logging.h
${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/glog/src/glog/vlog_is_on.h)
else()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread -fPIC")
endif()
option(ENABLE_GLOG "Whether to build glog" ON)
option(ENABLE_FST "Whether to build openfst" ON) # ITN need openfst compiled
option(BUILD_SHARED_LIBS "Build shared libraries" ON)
if(ENABLE_HTTP)
# cmake_policy(SET CMP0135 NEW)
include(FetchContent)
if(NOT EXISTS ${PROJECT_SOURCE_DIR}/third_party/asio/asio )
FetchContent_Declare(asio
URL https://github.com/chriskohlhoff/asio/archive/refs/tags/asio-1-24-0.tar.gz
SOURCE_DIR ${PROJECT_SOURCE_DIR}/third_party/asio
)
FetchContent_MakeAvailable(asio)
endif()
include_directories(${PROJECT_SOURCE_DIR}/third_party/asio/asio/include)
if(NOT EXISTS ${PROJECT_SOURCE_DIR}/third_party/json/ChangeLog.md )
FetchContent_Declare(json
URL https://github.com/nlohmann/json/archive/refs/tags/v3.11.2.tar.gz
SOURCE_DIR ${PROJECT_SOURCE_DIR}/third_party/json
)
FetchContent_MakeAvailable(json)
endif()
include_directories(${PROJECT_SOURCE_DIR}/third_party/json/include)
endif()
if(ENABLE_PORTAUDIO)
include(FetchContent)
set(portaudio_URL "http://files.portaudio.com/archives/pa_stable_v190700_20210406.tgz")
set(portaudio_URL2 "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/dep_libs/pa_stable_v190700_20210406.tgz")
set(portaudio_HASH "SHA256=47efbf42c77c19a05d22e627d42873e991ec0c1357219c0d74ce6a2948cb2def")
FetchContent_Declare(portaudio
URL
${portaudio_URL}
${portaudio_URL2}
URL_HASH ${portaudio_HASH}
)
FetchContent_GetProperties(portaudio)
if(NOT portaudio_POPULATED)
message(STATUS "Downloading portaudio from ${portaudio_URL}")
FetchContent_Populate(portaudio)
endif()
message(STATUS "portaudio is downloaded to ${portaudio_SOURCE_DIR}")
message(STATUS "portaudio's binary dir is ${portaudio_BINARY_DIR}")
add_subdirectory(${portaudio_SOURCE_DIR} ${portaudio_BINARY_DIR} EXCLUDE_FROM_ALL)
if(NOT WIN32)
target_compile_options(portaudio PRIVATE "-Wno-deprecated-declarations")
else()
install(TARGETS portaudio DESTINATION ..)
endif()
endif()
# Include generated *.pb.h files
link_directories(${ONNXRUNTIME_DIR}/lib)
link_directories(${FFMPEG_DIR}/lib)
if(ENABLE_GLOG)
include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/glog/src)
set(BUILD_TESTING OFF)
add_subdirectory(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/glog glog)
include_directories(${glog_BINARY_DIR})
endif()
if(ENABLE_FST)
# fst depend on glog and gflags
include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/gflags)
add_subdirectory(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/gflags gflags)
include_directories(${gflags_BINARY_DIR}/include)
# the following openfst if cloned from https://github.com/kkm000/openfst.git
# with some patch to fix the make errors.
add_subdirectory(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/openfst openfst)
include_directories(${openfst_SOURCE_DIR}/src/include)
if(WIN32)
include_directories(${openfst_SOURCE_DIR}/src/lib)
endif()
endif()
include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/include/)
include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/src)
include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/yaml-cpp/include/)
include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/kaldi-native-fbank)
include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/jieba/include)
include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/jieba/include/limonp/include)
include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party)
include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/kaldi)
add_subdirectory(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/yaml-cpp yaml-cpp)
add_subdirectory(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/kaldi-native-fbank/kaldi-native-fbank/csrc csrc)
add_subdirectory(${PROJECT_SOURCE_DIR}/../onnxruntime/src src)
add_subdirectory(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/kaldi kaldi)
# install openssl first apt-get install libssl-dev
find_package(OpenSSL REQUIRED)
message("CXX_FLAGS "${CMAKE_CXX_FLAGS})
#
get_property(includes DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} PROPERTY INCLUDE_DIRECTORIES)
#
foreach(include ${includes})
message("Include directory: ${include}")
endforeach()
add_subdirectory(bin)

View File

@ -0,0 +1,23 @@
if(WIN32)
include_directories(${ONNXRUNTIME_DIR}/include)
include_directories(${FFMPEG_DIR}/include)
include_directories(${OPENSSL_ROOT_DIR}//include)
link_directories(${OPENSSL_ROOT_DIR}/lib)
add_definitions(-D_WEBSOCKETPP_CPP11_RANDOM_DEVICE_)
add_definitions(-D_WEBSOCKETPP_CPP11_TYPE_TRAITS_)
add_compile_options("$<$<CXX_COMPILER_ID:MSVC>:/bigobj>")
add_compile_options("$<$<CXX_COMPILER_ID:MSVC>:/utf-8>")
SET(RELATION_SOURCE "../../onnxruntime/src/resample.cpp" "../../onnxruntime/src/util.cpp" "../../onnxruntime/src/alignedmem.cpp" "../../onnxruntime/src/encode_converter.cpp")
endif()
find_package(ZLIB REQUIRED)
file(GLOB SRC_FILES "*.cpp")
add_executable(funasr-http-server ${SRC_FILES} ${RELATION_SOURCE})
target_link_libraries(funasr-http-server PUBLIC funasr ${OPENSSL_CRYPTO_LIBRARY} ${OPENSSL_SSL_LIBRARY})

View File

@ -0,0 +1,20 @@
/**
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
* Reserved. MIT License (https://opensource.org/licenses/MIT)
*/
/* 2023-2024 by zhaomingwork@qq.com */
// FUNASR_MESSAGE define the needed message between funasr engine and http server
#ifndef HTTP_SERVER2_SESSIONS_HPP
#define HTTP_SERVER2_SESSIONS_HPP
#include "funasrruntime.h"
#include "nlohmann/json.hpp"
#include <atomic>
typedef struct {
nlohmann::json msg;
std::shared_ptr<std::vector<char>> samples;
std::shared_ptr<std::vector<std::vector<float>>> hotwords_embedding=nullptr;
FUNASR_DEC_HANDLE decoder_handle=nullptr;
std::atomic<int> status;
} FUNASR_MESSAGE;
#endif // HTTP_SERVER2_REQUEST_PARSER_HPP

View File

@ -0,0 +1,196 @@
/**
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
* Reserved. MIT License (https://opensource.org/licenses/MIT)
*/
/* 2023-2024 by zhaomingwork@qq.com */
//
// connection.cpp
// copy some codes from http://www.boost.org/
#include "connection.hpp"
#include <thread>
#include <utility>
namespace http {
namespace server2 {
//std::ofstream fwout("out.data", std::ios::binary);
std::shared_ptr<FUNASR_MESSAGE> &connection::get_data_msg() { return data_msg; }
connection::connection(asio::ip::tcp::socket socket,
asio::io_context &io_decoder, int connection_id,
std::shared_ptr<ModelDecoder> model_decoder)
: socket_(std::move(socket)),
io_decoder(io_decoder),
connection_id(connection_id),
model_decoder(model_decoder)
{
s_timer = std::make_shared<asio::steady_timer>(io_decoder);
}
void connection::setup_timer() {
if (data_msg->status == 1) return;
s_timer->expires_after(std::chrono::seconds(3));
s_timer->async_wait([=](const asio::error_code &ec) {
if (!ec) {
std::cout << "time is out!" << std::endl;
if (data_msg->status == 1) return;
data_msg->status = 1;
s_timer->cancel();
auto wf = std::bind(&connection::write_back, std::ref(*this), "");
// close the connection
strand_->post(wf);
}
});
}
void connection::start() {
std::lock_guard<std::mutex> lock(m_lock); // for threads safty
try {
data_msg = std::make_shared<FUNASR_MESSAGE>(); // put a new data vector for
// new connection
data_msg->samples = std::make_shared<std::vector<char>>();
//data_msg->samples->reserve(16000*20);
data_msg->msg = nlohmann::json::parse("{}");
data_msg->msg["wav_format"] = "pcm";
data_msg->msg["wav_name"] = "wav-default-id";
data_msg->msg["itn"] = true;
data_msg->msg["audio_fs"] = 16000; // default is 16k
data_msg->msg["access_num"] = 0; // the number of access for this object,
// when it is 0, we can free it saftly
data_msg->msg["is_eof"] = false;
data_msg->status = 0;
strand_ = std::make_shared<asio::io_context::strand>(io_decoder);
FUNASR_DEC_HANDLE decoder_handle = FunASRWfstDecoderInit(
model_decoder->get_asr_handle(), ASR_OFFLINE, global_beam_, lattice_beam_, am_scale_);
data_msg->decoder_handle = decoder_handle;
if (data_msg->hotwords_embedding == nullptr) {
std::unordered_map<std::string, int> merged_hws_map;
std::string nn_hotwords = "";
if (true) {
std::string json_string = "{}";
if (!json_string.empty()) {
nlohmann::json json_fst_hws;
try {
json_fst_hws = nlohmann::json::parse(json_string);
if (json_fst_hws.type() == nlohmann::json::value_t::object) {
// fst
try {
std::unordered_map<std::string, int> client_hws_map =
json_fst_hws;
merged_hws_map.insert(client_hws_map.begin(),
client_hws_map.end());
} catch (const std::exception &e) {
std::cout << e.what();
}
}
} catch (std::exception const &e) {
std::cout << e.what();
// nn
std::string client_nn_hws = "{}";
nn_hotwords += " " + client_nn_hws;
std::cout << "nn hotwords: " << client_nn_hws;
}
}
}
merged_hws_map.insert(hws_map_.begin(), hws_map_.end());
// fst
std::cout << "hotwords: ";
for (const auto &pair : merged_hws_map) {
nn_hotwords += " " + pair.first;
std::cout << pair.first << " : " << pair.second;
}
FunWfstDecoderLoadHwsRes(data_msg->decoder_handle, fst_inc_wts_,
merged_hws_map);
// nn
std::vector<std::vector<float>> new_hotwords_embedding =
CompileHotwordEmbedding(model_decoder->get_asr_handle(), nn_hotwords);
data_msg->hotwords_embedding =
std::make_shared<std::vector<std::vector<float>>>(
new_hotwords_embedding);
}
file_parse = std::make_shared<http::server2::file_parser>(data_msg);
do_read();
} catch (const std::exception &e) {
std::cout << "error:" << e.what();
}
}
void connection::write_back(std::string str) {
s_timer->cancel();
std::cout << "jsonresult=" << data_msg->msg["asr_result"].dump() << std::endl;
reply_ = reply::stock_reply(
data_msg->msg["asr_result"].dump()); // reply::stock_reply();
do_write();
}
void connection::do_read() {
// status==1 means time out
if (data_msg->status == 1) return;
s_timer->cancel();
setup_timer();
auto self(shared_from_this());
socket_.async_read_some(
asio::buffer(buffer_),
[this, self](asio::error_code ec, std::size_t bytes_transferred) {
if (!ec) {
auto is = std::begin(buffer_);
auto ie = std::next(is, bytes_transferred);
http::server2::file_parser::result_type rtype =
file_parse->parse_file(is, ie);
if (rtype == http::server2::file_parser::result_type::ok) {
//fwout.write(data_msg->samples->data(),data_msg->samples->size());
//fwout.flush();
auto wf = std::bind(&connection::write_back, std::ref(*this), "aa");
auto f = std::bind(&ModelDecoder::do_decoder,
std::ref(*model_decoder), std::ref(data_msg));
// for decode task
strand_->post(f);
// for close task
strand_->post(wf);
// std::this_thread::sleep_for(std::chrono::milliseconds(1000*10));
}
do_read();
}
});
}
void connection::do_write() {
auto self(shared_from_this());
asio::async_write(socket_, reply_.to_buffers(),
[this, self](asio::error_code ec, std::size_t) {
if (!ec) {
// Initiate graceful connection closure.
asio::error_code ignored_ec;
socket_.shutdown(asio::ip::tcp::socket::shutdown_both,
ignored_ec);
}
// No new asynchronous operations are started. This means
// that all shared_ptr references to the connection object
// will disappear and the object will be destroyed
// automatically after this handler returns. The
// connection class's destructor closes the socket.
});
}
} // namespace server2
} // namespace http

View File

@ -0,0 +1,104 @@
/**
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
* Reserved. MIT License (https://opensource.org/licenses/MIT)
*/
/* 2023-2024 by zhaomingwork@qq.com */
//
// copy some codes from http://www.boost.org/
//
#ifndef HTTP_SERVER2_CONNECTION_HPP
#define HTTP_SERVER2_CONNECTION_HPP
#include <array>
#include <asio.hpp>
#include <atomic>
#include <iostream>
#include <memory>
#include "reply.hpp"
#include <fstream>
#include "file_parse.hpp"
#include "model-decoder.h"
extern std::unordered_map<std::string, int> hws_map_;
extern int fst_inc_wts_;
extern float global_beam_, lattice_beam_, am_scale_;
namespace http {
namespace server2 {
/// Represents a single connection from a client.
class connection : public std::enable_shared_from_this<connection> {
public:
connection(const connection &) = delete;
connection &operator=(const connection &) = delete;
~connection() { std::cout << "one connection is close()" << std::endl; };
/// Construct a connection with the given socket.
explicit connection(asio::ip::tcp::socket socket,
asio::io_context &io_decoder, int connection_id,
std::shared_ptr<ModelDecoder> model_decoder);
/// Start the first asynchronous operation for the connection.
void start();
std::shared_ptr<FUNASR_MESSAGE> &get_data_msg();
void write_back(std::string str);
private:
/// Perform an asynchronous read operation.
void do_read();
/// Perform an asynchronous write operation.
void do_write();
void do_decoder();
void setup_timer();
/// Socket for the connection.
asio::ip::tcp::socket socket_;
/// Buffer for incoming data.
std::array<char, 8192> buffer_;
/// for time out
std::shared_ptr<asio::steady_timer> s_timer;
std::shared_ptr<ModelDecoder> model_decoder;
int connection_id = 0;
/// The reply to be sent back to the client.
reply reply_;
asio::io_context &io_decoder;
std::shared_ptr<FUNASR_MESSAGE> data_msg;
std::mutex m_lock;
std::shared_ptr<asio::io_context::strand> strand_;
std::shared_ptr<http::server2::file_parser> file_parse;
};
typedef std::shared_ptr<connection> connection_ptr;
} // namespace server2
} // namespace http
#endif // HTTP_SERVER2_CONNECTION_HPP

View File

@ -0,0 +1,29 @@
/**
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
* Reserved. MIT License (https://opensource.org/licenses/MIT)
*/
/* 2023-2024 by zhaomingwork@qq.com */
#include "file_parse.hpp"
namespace http {
namespace server2 {
file_parser::file_parser(std::shared_ptr<FUNASR_MESSAGE> data_msg)
:data_msg(data_msg)
{
now_state=start;
}
} // namespace server2
} // namespace http

View File

@ -0,0 +1,234 @@
/**
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
* Reserved. MIT License (https://opensource.org/licenses/MIT)
*/
/* 2023-2024 by zhaomingwork@qq.com */
// ~~~~~~~~~~~~~~~~~~
#ifndef HTTP_SERVER2_REQUEST_FILEPARSER_HPP
#define HTTP_SERVER2_REQUEST_FILEPARSER_HPP
#include <iostream>
#include <memory>
#include <tuple>
#include "asr_sessions.h"
namespace http {
namespace server2 {
/// Parser for incoming requests.
class file_parser {
public:
/// Construct ready to parse the request method.
explicit file_parser(std::shared_ptr<FUNASR_MESSAGE> data_msg);
/// Result of parse.
enum result_type { start, in_boundary, data, ok };
template <typename InputIterator>
void parse_one_line(InputIterator &is, InputIterator &ie, InputIterator &it) {
if (is != it) {
is = it;
}
if (*it == '\n') {
is = std::next(is);
}
it = std::find(is, ie, '\n');
std::string str(is, it);
}
std::string trim_name(std::string raw_string) {
int pos = raw_string.find('\"');
if (pos != std::string::npos) {
raw_string = raw_string.substr(pos + 1);
pos = raw_string.find('\"');
raw_string = raw_string.substr(0, pos);
}
return raw_string;
}
std::string parese_file_ext(std::string file_name) {
int pos = file_name.rfind('.');
std::string ext = "";
if (pos != std::string::npos) ext = file_name.substr(pos + 1);
return ext;
}
template <typename InputIterator>
int parse_data_content(InputIterator is, InputIterator ie, InputIterator it) {
int len = std::distance(it + 1, ie);
if (len <= 0) {
return 0;
}
std::string str(it + 1, ie);
// check if at the end, "--boundary--" need +4 for "--"
if (len == boundary.length() + 4)
{
std::string str(it + 1, ie);
// std::cout << "len good=" << str << std::endl;
if (boundary.length() > 1 && boundary[boundary.length() - 1] == '\n') {
// remove '\n' in boundary
boundary = boundary.substr(0, boundary.length() - 2);
}
if (boundary.length() > 1 && boundary[boundary.length() - 1] == '\r') {
// remove '\r' in boundary
boundary = boundary.substr(0, boundary.length() - 2);
}
auto found_boundary = str.find(boundary);
if (found_boundary == std::string::npos) {
std::cout << "not found end boundary!=" << found_boundary << std::endl;
return 0;
}
// remove the end of data that contains '\n' or '\r'
int last_sub = 0;
if (*(it) == '\n') {
last_sub++;
}
int lasts_len = std::distance(it, ie);
data_msg->samples->erase(data_msg->samples->end() - last_sub - lasts_len,
data_msg->samples->end());
std::cout << "one file finished, file size=" << data_msg->samples->size()
<< std::endl;
return 1;
}
}
template <typename InputIterator>
void parse_boundary_content(InputIterator is, InputIterator ie,
InputIterator it) {
parse_one_line(is, ie, it);
std::string str;
while (it != ie) {
str = std::string(is, it);
auto found_content = str.find("Content-Disposition:");
auto found_filename = str.find("filename=");
if (found_content != std::string::npos &&
found_filename != std::string::npos) {
std::string file_name =
str.substr(found_filename + 9, std::string::npos);
file_name = trim_name(file_name);
std::string ext = parese_file_ext(file_name);
if (file_name.find(".wav") != std::string::npos) {
std::cout << "set wav_format=pcm, file_name=" << file_name
<< std::endl;
data_msg->msg["wav_format"] = "pcm";
} else {
std::cout << "set wav_format=" << ext << ", file_name=" << file_name
<< std::endl;
data_msg->msg["wav_format"] = ext;
}
data_msg->msg["wav_name"] = file_name;
now_state = data;
} else {
auto found_content = str.find("Content-Disposition:");
auto found_name = str.find("name=");
if (found_content != std::string::npos &&
found_name != std::string::npos) {
std::string name = str.substr(found_name + 5, std::string::npos);
name = trim_name(name);
parse_one_line(is, ie, it);
if (*it == '\n') it++;
parse_one_line(is, ie, it);
str = std::string(is, it);
std::cout << "para: name=" << name << ",value=" << str << std::endl;
}
}
parse_one_line(is, ie, it);
if (now_state == data && std::distance(is, it) <= 2) {
break;
}
}
if (now_state == data) {
if (*it == '\n') it++;
data_msg->samples->insert(data_msg->samples->end(), it,
it + std::distance(it, ie));
// it=ie;
}
}
template <typename InputIterator>
result_type parse_file(InputIterator is, InputIterator ie) {
if (now_state == data) {
data_msg->samples->insert(data_msg->samples->end(), is, ie);
}
auto it = is;
while (it != ie) {
std::string str(is, it);
parse_one_line(is, ie, it);
if (now_state == data) {
// for data end search
int ret = parse_data_content(is, ie, it);
if (ret == 0) continue;
return ok;
} else {
std::string str(is, it + 1);
if (now_state == start) {
auto found_boundary = str.find("Content-Length:");
if (found_boundary != std::string::npos) {
std::string file_len =
str.substr(found_boundary + 15, std::string::npos);
data_msg->samples->reserve(std::stoi(file_len));
}
found_boundary = str.find("boundary=");
if (found_boundary != std::string::npos) {
boundary = str.substr(found_boundary + 9, std::string::npos);
now_state = in_boundary;
}
} else if (now_state == in_boundary) {
// for file header
auto found_boundary = str.find(boundary);
if (found_boundary != std::string::npos) {
parse_boundary_content(is, ie, it);
}
}
}
}
return now_state;
}
private:
std::shared_ptr<FUNASR_MESSAGE> data_msg;
result_type now_state;
std::string boundary = "";
};
} // namespace server2
} // namespace http
#endif // HTTP_SERVER2_REQUEST_FILEPARSER_HPP

View File

@ -0,0 +1,523 @@
/**
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
* Reserved. MIT License (https://opensource.org/licenses/MIT)
*/
/* 2023-2024 by zhaomingwork@qq.com */
#include "funasr-http-main.hpp"
#ifdef _WIN32
#include "win_func.h"
#else
#include <unistd.h>
#endif
#include <fstream>
#include "util.h"
// hotwords
std::unordered_map<std::string, int> hws_map_;
int fst_inc_wts_ = 20;
float global_beam_, lattice_beam_, am_scale_;
using namespace std;
void GetValue(TCLAP::ValueArg<std::string> &value_arg, string key,
std::map<std::string, std::string> &model_path) {
model_path.insert({key, value_arg.getValue()});
LOG(INFO) << key << " : " << value_arg.getValue();
}
FUNASR_HANDLE initAsr(std::map<std::string, std::string> &model_path,
int thread_num) {
try {
// init model with api
FUNASR_HANDLE asr_handle = FunOfflineInit(model_path, thread_num);
LOG(INFO) << "model successfully inited";
LOG(INFO) << "initAsr run check_and_clean_connection";
// std::thread
// clean_thread(&ModelDecoderSrv::check_and_clean_connection,this);
// clean_thread.detach();
LOG(INFO) << "initAsr run check_and_clean_connection finished";
return asr_handle;
} catch (const std::exception &e) {
LOG(INFO) << e.what();
// return nullptr;
}
}
int main(int argc, char *argv[]) {
#ifdef _WIN32
#include <windows.h>
SetConsoleOutputCP(65001);
#endif
try {
google::InitGoogleLogging(argv[0]);
FLAGS_logtostderr = true;
std::string offline_version = "";
#ifdef _WIN32
offline_version = "0.1.0";
#endif
TCLAP::CmdLine cmd("funasr-wss-server", ' ', offline_version);
TCLAP::ValueArg<std::string> download_model_dir(
"", "download-model-dir",
"Download model from Modelscope to download_model_dir", false,
"/workspace/models", "string");
TCLAP::ValueArg<std::string> model_dir(
"", OFFLINE_MODEL_DIR,
"default: "
"damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-onnx, "
"the asr model path, which "
"contains model_quant.onnx, config.yaml, am.mvn",
false,
"damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-onnx",
"string");
TCLAP::ValueArg<std::string> model_revision("", "offline-model-revision",
"ASR offline model revision",
false, "v2.0.4", "string");
TCLAP::ValueArg<std::string> quantize(
"", QUANTIZE,
"true (Default), load the model of model_quant.onnx in model_dir. If "
"set "
"false, load the model of model.onnx in model_dir",
false, "true", "string");
TCLAP::ValueArg<std::string> vad_dir(
"", VAD_DIR,
"default: damo/speech_fsmn_vad_zh-cn-16k-common-onnx, the vad model "
"path, which contains "
"model_quant.onnx, vad.yaml, vad.mvn",
false, "damo/speech_fsmn_vad_zh-cn-16k-common-onnx", "string");
TCLAP::ValueArg<std::string> vad_revision(
"", "vad-revision", "VAD model revision", false, "v2.0.4", "string");
TCLAP::ValueArg<std::string> vad_quant(
"", VAD_QUANT,
"true (Default), load the model of model_quant.onnx in vad_dir. If set "
"false, load the model of model.onnx in vad_dir",
false, "true", "string");
TCLAP::ValueArg<std::string> punc_dir(
"", PUNC_DIR,
"default: "
"damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727-onnx, "
"the punc model path, which contains "
"model_quant.onnx, punc.yaml",
false,
"damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727-onnx",
"string");
TCLAP::ValueArg<std::string> punc_revision(
"", "punc-revision", "PUNC model revision", false, "v2.0.4", "string");
TCLAP::ValueArg<std::string> punc_quant(
"", PUNC_QUANT,
"true (Default), load the model of model_quant.onnx in punc_dir. If "
"set "
"false, load the model of model.onnx in punc_dir",
false, "true", "string");
TCLAP::ValueArg<std::string> itn_dir(
"", ITN_DIR,
"default: thuduj12/fst_itn_zh, the itn model path, which contains "
"zh_itn_tagger.fst, zh_itn_verbalizer.fst",
false, "", "string");
TCLAP::ValueArg<std::string> itn_revision(
"", "itn-revision", "ITN model revision", false, "v1.0.1", "string");
TCLAP::ValueArg<std::string> listen_ip("", "listen-ip", "listen ip", false,
"0.0.0.0", "string");
TCLAP::ValueArg<int> port("", "port", "port", false, 80, "int");
TCLAP::ValueArg<int> io_thread_num("", "io-thread-num", "io thread num",
false, 8, "int");
TCLAP::ValueArg<int> decoder_thread_num(
"", "decoder-thread-num", "decoder thread num", false, 32, "int");
TCLAP::ValueArg<int> model_thread_num("", "model-thread-num",
"model thread num", false, 1, "int");
TCLAP::ValueArg<std::string> certfile(
"", "certfile",
"default: ../../../ssl_key/server.crt, path of certficate for WSS "
"connection. if it is empty, it will be in WS mode.",
false, "../../../ssl_key/server.crt", "string");
TCLAP::ValueArg<std::string> keyfile(
"", "keyfile",
"default: ../../../ssl_key/server.key, path of keyfile for WSS "
"connection",
false, "../../../ssl_key/server.key", "string");
TCLAP::ValueArg<float> global_beam("", GLOB_BEAM,
"the decoding beam for beam searching ",
false, 3.0, "float");
TCLAP::ValueArg<float> lattice_beam(
"", LAT_BEAM, "the lattice generation beam for beam searching ", false,
3.0, "float");
TCLAP::ValueArg<float> am_scale("", AM_SCALE,
"the acoustic scale for beam searching ",
false, 10.0, "float");
TCLAP::ValueArg<std::string> lm_dir(
"", LM_DIR,
"the LM model path, which contains compiled models: TLG.fst, "
"config.yaml ",
false, "", "string");
TCLAP::ValueArg<std::string> lm_revision(
"", "lm-revision", "LM model revision", false, "v1.0.2", "string");
TCLAP::ValueArg<std::string> hotword(
"", HOTWORD,
"the hotword file, one hotword perline, Format: Hotword Weight (could "
"be: 阿里巴巴 20)",
false, "/workspace/resources/hotwords.txt", "string");
TCLAP::ValueArg<std::int32_t> fst_inc_wts(
"", FST_INC_WTS, "the fst hotwords incremental bias", false, 20,
"int32_t");
// add file
cmd.add(hotword);
cmd.add(fst_inc_wts);
cmd.add(global_beam);
cmd.add(lattice_beam);
cmd.add(am_scale);
cmd.add(certfile);
cmd.add(keyfile);
cmd.add(download_model_dir);
cmd.add(model_dir);
cmd.add(model_revision);
cmd.add(quantize);
cmd.add(vad_dir);
cmd.add(vad_revision);
cmd.add(vad_quant);
cmd.add(punc_dir);
cmd.add(punc_revision);
cmd.add(punc_quant);
cmd.add(itn_dir);
cmd.add(itn_revision);
cmd.add(lm_dir);
cmd.add(lm_revision);
cmd.add(listen_ip);
cmd.add(port);
cmd.add(io_thread_num);
cmd.add(decoder_thread_num);
cmd.add(model_thread_num);
cmd.parse(argc, argv);
std::map<std::string, std::string> model_path;
GetValue(model_dir, MODEL_DIR, model_path);
GetValue(quantize, QUANTIZE, model_path);
GetValue(vad_dir, VAD_DIR, model_path);
GetValue(vad_quant, VAD_QUANT, model_path);
GetValue(punc_dir, PUNC_DIR, model_path);
GetValue(punc_quant, PUNC_QUANT, model_path);
GetValue(itn_dir, ITN_DIR, model_path);
GetValue(lm_dir, LM_DIR, model_path);
GetValue(hotword, HOTWORD, model_path);
GetValue(model_revision, "model-revision", model_path);
GetValue(vad_revision, "vad-revision", model_path);
GetValue(punc_revision, "punc-revision", model_path);
GetValue(itn_revision, "itn-revision", model_path);
GetValue(lm_revision, "lm-revision", model_path);
global_beam_ = global_beam.getValue();
lattice_beam_ = lattice_beam.getValue();
am_scale_ = am_scale.getValue();
// Download model form Modelscope
try {
std::string s_download_model_dir = download_model_dir.getValue();
std::string s_vad_path = model_path[VAD_DIR];
std::string s_vad_quant = model_path[VAD_QUANT];
std::string s_asr_path = model_path[MODEL_DIR];
std::string s_asr_quant = model_path[QUANTIZE];
std::string s_punc_path = model_path[PUNC_DIR];
std::string s_punc_quant = model_path[PUNC_QUANT];
std::string s_itn_path = model_path[ITN_DIR];
std::string s_lm_path = model_path[LM_DIR];
std::string python_cmd =
"python -m funasr.download.runtime_sdk_download_tool --type onnx "
"--quantize True ";
if (vad_dir.isSet() && !s_vad_path.empty()) {
std::string python_cmd_vad;
std::string down_vad_path;
std::string down_vad_model;
if (access(s_vad_path.c_str(), F_OK) == 0) {
// local
python_cmd_vad = python_cmd + " --model-name " + s_vad_path +
" --export-dir ./ " + " --model_revision " +
model_path["vad-revision"];
down_vad_path = s_vad_path;
} else {
// modelscope
LOG(INFO) << "Download model: " << s_vad_path << " from modelscope: ";
python_cmd_vad = python_cmd + " --model-name " + s_vad_path +
" --export-dir " + s_download_model_dir +
" --model_revision " + model_path["vad-revision"];
down_vad_path = s_download_model_dir + "/" + s_vad_path;
}
int ret = system(python_cmd_vad.c_str());
if (ret != 0) {
LOG(INFO) << "Failed to download model from modelscope. If you set "
"local vad model path, you can ignore the errors.";
}
down_vad_model = down_vad_path + "/model_quant.onnx";
if (s_vad_quant == "false" || s_vad_quant == "False" ||
s_vad_quant == "FALSE") {
down_vad_model = down_vad_path + "/model.onnx";
}
if (access(down_vad_model.c_str(), F_OK) != 0) {
LOG(ERROR) << down_vad_model << " do not exists.";
exit(-1);
} else {
model_path[VAD_DIR] = down_vad_path;
LOG(INFO) << "Set " << VAD_DIR << " : " << model_path[VAD_DIR];
}
} else {
LOG(INFO) << "VAD model is not set, use default.";
}
if (model_dir.isSet() && !s_asr_path.empty()) {
std::string python_cmd_asr;
std::string down_asr_path;
std::string down_asr_model;
// modify model-revision by model name
size_t found = s_asr_path.find(
"speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-"
"vocab8404");
if (found != std::string::npos) {
model_path["model-revision"] = "v1.2.4";
}
found = s_asr_path.find(
"speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-"
"vocab8404");
if (found != std::string::npos) {
model_path["model-revision"] = "v1.0.5";
}
found = s_asr_path.find(
"speech_paraformer-large_asr_nat-en-16k-common-vocab10020");
if (found != std::string::npos) {
model_path["model-revision"] = "v1.0.0";
s_itn_path = "";
s_lm_path = "";
}
if (access(s_asr_path.c_str(), F_OK) == 0) {
// local
python_cmd_asr = python_cmd + " --model-name " + s_asr_path +
" --export-dir ./ " + " --model_revision " +
model_path["model-revision"];
down_asr_path = s_asr_path;
} else {
// modelscope
LOG(INFO) << "Download model: " << s_asr_path << " from modelscope: ";
python_cmd_asr = python_cmd + " --model-name " + s_asr_path +
" --export-dir " + s_download_model_dir +
" --model_revision " + model_path["model-revision"];
down_asr_path = s_download_model_dir + "/" + s_asr_path;
}
int ret = system(python_cmd_asr.c_str());
if (ret != 0) {
LOG(INFO) << "Failed to download model from modelscope. If you set "
"local asr model path, you can ignore the errors.";
}
down_asr_model = down_asr_path + "/model_quant.onnx";
if (s_asr_quant == "false" || s_asr_quant == "False" ||
s_asr_quant == "FALSE") {
down_asr_model = down_asr_path + "/model.onnx";
}
if (access(down_asr_model.c_str(), F_OK) != 0) {
LOG(ERROR) << down_asr_model << " do not exists.";
exit(-1);
} else {
model_path[MODEL_DIR] = down_asr_path;
LOG(INFO) << "Set " << MODEL_DIR << " : " << model_path[MODEL_DIR];
}
} else {
LOG(INFO) << "ASR model is not set, use default.";
}
if (!s_itn_path.empty()) {
std::string python_cmd_itn;
std::string down_itn_path;
std::string down_itn_model;
if (access(s_itn_path.c_str(), F_OK) == 0) {
// local
python_cmd_itn = python_cmd + " --model-name " + s_itn_path +
" --export-dir ./ " + " --model_revision " +
model_path["itn-revision"] + " --export False ";
down_itn_path = s_itn_path;
} else {
// modelscope
LOG(INFO) << "Download model: " << s_itn_path
<< " from modelscope : ";
python_cmd_itn = python_cmd + " --model-name " + s_itn_path +
" --export-dir " + s_download_model_dir +
" --model_revision " + model_path["itn-revision"] +
" --export False ";
down_itn_path = s_download_model_dir + "/" + s_itn_path;
}
int ret = system(python_cmd_itn.c_str());
if (ret != 0) {
LOG(INFO) << "Failed to download model from modelscope. If you set "
"local itn model path, you can ignore the errors.";
}
down_itn_model = down_itn_path + "/zh_itn_tagger.fst";
if (access(down_itn_model.c_str(), F_OK) != 0) {
LOG(ERROR) << down_itn_model << " do not exists.";
exit(-1);
} else {
model_path[ITN_DIR] = down_itn_path;
LOG(INFO) << "Set " << ITN_DIR << " : " << model_path[ITN_DIR];
}
} else {
LOG(INFO) << "ITN model is not set, not executed.";
}
if (!s_lm_path.empty() && s_lm_path != "NONE" && s_lm_path != "none") {
std::string python_cmd_lm;
std::string down_lm_path;
std::string down_lm_model;
if (access(s_lm_path.c_str(), F_OK) == 0) {
// local
python_cmd_lm = python_cmd + " --model-name " + s_lm_path +
" --export-dir ./ " + " --model_revision " +
model_path["lm-revision"] + " --export False ";
down_lm_path = s_lm_path;
} else {
// modelscope
LOG(INFO) << "Download model: " << s_lm_path << " from modelscope : ";
python_cmd_lm = python_cmd + " --model-name " + s_lm_path +
" --export-dir " + s_download_model_dir +
" --model_revision " + model_path["lm-revision"] +
" --export False ";
down_lm_path = s_download_model_dir + "/" + s_lm_path;
}
int ret = system(python_cmd_lm.c_str());
if (ret != 0) {
LOG(INFO) << "Failed to download model from modelscope. If you set "
"local lm model path, you can ignore the errors.";
}
down_lm_model = down_lm_path + "/TLG.fst";
if (access(down_lm_model.c_str(), F_OK) != 0) {
LOG(ERROR) << down_lm_model << " do not exists.";
exit(-1);
} else {
model_path[LM_DIR] = down_lm_path;
LOG(INFO) << "Set " << LM_DIR << " : " << model_path[LM_DIR];
}
} else {
LOG(INFO) << "LM model is not set, not executed.";
model_path[LM_DIR] = "";
}
if (punc_dir.isSet() && !s_punc_path.empty()) {
std::string python_cmd_punc;
std::string down_punc_path;
std::string down_punc_model;
if (access(s_punc_path.c_str(), F_OK) == 0) {
// local
python_cmd_punc = python_cmd + " --model-name " + s_punc_path +
" --export-dir ./ " + " --model_revision " +
model_path["punc-revision"];
down_punc_path = s_punc_path;
} else {
// modelscope
LOG(INFO) << "Download model: " << s_punc_path
<< " from modelscope: ";
python_cmd_punc = python_cmd + " --model-name " + s_punc_path +
" --export-dir " + s_download_model_dir +
" --model_revision " + model_path["punc-revision"];
down_punc_path = s_download_model_dir + "/" + s_punc_path;
}
int ret = system(python_cmd_punc.c_str());
if (ret != 0) {
LOG(INFO) << "Failed to download model from modelscope. If you set "
"local punc model path, you can ignore the errors.";
}
down_punc_model = down_punc_path + "/model_quant.onnx";
if (s_punc_quant == "false" || s_punc_quant == "False" ||
s_punc_quant == "FALSE") {
down_punc_model = down_punc_path + "/model.onnx";
}
if (access(down_punc_model.c_str(), F_OK) != 0) {
LOG(ERROR) << down_punc_model << " do not exists.";
exit(-1);
} else {
model_path[PUNC_DIR] = down_punc_path;
LOG(INFO) << "Set " << PUNC_DIR << " : " << model_path[PUNC_DIR];
}
} else {
LOG(INFO) << "PUNC model is not set, use default.";
}
} catch (std::exception const &e) {
LOG(ERROR) << "Error: " << e.what();
}
std::string s_listen_ip = listen_ip.getValue();
int s_port = port.getValue();
int s_io_thread_num = io_thread_num.getValue();
int s_decoder_thread_num = decoder_thread_num.getValue();
int s_model_thread_num = model_thread_num.getValue();
asio::io_context io_decoder; // context for decoding
std::vector<std::thread> decoder_threads;
// hotword file
std::string hotword_path;
hotword_path = model_path.at(HOTWORD);
fst_inc_wts_ = fst_inc_wts.getValue();
LOG(INFO) << "hotword path: " << hotword_path;
funasr::ExtractHws(hotword_path, hws_map_);
auto conn_guard = asio::make_work_guard(
io_decoder); // make sure threads can wait in the queue
// create threads pool
for (int32_t i = 0; i < s_decoder_thread_num; ++i) {
decoder_threads.emplace_back([&io_decoder]() { io_decoder.run(); });
}
// ModelDecoderSrv modelSrv(
// io_decoder); // websocket server for asr engine
// modelSrv.initAsr(model_path, s_model_thread_num); // init asr model
// FUNASR_HANDLE asr_handle= initAsr();
LOG(INFO) << "decoder-thread-num: " << s_decoder_thread_num;
LOG(INFO) << "io-thread-num: " << s_io_thread_num;
LOG(INFO) << "model-thread-num: " << s_model_thread_num;
http::server2::server s(s_listen_ip, std::to_string(s_port), "./",
s_io_thread_num, io_decoder, model_path,
s_model_thread_num);
s.run();
LOG(INFO) << "http model loop " << s_port;
// wait for theads
for (auto &t : decoder_threads) {
t.join();
}
} catch (std::exception const &e) {
LOG(ERROR) << "Error: " << e.what();
}
return 0;
}

View File

@ -0,0 +1,20 @@
/**
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
* Reserved. MIT License (https://opensource.org/licenses/MIT)
*/
/* 2023-2024 by zhaomingwork@qq.com */
#ifndef HTTP_SERVER2_MAIN_HPP
#define HTTP_SERVER2_MAIN_HPP
#include "model-decoder.h"
#include "server.hpp"
namespace http {
namespace server2 {
} // namespace server2
} // namespace http
#endif // HTTP_SERVER2_MAIN_HPP

View File

@ -0,0 +1,27 @@
/**
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
* Reserved. MIT License (https://opensource.org/licenses/MIT)
*/
/* 2023-2024 by zhaomingwork@qq.com */
//
// header.hpp
// copy some codes from http://www.boost.org/
#ifndef HTTP_SERVER2_HEADER_HPP
#define HTTP_SERVER2_HEADER_HPP
#include <string>
namespace http {
namespace server2 {
struct header
{
std::string name;
std::string value;
};
} // namespace server2
} // namespace http
#endif // HTTP_SERVER2_HEADER_HPP

View File

@ -0,0 +1,66 @@
/**
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
* Reserved. MIT License (https://opensource.org/licenses/MIT)
*/
/* 2023-2024 by zhaomingwork@qq.com */
//
// io_context_pool.cpp
// ~~~~~~~~~~~~~~~~~~~
// copy some codes from http://www.boost.org/
#include "io_context_pool.hpp"
#include <stdexcept>
#include <thread>
namespace http {
namespace server2 {
io_context_pool::io_context_pool(std::size_t pool_size)
: next_io_context_(0)
{
if (pool_size == 0)
throw std::runtime_error("io_context_pool size is 0");
// Give all the io_contexts work to do so that their run() functions will not
// exit until they are explicitly stopped.
for (std::size_t i = 0; i < pool_size; ++i)
{
io_context_ptr io_context(new asio::io_context);
io_contexts_.push_back(io_context);
work_.push_back(asio::make_work_guard(*io_context));
}
}
void io_context_pool::run()
{
// Create a pool of threads to run all of the io_contexts.
std::vector<std::thread> threads;
for (std::size_t i = 0; i < io_contexts_.size(); ++i)
threads.emplace_back([this, i]{ io_contexts_[i]->run(); });
// Wait for all threads in the pool to exit.
for (std::size_t i = 0; i < threads.size(); ++i)
threads[i].join();
}
void io_context_pool::stop()
{
// Explicitly stop all io_contexts.
for (std::size_t i = 0; i < io_contexts_.size(); ++i)
io_contexts_[i]->stop();
}
asio::io_context& io_context_pool::get_io_context()
{
// Use a round-robin scheme to choose the next io_context to use.
asio::io_context& io_context = *io_contexts_[next_io_context_];
++next_io_context_;
if (next_io_context_ == io_contexts_.size())
next_io_context_ = 0;
return io_context;
}
} // namespace server2
} // namespace http

View File

@ -0,0 +1,59 @@
/**
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
* Reserved. MIT License (https://opensource.org/licenses/MIT)
*/
/* 2023-2024 by zhaomingwork@qq.com */
//
// io_context_pool.hpp
// ~~~~~~~~~~~~~~~~~~~
// copy some codes from http://www.boost.org/
#ifndef HTTP_SERVER2_IO_SERVICE_POOL_HPP
#define HTTP_SERVER2_IO_SERVICE_POOL_HPP
#include <asio.hpp>
#include <list>
#include <memory>
#include <vector>
namespace http {
namespace server2 {
/// A pool of io_context objects.
class io_context_pool
{
public:
/// Construct the io_context pool.
explicit io_context_pool(std::size_t pool_size);
/// Run all io_context objects in the pool.
void run();
/// Stop all io_context objects in the pool.
void stop();
/// Get an io_context to use.
asio::io_context& get_io_context();
private:
io_context_pool(const io_context_pool&) = delete;
io_context_pool& operator=(const io_context_pool&) = delete;
typedef std::shared_ptr<::asio::io_context> io_context_ptr;
typedef asio::executor_work_guard<
asio::io_context::executor_type> io_context_work;
/// The pool of io_contexts.
std::vector<io_context_ptr> io_contexts_;
/// The work that keeps the io_contexts running.
std::list<io_context_work> work_;
/// The next io_context to use for a connection.
std::size_t next_io_context_;
};
} // namespace server2
} // namespace http
#endif // HTTP_SERVER2_IO_SERVICE_POOL_HPP

View File

@ -0,0 +1,119 @@
/**
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
* Reserved. MIT License (https://opensource.org/licenses/MIT)
*/
/* 2023-2024 by zhaomingwork@qq.com */
// funasr asr engine
#include "model-decoder.h"
#include <thread>
#include <utility>
#include <vector>
extern std::unordered_map<std::string, int> hws_map_;
extern int fst_inc_wts_;
extern float global_beam_, lattice_beam_, am_scale_;
// feed msg to asr engine for decoder
void ModelDecoder::do_decoder(std::shared_ptr<FUNASR_MESSAGE> session_msg) {
try {
// std::this_thread::sleep_for(std::chrono::milliseconds(1000*10));
if (session_msg->status == 1) return;
//std::cout << "in do_decoder" << std::endl;
std::shared_ptr<std::vector<char>> buffer = session_msg->samples;
int num_samples = buffer->size(); // the size of the buf
std::string wav_name =session_msg->msg["wav_name"];
bool itn = session_msg->msg["itn"];
int audio_fs = session_msg->msg["audio_fs"];;
std::string wav_format = session_msg->msg["wav_format"];
if (num_samples > 0 && session_msg->hotwords_embedding->size() > 0) {
std::string asr_result = "";
std::string stamp_res = "";
std::string stamp_sents = "";
try {
std::vector<std::vector<float>> hotwords_embedding_(
*(session_msg->hotwords_embedding));
FUNASR_RESULT Result = FunOfflineInferBuffer(
asr_handle, buffer->data(), buffer->size(), RASR_NONE, nullptr,
std::move(hotwords_embedding_), audio_fs, wav_format, itn,
session_msg->decoder_handle);
if (Result != nullptr) {
asr_result = FunASRGetResult(Result, 0); // get decode result
stamp_res = FunASRGetStamp(Result);
stamp_sents = FunASRGetStampSents(Result);
FunASRFreeResult(Result);
} else {
std::this_thread::sleep_for(std::chrono::milliseconds(20));
}
} catch (std::exception const &e) {
std::cout << "error in decoder!!! "<<e.what() <<std::endl;
}
nlohmann::json jsonresult; // result json
jsonresult["text"] = asr_result; // put result in 'text'
jsonresult["mode"] = "offline";
jsonresult["is_final"] = false;
if (stamp_res != "") {
jsonresult["timestamp"] = stamp_res;
}
if (stamp_sents != "") {
try {
nlohmann::json json_stamp = nlohmann::json::parse(stamp_sents);
jsonresult["stamp_sents"] = json_stamp;
} catch (std::exception const &e) {
std::cout << "error:" << e.what();
jsonresult["stamp_sents"] = "";
}
}
jsonresult["wav_name"] = wav_name;
std::cout << "buffer.size=" << buffer->size()
<< ",result json=" << jsonresult.dump() << std::endl;
FunWfstDecoderUnloadHwsRes(session_msg->decoder_handle);
FunASRWfstDecoderUninit(session_msg->decoder_handle);
session_msg->status = 1;
session_msg->msg["asr_result"] = jsonresult;
return;
} else {
std::cout << "Sent empty msg";
nlohmann::json jsonresult; // result json
jsonresult["text"] = ""; // put result in 'text'
jsonresult["mode"] = "offline";
jsonresult["is_final"] = false;
jsonresult["wav_name"] = wav_name;
}
} catch (std::exception const &e) {
std::cerr << "Error: " << e.what() << std::endl;
}
}
// init asr model
FUNASR_HANDLE ModelDecoder::initAsr(std::map<std::string, std::string> &model_path,
int thread_num) {
try {
// init model with api
asr_handle = FunOfflineInit(model_path, thread_num);
LOG(INFO) << "model successfully inited";
return asr_handle;
} catch (const std::exception &e) {
LOG(INFO) << e.what();
return nullptr;
}
}

View File

@ -0,0 +1,60 @@
/**
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
* Reserved. MIT License (https://opensource.org/licenses/MIT)
*/
/* 2023-2024 by zhaomingwork@qq.com */
// funasr asr engine
#ifndef MODEL_DECODER_SERVER_H_
#define MODEL_DECODER_SERVER_H_
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <thread>
#include <unordered_map>
#include <utility>
#define ASIO_STANDALONE 1 // not boost
#include <glog/logging.h>
#include <fstream>
#include <functional>
#include "asio.hpp"
#include "asr_sessions.h"
#include "com-define.h"
#include "funasrruntime.h"
#include "nlohmann/json.hpp"
#include "tclap/CmdLine.h"
#include "util/text-utils.h"
class ModelDecoder {
public:
ModelDecoder(asio::io_context &io_decoder,
std::map<std::string, std::string> &model_path, int thread_num)
: io_decoder_(io_decoder) {
asr_handle = initAsr(model_path, thread_num);
}
void do_decoder(std::shared_ptr<FUNASR_MESSAGE> session_msg);
FUNASR_HANDLE initAsr(std::map<std::string, std::string> &model_path, int thread_num);
asio::io_context &io_decoder_; // threads for asr decoder
FUNASR_HANDLE get_asr_handle()
{
return asr_handle;
}
private:
FUNASR_HANDLE asr_handle; // asr engine handle
bool isonline = false; // online or offline engine, now only support offline
};
#endif // MODEL_DECODER_SERVER_H_

245
runtime/http/bin/reply.cpp Normal file
View File

@ -0,0 +1,245 @@
/**
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
* Reserved. MIT License (https://opensource.org/licenses/MIT)
*/
/* 2023-2024 by zhaomingwork@qq.com */
// reply.cpp
// ~~~~~~~~~
//
// copy some codes from http://www.boost.org/
#include "reply.hpp"
#include <iostream>
#include <string>
namespace http {
namespace server2 {
namespace status_strings {
const std::string ok = "HTTP/1.0 200 OK\r\n";
const std::string created = "HTTP/1.0 201 Created\r\n";
const std::string accepted = "HTTP/1.0 202 Accepted\r\n";
const std::string no_content = "HTTP/1.0 204 No Content\r\n";
const std::string multiple_choices = "HTTP/1.0 300 Multiple Choices\r\n";
const std::string moved_permanently = "HTTP/1.0 301 Moved Permanently\r\n";
const std::string moved_temporarily = "HTTP/1.0 302 Moved Temporarily\r\n";
const std::string not_modified = "HTTP/1.0 304 Not Modified\r\n";
const std::string bad_request = "HTTP/1.0 400 Bad Request\r\n";
const std::string unauthorized = "HTTP/1.0 401 Unauthorized\r\n";
const std::string forbidden = "HTTP/1.0 403 Forbidden\r\n";
const std::string not_found = "HTTP/1.0 404 Not Found\r\n";
const std::string internal_server_error =
"HTTP/1.0 500 Internal Server Error\r\n";
const std::string not_implemented = "HTTP/1.0 501 Not Implemented\r\n";
const std::string bad_gateway = "HTTP/1.0 502 Bad Gateway\r\n";
const std::string service_unavailable = "HTTP/1.0 503 Service Unavailable\r\n";
asio::const_buffer to_buffer(reply::status_type status) {
switch (status) {
case reply::ok:
return asio::buffer(ok);
case reply::created:
return asio::buffer(created);
case reply::accepted:
return asio::buffer(accepted);
case reply::no_content:
return asio::buffer(no_content);
case reply::multiple_choices:
return asio::buffer(multiple_choices);
case reply::moved_permanently:
return asio::buffer(moved_permanently);
case reply::moved_temporarily:
return asio::buffer(moved_temporarily);
case reply::not_modified:
return asio::buffer(not_modified);
case reply::bad_request:
return asio::buffer(bad_request);
case reply::unauthorized:
return asio::buffer(unauthorized);
case reply::forbidden:
return asio::buffer(forbidden);
case reply::not_found:
return asio::buffer(not_found);
case reply::internal_server_error:
return asio::buffer(internal_server_error);
case reply::not_implemented:
return asio::buffer(not_implemented);
case reply::bad_gateway:
return asio::buffer(bad_gateway);
case reply::service_unavailable:
return asio::buffer(service_unavailable);
default:
return asio::buffer(internal_server_error);
}
}
} // namespace status_strings
namespace misc_strings {
const char name_value_separator[] = {':', ' '};
const char crlf[] = {'\r', '\n'};
} // namespace misc_strings
std::vector<::asio::const_buffer> reply::to_buffers() {
std::vector<::asio::const_buffer> buffers;
buffers.push_back(status_strings::to_buffer(status));
for (std::size_t i = 0; i < headers.size(); ++i) {
header &h = headers[i];
buffers.push_back(asio::buffer(h.name));
buffers.push_back(asio::buffer(misc_strings::name_value_separator));
buffers.push_back(asio::buffer(h.value));
buffers.push_back(asio::buffer(misc_strings::crlf));
}
buffers.push_back(asio::buffer(misc_strings::crlf));
buffers.push_back(asio::buffer(content));
return buffers;
}
namespace stock_replies {
const char ok[] = "";
const char created[] =
"<html>"
"<head><title>Created</title></head>"
"<body><h1>201 Created</h1></body>"
"</html>";
const char accepted[] =
"<html>"
"<head><title>Accepted</title></head>"
"<body><h1>202 Accepted</h1></body>"
"</html>";
const char no_content[] =
"<html>"
"<head><title>No Content</title></head>"
"<body><h1>204 Content</h1></body>"
"</html>";
const char multiple_choices[] =
"<html>"
"<head><title>Multiple Choices</title></head>"
"<body><h1>300 Multiple Choices</h1></body>"
"</html>";
const char moved_permanently[] =
"<html>"
"<head><title>Moved Permanently</title></head>"
"<body><h1>301 Moved Permanently</h1></body>"
"</html>";
const char moved_temporarily[] =
"<html>"
"<head><title>Moved Temporarily</title></head>"
"<body><h1>302 Moved Temporarily</h1></body>"
"</html>";
const char not_modified[] =
"<html>"
"<head><title>Not Modified</title></head>"
"<body><h1>304 Not Modified</h1></body>"
"</html>";
const char bad_request[] =
"<html>"
"<head><title>Bad Request</title></head>"
"<body><h1>400 Bad Request</h1></body>"
"</html>";
const char unauthorized[] =
"<html>"
"<head><title>Unauthorized</title></head>"
"<body><h1>401 Unauthorized</h1></body>"
"</html>";
const char forbidden[] =
"<html>"
"<head><title>Forbidden</title></head>"
"<body><h1>403 Forbidden</h1></body>"
"</html>";
const char not_found[] =
"<html>"
"<head><title>Not Found</title></head>"
"<body><h1>404 Not Found</h1></body>"
"</html>";
const char internal_server_error[] =
"<html>"
"<head><title>Internal Server Error</title></head>"
"<body><h1>500 Internal Server Error</h1></body>"
"</html>";
const char not_implemented[] =
"<html>"
"<head><title>Not Implemented</title></head>"
"<body><h1>501 Not Implemented</h1></body>"
"</html>";
const char bad_gateway[] =
"<html>"
"<head><title>Bad Gateway</title></head>"
"<body><h1>502 Bad Gateway</h1></body>"
"</html>";
const char service_unavailable[] =
"<html>"
"<head><title>Service Unavailable</title></head>"
"<body><h1>503 Service Unavailable</h1></body>"
"</html>";
std::string to_string(reply::status_type status) {
switch (status) {
case reply::ok:
return ok;
case reply::created:
return created;
case reply::accepted:
return accepted;
case reply::no_content:
return no_content;
case reply::multiple_choices:
return multiple_choices;
case reply::moved_permanently:
return moved_permanently;
case reply::moved_temporarily:
return moved_temporarily;
case reply::not_modified:
return not_modified;
case reply::bad_request:
return bad_request;
case reply::unauthorized:
return unauthorized;
case reply::forbidden:
return forbidden;
case reply::not_found:
return not_found;
case reply::internal_server_error:
return internal_server_error;
case reply::not_implemented:
return not_implemented;
case reply::bad_gateway:
return bad_gateway;
case reply::service_unavailable:
return service_unavailable;
default:
return internal_server_error;
}
}
} // namespace stock_replies
reply reply::stock_reply(std::string jsonresult) {
reply rep;
rep.status = reply::ok;
rep.content = jsonresult+"\n";
rep.headers.resize(2);
rep.headers[0].name = "Content-Length";
rep.headers[0].value = std::to_string(rep.content.size());
rep.headers[1].name = "Content-Type";
rep.headers[1].value = "text/html;charset=utf-8";
return rep;
}
reply reply::stock_reply(reply::status_type status) {
reply rep;
rep.status = status;
rep.content = stock_replies::to_string(status);
rep.headers.resize(2);
rep.headers[0].name = "Content-Length";
rep.headers[0].value = std::to_string(rep.content.size());
rep.headers[1].name = "Content-Type";
rep.headers[1].value = "text/html";
return rep;
}
} // namespace server2
} // namespace http

View File

@ -0,0 +1,64 @@
/**
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
* Reserved. MIT License (https://opensource.org/licenses/MIT)
*/
/* 2023-2024 by zhaomingwork@qq.com */
// reply.hpp
// ~~~~~~~~~
//
// copy some codes from http://www.boost.org/
#ifndef HTTP_SERVER2_REPLY_HPP
#define HTTP_SERVER2_REPLY_HPP
#include <asio.hpp>
#include <string>
#include <vector>
#include "header.hpp"
namespace http {
namespace server2 {
/// A reply to be sent to a client.
struct reply {
/// The status of the reply.
enum status_type {
ok = 200,
created = 201,
accepted = 202,
no_content = 204,
multiple_choices = 300,
moved_permanently = 301,
moved_temporarily = 302,
not_modified = 304,
bad_request = 400,
unauthorized = 401,
forbidden = 403,
not_found = 404,
internal_server_error = 500,
not_implemented = 501,
bad_gateway = 502,
service_unavailable = 503
} status;
/// The headers to be included in the reply.
std::vector<header> headers;
/// The content to be sent in the reply.
std::string content;
/// Convert the reply into a vector of buffers. The buffers do not own the
/// underlying memory blocks, therefore the reply object must remain valid and
/// not be changed until the write operation has completed.
std::vector<::asio::const_buffer> to_buffers();
/// Get a stock reply.
static reply stock_reply(status_type status);
static reply stock_reply(std::string jsonresult);
};
} // namespace server2
} // namespace http
#endif // HTTP_SERVER2_REPLY_HPP

113
runtime/http/bin/server.cpp Normal file
View File

@ -0,0 +1,113 @@
/**
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
* Reserved. MIT License (https://opensource.org/licenses/MIT)
*/
/* 2023-2024 by zhaomingwork@qq.com */
//
// server.cpp
// copy some codes from http://www.boost.org/
#include "server.hpp"
#include <signal.h>
#include <fstream>
#include <iostream>
#include <utility>
#include "util.h"
namespace http {
namespace server2 {
server::server(const std::string &address, const std::string &port,
const std::string &doc_root, std::size_t io_context_pool_size,
asio::io_context &decoder_context,
std::map<std::string, std::string> &model_path, int thread_num)
: io_context_pool_(io_context_pool_size),
signals_(io_context_pool_.get_io_context()),
acceptor_(io_context_pool_.get_io_context()),
decoder_context(decoder_context) {
// Register to handle the signals that indicate when the server should exit.
// It is safe to register for the same signal multiple times in a program,
// provided all registration for the specified signal is made through Asio.
try {
model_decoder =
std::make_shared<ModelDecoder>(decoder_context, model_path, thread_num);
LOG(INFO) << "try to listen on port:" << port << std::endl;
LOG(INFO) << "still not work, pls wait... " << std::endl;
LOG(INFO) << "if always waiting here, may be port in used, pls change the "
"port or kill pre-process!"
<< std::endl;
atom_id = 0;
// init model with api
signals_.add(SIGINT);
signals_.add(SIGTERM);
#if defined(SIGQUIT)
signals_.add(SIGQUIT);
#endif // defined(SIGQUIT)
do_await_stop();
// Open the acceptor with the option to reuse the address (i.e.
// SO_REUSEADDR).
asio::ip::tcp::resolver resolver(acceptor_.get_executor());
asio::ip::tcp::endpoint endpoint = *resolver.resolve(address, port).begin();
acceptor_.open(endpoint.protocol());
acceptor_.set_option(asio::ip::tcp::acceptor::reuse_address(true));
acceptor_.bind(endpoint);
acceptor_.listen();
do_accept();
std::cout << "use curl to test,just as " << std::endl;
std::cout << "curl -F \"file=@example.wav\" 127.0.0.1:80" << std::endl;
std::cout << "http post only support offline mode, if you want online "
"mode, pls try websocket!"
<< std::endl;
std::cout << "now succeed listen on port " << address << ":" << port
<< ", can accept data now!!!" << std::endl;
} catch (const std::exception &e) {
std::cout << "error:" << e.what();
}
}
void server::run() { io_context_pool_.run(); }
void server::do_accept() {
acceptor_.async_accept(
io_context_pool_.get_io_context(),
[this](asio::error_code ec, asio::ip::tcp::socket socket) {
// Check whether the server was stopped by a signal before this
// completion handler had a chance to run.
if (!acceptor_.is_open()) {
return;
}
if (!ec) {
std::lock_guard<std::mutex> lk(m_lock);
atom_id = atom_id + 1;
std::make_shared<connection>(std::move(socket), decoder_context,
(atom_id).load(), model_decoder)
->start();
}
do_accept();
});
}
void server::do_await_stop() {
signals_.async_wait([this](asio::error_code /*ec*/, int /*signo*/) {
io_context_pool_.stop();
});
}
} // namespace server2
} // namespace http

View File

@ -0,0 +1,71 @@
/**
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
* Reserved. MIT License (https://opensource.org/licenses/MIT)
*/
/* 2023-2024 by zhaomingwork@qq.com */
//
// server.hpp
// ~~~~~~~~~~
// copy some codes from http://www.boost.org/
#ifndef HTTP_SERVER2_SERVER_HPP
#define HTTP_SERVER2_SERVER_HPP
#include <asio.hpp>
#include <atomic>
#include <string>
#include "connection.hpp"
#include "funasrruntime.h"
#include "io_context_pool.hpp"
#include "model-decoder.h"
#include "util.h"
namespace http {
namespace server2 {
/// The top-level class of the HTTP server.
class server {
public:
server(const server &) = delete;
server &operator=(const server &) = delete;
/// Construct the server to listen on the specified TCP address and port, and
/// serve up files from the given directory.
explicit server(const std::string &address, const std::string &port,
const std::string &doc_root, std::size_t io_context_pool_size,
asio::io_context &decoder_context,
std::map<std::string, std::string> &model_path,
int thread_num);
/// Run the server's io_context loop.
void run();
private:
/// Perform an asynchronous accept operation.
void do_accept();
/// Wait for a request to stop the server.
void do_await_stop();
/// The pool of io_context objects used to perform asynchronous operations.
io_context_pool io_context_pool_;
asio::io_context &decoder_context;
/// The signal_set is used to register for process termination notifications.
asio::signal_set signals_;
/// Acceptor used to listen for incoming connections.
asio::ip::tcp::acceptor acceptor_;
std::shared_ptr<ModelDecoder> model_decoder;
std::atomic<int> atom_id;
std::mutex m_lock;
};
} // namespace server2
} // namespace http
#endif // HTTP_SERVER2_SERVER_HPP

58
runtime/http/readme.md Normal file
View File

@ -0,0 +1,58 @@
# Advanced Development Guide (File transcription service) ([click](../docs/SDK_advanced_guide_offline.md))
# Real-time Speech Transcription Service Development Guide ([click](../docs/SDK_advanced_guide_online.md))
# If you want to compile the file yourself, you can follow the steps below.
## Building for Linux/Unix
### Download onnxruntime
```shell
wget https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/dep_libs/onnxruntime-linux-x64-1.14.0.tgz
tar -zxvf onnxruntime-linux-x64-1.14.0.tgz
```
### Download ffmpeg
```shell
wget https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/dep_libs/ffmpeg-master-latest-linux64-gpl-shared.tar.xz
tar -xvf ffmpeg-master-latest-linux64-gpl-shared.tar.xz
```
### Install deps
```shell
# openblas
sudo apt-get install libopenblas-dev #ubuntu
# sudo yum -y install openblas-devel #centos
# openssl
apt-get install libssl-dev #ubuntu
# yum install openssl-devel #centos
```
### Build runtime
```shell
git clone https://github.com/alibaba-damo-academy/FunASR.git && cd FunASR/runtime/http
mkdir build && cd build
cmake -DCMAKE_BUILD_TYPE=release .. -DONNXRUNTIME_DIR=/path/to/onnxruntime-linux-x64-1.14.0 -DFFMPEG_DIR=/path/to/ffmpeg-master-latest-linux64-gpl-shared
make -j 4
```
### test
```shell
curl -F \"file=@example.wav\" 127.0.0.1:80
```
### run
```shell
./funasr-http-server \
--lm-dir '' \
--itn-dir '' \
--download-model-dir ${download_model_dir} \
--model-dir ${model_dir} \
--vad-dir ${vad_dir} \
--punc-dir ${punc_dir} \
--decoder-thread-num ${decoder_thread_num} \
--io-thread-num ${io_thread_num} \
--port ${port} \
```

61
runtime/http/readme_zh.md Normal file
View File

@ -0,0 +1,61 @@
# FunASR离线文件转写服务开发指南([点击此处](../docs/SDK_advanced_guide_offline_zh.md))
# FunASR实时语音听写服务开发指南([点击此处](../docs/SDK_advanced_guide_online_zh.md))
# 如果您想自己编译文件,可以参考下述步骤
## Linux/Unix 平台编译
### 下载 onnxruntime
```shell
wget https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/dep_libs/onnxruntime-linux-x64-1.14.0.tgz
tar -zxvf onnxruntime-linux-x64-1.14.0.tgz
```
### 下载 ffmpeg
```shell
wget https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/dep_libs/ffmpeg-master-latest-linux64-gpl-shared.tar.xz
tar -xvf ffmpeg-master-latest-linux64-gpl-shared.tar.xz
```
### 安装依赖
```shell
# openblas
sudo apt-get install libopenblas-dev #ubuntu
# sudo yum -y install openblas-devel #centos
# openssl
apt-get install libssl-dev #ubuntu
# yum install openssl-devel #centos
```
### 编译 runtime
```shell
git clone https://github.com/alibaba-damo-academy/FunASR.git && cd FunASR/runtime/http
mkdir build && cd build
cmake -DCMAKE_BUILD_TYPE=release .. -DONNXRUNTIME_DIR=/path/to/onnxruntime-linux-x64-1.14.0 -DFFMPEG_DIR=/path/to/ffmpeg-master-latest-linux64-gpl-shared
make -j 4
```
### 测试
```shell
curl -F \"file=@example.wav\" 127.0.0.1:80
```
### 运行
```shell
./funasr-http-server \
--lm-dir '' \
--itn-dir '' \
--download-model-dir ${download_model_dir} \
--model-dir ${model_dir} \
--vad-dir ${vad_dir} \
--punc-dir ${punc_dir} \
--decoder-thread-num ${decoder_thread_num} \
--io-thread-num ${io_thread_num} \
--port ${port} \
```

View File

@ -0,0 +1,15 @@
#### Download onnxruntime
```shell
bash third_party/download_onnxruntime.sh
```
#### Download ffmpeg
```shell
bash third_party/download_ffmpeg.sh
```
#### Install openblas and openssl
```shell
sudo apt-get install libopenblas-dev libssl-dev #ubuntu
# sudo yum -y install openblas-devel openssl-devel #centos
```