add mic for funasr-wss-client-2pass

This commit is contained in:
雾聪 2023-09-07 14:23:58 +08:00
parent b0b4d8a45d
commit b26d3de5fa
4 changed files with 288 additions and 64 deletions

View File

@ -7,6 +7,7 @@ set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
option(ENABLE_WEBSOCKET "Whether to build websocket server" ON)
option(ENABLE_PORTAUDIO "Whether to build websocket server" ON)
if(ENABLE_WEBSOCKET)
# cmake_policy(SET CMP0135 NEW)
@ -38,6 +39,37 @@ if(ENABLE_WEBSOCKET)
endif()
if(ENABLE_PORTAUDIO)
include(FetchContent)
set(portaudio_URL "http://files.portaudio.com/archives/pa_stable_v190700_20210406.tgz")
set(portaudio_URL2 "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/dep_libs/pa_stable_v190700_20210406.tgz")
set(portaudio_HASH "SHA256=47efbf42c77c19a05d22e627d42873e991ec0c1357219c0d74ce6a2948cb2def")
FetchContent_Declare(portaudio
URL
${portaudio_URL}
${portaudio_URL2}
URL_HASH ${portaudio_HASH}
)
FetchContent_GetProperties(portaudio)
if(NOT portaudio_POPULATED)
message(STATUS "Downloading portaudio from ${portaudio_URL}")
FetchContent_Populate(portaudio)
endif()
message(STATUS "portaudio is downloaded to ${portaudio_SOURCE_DIR}")
message(STATUS "portaudio's binary dir is ${portaudio_BINARY_DIR}")
add_subdirectory(${portaudio_SOURCE_DIR} ${portaudio_BINARY_DIR} EXCLUDE_FROM_ALL)
if(NOT WIN32)
target_compile_options(portaudio PRIVATE "-Wno-deprecated-declarations")
else()
install(TARGETS portaudio DESTINATION ..)
endif()
endif()
# Include generated *.pb.h files
link_directories(${ONNXRUNTIME_DIR}/lib)
link_directories(${FFMPEG_DIR}/lib)
@ -61,9 +93,9 @@ find_package(OpenSSL REQUIRED)
add_executable(funasr-wss-server "funasr-wss-server.cpp" "websocket-server.cpp")
add_executable(funasr-wss-server-2pass "funasr-wss-server-2pass.cpp" "websocket-server-2pass.cpp")
add_executable(funasr-wss-client "funasr-wss-client.cpp")
add_executable(funasr-wss-client-2pass "funasr-wss-client-2pass.cpp")
add_executable(funasr-wss-client-2pass "funasr-wss-client-2pass.cpp" "microphone.cpp")
target_link_libraries(funasr-wss-client PUBLIC funasr ssl crypto)
target_link_libraries(funasr-wss-client-2pass PUBLIC funasr ssl crypto)
target_link_libraries(funasr-wss-client-2pass PUBLIC funasr ssl crypto portaudio)
target_link_libraries(funasr-wss-server PUBLIC funasr ssl crypto)
target_link_libraries(funasr-wss-server-2pass PUBLIC funasr ssl crypto)

View File

@ -17,6 +17,7 @@
#define ASIO_STANDALONE 1
#include <glog/logging.h>
#include "portaudio.h"
#include <atomic>
#include <fstream>
@ -30,6 +31,7 @@
#include "audio.h"
#include "nlohmann/json.hpp"
#include "tclap/CmdLine.h"
#include "microphone.h"
/**
* Define a semi-cross platform helper method that waits/sleeps for a bit.
@ -123,7 +125,6 @@ class WebsocketClient {
if (ec) {
LOG(ERROR) << "Error closing connection " << ec.message();
}
}
}
}
@ -131,7 +132,7 @@ class WebsocketClient {
// This method will block until the connection is complete
void run(const std::string& uri, const std::vector<string>& wav_list,
const std::vector<string>& wav_ids, std::string asr_mode,
std::vector<int> chunk_size) {
std::vector<int> chunk_size, bool is_record=false) {
// Create a new connection to the given URI
websocketpp::lib::error_code ec;
typename websocketpp::client<T>::connection_ptr con =
@ -152,8 +153,11 @@ class WebsocketClient {
// Create a thread to run the ASIO io_service event loop
websocketpp::lib::thread asio_thread(&websocketpp::client<T>::run,
&m_client);
send_wav_data(wav_list[0], wav_ids[0], asr_mode, chunk_size);
if(is_record){
send_rec_data(asr_mode, chunk_size);
}else{
send_wav_data(wav_list[0], wav_ids[0], asr_mode, chunk_size);
}
WaitABit();
@ -264,16 +268,11 @@ class WebsocketClient {
send_block = len - offset;
}
m_client.send(m_hdl, iArray + offset, send_block * sizeof(short),
websocketpp::frame::opcode::binary, ec);
websocketpp::frame::opcode::binary, ec);
offset += send_block;
}
LOG(INFO) << "sended data len=" << len * sizeof(short);
// The most likely error that we will get is that the connection is
// not in the right state. Usually this means we tried to send a
// message to a connection that was closed or in the process of
// closing. While many errors here can be easily recovered from,
// in this simple example, we'll stop the data loop.
if (ec) {
m_client.get_alog().write(websocketpp::log::alevel::app,
"Send Error: " + ec.message());
@ -300,11 +299,6 @@ class WebsocketClient {
}
LOG(INFO) << "sended data len=" << len;
// The most likely error that we will get is that the connection is
// not in the right state. Usually this means we tried to send a
// message to a connection that was closed or in the process of
// closing. While many errors here can be easily recovered from,
// in this simple example, we'll stop the data loop.
if (ec) {
m_client.get_alog().write(websocketpp::log::alevel::app,
"Send Error: " + ec.message());
@ -317,6 +311,137 @@ class WebsocketClient {
ec);
WaitABit();
}
static int RecordCallback(const void* inputBuffer, void* outputBuffer,
unsigned long framesPerBuffer, const PaStreamCallbackTimeInfo* timeInfo,
PaStreamCallbackFlags statusFlags, void* userData)
{
std::vector<float>* buffer = static_cast<std::vector<float>*>(userData);
const float* input = static_cast<const float*>(inputBuffer);
for (unsigned int i = 0; i < framesPerBuffer; i++)
{
buffer->push_back(input[i]);
}
return paContinue;
}
void send_rec_data(std::string asr_mode, std::vector<int> chunk_vector) {
// first message
bool wait = false;
while (1) {
{
scoped_lock guard(m_lock);
// If the connection has been closed, stop generating data
if (m_done) {
break;
}
// If the connection hasn't been opened yet wait a bit and retry
if (!m_open) {
wait = true;
} else {
break;
}
}
if (wait) {
// LOG(INFO) << "wait.." << m_open;
WaitABit();
continue;
}
}
websocketpp::lib::error_code ec;
nlohmann::json jsonbegin;
nlohmann::json chunk_size = nlohmann::json::array();
chunk_size.push_back(chunk_vector[0]);
chunk_size.push_back(chunk_vector[1]);
chunk_size.push_back(chunk_vector[2]);
jsonbegin["mode"] = asr_mode;
jsonbegin["chunk_size"] = chunk_size;
jsonbegin["wav_name"] = "record";
jsonbegin["wav_format"] = "pcm";
jsonbegin["is_speaking"] = true;
m_client.send(m_hdl, jsonbegin.dump(), websocketpp::frame::opcode::text,
ec);
// mic
Microphone mic;
PaDeviceIndex num_devices = Pa_GetDeviceCount();
LOG(INFO) << "Num devices: " << num_devices;
PaStreamParameters param;
param.device = Pa_GetDefaultInputDevice();
if (param.device == paNoDevice) {
LOG(INFO) << "No default input device found";
exit(EXIT_FAILURE);
}
LOG(INFO) << "Use default device: " << param.device;
const PaDeviceInfo *info = Pa_GetDeviceInfo(param.device);
LOG(INFO) << " Name: " << info->name;
LOG(INFO) << " Max input channels: " << info->maxInputChannels;
param.channelCount = 1;
param.sampleFormat = paFloat32;
param.suggestedLatency = info->defaultLowInputLatency;
param.hostApiSpecificStreamInfo = nullptr;
float sample_rate = 16000;
PaStream *stream;
std::vector<float> buffer;
PaError err =
Pa_OpenStream(&stream, &param, nullptr, /* &outputParameters, */
sample_rate,
0, // frames per buffer
paClipOff, // we won't output out of range samples
// so don't bother clipping them
RecordCallback, &buffer);
if (err != paNoError) {
LOG(ERROR) << "portaudio error: " << Pa_GetErrorText(err);
exit(EXIT_FAILURE);
}
err = Pa_StartStream(stream);
LOG(INFO) << "Started: ";
if (err != paNoError) {
LOG(ERROR) << "portaudio error: " << Pa_GetErrorText(err);
exit(EXIT_FAILURE);
}
while(true){
int len = buffer.size();
short* iArray = new short[len];
for (size_t i = 0; i < len; ++i) {
iArray[i] = (short)(buffer[i] * 32768);
}
m_client.send(m_hdl, iArray, len * sizeof(short),
websocketpp::frame::opcode::binary, ec);
buffer.clear();
if (ec) {
m_client.get_alog().write(websocketpp::log::alevel::app,
"Send Error: " + ec.message());
}
Pa_Sleep(20); // sleep for 20ms
}
nlohmann::json jsonresult;
jsonresult["is_speaking"] = false;
m_client.send(m_hdl, jsonresult.dump(), websocketpp::frame::opcode::text,
ec);
err = Pa_CloseStream(stream);
if (err != paNoError) {
LOG(INFO) << "portaudio error: " << Pa_GetErrorText(err);
exit(EXIT_FAILURE);
}
}
websocketpp::client<T> m_client;
private:
@ -331,7 +456,7 @@ int main(int argc, char* argv[]) {
google::InitGoogleLogging(argv[0]);
FLAGS_logtostderr = true;
TCLAP::CmdLine cmd("funasr-wss-client", ' ', "1.0");
TCLAP::CmdLine cmd("funasr-wss-client-2pass", ' ', "1.0");
TCLAP::ValueArg<std::string> server_ip_("", "server-ip", "server-ip", true,
"127.0.0.1", "string");
TCLAP::ValueArg<std::string> port_("", "port", "port", true, "10095",
@ -340,7 +465,11 @@ int main(int argc, char* argv[]) {
"", "wav-path",
"the input could be: wav_path, e.g.: asr_example.wav; pcm_path, e.g.: "
"asr_example.pcm; wav.scp, kaldi style wav list (wav_id \t wav_path)",
true, "", "string");
false, "", "string");
TCLAP::ValueArg<int> record_(
"", "record",
"record is 1 means use record", false, 0,
"int");
TCLAP::ValueArg<std::string> asr_mode_("", ASR_MODE, "offline, online, 2pass",
false, "2pass", "string");
TCLAP::ValueArg<std::string> chunk_size_("", "chunk-size",
@ -357,6 +486,7 @@ int main(int argc, char* argv[]) {
cmd.add(port_);
cmd.add(wav_path_);
cmd.add(asr_mode_);
cmd.add(record_);
cmd.add(chunk_size_);
cmd.add(thread_num_);
cmd.add(is_ssl_);
@ -382,6 +512,7 @@ int main(int argc, char* argv[]) {
int threads_num = thread_num_.getValue();
int is_ssl = is_ssl_.getValue();
int is_record = record_.getValue();
std::string uri = "";
if (is_ssl == 1) {
@ -390,60 +521,78 @@ int main(int argc, char* argv[]) {
uri = "ws://" + server_ip + ":" + port;
}
// read wav_path
std::vector<string> wav_list;
std::vector<string> wav_ids;
string default_id = "wav_default_id";
if (IsTargetFile(wav_path, "scp")) {
ifstream in(wav_path);
if (!in.is_open()) {
printf("Failed to open scp file");
return 0;
}
string line;
while (getline(in, line)) {
istringstream iss(line);
string column1, column2;
iss >> column1 >> column2;
wav_list.emplace_back(column2);
wav_ids.emplace_back(column1);
}
in.close();
} else {
wav_list.emplace_back(wav_path);
wav_ids.emplace_back(default_id);
}
for (size_t wav_i = 0; wav_i < wav_list.size(); wav_i = wav_i + threads_num) {
std::vector<websocketpp::lib::thread> client_threads;
for (size_t i = 0; i < threads_num; i++) {
if (wav_i + i >= wav_list.size()) {
break;
}
if(is_record == 1){
std::vector<string> tmp_wav_list;
std::vector<string> tmp_wav_ids;
tmp_wav_list.emplace_back(wav_list[wav_i + i]);
tmp_wav_ids.emplace_back(wav_ids[wav_i + i]);
if (is_ssl == 1) {
WebsocketClient<websocketpp::config::asio_tls_client> c(is_ssl);
client_threads.emplace_back(
[uri, tmp_wav_list, tmp_wav_ids, asr_mode, chunk_size, is_ssl]() {
if (is_ssl == 1) {
WebsocketClient<websocketpp::config::asio_tls_client> c(is_ssl);
c.m_client.set_tls_init_handler(bind(&OnTlsInit, ::_1));
c.m_client.set_tls_init_handler(bind(&OnTlsInit, ::_1));
c.run(uri, tmp_wav_list, tmp_wav_ids, asr_mode, chunk_size, true);
} else {
WebsocketClient<websocketpp::config::asio_client> c(is_ssl);
c.run(uri, tmp_wav_list, tmp_wav_ids, asr_mode, chunk_size);
} else {
WebsocketClient<websocketpp::config::asio_client> c(is_ssl);
c.run(uri, tmp_wav_list, tmp_wav_ids, asr_mode, chunk_size, true);
}
c.run(uri, tmp_wav_list, tmp_wav_ids, asr_mode, chunk_size);
}
});
}else{
// read wav_path
std::vector<string> wav_list;
std::vector<string> wav_ids;
string default_id = "wav_default_id";
if (IsTargetFile(wav_path, "scp")) {
ifstream in(wav_path);
if (!in.is_open()) {
printf("Failed to open scp file");
return 0;
}
string line;
while (getline(in, line)) {
istringstream iss(line);
string column1, column2;
iss >> column1 >> column2;
wav_list.emplace_back(column2);
wav_ids.emplace_back(column1);
}
in.close();
} else {
wav_list.emplace_back(wav_path);
wav_ids.emplace_back(default_id);
}
for (auto& t : client_threads) {
t.join();
for (size_t wav_i = 0; wav_i < wav_list.size(); wav_i = wav_i + threads_num) {
std::vector<websocketpp::lib::thread> client_threads;
for (size_t i = 0; i < threads_num; i++) {
if (wav_i + i >= wav_list.size()) {
break;
}
std::vector<string> tmp_wav_list;
std::vector<string> tmp_wav_ids;
tmp_wav_list.emplace_back(wav_list[wav_i + i]);
tmp_wav_ids.emplace_back(wav_ids[wav_i + i]);
client_threads.emplace_back(
[uri, tmp_wav_list, tmp_wav_ids, asr_mode, chunk_size, is_ssl]() {
if (is_ssl == 1) {
WebsocketClient<websocketpp::config::asio_tls_client> c(is_ssl);
c.m_client.set_tls_init_handler(bind(&OnTlsInit, ::_1));
c.run(uri, tmp_wav_list, tmp_wav_ids, asr_mode, chunk_size);
} else {
WebsocketClient<websocketpp::config::asio_client> c(is_ssl);
c.run(uri, tmp_wav_list, tmp_wav_ids, asr_mode, chunk_size);
}
});
}
for (auto& t : client_threads) {
t.join();
}
}
}
}

View File

@ -0,0 +1,27 @@
/**
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
* Reserved. MIT License (https://opensource.org/licenses/MIT)
*/
#include "microphone.h"
#include <stdio.h>
#include <stdlib.h>
#include "portaudio.h" // NOLINT
Microphone::Microphone() {
PaError err = Pa_Initialize();
if (err != paNoError) {
LOG(ERROR)<<"portaudio error: " << Pa_GetErrorText(err);
exit(-1);
}
}
Microphone::~Microphone() {
PaError err = Pa_Terminate();
if (err != paNoError) {
LOG(ERROR)<<"portaudio error: " << Pa_GetErrorText(err);
exit(-1);
}
}

View File

@ -0,0 +1,16 @@
/**
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
* Reserved. MIT License (https://opensource.org/licenses/MIT)
*/
#ifndef WEBSOCKET_MICROPHONE_H_
#define WEBSOCKET_MICROPHONE_H_
#include <glog/logging.h>
class Microphone {
public:
Microphone();
~Microphone();
};
#endif // WEBSOCKET_MICROPHONE_H_