mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
update wbsocket for sensevoice & onnx models
This commit is contained in:
parent
67239ea39b
commit
3e44172c8b
@ -115,7 +115,7 @@ class WebsocketClient {
|
|||||||
|
|
||||||
// This method will block until the connection is complete
|
// This method will block until the connection is complete
|
||||||
void run(const std::string& uri, const std::vector<string>& wav_list, const std::vector<string>& wav_ids,
|
void run(const std::string& uri, const std::vector<string>& wav_list, const std::vector<string>& wav_ids,
|
||||||
int audio_fs, const std::unordered_map<std::string, int>& hws_map, int use_itn=1) {
|
int audio_fs, const std::unordered_map<std::string, int>& hws_map, int use_itn=1, int svs_itn=1) {
|
||||||
// Create a new connection to the given URI
|
// Create a new connection to the given URI
|
||||||
websocketpp::lib::error_code ec;
|
websocketpp::lib::error_code ec;
|
||||||
typename websocketpp::client<T>::connection_ptr con =
|
typename websocketpp::client<T>::connection_ptr con =
|
||||||
@ -147,7 +147,7 @@ class WebsocketClient {
|
|||||||
cv.wait(lock);
|
cv.wait(lock);
|
||||||
}
|
}
|
||||||
total_send += 1;
|
total_send += 1;
|
||||||
send_wav_data(wav_list[i], wav_ids[i], audio_fs, hws_map, send_hotword, use_itn);
|
send_wav_data(wav_list[i], wav_ids[i], audio_fs, hws_map, send_hotword, use_itn, svs_itn);
|
||||||
if(send_hotword){
|
if(send_hotword){
|
||||||
send_hotword = false;
|
send_hotword = false;
|
||||||
}
|
}
|
||||||
@ -186,7 +186,7 @@ class WebsocketClient {
|
|||||||
// send wav to server
|
// send wav to server
|
||||||
void send_wav_data(string wav_path, string wav_id, int audio_fs,
|
void send_wav_data(string wav_path, string wav_id, int audio_fs,
|
||||||
const std::unordered_map<std::string, int>& hws_map,
|
const std::unordered_map<std::string, int>& hws_map,
|
||||||
bool send_hotword, bool use_itn) {
|
bool send_hotword, bool use_itn, bool svs_itn) {
|
||||||
uint64_t count = 0;
|
uint64_t count = 0;
|
||||||
std::stringstream val;
|
std::stringstream val;
|
||||||
|
|
||||||
@ -239,9 +239,13 @@ class WebsocketClient {
|
|||||||
jsonbegin["wav_format"] = wav_format;
|
jsonbegin["wav_format"] = wav_format;
|
||||||
jsonbegin["audio_fs"] = sampling_rate;
|
jsonbegin["audio_fs"] = sampling_rate;
|
||||||
jsonbegin["itn"] = true;
|
jsonbegin["itn"] = true;
|
||||||
|
jsonbegin["svs_itn"] = true;
|
||||||
if(use_itn == 0){
|
if(use_itn == 0){
|
||||||
jsonbegin["itn"] = false;
|
jsonbegin["itn"] = false;
|
||||||
}
|
}
|
||||||
|
if(svs_itn == 0){
|
||||||
|
jsonbegin["svs_itn"] = false;
|
||||||
|
}
|
||||||
jsonbegin["is_speaking"] = true;
|
jsonbegin["is_speaking"] = true;
|
||||||
if(send_hotword){
|
if(send_hotword){
|
||||||
if(!hws_map.empty()){
|
if(!hws_map.empty()){
|
||||||
@ -368,6 +372,9 @@ int main(int argc, char* argv[]) {
|
|||||||
TCLAP::ValueArg<int> use_itn_(
|
TCLAP::ValueArg<int> use_itn_(
|
||||||
"", "use-itn",
|
"", "use-itn",
|
||||||
"use-itn is 1 means use itn, 0 means not use itn", false, 1, "int");
|
"use-itn is 1 means use itn, 0 means not use itn", false, 1, "int");
|
||||||
|
TCLAP::ValueArg<int> svs_itn_(
|
||||||
|
"", "svs-itn",
|
||||||
|
"svs-itn is 1 means use itn and punc, 0 means not use", false, 1, "int");
|
||||||
TCLAP::ValueArg<std::string> hotword_("", HOTWORD,
|
TCLAP::ValueArg<std::string> hotword_("", HOTWORD,
|
||||||
"the hotword file, one hotword perline, Format: Hotword Weight (could be: 阿里巴巴 20)", false, "", "string");
|
"the hotword file, one hotword perline, Format: Hotword Weight (could be: 阿里巴巴 20)", false, "", "string");
|
||||||
|
|
||||||
@ -378,6 +385,7 @@ int main(int argc, char* argv[]) {
|
|||||||
cmd.add(thread_num_);
|
cmd.add(thread_num_);
|
||||||
cmd.add(is_ssl_);
|
cmd.add(is_ssl_);
|
||||||
cmd.add(use_itn_);
|
cmd.add(use_itn_);
|
||||||
|
cmd.add(svs_itn_);
|
||||||
cmd.add(hotword_);
|
cmd.add(hotword_);
|
||||||
cmd.parse(argc, argv);
|
cmd.parse(argc, argv);
|
||||||
|
|
||||||
@ -387,6 +395,7 @@ int main(int argc, char* argv[]) {
|
|||||||
int threads_num = thread_num_.getValue();
|
int threads_num = thread_num_.getValue();
|
||||||
int is_ssl = is_ssl_.getValue();
|
int is_ssl = is_ssl_.getValue();
|
||||||
int use_itn = use_itn_.getValue();
|
int use_itn = use_itn_.getValue();
|
||||||
|
int svs_itn = svs_itn_.getValue();
|
||||||
|
|
||||||
std::vector<websocketpp::lib::thread> client_threads;
|
std::vector<websocketpp::lib::thread> client_threads;
|
||||||
std::string uri = "";
|
std::string uri = "";
|
||||||
@ -431,17 +440,17 @@ int main(int argc, char* argv[]) {
|
|||||||
|
|
||||||
int audio_fs = audio_fs_.getValue();
|
int audio_fs = audio_fs_.getValue();
|
||||||
for (size_t i = 0; i < threads_num; i++) {
|
for (size_t i = 0; i < threads_num; i++) {
|
||||||
client_threads.emplace_back([uri, wav_list, wav_ids, audio_fs, is_ssl, hws_map, use_itn]() {
|
client_threads.emplace_back([uri, wav_list, wav_ids, audio_fs, is_ssl, hws_map, use_itn, svs_itn]() {
|
||||||
if (is_ssl == 1) {
|
if (is_ssl == 1) {
|
||||||
WebsocketClient<websocketpp::config::asio_tls_client> c(is_ssl);
|
WebsocketClient<websocketpp::config::asio_tls_client> c(is_ssl);
|
||||||
|
|
||||||
c.m_client.set_tls_init_handler(bind(&OnTlsInit, ::_1));
|
c.m_client.set_tls_init_handler(bind(&OnTlsInit, ::_1));
|
||||||
|
|
||||||
c.run(uri, wav_list, wav_ids, audio_fs, hws_map, use_itn);
|
c.run(uri, wav_list, wav_ids, audio_fs, hws_map, use_itn, svs_itn);
|
||||||
} else {
|
} else {
|
||||||
WebsocketClient<websocketpp::config::asio_client> c(is_ssl);
|
WebsocketClient<websocketpp::config::asio_client> c(is_ssl);
|
||||||
|
|
||||||
c.run(uri, wav_list, wav_ids, audio_fs, hws_map, use_itn);
|
c.run(uri, wav_list, wav_ids, audio_fs, hws_map, use_itn, svs_itn);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|||||||
@ -55,11 +55,11 @@ int main(int argc, char* argv[]) {
|
|||||||
|
|
||||||
TCLAP::ValueArg<std::string> offline_model_revision(
|
TCLAP::ValueArg<std::string> offline_model_revision(
|
||||||
"", "offline-model-revision", "ASR offline model revision", false,
|
"", "offline-model-revision", "ASR offline model revision", false,
|
||||||
"v2.0.4", "string");
|
"v2.0.5", "string");
|
||||||
|
|
||||||
TCLAP::ValueArg<std::string> online_model_revision(
|
TCLAP::ValueArg<std::string> online_model_revision(
|
||||||
"", "online-model-revision", "ASR online model revision", false,
|
"", "online-model-revision", "ASR online model revision", false,
|
||||||
"v2.0.4", "string");
|
"v2.0.5", "string");
|
||||||
|
|
||||||
TCLAP::ValueArg<std::string> quantize(
|
TCLAP::ValueArg<std::string> quantize(
|
||||||
"", QUANTIZE,
|
"", QUANTIZE,
|
||||||
@ -85,7 +85,7 @@ int main(int argc, char* argv[]) {
|
|||||||
"model_quant.onnx, punc.yaml",
|
"model_quant.onnx, punc.yaml",
|
||||||
false, "damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727-onnx", "string");
|
false, "damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727-onnx", "string");
|
||||||
TCLAP::ValueArg<std::string> punc_revision(
|
TCLAP::ValueArg<std::string> punc_revision(
|
||||||
"", "punc-revision", "PUNC model revision", false, "v2.0.4", "string");
|
"", "punc-revision", "PUNC model revision", false, "v2.0.5", "string");
|
||||||
TCLAP::ValueArg<std::string> punc_quant(
|
TCLAP::ValueArg<std::string> punc_quant(
|
||||||
"", PUNC_QUANT,
|
"", PUNC_QUANT,
|
||||||
"true (Default), load the model of model_quant.onnx in punc_dir. If "
|
"true (Default), load the model of model_quant.onnx in punc_dir. If "
|
||||||
@ -262,7 +262,7 @@ int main(int argc, char* argv[]) {
|
|||||||
|
|
||||||
size_t found = s_offline_asr_path.find("speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404");
|
size_t found = s_offline_asr_path.find("speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404");
|
||||||
if (found != std::string::npos) {
|
if (found != std::string::npos) {
|
||||||
model_path["offline-model-revision"]="v2.0.4";
|
model_path["offline-model-revision"]="v2.0.5";
|
||||||
}
|
}
|
||||||
|
|
||||||
found = s_offline_asr_path.find("speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404");
|
found = s_offline_asr_path.find("speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404");
|
||||||
@ -272,7 +272,7 @@ int main(int argc, char* argv[]) {
|
|||||||
|
|
||||||
found = s_offline_asr_path.find("speech_paraformer-large_asr_nat-en-16k-common-vocab10020");
|
found = s_offline_asr_path.find("speech_paraformer-large_asr_nat-en-16k-common-vocab10020");
|
||||||
if (found != std::string::npos) {
|
if (found != std::string::npos) {
|
||||||
model_path["model-revision"]="v2.0.4";
|
model_path["model-revision"]="v2.0.5";
|
||||||
s_itn_path="";
|
s_itn_path="";
|
||||||
s_lm_path="";
|
s_lm_path="";
|
||||||
}
|
}
|
||||||
|
|||||||
@ -50,7 +50,7 @@ int main(int argc, char* argv[]) {
|
|||||||
TCLAP::ValueArg<std::string> model_revision(
|
TCLAP::ValueArg<std::string> model_revision(
|
||||||
"", "model-revision",
|
"", "model-revision",
|
||||||
"ASR model revision",
|
"ASR model revision",
|
||||||
false, "v2.0.4", "string");
|
false, "v2.0.5", "string");
|
||||||
TCLAP::ValueArg<std::string> quantize(
|
TCLAP::ValueArg<std::string> quantize(
|
||||||
"", QUANTIZE,
|
"", QUANTIZE,
|
||||||
"true (Default), load the model of model_quant.onnx in model_dir. If set "
|
"true (Default), load the model of model_quant.onnx in model_dir. If set "
|
||||||
@ -81,7 +81,7 @@ int main(int argc, char* argv[]) {
|
|||||||
TCLAP::ValueArg<std::string> punc_revision(
|
TCLAP::ValueArg<std::string> punc_revision(
|
||||||
"", "punc-revision",
|
"", "punc-revision",
|
||||||
"PUNC model revision",
|
"PUNC model revision",
|
||||||
false, "v2.0.4", "string");
|
false, "v2.0.5", "string");
|
||||||
TCLAP::ValueArg<std::string> punc_quant(
|
TCLAP::ValueArg<std::string> punc_quant(
|
||||||
"", PUNC_QUANT,
|
"", PUNC_QUANT,
|
||||||
"true (Default), load the model of model_quant.onnx in punc_dir. If set "
|
"true (Default), load the model of model_quant.onnx in punc_dir. If set "
|
||||||
@ -247,7 +247,7 @@ int main(int argc, char* argv[]) {
|
|||||||
// modify model-revision by model name
|
// modify model-revision by model name
|
||||||
size_t found = s_asr_path.find("speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404");
|
size_t found = s_asr_path.find("speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404");
|
||||||
if (found != std::string::npos) {
|
if (found != std::string::npos) {
|
||||||
model_path["model-revision"]="v2.0.4";
|
model_path["model-revision"]="v2.0.5";
|
||||||
}
|
}
|
||||||
|
|
||||||
found = s_asr_path.find("speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404");
|
found = s_asr_path.find("speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404");
|
||||||
@ -257,11 +257,22 @@ int main(int argc, char* argv[]) {
|
|||||||
|
|
||||||
found = s_asr_path.find("speech_paraformer-large_asr_nat-en-16k-common-vocab10020");
|
found = s_asr_path.find("speech_paraformer-large_asr_nat-en-16k-common-vocab10020");
|
||||||
if (found != std::string::npos) {
|
if (found != std::string::npos) {
|
||||||
model_path["model-revision"]="v2.0.4";
|
model_path["model-revision"]="v2.0.5";
|
||||||
s_itn_path="";
|
s_itn_path="";
|
||||||
s_lm_path="";
|
s_lm_path="";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
found = s_asr_path.find(MODEL_SVS);
|
||||||
|
if (found != std::string::npos) {
|
||||||
|
model_path["model-revision"]="v2.0.5";
|
||||||
|
s_itn_path="";
|
||||||
|
model_path[ITN_DIR]="";
|
||||||
|
s_lm_path="";
|
||||||
|
model_path[LM_DIR]="";
|
||||||
|
s_punc_path="";
|
||||||
|
model_path[PUNC_DIR]="";
|
||||||
|
}
|
||||||
|
|
||||||
if (use_gpu_){
|
if (use_gpu_){
|
||||||
model_type = "torchscript";
|
model_type = "torchscript";
|
||||||
if (s_blade=="true" || s_blade=="True" || s_blade=="TRUE"){
|
if (s_blade=="true" || s_blade=="True" || s_blade=="TRUE"){
|
||||||
|
|||||||
@ -67,7 +67,9 @@ void WebSocketServer::do_decoder(const std::vector<char>& buffer,
|
|||||||
bool itn,
|
bool itn,
|
||||||
int audio_fs,
|
int audio_fs,
|
||||||
std::string wav_format,
|
std::string wav_format,
|
||||||
FUNASR_DEC_HANDLE& decoder_handle) {
|
FUNASR_DEC_HANDLE& decoder_handle,
|
||||||
|
std::string svs_lang,
|
||||||
|
bool sys_itn) {
|
||||||
try {
|
try {
|
||||||
int num_samples = buffer.size(); // the size of the buf
|
int num_samples = buffer.size(); // the size of the buf
|
||||||
|
|
||||||
@ -78,7 +80,8 @@ void WebSocketServer::do_decoder(const std::vector<char>& buffer,
|
|||||||
try{
|
try{
|
||||||
FUNASR_RESULT Result = FunOfflineInferBuffer(
|
FUNASR_RESULT Result = FunOfflineInferBuffer(
|
||||||
asr_handle, buffer.data(), buffer.size(), RASR_NONE, nullptr,
|
asr_handle, buffer.data(), buffer.size(), RASR_NONE, nullptr,
|
||||||
hotwords_embedding, audio_fs, wav_format, itn, decoder_handle);
|
hotwords_embedding, audio_fs, wav_format, itn, decoder_handle,
|
||||||
|
svs_lang, sys_itn);
|
||||||
if (Result != nullptr){
|
if (Result != nullptr){
|
||||||
asr_result = FunASRGetResult(Result, 0); // get decode result
|
asr_result = FunASRGetResult(Result, 0); // get decode result
|
||||||
stamp_res = FunASRGetStamp(Result);
|
stamp_res = FunASRGetStamp(Result);
|
||||||
@ -162,6 +165,8 @@ void WebSocketServer::on_open(websocketpp::connection_hdl hdl) {
|
|||||||
data_msg->msg["audio_fs"] = 16000; // default is 16k
|
data_msg->msg["audio_fs"] = 16000; // default is 16k
|
||||||
data_msg->msg["access_num"] = 0; // the number of access for this object, when it is 0, we can free it saftly
|
data_msg->msg["access_num"] = 0; // the number of access for this object, when it is 0, we can free it saftly
|
||||||
data_msg->msg["is_eof"]=false;
|
data_msg->msg["is_eof"]=false;
|
||||||
|
data_msg->msg["svs_lang"]="auto";
|
||||||
|
data_msg->msg["svs_itn"]=true;
|
||||||
FUNASR_DEC_HANDLE decoder_handle =
|
FUNASR_DEC_HANDLE decoder_handle =
|
||||||
FunASRWfstDecoderInit(asr_handle, ASR_OFFLINE, global_beam_, lattice_beam_, am_scale_);
|
FunASRWfstDecoderInit(asr_handle, ASR_OFFLINE, global_beam_, lattice_beam_, am_scale_);
|
||||||
data_msg->decoder_handle = decoder_handle;
|
data_msg->decoder_handle = decoder_handle;
|
||||||
@ -357,6 +362,12 @@ void WebSocketServer::on_message(websocketpp::connection_hdl hdl,
|
|||||||
if (jsonresult.contains("itn")) {
|
if (jsonresult.contains("itn")) {
|
||||||
msg_data->msg["itn"] = jsonresult["itn"];
|
msg_data->msg["itn"] = jsonresult["itn"];
|
||||||
}
|
}
|
||||||
|
if (jsonresult.contains("svs_lang")) {
|
||||||
|
msg_data->msg["svs_lang"] = jsonresult["svs_lang"];
|
||||||
|
}
|
||||||
|
if (jsonresult.contains("svs_itn")) {
|
||||||
|
msg_data->msg["svs_itn"] = jsonresult["svs_itn"];
|
||||||
|
}
|
||||||
if ((jsonresult["is_speaking"] == false ||
|
if ((jsonresult["is_speaking"] == false ||
|
||||||
jsonresult["is_finished"] == true) &&
|
jsonresult["is_finished"] == true) &&
|
||||||
msg_data->msg["is_eof"] != true &&
|
msg_data->msg["is_eof"] != true &&
|
||||||
@ -375,7 +386,9 @@ void WebSocketServer::on_message(websocketpp::connection_hdl hdl,
|
|||||||
msg_data->msg["itn"],
|
msg_data->msg["itn"],
|
||||||
msg_data->msg["audio_fs"],
|
msg_data->msg["audio_fs"],
|
||||||
msg_data->msg["wav_format"],
|
msg_data->msg["wav_format"],
|
||||||
std::ref(msg_data->decoder_handle)));
|
std::ref(msg_data->decoder_handle),
|
||||||
|
msg_data->msg["svs_lang"],
|
||||||
|
msg_data->msg["svs_itn"]));
|
||||||
msg_data->msg["access_num"]=(int)(msg_data->msg["access_num"])+1;
|
msg_data->msg["access_num"]=(int)(msg_data->msg["access_num"])+1;
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
|
|||||||
@ -122,7 +122,9 @@ class WebSocketServer {
|
|||||||
bool itn,
|
bool itn,
|
||||||
int audio_fs,
|
int audio_fs,
|
||||||
std::string wav_format,
|
std::string wav_format,
|
||||||
FUNASR_DEC_HANDLE& decoder_handle);
|
FUNASR_DEC_HANDLE& decoder_handle,
|
||||||
|
std::string svs_lang,
|
||||||
|
bool sys_itn);
|
||||||
|
|
||||||
void initAsr(std::map<std::string, std::string>& model_path, int thread_num, bool use_gpu=false, int batch_size=1);
|
void initAsr(std::map<std::string, std::string>& model_path, int thread_num, bool use_gpu=false, int batch_size=1);
|
||||||
void on_message(websocketpp::connection_hdl hdl, message_ptr msg);
|
void on_message(websocketpp::connection_hdl hdl, message_ptr msg);
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user