mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
Merge branch 'main' into dev_wjm_infer
This commit is contained in:
commit
7375292887
Binary file not shown.
|
Before Width: | Height: | Size: 188 KiB After Width: | Height: | Size: 183 KiB |
@ -17,7 +17,7 @@ inference_diar_pipline = pipeline(
|
|||||||
diar_model_config="sond.yaml",
|
diar_model_config="sond.yaml",
|
||||||
model='damo/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch',
|
model='damo/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch',
|
||||||
sv_model="damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch",
|
sv_model="damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch",
|
||||||
sv_model_revision="master",
|
sv_model_revision="v1.2.2",
|
||||||
)
|
)
|
||||||
|
|
||||||
# use audio_list as the input, where the first one is the record to be detected
|
# use audio_list as the input, where the first one is the record to be detected
|
||||||
|
|||||||
@ -1,11 +1,11 @@
|
|||||||
#include "precomp.h"
|
#include "precomp.h"
|
||||||
|
#include <unistd.h>
|
||||||
|
|
||||||
namespace funasr {
|
namespace funasr {
|
||||||
OfflineStream::OfflineStream(std::map<std::string, std::string>& model_path, int thread_num)
|
OfflineStream::OfflineStream(std::map<std::string, std::string>& model_path, int thread_num)
|
||||||
{
|
{
|
||||||
// VAD model
|
// VAD model
|
||||||
if(model_path.find(VAD_DIR) != model_path.end()){
|
if(model_path.find(VAD_DIR) != model_path.end()){
|
||||||
use_vad = true;
|
|
||||||
string vad_model_path;
|
string vad_model_path;
|
||||||
string vad_cmvn_path;
|
string vad_cmvn_path;
|
||||||
string vad_config_path;
|
string vad_config_path;
|
||||||
@ -16,8 +16,16 @@ OfflineStream::OfflineStream(std::map<std::string, std::string>& model_path, int
|
|||||||
}
|
}
|
||||||
vad_cmvn_path = PathAppend(model_path.at(VAD_DIR), VAD_CMVN_NAME);
|
vad_cmvn_path = PathAppend(model_path.at(VAD_DIR), VAD_CMVN_NAME);
|
||||||
vad_config_path = PathAppend(model_path.at(VAD_DIR), VAD_CONFIG_NAME);
|
vad_config_path = PathAppend(model_path.at(VAD_DIR), VAD_CONFIG_NAME);
|
||||||
vad_handle = make_unique<FsmnVad>();
|
if (access(vad_model_path.c_str(), F_OK) != 0 ||
|
||||||
vad_handle->InitVad(vad_model_path, vad_cmvn_path, vad_config_path, thread_num);
|
access(vad_cmvn_path.c_str(), F_OK) != 0 ||
|
||||||
|
access(vad_config_path.c_str(), F_OK) != 0 )
|
||||||
|
{
|
||||||
|
LOG(INFO) << "VAD model file is not exist, skip load vad model.";
|
||||||
|
}else{
|
||||||
|
vad_handle = make_unique<FsmnVad>();
|
||||||
|
vad_handle->InitVad(vad_model_path, vad_cmvn_path, vad_config_path, thread_num);
|
||||||
|
use_vad = true;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// AM model
|
// AM model
|
||||||
@ -39,7 +47,6 @@ OfflineStream::OfflineStream(std::map<std::string, std::string>& model_path, int
|
|||||||
|
|
||||||
// PUNC model
|
// PUNC model
|
||||||
if(model_path.find(PUNC_DIR) != model_path.end()){
|
if(model_path.find(PUNC_DIR) != model_path.end()){
|
||||||
use_punc = true;
|
|
||||||
string punc_model_path;
|
string punc_model_path;
|
||||||
string punc_config_path;
|
string punc_config_path;
|
||||||
|
|
||||||
@ -49,8 +56,15 @@ OfflineStream::OfflineStream(std::map<std::string, std::string>& model_path, int
|
|||||||
}
|
}
|
||||||
punc_config_path = PathAppend(model_path.at(PUNC_DIR), PUNC_CONFIG_NAME);
|
punc_config_path = PathAppend(model_path.at(PUNC_DIR), PUNC_CONFIG_NAME);
|
||||||
|
|
||||||
punc_handle = make_unique<CTTransformer>();
|
if (access(punc_model_path.c_str(), F_OK) != 0 ||
|
||||||
punc_handle->InitPunc(punc_model_path, punc_config_path, thread_num);
|
access(punc_config_path.c_str(), F_OK) != 0 )
|
||||||
|
{
|
||||||
|
LOG(INFO) << "PUNC model file is not exist, skip load punc model.";
|
||||||
|
}else{
|
||||||
|
punc_handle = make_unique<CTTransformer>();
|
||||||
|
punc_handle->InitPunc(punc_model_path, punc_config_path, thread_num);
|
||||||
|
use_punc = true;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -71,6 +71,8 @@ print(args)
|
|||||||
from queue import Queue
|
from queue import Queue
|
||||||
|
|
||||||
voices = Queue()
|
voices = Queue()
|
||||||
|
offline_msg_done=False
|
||||||
|
|
||||||
ibest_writer = None
|
ibest_writer = None
|
||||||
if args.output_dir is not None:
|
if args.output_dir is not None:
|
||||||
writer = DatadirWriter(args.output_dir)
|
writer = DatadirWriter(args.output_dir)
|
||||||
@ -158,13 +160,20 @@ async def record_from_scp(chunk_begin, chunk_size):
|
|||||||
message = json.dumps({"is_speaking": is_speaking})
|
message = json.dumps({"is_speaking": is_speaking})
|
||||||
#voices.put(message)
|
#voices.put(message)
|
||||||
await websocket.send(message)
|
await websocket.send(message)
|
||||||
# print("data_chunk: ", len(data_chunk))
|
|
||||||
# print(voices.qsize())
|
|
||||||
sleep_duration = 0.001 if args.send_without_sleep else 60 * args.chunk_size[1] / args.chunk_interval / 1000
|
sleep_duration = 0.001 if args.send_without_sleep else 60 * args.chunk_size[1] / args.chunk_interval / 1000
|
||||||
await asyncio.sleep(sleep_duration)
|
await asyncio.sleep(sleep_duration)
|
||||||
|
# when all data sent, we need to close websocket
|
||||||
while not voices.empty():
|
while not voices.empty():
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
await asyncio.sleep(3)
|
await asyncio.sleep(3)
|
||||||
|
# in offline mode, wait until the result message has been received
|
||||||
|
|
||||||
|
if args.mode=="offline":
|
||||||
|
global offline_msg_done
|
||||||
|
while not offline_msg_done:
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
await websocket.close()
|
await websocket.close()
|
||||||
|
|
||||||
|
|
||||||
@ -173,7 +182,7 @@ async def record_from_scp(chunk_begin, chunk_size):
|
|||||||
|
|
||||||
|
|
||||||
async def message(id):
|
async def message(id):
|
||||||
global websocket,voices
|
global websocket,voices,offline_msg_done
|
||||||
text_print = ""
|
text_print = ""
|
||||||
text_print_2pass_online = ""
|
text_print_2pass_online = ""
|
||||||
text_print_2pass_offline = ""
|
text_print_2pass_offline = ""
|
||||||
@ -183,7 +192,6 @@ async def message(id):
|
|||||||
meg = await websocket.recv()
|
meg = await websocket.recv()
|
||||||
meg = json.loads(meg)
|
meg = json.loads(meg)
|
||||||
wav_name = meg.get("wav_name", "demo")
|
wav_name = meg.get("wav_name", "demo")
|
||||||
# print(wav_name)
|
|
||||||
text = meg["text"]
|
text = meg["text"]
|
||||||
if ibest_writer is not None:
|
if ibest_writer is not None:
|
||||||
ibest_writer["text"][wav_name] = text
|
ibest_writer["text"][wav_name] = text
|
||||||
@ -198,6 +206,7 @@ async def message(id):
|
|||||||
text_print = text_print[-args.words_max_print:]
|
text_print = text_print[-args.words_max_print:]
|
||||||
os.system('clear')
|
os.system('clear')
|
||||||
print("\rpid" + str(id) + ": " + text_print)
|
print("\rpid" + str(id) + ": " + text_print)
|
||||||
|
offline_msg_done=True
|
||||||
else:
|
else:
|
||||||
if meg["mode"] == "2pass-online":
|
if meg["mode"] == "2pass-online":
|
||||||
text_print_2pass_online += "{}".format(text)
|
text_print_2pass_online += "{}".format(text)
|
||||||
@ -233,8 +242,10 @@ async def ws_client(id, chunk_begin, chunk_size):
|
|||||||
if args.audio_in is None:
|
if args.audio_in is None:
|
||||||
chunk_begin=0
|
chunk_begin=0
|
||||||
chunk_size=1
|
chunk_size=1
|
||||||
global websocket,voices
|
global websocket,voices,offline_msg_done
|
||||||
|
|
||||||
for i in range(chunk_begin,chunk_begin+chunk_size):
|
for i in range(chunk_begin,chunk_begin+chunk_size):
|
||||||
|
offline_msg_done=False
|
||||||
voices = Queue()
|
voices = Queue()
|
||||||
if args.ssl == 1:
|
if args.ssl == 1:
|
||||||
ssl_context = ssl.SSLContext()
|
ssl_context = ssl.SSLContext()
|
||||||
@ -251,7 +262,7 @@ async def ws_client(id, chunk_begin, chunk_size):
|
|||||||
else:
|
else:
|
||||||
task = asyncio.create_task(record_microphone())
|
task = asyncio.create_task(record_microphone())
|
||||||
#task2 = asyncio.create_task(ws_send())
|
#task2 = asyncio.create_task(ws_send())
|
||||||
task3 = asyncio.create_task(message(id))
|
task3 = asyncio.create_task(message(str(id)+"_"+str(i))) #processid+fileid
|
||||||
await asyncio.gather(task, task3)
|
await asyncio.gather(task, task3)
|
||||||
exit(0)
|
exit(0)
|
||||||
|
|
||||||
|
|||||||
@ -56,8 +56,8 @@ add_subdirectory(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/glog glog)
|
|||||||
# install openssl first apt-get install libssl-dev
|
# install openssl first apt-get install libssl-dev
|
||||||
find_package(OpenSSL REQUIRED)
|
find_package(OpenSSL REQUIRED)
|
||||||
|
|
||||||
add_executable(funasr-ws-server "funasr-ws-server.cpp" "websocket-server.cpp")
|
add_executable(funasr-wss-server "funasr-wss-server.cpp" "websocket-server.cpp")
|
||||||
add_executable(funasr-ws-client "funasr-ws-client.cpp")
|
add_executable(funasr-wss-client "funasr-wss-client.cpp")
|
||||||
|
|
||||||
target_link_libraries(funasr-ws-client PUBLIC funasr ssl crypto)
|
target_link_libraries(funasr-wss-client PUBLIC funasr ssl crypto)
|
||||||
target_link_libraries(funasr-ws-server PUBLIC funasr ssl crypto)
|
target_link_libraries(funasr-wss-server PUBLIC funasr ssl crypto)
|
||||||
|
|||||||
@ -5,7 +5,14 @@
|
|||||||
/* 2022-2023 by zhaomingwork */
|
/* 2022-2023 by zhaomingwork */
|
||||||
|
|
||||||
// client for websocket, support multiple threads
|
// client for websocket, support multiple threads
|
||||||
// Usage: websocketclient server_ip port wav_path threads_num
|
// ./funasr-wss-client --server-ip <string>
|
||||||
|
// --port <string>
|
||||||
|
// --wav-path <string>
|
||||||
|
// [--thread-num <int>]
|
||||||
|
// [--is-ssl <int>] [--]
|
||||||
|
// [--version] [-h]
|
||||||
|
// example:
|
||||||
|
// ./funasr-wss-client --server-ip 127.0.0.1 --port 8889 --wav-path test.wav --thread-num 1 --is-ssl 0
|
||||||
|
|
||||||
#define ASIO_STANDALONE 1
|
#define ASIO_STANDALONE 1
|
||||||
#include <websocketpp/client.hpp>
|
#include <websocketpp/client.hpp>
|
||||||
@ -55,7 +62,7 @@ context_ptr OnTlsInit(websocketpp::connection_hdl) {
|
|||||||
asio::ssl::context::no_sslv3 | asio::ssl::context::single_dh_use);
|
asio::ssl::context::no_sslv3 | asio::ssl::context::single_dh_use);
|
||||||
|
|
||||||
} catch (std::exception& e) {
|
} catch (std::exception& e) {
|
||||||
std::cout << e.what() << std::endl;
|
LOG(ERROR) << e.what();
|
||||||
}
|
}
|
||||||
return ctx;
|
return ctx;
|
||||||
}
|
}
|
||||||
@ -99,7 +106,16 @@ class WebsocketClient {
|
|||||||
const std::string& payload = msg->get_payload();
|
const std::string& payload = msg->get_payload();
|
||||||
switch (msg->get_opcode()) {
|
switch (msg->get_opcode()) {
|
||||||
case websocketpp::frame::opcode::text:
|
case websocketpp::frame::opcode::text:
|
||||||
std::cout << "on_message = " << payload << std::endl;
|
total_num=total_num+1;
|
||||||
|
LOG(INFO)<<total_num<<",on_message = " << payload;
|
||||||
|
if((total_num+1)==wav_index)
|
||||||
|
{
|
||||||
|
websocketpp::lib::error_code ec;
|
||||||
|
m_client.close(m_hdl, websocketpp::close::status::going_away, "", ec);
|
||||||
|
if (ec){
|
||||||
|
LOG(ERROR)<< "Error closing connection " << ec.message();
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -132,12 +148,8 @@ class WebsocketClient {
|
|||||||
}
|
}
|
||||||
send_wav_data(wav_list[i], wav_ids[i]);
|
send_wav_data(wav_list[i], wav_ids[i]);
|
||||||
}
|
}
|
||||||
WaitABit();
|
WaitABit();
|
||||||
m_client.close(m_hdl,websocketpp::close::status::going_away, "", ec);
|
|
||||||
if (ec) {
|
|
||||||
std::cout << "> Error closing connection " << ec.message() << std::endl;
|
|
||||||
}
|
|
||||||
//send_wav_data();
|
|
||||||
asio_thread.join();
|
asio_thread.join();
|
||||||
|
|
||||||
}
|
}
|
||||||
@ -206,7 +218,7 @@ class WebsocketClient {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (wait) {
|
if (wait) {
|
||||||
std::cout << "wait.." << m_open << std::endl;
|
LOG(INFO) << "wait.." << m_open;
|
||||||
WaitABit();
|
WaitABit();
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@ -236,7 +248,7 @@ class WebsocketClient {
|
|||||||
// send data to server
|
// send data to server
|
||||||
m_client.send(m_hdl, iArray, len * sizeof(short),
|
m_client.send(m_hdl, iArray, len * sizeof(short),
|
||||||
websocketpp::frame::opcode::binary, ec);
|
websocketpp::frame::opcode::binary, ec);
|
||||||
std::cout << "sended data len=" << len * sizeof(short) << std::endl;
|
LOG(INFO) << "sended data len=" << len * sizeof(short);
|
||||||
// The most likely error that we will get is that the connection is
|
// The most likely error that we will get is that the connection is
|
||||||
// not in the right state. Usually this means we tried to send a
|
// not in the right state. Usually this means we tried to send a
|
||||||
// message to a connection that was closed or in the process of
|
// message to a connection that was closed or in the process of
|
||||||
@ -247,14 +259,13 @@ class WebsocketClient {
|
|||||||
"Send Error: " + ec.message());
|
"Send Error: " + ec.message());
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
// WaitABit();
|
||||||
WaitABit();
|
|
||||||
}
|
}
|
||||||
nlohmann::json jsonresult;
|
nlohmann::json jsonresult;
|
||||||
jsonresult["is_speaking"] = false;
|
jsonresult["is_speaking"] = false;
|
||||||
m_client.send(m_hdl, jsonresult.dump(), websocketpp::frame::opcode::text,
|
m_client.send(m_hdl, jsonresult.dump(), websocketpp::frame::opcode::text,
|
||||||
ec);
|
ec);
|
||||||
WaitABit();
|
// WaitABit();
|
||||||
}
|
}
|
||||||
websocketpp::client<T> m_client;
|
websocketpp::client<T> m_client;
|
||||||
|
|
||||||
@ -263,6 +274,7 @@ class WebsocketClient {
|
|||||||
websocketpp::lib::mutex m_lock;
|
websocketpp::lib::mutex m_lock;
|
||||||
bool m_open;
|
bool m_open;
|
||||||
bool m_done;
|
bool m_done;
|
||||||
|
int total_num=0;
|
||||||
};
|
};
|
||||||
|
|
||||||
int main(int argc, char* argv[]) {
|
int main(int argc, char* argv[]) {
|
||||||
@ -5,7 +5,7 @@
|
|||||||
/* 2022-2023 by zhaomingwork */
|
/* 2022-2023 by zhaomingwork */
|
||||||
|
|
||||||
// io server
|
// io server
|
||||||
// Usage:websocketmain [--model_thread_num <int>] [--decoder_thread_num <int>]
|
// Usage: funasr-wss-server [--model_thread_num <int>] [--decoder_thread_num <int>]
|
||||||
// [--io_thread_num <int>] [--port <int>] [--listen_ip
|
// [--io_thread_num <int>] [--port <int>] [--listen_ip
|
||||||
// <string>] [--punc-quant <string>] [--punc-dir <string>]
|
// <string>] [--punc-quant <string>] [--punc-dir <string>]
|
||||||
// [--vad-quant <string>] [--vad-dir <string>] [--quantize
|
// [--vad-quant <string>] [--vad-dir <string>] [--quantize
|
||||||
@ -15,44 +15,43 @@
|
|||||||
using namespace std;
|
using namespace std;
|
||||||
void GetValue(TCLAP::ValueArg<std::string>& value_arg, string key,
|
void GetValue(TCLAP::ValueArg<std::string>& value_arg, string key,
|
||||||
std::map<std::string, std::string>& model_path) {
|
std::map<std::string, std::string>& model_path) {
|
||||||
if (value_arg.isSet()) {
|
|
||||||
model_path.insert({key, value_arg.getValue()});
|
model_path.insert({key, value_arg.getValue()});
|
||||||
LOG(INFO) << key << " : " << value_arg.getValue();
|
LOG(INFO) << key << " : " << value_arg.getValue();
|
||||||
}
|
|
||||||
}
|
}
|
||||||
int main(int argc, char* argv[]) {
|
int main(int argc, char* argv[]) {
|
||||||
try {
|
try {
|
||||||
google::InitGoogleLogging(argv[0]);
|
google::InitGoogleLogging(argv[0]);
|
||||||
FLAGS_logtostderr = true;
|
FLAGS_logtostderr = true;
|
||||||
|
|
||||||
TCLAP::CmdLine cmd("websocketmain", ' ', "1.0");
|
TCLAP::CmdLine cmd("funasr-ws-server", ' ', "1.0");
|
||||||
TCLAP::ValueArg<std::string> model_dir(
|
TCLAP::ValueArg<std::string> model_dir(
|
||||||
"", MODEL_DIR,
|
"", MODEL_DIR,
|
||||||
"the asr model path, which contains model.onnx, config.yaml, am.mvn",
|
"default: /workspace/models/asr, the asr model path, which contains model.onnx, config.yaml, am.mvn",
|
||||||
true, "", "string");
|
false, "/workspace/models/asr", "string");
|
||||||
TCLAP::ValueArg<std::string> quantize(
|
TCLAP::ValueArg<std::string> quantize(
|
||||||
"", QUANTIZE,
|
"", QUANTIZE,
|
||||||
"false (Default), load the model of model.onnx in model_dir. If set "
|
"true (Default), load the model of model.onnx in model_dir. If set "
|
||||||
"true, load the model of model_quant.onnx in model_dir",
|
"true, load the model of model_quant.onnx in model_dir",
|
||||||
false, "false", "string");
|
false, "true", "string");
|
||||||
TCLAP::ValueArg<std::string> vad_dir(
|
TCLAP::ValueArg<std::string> vad_dir(
|
||||||
"", VAD_DIR,
|
"", VAD_DIR,
|
||||||
"the vad model path, which contains model.onnx, vad.yaml, vad.mvn",
|
"default: /workspace/models/vad, the vad model path, which contains model.onnx, vad.yaml, vad.mvn",
|
||||||
false, "", "string");
|
false, "/workspace/models/vad", "string");
|
||||||
TCLAP::ValueArg<std::string> vad_quant(
|
TCLAP::ValueArg<std::string> vad_quant(
|
||||||
"", VAD_QUANT,
|
"", VAD_QUANT,
|
||||||
"false (Default), load the model of model.onnx in vad_dir. If set "
|
"true (Default), load the model of model.onnx in vad_dir. If set "
|
||||||
"true, load the model of model_quant.onnx in vad_dir",
|
"true, load the model of model_quant.onnx in vad_dir",
|
||||||
false, "false", "string");
|
false, "true", "string");
|
||||||
TCLAP::ValueArg<std::string> punc_dir(
|
TCLAP::ValueArg<std::string> punc_dir(
|
||||||
"", PUNC_DIR,
|
"", PUNC_DIR,
|
||||||
"the punc model path, which contains model.onnx, punc.yaml", false, "",
|
"default: /workspace/models/punc, the punc model path, which contains model.onnx, punc.yaml",
|
||||||
|
false, "/workspace/models/punc",
|
||||||
"string");
|
"string");
|
||||||
TCLAP::ValueArg<std::string> punc_quant(
|
TCLAP::ValueArg<std::string> punc_quant(
|
||||||
"", PUNC_QUANT,
|
"", PUNC_QUANT,
|
||||||
"false (Default), load the model of model.onnx in punc_dir. If set "
|
"true (Default), load the model of model.onnx in punc_dir. If set "
|
||||||
"true, load the model of model_quant.onnx in punc_dir",
|
"true, load the model of model_quant.onnx in punc_dir",
|
||||||
false, "false", "string");
|
false, "true", "string");
|
||||||
|
|
||||||
TCLAP::ValueArg<std::string> listen_ip("", "listen_ip", "listen_ip", false,
|
TCLAP::ValueArg<std::string> listen_ip("", "listen_ip", "listen_ip", false,
|
||||||
"0.0.0.0", "string");
|
"0.0.0.0", "string");
|
||||||
@ -64,10 +63,12 @@ int main(int argc, char* argv[]) {
|
|||||||
TCLAP::ValueArg<int> model_thread_num("", "model_thread_num",
|
TCLAP::ValueArg<int> model_thread_num("", "model_thread_num",
|
||||||
"model_thread_num", false, 1, "int");
|
"model_thread_num", false, 1, "int");
|
||||||
|
|
||||||
TCLAP::ValueArg<std::string> certfile("", "certfile", "certfile", false, "",
|
TCLAP::ValueArg<std::string> certfile("", "certfile",
|
||||||
"string");
|
"default: ../../../ssl_key/server.crt, path of certficate for WSS connection. if it is empty, it will be in WS mode.",
|
||||||
TCLAP::ValueArg<std::string> keyfile("", "keyfile", "keyfile", false, "",
|
false, "../../../ssl_key/server.crt", "string");
|
||||||
"string");
|
TCLAP::ValueArg<std::string> keyfile("", "keyfile",
|
||||||
|
"default: ../../../ssl_key/server.key, path of keyfile for WSS connection",
|
||||||
|
false, "../../../ssl_key/server.key", "string");
|
||||||
|
|
||||||
cmd.add(certfile);
|
cmd.add(certfile);
|
||||||
cmd.add(keyfile);
|
cmd.add(keyfile);
|
||||||
@ -51,7 +51,7 @@ make
|
|||||||
|
|
||||||
```shell
|
```shell
|
||||||
cd bin
|
cd bin
|
||||||
./funasr-ws-server [--model_thread_num <int>] [--decoder_thread_num <int>]
|
./funasr-wss-server [--model_thread_num <int>] [--decoder_thread_num <int>]
|
||||||
[--io_thread_num <int>] [--port <int>] [--listen_ip
|
[--io_thread_num <int>] [--port <int>] [--listen_ip
|
||||||
<string>] [--punc-quant <string>] [--punc-dir <string>]
|
<string>] [--punc-quant <string>] [--punc-dir <string>]
|
||||||
[--vad-quant <string>] [--vad-dir <string>] [--quantize
|
[--vad-quant <string>] [--vad-dir <string>] [--quantize
|
||||||
@ -59,19 +59,19 @@ cd bin
|
|||||||
[--certfile <string>] [--] [--version] [-h]
|
[--certfile <string>] [--] [--version] [-h]
|
||||||
Where:
|
Where:
|
||||||
--model-dir <string>
|
--model-dir <string>
|
||||||
(required) the asr model path, which contains model.onnx, config.yaml, am.mvn
|
default: /workspace/models/asr, the asr model path, which contains model.onnx, config.yaml, am.mvn
|
||||||
--quantize <string>
|
--quantize <string>
|
||||||
false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir
|
true (Default), load the model of model_quant.onnx in model_dir. If set false, load the model of model.onnx in model_dir
|
||||||
|
|
||||||
--vad-dir <string>
|
--vad-dir <string>
|
||||||
the vad model path, which contains model.onnx, vad.yaml, vad.mvn
|
default: /workspace/models/vad, the vad model path, which contains model.onnx, vad.yaml, vad.mvn
|
||||||
--vad-quant <string>
|
--vad-quant <string>
|
||||||
false (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir
|
true (Default), load the model of model_quant.onnx in vad_dir. If set false, load the model of model.onnx in vad_dir
|
||||||
|
|
||||||
--punc-dir <string>
|
--punc-dir <string>
|
||||||
the punc model path, which contains model.onnx, punc.yaml
|
default: /workspace/models/punc, the punc model path, which contains model.onnx, punc.yaml
|
||||||
--punc-quant <string>
|
--punc-quant <string>
|
||||||
false (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir
|
true (Default), load the model of model_quant.onnx in punc_dir. If set false, load the model of model.onnx in punc_dir
|
||||||
|
|
||||||
--decoder_thread_num <int>
|
--decoder_thread_num <int>
|
||||||
number of threads for decoder, default:8
|
number of threads for decoder, default:8
|
||||||
@ -80,21 +80,18 @@ Where:
|
|||||||
--port <int>
|
--port <int>
|
||||||
listen port, default:8889
|
listen port, default:8889
|
||||||
--certfile <string>
|
--certfile <string>
|
||||||
path of certficate for WSS connection. if it is empty, it will be in WS mode.
|
default: ../../../ssl_key/server.crt, path of certificate for WSS connection. If it is empty, it will be in WS mode.
|
||||||
--keyfile <string>
|
--keyfile <string>
|
||||||
path of keyfile for WSS connection
|
default: ../../../ssl_key/server.key, path of keyfile for WSS connection
|
||||||
|
|
||||||
Required: --model-dir <string>
|
|
||||||
If use vad, please add: --vad-dir <string>
|
|
||||||
If use punc, please add: --punc-dir <string>
|
|
||||||
example:
|
example:
|
||||||
funasr-ws-server --model-dir /FunASR/funasr/runtime/onnxruntime/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch
|
./funasr-wss-server --model-dir /FunASR/funasr/runtime/onnxruntime/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch
|
||||||
```
|
```
|
||||||
|
|
||||||
## Run websocket client test
|
## Run websocket client test
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
./funasr-ws-client --server-ip <string>
|
./funasr-wss-client --server-ip <string>
|
||||||
--port <string>
|
--port <string>
|
||||||
--wav-path <string>
|
--wav-path <string>
|
||||||
[--thread-num <int>]
|
[--thread-num <int>]
|
||||||
@ -119,7 +116,7 @@ Where:
|
|||||||
is-ssl set to 1 means use a wss connection; otherwise use a ws connection
|
is-ssl set to 1 means use a wss connection; otherwise use a ws connection
|
||||||
|
|
||||||
example:
|
example:
|
||||||
./funasr-ws-client --server-ip 127.0.0.1 --port 8889 --wav-path test.wav --thread-num 1 --is-ssl 0
|
./funasr-wss-client --server-ip 127.0.0.1 --port 8889 --wav-path test.wav --thread-num 1 --is-ssl 1
|
||||||
|
|
||||||
result json, example like:
|
result json, example like:
|
||||||
{"mode":"offline","text":"欢迎大家来体验达摩院推出的语音识别模型","wav_name":"wav2"}
|
{"mode":"offline","text":"欢迎大家来体验达摩院推出的语音识别模型","wav_name":"wav2"}
|
||||||
|
|||||||
@ -22,12 +22,11 @@ context_ptr WebSocketServer::on_tls_init(tls_mode mode,
|
|||||||
std::string& s_keyfile) {
|
std::string& s_keyfile) {
|
||||||
namespace asio = websocketpp::lib::asio;
|
namespace asio = websocketpp::lib::asio;
|
||||||
|
|
||||||
std::cout << "on_tls_init called with hdl: " << hdl.lock().get() << std::endl;
|
LOG(INFO) << "on_tls_init called with hdl: " << hdl.lock().get();
|
||||||
std::cout << "using TLS mode: "
|
LOG(INFO) << "using TLS mode: "
|
||||||
<< (mode == MOZILLA_MODERN ? "Mozilla Modern"
|
<< (mode == MOZILLA_MODERN ? "Mozilla Modern"
|
||||||
: "Mozilla Intermediate")
|
: "Mozilla Intermediate");
|
||||||
<< std::endl;
|
|
||||||
|
|
||||||
context_ptr ctx = websocketpp::lib::make_shared<asio::ssl::context>(
|
context_ptr ctx = websocketpp::lib::make_shared<asio::ssl::context>(
|
||||||
asio::ssl::context::sslv23);
|
asio::ssl::context::sslv23);
|
||||||
|
|
||||||
@ -49,7 +48,7 @@ context_ptr WebSocketServer::on_tls_init(tls_mode mode,
|
|||||||
ctx->use_private_key_file(s_keyfile, asio::ssl::context::pem);
|
ctx->use_private_key_file(s_keyfile, asio::ssl::context::pem);
|
||||||
|
|
||||||
} catch (std::exception& e) {
|
} catch (std::exception& e) {
|
||||||
std::cout << "Exception: " << e.what() << std::endl;
|
LOG(INFO) << "Exception: " << e.what();
|
||||||
}
|
}
|
||||||
return ctx;
|
return ctx;
|
||||||
}
|
}
|
||||||
@ -86,8 +85,7 @@ void WebSocketServer::do_decoder(const std::vector<char>& buffer,
|
|||||||
ec);
|
ec);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::cout << "buffer.size=" << buffer.size()
|
LOG(INFO) << "buffer.size=" << buffer.size() << ",result json=" << jsonresult.dump();
|
||||||
<< ",result json=" << jsonresult.dump() << std::endl;
|
|
||||||
if (!isonline) {
|
if (!isonline) {
|
||||||
// close the client if it is not online asr
|
// close the client if it is not online asr
|
||||||
// server_->close(hdl, websocketpp::close::status::normal, "DONE", ec);
|
// server_->close(hdl, websocketpp::close::status::normal, "DONE", ec);
|
||||||
@ -110,14 +108,14 @@ void WebSocketServer::on_open(websocketpp::connection_hdl hdl) {
|
|||||||
data_msg->samples = std::make_shared<std::vector<char>>();
|
data_msg->samples = std::make_shared<std::vector<char>>();
|
||||||
data_msg->msg = nlohmann::json::parse("{}");
|
data_msg->msg = nlohmann::json::parse("{}");
|
||||||
data_map.emplace(hdl, data_msg);
|
data_map.emplace(hdl, data_msg);
|
||||||
std::cout << "on_open, active connections: " << data_map.size() << std::endl;
|
LOG(INFO) << "on_open, active connections: " << data_map.size();
|
||||||
}
|
}
|
||||||
|
|
||||||
void WebSocketServer::on_close(websocketpp::connection_hdl hdl) {
|
void WebSocketServer::on_close(websocketpp::connection_hdl hdl) {
|
||||||
scoped_lock guard(m_lock);
|
scoped_lock guard(m_lock);
|
||||||
data_map.erase(hdl); // remove data vector when connection is closed
|
data_map.erase(hdl); // remove data vector when connection is closed
|
||||||
|
|
||||||
std::cout << "on_close, active connections: " << data_map.size() << std::endl;
|
LOG(INFO) << "on_close, active connections: " << data_map.size();
|
||||||
}
|
}
|
||||||
|
|
||||||
// remove closed connection
|
// remove closed connection
|
||||||
@ -143,7 +141,7 @@ void WebSocketServer::check_and_clean_connection() {
|
|||||||
}
|
}
|
||||||
for (auto hdl : to_remove) {
|
for (auto hdl : to_remove) {
|
||||||
data_map.erase(hdl);
|
data_map.erase(hdl);
|
||||||
std::cout << "remove one connection " << std::endl;
|
LOG(INFO)<< "remove one connection ";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
void WebSocketServer::on_message(websocketpp::connection_hdl hdl,
|
void WebSocketServer::on_message(websocketpp::connection_hdl hdl,
|
||||||
@ -161,7 +159,7 @@ void WebSocketServer::on_message(websocketpp::connection_hdl hdl,
|
|||||||
|
|
||||||
lock.unlock();
|
lock.unlock();
|
||||||
if (sample_data_p == nullptr) {
|
if (sample_data_p == nullptr) {
|
||||||
std::cout << "error when fetch sample data vector" << std::endl;
|
LOG(INFO) << "error when fetch sample data vector";
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -176,7 +174,7 @@ void WebSocketServer::on_message(websocketpp::connection_hdl hdl,
|
|||||||
|
|
||||||
if (jsonresult["is_speaking"] == false ||
|
if (jsonresult["is_speaking"] == false ||
|
||||||
jsonresult["is_finished"] == true) {
|
jsonresult["is_finished"] == true) {
|
||||||
std::cout << "client done" << std::endl;
|
LOG(INFO) << "client done";
|
||||||
|
|
||||||
if (isonline) {
|
if (isonline) {
|
||||||
// do_close(ws);
|
// do_close(ws);
|
||||||
@ -225,9 +223,9 @@ void WebSocketServer::initAsr(std::map<std::string, std::string>& model_path,
|
|||||||
// init model with api
|
// init model with api
|
||||||
|
|
||||||
asr_hanlde = FunOfflineInit(model_path, thread_num);
|
asr_hanlde = FunOfflineInit(model_path, thread_num);
|
||||||
std::cout << "model ready" << std::endl;
|
LOG(INFO) << "model successfully inited";
|
||||||
|
|
||||||
} catch (const std::exception& e) {
|
} catch (const std::exception& e) {
|
||||||
std::cout << e.what() << std::endl;
|
LOG(INFO) << e.what();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -87,6 +87,7 @@ class TestParaformerInferencePipelines(unittest.TestCase):
|
|||||||
rec_result = inference_pipeline(
|
rec_result = inference_pipeline(
|
||||||
audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_hotword.wav')
|
audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_hotword.wav')
|
||||||
logger.info("asr inference result: {0}".format(rec_result))
|
logger.info("asr inference result: {0}".format(rec_result))
|
||||||
|
assert rec_result["text"] == "国务院发展研究中心市场经济研究所副所长邓郁松认为"
|
||||||
|
|
||||||
def test_paraformer_large_aishell1(self):
|
def test_paraformer_large_aishell1(self):
|
||||||
inference_pipeline = pipeline(
|
inference_pipeline = pipeline(
|
||||||
@ -95,6 +96,7 @@ class TestParaformerInferencePipelines(unittest.TestCase):
|
|||||||
rec_result = inference_pipeline(
|
rec_result = inference_pipeline(
|
||||||
audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
|
audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
|
||||||
logger.info("asr inference result: {0}".format(rec_result))
|
logger.info("asr inference result: {0}".format(rec_result))
|
||||||
|
assert rec_result["text"] == "欢迎大家来体验达摩院推出的语音识别模型"
|
||||||
|
|
||||||
def test_paraformer_large_aishell2(self):
|
def test_paraformer_large_aishell2(self):
|
||||||
inference_pipeline = pipeline(
|
inference_pipeline = pipeline(
|
||||||
@ -103,6 +105,7 @@ class TestParaformerInferencePipelines(unittest.TestCase):
|
|||||||
rec_result = inference_pipeline(
|
rec_result = inference_pipeline(
|
||||||
audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
|
audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
|
||||||
logger.info("asr inference result: {0}".format(rec_result))
|
logger.info("asr inference result: {0}".format(rec_result))
|
||||||
|
assert rec_result["text"] == "欢迎大家来体验达摩院推出的语音识别模型"
|
||||||
|
|
||||||
def test_paraformer_large_common(self):
|
def test_paraformer_large_common(self):
|
||||||
inference_pipeline = pipeline(
|
inference_pipeline = pipeline(
|
||||||
@ -111,6 +114,7 @@ class TestParaformerInferencePipelines(unittest.TestCase):
|
|||||||
rec_result = inference_pipeline(
|
rec_result = inference_pipeline(
|
||||||
audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
|
audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
|
||||||
logger.info("asr inference result: {0}".format(rec_result))
|
logger.info("asr inference result: {0}".format(rec_result))
|
||||||
|
assert rec_result["text"] == "欢迎大家来体验达摩院推出的语音识别模型"
|
||||||
|
|
||||||
def test_paraformer_large_online_common(self):
|
def test_paraformer_large_online_common(self):
|
||||||
inference_pipeline = pipeline(
|
inference_pipeline = pipeline(
|
||||||
@ -119,6 +123,7 @@ class TestParaformerInferencePipelines(unittest.TestCase):
|
|||||||
rec_result = inference_pipeline(
|
rec_result = inference_pipeline(
|
||||||
audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
|
audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
|
||||||
logger.info("asr inference result: {0}".format(rec_result))
|
logger.info("asr inference result: {0}".format(rec_result))
|
||||||
|
assert rec_result["text"] == "欢迎大 家来 体验达 摩院推 出的 语音识 别模 型"
|
||||||
|
|
||||||
def test_paraformer_online_common(self):
|
def test_paraformer_online_common(self):
|
||||||
inference_pipeline = pipeline(
|
inference_pipeline = pipeline(
|
||||||
@ -127,6 +132,7 @@ class TestParaformerInferencePipelines(unittest.TestCase):
|
|||||||
rec_result = inference_pipeline(
|
rec_result = inference_pipeline(
|
||||||
audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
|
audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
|
||||||
logger.info("asr inference result: {0}".format(rec_result))
|
logger.info("asr inference result: {0}".format(rec_result))
|
||||||
|
assert rec_result["text"] == "欢迎 大家来 体验达 摩院推 出的 语音识 别模 型"
|
||||||
|
|
||||||
def test_paraformer_tiny_commandword(self):
|
def test_paraformer_tiny_commandword(self):
|
||||||
inference_pipeline = pipeline(
|
inference_pipeline = pipeline(
|
||||||
|
|||||||
@ -26,6 +26,7 @@ class TestParaformerInferencePipelines(unittest.TestCase):
|
|||||||
rec_result = inference_pipeline(
|
rec_result = inference_pipeline(
|
||||||
audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
|
audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
|
||||||
logger.info("asr_vad_punc inference result: {0}".format(rec_result))
|
logger.info("asr_vad_punc inference result: {0}".format(rec_result))
|
||||||
|
assert rec_result["text"] == "欢迎大家来体验达摩院推出的语音识别模型。"
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|||||||
@ -24,16 +24,15 @@ class TestXVectorInferencePipelines(unittest.TestCase):
|
|||||||
rec_result = inference_sv_pipline(audio_in=(
|
rec_result = inference_sv_pipline(audio_in=(
|
||||||
'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_enroll.wav',
|
'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_enroll.wav',
|
||||||
'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_same.wav'))
|
'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_same.wav'))
|
||||||
assert abs(rec_result["scores"][0] - 0.85) < 0.1 and abs(rec_result["scores"][1] - 0.14) < 0.1
|
assert abs(rec_result["scores"][0]-0.85) < 0.1 and abs(rec_result["scores"][1]-0.14) < 0.1
|
||||||
logger.info(f"Similarity {rec_result['scores']}")
|
logger.info(f"Similarity {rec_result['scores']}")
|
||||||
|
|
||||||
# different speaker
|
# different speaker
|
||||||
rec_result = inference_sv_pipline(audio_in=(
|
rec_result = inference_sv_pipline(audio_in=(
|
||||||
'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_enroll.wav',
|
'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_enroll.wav',
|
||||||
'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_different.wav'))
|
'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_different.wav'))
|
||||||
assert abs(rec_result["scores"][0] - 0.0) < 0.1 and abs(rec_result["scores"][1] - 1.0) < 0.1
|
assert abs(rec_result["scores"][0]-0.0) < 0.1 and abs(rec_result["scores"][1]-1.0) < 0.1
|
||||||
logger.info(f"Similarity {rec_result['scores']}")
|
logger.info(f"Similarity {rec_result['scores']}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
Loading…
Reference in New Issue
Block a user