adapt to cpp runtime

This commit is contained in:
雾聪 2024-08-01 14:52:07 +08:00
parent dba846c352
commit 775f740239
20 changed files with 221 additions and 337 deletions

View File

@ -9,45 +9,45 @@ include_directories(${PROJECT_SOURCE_DIR}/third_party)
SET(RELATION_SOURCE "../src/resample.cpp" "../src/util.cpp" "../src/alignedmem.cpp" "../src/encode_converter.cpp")
endif()
# add_executable(funasr-onnx-offline "funasr-onnx-offline.cpp" ${RELATION_SOURCE})
# target_link_options(funasr-onnx-offline PRIVATE "-Wl,--no-as-needed")
# target_link_libraries(funasr-onnx-offline PUBLIC funasr)
add_executable(funasr-onnx-offline "funasr-onnx-offline.cpp" ${RELATION_SOURCE})
target_link_options(funasr-onnx-offline PRIVATE "-Wl,--no-as-needed")
target_link_libraries(funasr-onnx-offline PUBLIC funasr)
# add_executable(funasr-onnx-offline-vad "funasr-onnx-offline-vad.cpp" ${RELATION_SOURCE})
# target_link_options(funasr-onnx-offline-vad PRIVATE "-Wl,--no-as-needed")
# target_link_libraries(funasr-onnx-offline-vad PUBLIC funasr)
add_executable(funasr-onnx-offline-vad "funasr-onnx-offline-vad.cpp" ${RELATION_SOURCE})
target_link_options(funasr-onnx-offline-vad PRIVATE "-Wl,--no-as-needed")
target_link_libraries(funasr-onnx-offline-vad PUBLIC funasr)
# add_executable(funasr-onnx-online-vad "funasr-onnx-online-vad.cpp" ${RELATION_SOURCE})
# target_link_options(funasr-onnx-online-vad PRIVATE "-Wl,--no-as-needed")
# target_link_libraries(funasr-onnx-online-vad PUBLIC funasr)
add_executable(funasr-onnx-online-vad "funasr-onnx-online-vad.cpp" ${RELATION_SOURCE})
target_link_options(funasr-onnx-online-vad PRIVATE "-Wl,--no-as-needed")
target_link_libraries(funasr-onnx-online-vad PUBLIC funasr)
# add_executable(funasr-onnx-online-asr "funasr-onnx-online-asr.cpp" ${RELATION_SOURCE})
# target_link_options(funasr-onnx-online-asr PRIVATE "-Wl,--no-as-needed")
# target_link_libraries(funasr-onnx-online-asr PUBLIC funasr)
add_executable(funasr-onnx-online-asr "funasr-onnx-online-asr.cpp" ${RELATION_SOURCE})
target_link_options(funasr-onnx-online-asr PRIVATE "-Wl,--no-as-needed")
target_link_libraries(funasr-onnx-online-asr PUBLIC funasr)
# add_executable(funasr-onnx-offline-punc "funasr-onnx-offline-punc.cpp" ${RELATION_SOURCE})
# target_link_options(funasr-onnx-offline-punc PRIVATE "-Wl,--no-as-needed")
# target_link_libraries(funasr-onnx-offline-punc PUBLIC funasr)
add_executable(funasr-onnx-offline-punc "funasr-onnx-offline-punc.cpp" ${RELATION_SOURCE})
target_link_options(funasr-onnx-offline-punc PRIVATE "-Wl,--no-as-needed")
target_link_libraries(funasr-onnx-offline-punc PUBLIC funasr)
# add_executable(funasr-onnx-online-punc "funasr-onnx-online-punc.cpp" ${RELATION_SOURCE})
# target_link_options(funasr-onnx-online-punc PRIVATE "-Wl,--no-as-needed")
# target_link_libraries(funasr-onnx-online-punc PUBLIC funasr)
add_executable(funasr-onnx-online-punc "funasr-onnx-online-punc.cpp" ${RELATION_SOURCE})
target_link_options(funasr-onnx-online-punc PRIVATE "-Wl,--no-as-needed")
target_link_libraries(funasr-onnx-online-punc PUBLIC funasr)
# add_executable(funasr-onnx-offline-rtf "funasr-onnx-offline-rtf.cpp" ${RELATION_SOURCE})
# target_link_options(funasr-onnx-offline-rtf PRIVATE "-Wl,--no-as-needed")
# target_link_libraries(funasr-onnx-offline-rtf PUBLIC funasr)
add_executable(funasr-onnx-offline-rtf "funasr-onnx-offline-rtf.cpp" ${RELATION_SOURCE})
target_link_options(funasr-onnx-offline-rtf PRIVATE "-Wl,--no-as-needed")
target_link_libraries(funasr-onnx-offline-rtf PUBLIC funasr)
# add_executable(funasr-onnx-2pass "funasr-onnx-2pass.cpp" ${RELATION_SOURCE})
# target_link_options(funasr-onnx-2pass PRIVATE "-Wl,--no-as-needed")
# target_link_libraries(funasr-onnx-2pass PUBLIC funasr)
add_executable(funasr-onnx-2pass "funasr-onnx-2pass.cpp" ${RELATION_SOURCE})
target_link_options(funasr-onnx-2pass PRIVATE "-Wl,--no-as-needed")
target_link_libraries(funasr-onnx-2pass PUBLIC funasr)
# add_executable(funasr-onnx-2pass-rtf "funasr-onnx-2pass-rtf.cpp" ${RELATION_SOURCE})
# target_link_options(funasr-onnx-2pass-rtf PRIVATE "-Wl,--no-as-needed")
# target_link_libraries(funasr-onnx-2pass-rtf PUBLIC funasr)
add_executable(funasr-onnx-2pass-rtf "funasr-onnx-2pass-rtf.cpp" ${RELATION_SOURCE})
target_link_options(funasr-onnx-2pass-rtf PRIVATE "-Wl,--no-as-needed")
target_link_libraries(funasr-onnx-2pass-rtf PUBLIC funasr)
# add_executable(funasr-onnx-online-rtf "funasr-onnx-online-rtf.cpp" ${RELATION_SOURCE})
# target_link_options(funasr-onnx-online-rtf PRIVATE "-Wl,--no-as-needed")
# target_link_libraries(funasr-onnx-online-rtf PUBLIC funasr)
add_executable(funasr-onnx-online-rtf "funasr-onnx-online-rtf.cpp" ${RELATION_SOURCE})
target_link_options(funasr-onnx-online-rtf PRIVATE "-Wl,--no-as-needed")
target_link_libraries(funasr-onnx-online-rtf PUBLIC funasr)
add_executable(funasr-onnx-offline-sv "funasr-onnx-offline-sv.cpp" ${RELATION_SOURCE})
target_link_options(funasr-onnx-offline-sv PRIVATE "-Wl,--no-as-needed")

View File

@ -88,6 +88,7 @@ void runReg(FUNASR_HANDLE tpass_handle, std::vector<int> chunk_size, vector<stri
int step = 1600*2;
bool is_final = false;
std::vector<std::vector<string>> punc_cache(2);
FunTpassOnlineReset(tpass_online_handle);
for (int sample_offset = 0; sample_offset < buff_len; sample_offset += std::min(step, buff_len - sample_offset)) {
if (sample_offset + step >= buff_len - 1) {
step = buff_len - sample_offset;
@ -96,7 +97,7 @@ void runReg(FUNASR_HANDLE tpass_handle, std::vector<int> chunk_size, vector<stri
is_final = false;
}
FUNASR_RESULT result = FunTpassInferBuffer(tpass_handle, tpass_online_handle, speech_buff+sample_offset, step, punc_cache, is_final,
sampling_rate_, "pcm", (ASR_TYPE)asr_mode_, hotwords_embedding, true, decoder_handle);
sampling_rate_, "pcm", (ASR_TYPE)asr_mode_, hotwords_embedding, true, true, decoder_handle);
if (result)
{
FunASRFreeResult(result);
@ -137,6 +138,7 @@ void runReg(FUNASR_HANDLE tpass_handle, std::vector<int> chunk_size, vector<stri
string tpass_res="";
string time_stamp_res="";
std::vector<std::vector<string>> punc_cache(2);
FunTpassOnlineReset(tpass_online_handle);
for (int sample_offset = 0; sample_offset < buff_len; sample_offset += std::min(step, buff_len - sample_offset)) {
if (sample_offset + step >= buff_len - 1) {
step = buff_len - sample_offset;
@ -146,7 +148,7 @@ void runReg(FUNASR_HANDLE tpass_handle, std::vector<int> chunk_size, vector<stri
}
gettimeofday(&start, nullptr);
FUNASR_RESULT result = FunTpassInferBuffer(tpass_handle, tpass_online_handle, speech_buff+sample_offset, step, punc_cache, is_final,
sampling_rate_, "pcm", (ASR_TYPE)asr_mode_, hotwords_embedding, true, decoder_handle);
sampling_rate_, "pcm", (ASR_TYPE)asr_mode_, hotwords_embedding, true, true, decoder_handle);
gettimeofday(&end, nullptr);
seconds = (end.tv_sec - start.tv_sec);
long taking_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
@ -161,8 +163,16 @@ void runReg(FUNASR_HANDLE tpass_handle, std::vector<int> chunk_size, vector<stri
}
string tpass_msg = FunASRGetTpassResult(result, 0);
tpass_res += tpass_msg;
if(tpass_msg != ""){
LOG(INFO) <<"Thread: " << this_thread::get_id() <<" " << wav_ids[i] <<" offline results : "<<tpass_msg;
if (tpass_msg != "")
{
int speaker_idx = FunASRGetSvResult(result, 0);
if (speaker_idx != -1)
{
LOG(INFO) <<"Thread: " << this_thread::get_id() <<" " << wav_ids[i] << " speaker_id : " << speaker_idx << " offline results : " << tpass_msg;
} else{
LOG(INFO) <<"Thread: " << this_thread::get_id() <<" " << wav_ids[i] <<" offline results : "<<tpass_msg;
}
}
string stamp = FunASRGetStamp(result);
if(stamp !=""){
@ -223,6 +233,8 @@ int main(int argc, char** argv)
TCLAP::ValueArg<std::string> vad_quant("", VAD_QUANT, "true (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir", false, "true", "string");
TCLAP::ValueArg<std::string> punc_dir("", PUNC_DIR, "the punc online model path, which contains model.onnx, punc.yaml", false, "", "string");
TCLAP::ValueArg<std::string> punc_quant("", PUNC_QUANT, "true (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "true", "string");
TCLAP::ValueArg<std::string> sv_dir("", SV_DIR, "the sv online model path, which contains model.onnx, config.yaml", false, "", "string");
TCLAP::ValueArg<std::string> sv_quant("", SV_QUANT, "true (Default), load the model of model.onnx in sv_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "true", "string");
TCLAP::ValueArg<std::string> itn_dir("", ITN_DIR, "the itn model(fst) path, which contains zh_itn_tagger.fst and zh_itn_verbalizer.fst", false, "", "string");
TCLAP::ValueArg<std::string> lm_dir("", LM_DIR, "the lm model path, which contains compiled models: TLG.fst, config.yaml, lexicon.txt ", false, "", "string");
TCLAP::ValueArg<float> global_beam("", GLOB_BEAM, "the decoding beam for beam searching ", false, 3.0, "float");
@ -245,6 +257,8 @@ int main(int argc, char** argv)
cmd.add(punc_dir);
cmd.add(punc_quant);
cmd.add(itn_dir);
cmd.add(sv_dir);
cmd.add(sv_quant);
cmd.add(lm_dir);
cmd.add(global_beam);
cmd.add(lattice_beam);
@ -266,8 +280,10 @@ int main(int argc, char** argv)
GetValue(vad_quant, VAD_QUANT, model_path);
GetValue(punc_dir, PUNC_DIR, model_path);
GetValue(punc_quant, PUNC_QUANT, model_path);
GetValue(itn_dir, ITN_DIR, model_path);
GetValue(sv_dir, SV_DIR, model_path);
GetValue(sv_quant, SV_QUANT, model_path);
GetValue(lm_dir, LM_DIR, model_path);
GetValue(itn_dir, ITN_DIR, model_path);
GetValue(wav_path, WAV_PATH, model_path);
GetValue(asr_mode, ASR_MODE, model_path);
@ -285,9 +301,9 @@ int main(int argc, char** argv)
LOG(ERROR) << "Wrong asr-mode : " << model_path[ASR_MODE];
exit(-1);
}
FUNASR_HANDLE tpass_hanlde=FunTpassInit(model_path, thread_num);
FUNASR_HANDLE tpass_handle = FunTpassInit(model_path, thread_num);
if (!tpass_hanlde)
if (!tpass_handle)
{
LOG(ERROR) << "FunTpassInit init failed";
exit(-1);
@ -326,7 +342,7 @@ int main(int argc, char** argv)
return 0;
}
string line;
while(getline(in, line))
while (getline(in, line))
{
istringstream iss(line);
string column1, column2;
@ -349,7 +365,7 @@ int main(int argc, char** argv)
int rtf_threds = thread_num_.getValue();
for (int i = 0; i < rtf_threds; i++)
{
threads.emplace_back(thread(runReg, tpass_hanlde, chunk_size, wav_list, wav_ids, audio_fs.getValue(), &total_length, &total_time, i, (ASR_TYPE)asr_mode_, nn_hotwords_,
threads.emplace_back(thread(runReg, tpass_handle, chunk_size, wav_list, wav_ids, audio_fs.getValue(), &total_length, &total_time, i, (ASR_TYPE)asr_mode_, nn_hotwords_,
glob_beam, lat_beam, am_sc, fst_inc_wts.getValue(), hws_map));
}
@ -363,7 +379,7 @@ int main(int argc, char** argv)
LOG(INFO) << "total_rtf " << (double)total_time/ (total_length*1000000);
LOG(INFO) << "speedup " << 1.0/((double)total_time/ (total_length*1000000));
FunTpassUninit(tpass_hanlde);
FunTpassUninit(tpass_handle);
return 0;
}

View File

@ -48,11 +48,11 @@ int main(int argc, char **argv)
TCLAP::ValueArg<std::string> offline_model_dir("", OFFLINE_MODEL_DIR, "the asr offline model path, which contains model.onnx, config.yaml, am.mvn", true, "", "string");
TCLAP::ValueArg<std::string> online_model_dir("", ONLINE_MODEL_DIR, "the asr online model path, which contains model.onnx, decoder.onnx, config.yaml, am.mvn", true, "", "string");
TCLAP::ValueArg<std::string> quantize("", QUANTIZE, "true (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "true", "string");
TCLAP::ValueArg<std::string> vad_dir("", VAD_DIR, "the vad online model path, which contains model.onnx, vad.yaml, vad.mvn", true, "", "string");
TCLAP::ValueArg<std::string> vad_dir("", VAD_DIR, "the vad online model path, which contains model.onnx, vad.yaml, vad.mvn", false, "", "string");
TCLAP::ValueArg<std::string> vad_quant("", VAD_QUANT, "true (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir", false, "true", "string");
TCLAP::ValueArg<std::string> punc_dir("", PUNC_DIR, "the punc online model path, which contains model.onnx, punc.yaml", true, "", "string");
TCLAP::ValueArg<std::string> punc_dir("", PUNC_DIR, "the punc online model path, which contains model.onnx, punc.yaml", false, "", "string");
TCLAP::ValueArg<std::string> punc_quant("", PUNC_QUANT, "true (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "true", "string");
TCLAP::ValueArg<std::string> sv_dir("", SV_DIR, "the sv online model path, which contains model.onnx, config.yaml", true, "", "string");
TCLAP::ValueArg<std::string> sv_dir("", SV_DIR, "the sv online model path, which contains model.onnx, config.yaml", false, "", "string");
TCLAP::ValueArg<std::string> sv_quant("", SV_QUANT, "true (Default), load the model of model.onnx in sv_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "true", "string");
TCLAP::ValueArg<std::string> itn_dir("", ITN_DIR, "the itn model(fst) path, which contains zh_itn_tagger.fst and zh_itn_verbalizer.fst", false, "", "string");
TCLAP::ValueArg<std::string> lm_dir("", LM_DIR, "the lm model path, which contains compiled models: TLG.fst, config.yaml, lexicon.txt ", false, "", "string");
@ -74,6 +74,7 @@ int main(int argc, char **argv)
cmd.add(vad_quant);
cmd.add(punc_dir);
cmd.add(punc_quant);
cmd.add(itn_dir);
cmd.add(sv_dir);
cmd.add(sv_quant);
cmd.add(lm_dir);
@ -81,7 +82,6 @@ int main(int argc, char **argv)
cmd.add(lattice_beam);
cmd.add(am_scale);
cmd.add(fst_inc_wts);
cmd.add(itn_dir);
cmd.add(wav_path);
cmd.add(audio_fs);
cmd.add(asr_mode);
@ -135,8 +135,7 @@ int main(int argc, char **argv)
float glob_beam = 3.0f;
float lat_beam = 3.0f;
float am_sc = 10.0f;
if (lm_dir.isSet())
{
if (lm_dir.isSet()) {
glob_beam = global_beam.getValue();
lat_beam = lattice_beam.getValue();
am_sc = am_scale.getValue();
@ -144,9 +143,6 @@ int main(int argc, char **argv)
// init wfst decoder
FUNASR_DEC_HANDLE decoder_handle = FunASRWfstDecoderInit(tpass_handle, ASR_TWO_PASS, glob_beam, lat_beam, am_sc);
// init sv-cam
bool sv_mode = 1; // use cam model for speaker verification
std::vector<std::vector<float>> voice_feats;
gettimeofday(&end, nullptr);
long seconds = (end.tv_sec - start.tv_sec);
long modle_init_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
@ -165,8 +161,7 @@ int main(int argc, char **argv)
string default_id = "wav_default_id";
string wav_path_ = model_path.at(WAV_PATH);
if (is_target_file(wav_path_, "scp"))
{
if (is_target_file(wav_path_, "scp")) {
ifstream in(wav_path_);
if (!in.is_open())
{
@ -183,8 +178,7 @@ int main(int argc, char **argv)
wav_ids.emplace_back(column1);
}
in.close();
}
else
}else
{
wav_list.emplace_back(wav_path_);
wav_ids.emplace_back(default_id);
@ -239,6 +233,7 @@ int main(int argc, char **argv)
string tpass_res = "";
string time_stamp_res = "";
std::vector<std::vector<string>> punc_cache(2);
FunTpassOnlineReset(tpass_online_handle);
for (int sample_offset = 0; sample_offset < buff_len; sample_offset += std::min(step, buff_len - sample_offset))
{
if (sample_offset + step >= buff_len - 1)
@ -252,9 +247,8 @@ int main(int argc, char **argv)
}
gettimeofday(&start, nullptr);
FUNASR_RESULT result = FunTpassInferBuffer(tpass_handle, tpass_online_handle,
voice_feats, sv_mode,
speech_buff + sample_offset, step, punc_cache, is_final, sampling_rate_, "pcm",
(ASR_TYPE)asr_mode_, hotwords_embedding, true, decoder_handle);
(ASR_TYPE)asr_mode_, hotwords_embedding, true, true, decoder_handle);
gettimeofday(&end, nullptr);
seconds = (end.tv_sec - start.tv_sec);
taking_micros += ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
@ -263,35 +257,36 @@ int main(int argc, char **argv)
{
string online_msg = FunASRGetResult(result, 0);
online_res += online_msg;
// if (online_msg != "")
// {
// LOG(INFO) << wav_id << " : " << online_msg;
// }
if (online_msg != "")
{
LOG(INFO) << wav_id << " : " << online_msg;
}
string tpass_msg = FunASRGetTpassResult(result, 0);
tpass_res += tpass_msg;
if (tpass_msg != "")
{
LOG(INFO) << wav_id << " offline results : " << tpass_msg;
int speaker_idx = FunASRGetSvResult(result, 0);
if (speaker_idx != -999)
if (speaker_idx != -1)
{
LOG(INFO) << "speaker_idx: " << speaker_idx;
LOG(INFO) << wav_id << " speaker_id : " << speaker_idx << " offline results : " << tpass_msg;
} else{
LOG(INFO) << wav_id << " offline results : " << tpass_msg;
}
}
string stamp = FunASRGetStamp(result);
if (stamp != "")
{
LOG(INFO) << wav_ids[i] << " time stamp : " << stamp;
if (time_stamp_res == "")
{
time_stamp_res += stamp;
}
else
{
time_stamp_res = time_stamp_res.erase(time_stamp_res.length() - 1) + "," + stamp.substr(1);
}
}
// string stamp = FunASRGetStamp(result);
// if (stamp != "")
// {
// LOG(INFO) << wav_ids[i] << " time stamp : " << stamp;
// if (time_stamp_res == "")
// {
// time_stamp_res += stamp;
// }
// else
// {
// time_stamp_res = time_stamp_res.erase(time_stamp_res.length() - 1) + "," + stamp.substr(1);
// }
// }
snippet_time += FunASRGetRetSnippetTime(result);
FunASRFreeResult(result);
}

View File

@ -92,16 +92,10 @@ int main(int argc, char *argv[])
google::InitGoogleLogging(argv[0]);
FLAGS_logtostderr = true;
TCLAP::CmdLine cmd("funasr-onnx-offline-sv", ' ', "1.0");
// TCLAP::ValueArg<std::string> model_dir("", SV_DIR, "the cam model path, which contains model.onnx, cam.yaml", true, "", "string");
// TCLAP::ValueArg<std::string> sv_quant("", SV_QUANT, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "true", "string");
// TCLAP::ValueArg<std::string> wav_file1("", "wav_file1", "the input could be: wav_path, e.g.: asr_example1.wav; pcm_path, e.g.: asr_example.pcm; wav.scp, ", true, "", "string");
// TCLAP::ValueArg<std::string> wav_file2("", "wav_file2", "the input could be: wav_path, e.g.: asr_example2.wav; pcm_path, e.g.: asr_example.pcm; wav.scp,", true, "", "string");
// TCLAP::ValueArg<std::int32_t> onnx_thread("", "model-thread-num", "onnxruntime SetIntraOpNumThreads", false, 1, "int32_t");
TCLAP::ValueArg<std::string> model_dir("", SV_DIR, "the cam model path, which contains model.onnx, cam.yaml", false, "/workspace/models/weights2/camplus_sv_zh-cn-16k-common-onnx", "string");
TCLAP::ValueArg<std::string> sv_quant("", SV_QUANT, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "flase", "string");
TCLAP::ValueArg<std::string> wav_file1("", "wav_file1", "the input could be: wav_path, e.g.: asr_example1.wav; pcm_path, e.g.: asr_example.pcm; wav.scp, ", false, "/home/wzp/project/FunASR/speaker1_a_cn_16k.wav", "string");
TCLAP::ValueArg<std::string> wav_file2("", "wav_file2", "the input could be: wav_path, e.g.: asr_example2.wav; pcm_path, e.g.: asr_example.pcm; wav.scp,", false, "/home/wzp/project/FunASR/speaker1_b_cn_16k.wav", "string");
TCLAP::ValueArg<std::string> model_dir("", SV_DIR, "the cam model path, which contains model.onnx, cam.yaml", true, "", "string");
TCLAP::ValueArg<std::string> sv_quant("", SV_QUANT, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "true", "string");
TCLAP::ValueArg<std::string> wav_file1("", "wav_file1", "the input could be: wav_path, e.g.: asr_example1.wav; pcm_path, e.g.: asr_example.pcm; wav.scp,", false, "/workspace/models/iic/speech_campplus_sv_zh-cn_16k-common/examples/speaker1_a_cn_16k.wav", "string");
TCLAP::ValueArg<std::string> wav_file2("", "wav_file2", "the input could be: wav_path, e.g.: asr_example2.wav; pcm_path, e.g.: asr_example.pcm; wav.scp,", false, "/workspace/models/iic/speech_campplus_sv_zh-cn_16k-common/examples/speaker1_b_cn_16k.wav", "string");
TCLAP::ValueArg<std::int32_t> onnx_thread("", "model-thread-num", "onnxruntime SetIntraOpNumThreads", false, 1, "int32_t");
cmd.add(model_dir);

View File

@ -51,6 +51,8 @@ int main(int argc, char** argv)
TCLAP::ValueArg<std::string> punc_dir("", PUNC_DIR, "the punc model path, which contains model.onnx, punc.yaml", false, "", "string");
TCLAP::ValueArg<std::string> punc_quant("", PUNC_QUANT, "true (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "true", "string");
TCLAP::ValueArg<std::string> lm_dir("", LM_DIR, "the lm model path, which contains compiled models: TLG.fst, config.yaml, lexicon.txt ", false, "", "string");
TCLAP::ValueArg<std::string> sv_dir("", SV_DIR, "the sv online model path, which contains model.onnx, config.yaml", false, "", "string");
TCLAP::ValueArg<std::string> sv_quant("", SV_QUANT, "true (Default), load the model of model.onnx in sv_dir. If set true, load the model of model_quant.onnx in sv_dir", false, "true", "string");
TCLAP::ValueArg<float> global_beam("", GLOB_BEAM, "the decoding beam for beam searching ", false, 3.0, "float");
TCLAP::ValueArg<float> lattice_beam("", LAT_BEAM, "the lattice generation beam for beam searching ", false, 3.0, "float");
TCLAP::ValueArg<float> am_scale("", AM_SCALE, "the acoustic scale for beam searching ", false, 10.0, "float");
@ -71,6 +73,8 @@ int main(int argc, char** argv)
cmd.add(punc_quant);
cmd.add(itn_dir);
cmd.add(lm_dir);
cmd.add(sv_dir);
cmd.add(sv_quant);
cmd.add(global_beam);
cmd.add(lattice_beam);
cmd.add(am_scale);
@ -90,6 +94,8 @@ int main(int argc, char** argv)
GetValue(vad_quant, VAD_QUANT, model_path);
GetValue(punc_dir, PUNC_DIR, model_path);
GetValue(punc_quant, PUNC_QUANT, model_path);
GetValue(sv_dir, SV_DIR, model_path);
GetValue(sv_quant, SV_QUANT, model_path);
GetValue(itn_dir, ITN_DIR, model_path);
GetValue(lm_dir, LM_DIR, model_path);
GetValue(wav_path, WAV_PATH, model_path);
@ -190,10 +196,10 @@ int main(int argc, char** argv)
// int buff_len = audio.GetSpeechLen()*2;
// gettimeofday(&start, nullptr);
// FUNASR_RESULT result=FunOfflineInferBuffer(asr_hanlde, speech_buff, buff_len, RASR_NONE, nullptr, hotwords_embedding, audio_fs.getValue(), "pcm", true, decoder_handle);
// FUNASR_RESULT result=FunOfflineInferBuffer(asr_hanlde, speech_buff, buff_len, RASR_NONE, nullptr, hotwords_embedding, voice_feats, audio_fs.getValue(), "pcm", true, true, decoder_handle);
// For debug:end
FUNASR_RESULT result=FunOfflineInfer(asr_hanlde, wav_file.c_str(), RASR_NONE, nullptr, hotwords_embedding, audio_fs.getValue(), true, decoder_handle);
std::vector<std::vector<float>> voice_feats;
FUNASR_RESULT result=FunOfflineInfer(asr_hanlde, wav_file.c_str(), RASR_NONE, nullptr, hotwords_embedding, voice_feats, audio_fs.getValue(), true, true, decoder_handle);
gettimeofday(&end, nullptr);
seconds = (end.tv_sec - start.tv_sec);
taking_micros += ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);

View File

@ -9,14 +9,10 @@ namespace funasr {
class SvModel {
public:
virtual ~SvModel(){};
virtual void InitSv(const std::string &model, const std::string &cmvn, const std::string &config, int thread_num)=0;
virtual void InitSv(const std::string &model, const std::string &config, int thread_num)=0;
virtual std::vector<std::vector<float>> Infer(std::vector<float> &waves)=0;
float threshold=0.40;
};
SvModel *CreateSvModel(std::map<std::string, std::string>& model_path, int thread_num);
SvModel *CreateAndInferSvModel(std::map<std::string, std::string>& model_path, int thread_num);
// std::vector<std::vector<float>> InferSvModel(std::map<std::string, std::string>& model_path, int thread_num, std::vector<float>wave);
SvModel *CreateSVModel(std::map<std::string, std::string>& model_path, int thread_num);
} // namespace funasr

View File

@ -101,12 +101,12 @@ _FUNASRAPI FUNASR_HANDLE FunOfflineInit(std::map<std::string, std::string>& mo
_FUNASRAPI void FunOfflineReset(FUNASR_HANDLE handle, FUNASR_DEC_HANDLE dec_handle=nullptr);
// buffer
_FUNASRAPI FUNASR_RESULT FunOfflineInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len,
FUNASR_MODE mode, QM_CALLBACK fn_callback, const std::vector<std::vector<float>> &hw_emb,
int sampling_rate=16000, std::string wav_format="pcm", bool itn=true, FUNASR_DEC_HANDLE dec_handle=nullptr);
FUNASR_MODE mode, QM_CALLBACK fn_callback, const std::vector<std::vector<float>> &hw_emb, std::vector<std::vector<float>> &voice_feats,
int sampling_rate=16000, std::string wav_format="pcm", bool use_itn=true, bool use_sv=true, FUNASR_DEC_HANDLE dec_handle=nullptr);
// file, support wav & pcm
_FUNASRAPI FUNASR_RESULT FunOfflineInfer(FUNASR_HANDLE handle, const char* sz_filename, FUNASR_MODE mode,
QM_CALLBACK fn_callback, const std::vector<std::vector<float>> &hw_emb,
int sampling_rate=16000, bool itn=true, FUNASR_DEC_HANDLE dec_handle=nullptr);
QM_CALLBACK fn_callback, const std::vector<std::vector<float>> &hw_emb, std::vector<std::vector<float>> &voice_feats,
int sampling_rate=16000, bool use_itn=true, bool use_sv=true, FUNASR_DEC_HANDLE dec_handle=nullptr);
//#if !defined(__APPLE__)
_FUNASRAPI const std::vector<std::vector<float>> CompileHotwordEmbedding(FUNASR_HANDLE handle, std::string &hotwords, ASR_TYPE mode=ASR_OFFLINE);
//#endif
@ -116,12 +116,12 @@ _FUNASRAPI void FunOfflineUninit(FUNASR_HANDLE handle);
//2passStream
_FUNASRAPI FUNASR_HANDLE FunTpassInit(std::map<std::string, std::string>& model_path, int thread_num);
_FUNASRAPI FUNASR_HANDLE FunTpassOnlineInit(FUNASR_HANDLE tpass_handle, std::vector<int> chunk_size={5,10,5});
_FUNASRAPI void FunTpassOnlineReset(FUNASR_HANDLE tpass_online_handle);
// buffer
_FUNASRAPI FUNASR_RESULT FunTpassInferBuffer(FUNASR_HANDLE handle, FUNASR_HANDLE online_handle,
std::vector<std::vector<float>>& voice_feats, bool sv_mode, const char* sz_buf,
_FUNASRAPI FUNASR_RESULT FunTpassInferBuffer(FUNASR_HANDLE handle, FUNASR_HANDLE online_handle, const char* sz_buf,
int n_len, std::vector<std::vector<std::string>> &punc_cache, bool input_finished=true,
int sampling_rate=16000, std::string wav_format="pcm", ASR_TYPE mode=ASR_TWO_PASS,
const std::vector<std::vector<float>> &hw_emb={{0.0}}, bool itn=true, FUNASR_DEC_HANDLE dec_handle=nullptr);
int sampling_rate=16000, std::string wav_format="pcm", ASR_TYPE mode=ASR_TWO_PASS,
const std::vector<std::vector<float>> &hw_emb={{0.0}}, bool use_itn=true, bool use_sv=true, FUNASR_DEC_HANDLE dec_handle=nullptr);
_FUNASRAPI void FunTpassUninit(FUNASR_HANDLE handle);
_FUNASRAPI void FunTpassOnlineUninit(FUNASR_HANDLE handle);

View File

@ -7,6 +7,7 @@
#include "model.h"
#include "punc-model.h"
#include "vad-model.h"
#include "cam-sv-model.h"
#if !defined(__APPLE__)
#include "itn-model.h"
#endif
@ -20,17 +21,20 @@ class OfflineStream {
std::unique_ptr<VadModel> vad_handle= nullptr;
std::unique_ptr<Model> asr_handle= nullptr;
std::unique_ptr<PuncModel> punc_handle= nullptr;
std::unique_ptr<SvModel> sv_handle = nullptr;
#if !defined(__APPLE__)
std::unique_ptr<ITNModel> itn_handle = nullptr;
#endif
bool UseVad(){return use_vad;};
bool UsePunc(){return use_punc;};
bool UseITN(){return use_itn;};
bool UseSv(){return use_sv;};
private:
bool use_vad=false;
bool use_punc=false;
bool use_itn=false;
bool use_sv=false;
};
OfflineStream *CreateOfflineStream(std::map<std::string, std::string>& model_path, int thread_num=1, bool use_gpu=false, int batch_size=1);

View File

@ -14,7 +14,11 @@ class TpassOnlineStream {
std::unique_ptr<VadModel> vad_online_handle = nullptr;
std::unique_ptr<Model> asr_online_handle = nullptr;
//for sv-cam
std::vector<std::vector<float>> voice_feats;
};
TpassOnlineStream* CreateTpassOnlineStream(void* tpass_stream, std::vector<int> chunk_size);
void TpassOnlineCacheReset(void* tpass_online_stream);
} // namespace funasr
#endif

View File

@ -2,36 +2,12 @@
namespace funasr
{
// Construct and initialize a CAM++ speaker-verification (SV) model.
// model_path: config map; reads MODEL_DIR (model directory) plus the optional
//             QUANTIZE flag ("true" selects model_quant.onnx over model.onnx).
// thread_num: intra-op thread count forwarded to InitSv.
// Returns a heap-allocated SvModel; the caller owns (and must delete) it.
// NOTE(review): this factory keys off MODEL_DIR while the sibling factory
// uses SV_DIR — confirm which directory key callers actually populate.
SvModel *CreateSVModel(std::map<std::string, std::string> &model_path, int thread_num)
{
SvModel *mm;
mm = new CamPPlusSv();
// NOTE(review): the "vad_*" names look copy-pasted from the VAD factory;
// these are SV model/cmvn/config paths.
string vad_model_path;
string vad_cmvn_path;
string vad_config_path;
vad_model_path = PathAppend(model_path.at(MODEL_DIR), MODEL_NAME);
// Prefer the quantized onnx file when QUANTIZE == "true".
if (model_path.find(QUANTIZE) != model_path.end() && model_path.at(QUANTIZE) == "true")
{
vad_model_path = PathAppend(model_path.at(MODEL_DIR), QUANT_MODEL_NAME);
}
vad_cmvn_path = PathAppend(model_path.at(MODEL_DIR), SV_CMVN_NAME);
vad_config_path = PathAppend(model_path.at(MODEL_DIR), SV_CONFIG_NAME);
// NOTE(review): map::at throws if MODEL_DIR is absent — confirm callers
// guarantee the key before invoking this factory.
mm->InitSv(vad_model_path, vad_cmvn_path, vad_config_path, thread_num);
return mm;
}
// SvModel *CreateAndInferSvModel(std::map<std::string, std::string> &model_path, int thread_num, std::vector<float> wave)
SvModel *CreateAndInferSvModel(std::map<std::string, std::string> &model_path, int thread_num)
{
SvModel *mm;
mm = new CamPPlusSv();
string vad_model_path;
string vad_cmvn_path;
string vad_config_path;
vad_model_path = PathAppend(model_path.at(SV_DIR), MODEL_NAME);
@ -39,35 +15,10 @@ namespace funasr
{
vad_model_path = PathAppend(model_path.at(SV_DIR), QUANT_MODEL_NAME);
}
vad_cmvn_path = PathAppend(model_path.at(SV_DIR), SV_CMVN_NAME);
vad_config_path = PathAppend(model_path.at(SV_DIR), SV_CONFIG_NAME);
mm->InitSv(vad_model_path, vad_cmvn_path, vad_config_path, thread_num);
mm->InitSv(vad_model_path, vad_config_path, thread_num);
return mm;
}
// std::vector<std::vector<float>> InferSvModel(std::map<std::string, std::string> &model_path, int thread_num, std::vector<float> wave)
// {
// SvModel *mm;
// mm = new CamPPlusSv();
// string vad_model_path;
// string vad_cmvn_path;
// string vad_config_path;
// vad_model_path = PathAppend(model_path.at(MODEL_DIR), MODEL_NAME);
// if (model_path.find(QUANTIZE) != model_path.end() && model_path.at(QUANTIZE) == "true")
// {
// vad_model_path = PathAppend(model_path.at(MODEL_DIR), QUANT_MODEL_NAME);
// }
// vad_cmvn_path = PathAppend(model_path.at(MODEL_DIR), SV_CMVN_NAME);
// vad_config_path = PathAppend(model_path.at(MODEL_DIR), SV_CONFIG_NAME);
// mm->InitSv(vad_model_path, vad_cmvn_path, vad_config_path, thread_num);
// std::vector<std::vector<float>> result = mm->Infer(wave);
// delete mm;
// return result;
// }
} // namespace funasr

View File

@ -7,29 +7,16 @@
#include "precomp.h"
#include <vector>
/// Debug helper: print the 2-D shape of a nested vector to stdout as
/// "vec_shape= [rows, cols]". For an empty outer vector only the row
/// count is printed (the cols slot is left blank): "vec_shape= [0, ]".
template <typename T>
void print_vec_shape(const std::vector<std::vector<T>> &data)
{
    const auto rows = data.size();
    std::cout << "vec_shape= [" << rows << ", ";
    if (rows > 0) {
        std::cout << data.front().size();
    }
    std::cout << "]" << std::endl;
}
namespace funasr
{
void CamPPlusSv::InitSv(const std::string &model, const std::string &cmvn, const std::string &config, int thread_num)
// Initialize the CAM++ SV model: configure the ONNX Runtime session options,
// load the onnx model, then read parameters from the yaml config file.
// model:      path to model.onnx (or model_quant.onnx)
// config:     path to the yaml configuration file
// thread_num: onnxruntime intra-op thread count
void CamPPlusSv::InitSv(const std::string &model, const std::string &config, int thread_num)
{
session_options_.SetIntraOpNumThreads(thread_num);
session_options_.SetGraphOptimizationLevel(ORT_ENABLE_ALL);
// Arena disabled — presumably to limit allocator memory retention across
// sessions; TODO confirm this is intentional for the SV model.
session_options_.DisableCpuMemArena();
ReadModel(model.c_str());
// LoadCmvn(cmvn.c_str());  // cmvn loading retired together with the cmvn parameter
LoadConfigFromYaml(config.c_str());
}
void CamPPlusSv::LoadConfigFromYaml(const char *filename)
@ -104,8 +91,8 @@ namespace funasr
shape << j;
shape << " ";
}
LOG(INFO) << "\tInput " << i << " : name=" << name.get() << " type=" << type
<< " dims=" << shape.str();
// LOG(INFO) << "\tInput " << i << " : name=" << name.get() << " type=" << type
// << " dims=" << shape.str();
(*in_names)[i] = name.get();
name.release();
}
@ -125,8 +112,8 @@ namespace funasr
shape << j;
shape << " ";
}
LOG(INFO) << "\tOutput " << i << " : name=" << name.get() << " type=" << type
<< " dims=" << shape.str();
// LOG(INFO) << "\tOutput " << i << " : name=" << name.get() << " type=" << type
// << " dims=" << shape.str();
(*out_names)[i] = name.get();
name.release();
}
@ -205,116 +192,6 @@ namespace funasr
}
}
// Parse a Kaldi-style cmvn statistics file and fill means_list_ / vars_list_.
//
// Scanned line by line:
//   * an "<AddShift>" marker line is followed by a "<LearnRateCoef> ..." line
//     whose tokens in [3, size-1) are the mean values (appended to means_list_);
//   * a "<Rescale>" marker line is followed by a "<LearnRateCoef> ..." line
//     whose tokens in [3, size-1) are the scale values (appended to vars_list_).
//
// filename: path to the cmvn file. On open failure or a parse exception the
// process exits with -1 (matching the other loaders in this file).
//
// Fixes vs. previous version:
//   * a blank line used to make line_item[0] an out-of-bounds access (UB);
//     empty token vectors are now skipped / guarded explicitly;
//   * loop indices are size_t with "j + 1 < size()" instead of the
//     signed/unsigned "int j < size() - 1" comparison.
void CamPPlusSv::LoadCmvn(const char *filename)
{
    try
    {
        using namespace std;
        ifstream cmvn_stream(filename);
        if (!cmvn_stream.is_open())
        {
            LOG(ERROR) << "Failed to open file: " << filename;
            exit(-1);
        }
        string line;
        while (getline(cmvn_stream, line))
        {
            istringstream iss(line);
            vector<string> line_item{istream_iterator<string>{iss}, istream_iterator<string>{}};
            // Skip blank / whitespace-only lines (previously UB via line_item[0]).
            if (line_item.empty())
            {
                continue;
            }
            if (line_item[0] == "<AddShift>")
            {
                getline(cmvn_stream, line);
                istringstream means_lines_stream(line);
                vector<string> means_lines{istream_iterator<string>{means_lines_stream}, istream_iterator<string>{}};
                if (!means_lines.empty() && means_lines[0] == "<LearnRateCoef>")
                {
                    // Tokens 0..2 are markers/metadata and the final token is a
                    // closing bracket; everything between is a mean value.
                    for (size_t j = 3; j + 1 < means_lines.size(); j++)
                    {
                        means_list_.push_back(stof(means_lines[j]));
                    }
                    continue;
                }
            }
            else if (line_item[0] == "<Rescale>")
            {
                getline(cmvn_stream, line);
                istringstream vars_lines_stream(line);
                vector<string> vars_lines{istream_iterator<string>{vars_lines_stream}, istream_iterator<string>{}};
                if (!vars_lines.empty() && vars_lines[0] == "<LearnRateCoef>")
                {
                    for (size_t j = 3; j + 1 < vars_lines.size(); j++)
                    {
                        vars_list_.push_back(stof(vars_lines[j]));
                    }
                    continue;
                }
            }
        }
    }
    catch (std::exception const &e)
    {
        LOG(ERROR) << "Error when load vad cmvn : " << e.what();
        exit(-1);
    }
}
void CamPPlusSv::LfrCmvn(std::vector<std::vector<float>> &vad_feats)
{
// std::vector<std::vector<float>> out_feats;
// int T = vad_feats.size();
// int T_lrf = ceil(1.0 * T / lfr_n);
// // Pad frames at start(copy first frame)
// for (int i = 0; i < (lfr_m - 1) / 2; i++)
// {
// vad_feats.insert(vad_feats.begin(), vad_feats[0]);
// }
// // Merge lfr_m frames as one,lfr_n frames per window
// T = T + (lfr_m - 1) / 2;
// std::vector<float> p;
// for (int i = 0; i < T_lrf; i++)
// {
// if (lfr_m <= T - i * lfr_n)
// {
// for (int j = 0; j < lfr_m; j++)
// {
// p.insert(p.end(), vad_feats[i * lfr_n + j].begin(), vad_feats[i * lfr_n + j].end());
// }
// out_feats.emplace_back(p);
// p.clear();
// }
// else
// {
// // Fill to lfr_m frames at last window if less than lfr_m frames (copy last frame)
// int num_padding = lfr_m - (T - i * lfr_n);
// for (int j = 0; j < (vad_feats.size() - i * lfr_n); j++)
// {
// p.insert(p.end(), vad_feats[i * lfr_n + j].begin(), vad_feats[i * lfr_n + j].end());
// }
// for (int j = 0; j < num_padding; j++)
// {
// p.insert(p.end(), vad_feats[vad_feats.size() - 1].begin(), vad_feats[vad_feats.size() - 1].end());
// }
// out_feats.emplace_back(p);
// p.clear();
// }
// }
// //Apply cmvn
std::vector<std::vector<float>> out_feats;
out_feats = vad_feats;
// only Apply cmvn
for (auto &out_feat : out_feats)
{
// for (int j = 0; j < means_list_.size(); j++) {
for (int j = 0; j < std::min(out_feats[0].size(), means_list_.size()); j++)
{
out_feat[j] = (out_feat[j] + means_list_[j]) * vars_list_[j];
}
}
vad_feats = out_feats;
}
void CamPPlusSv::SubMean(std::vector<std::vector<float>> &voice_feats)
{
if (voice_feats.size() > 0)

View File

@ -19,7 +19,7 @@ namespace funasr
public:
CamPPlusSv();
~CamPPlusSv();
void InitSv(const std::string &model, const std::string &cmvn, const std::string &config, int thread_num);
void InitSv(const std::string &model, const std::string &config, int thread_num);
std::vector<std::vector<float>> Infer(std::vector<float> &waves);
void Forward(
const std::vector<std::vector<float>> &chunk_feats,
@ -30,29 +30,19 @@ namespace funasr
Ort::SessionOptions session_options_;
std::vector<const char *> cam_in_names_;
std::vector<const char *> cam_out_names_;
std::vector<std::vector<float>> in_cache_;
knf::FbankOptions fbank_opts_;
std::vector<float> means_list_;
std::vector<float> vars_list_;
int sample_rate_ = MODEL_SAMPLE_RATE;
int lfr_m = VAD_LFR_M;
int lfr_n = VAD_LFR_N;
private:
void ReadModel(const char *cam_model);
void LoadConfigFromYaml(const char *filename);
static void GetInputOutputInfo(
const std::shared_ptr<Ort::Session> &session,
std::vector<const char *> *in_names, std::vector<const char *> *out_names);
void FbankKaldi(float sample_rate, std::vector<std::vector<float>> &vad_feats,
std::vector<float> &waves);
void LfrCmvn(std::vector<std::vector<float>> &vad_feats);
void LoadCmvn(const char *filename);
void SubMean(std::vector<std::vector<float>>& voice_feats);
};

View File

@ -12,8 +12,8 @@ typedef struct
std::string stamp_sents;
std::string tpass_msg;
float snippet_time;
int speaker_idx = -999;
std::vector<float> speaker_emb;
std::vector<int> speaker_idxs;
std::vector<std::vector<float>> speaker_embs;
}FUNASR_RECOG_RESULT;
typedef struct

View File

@ -50,6 +50,11 @@
return funasr::CreateTpassOnlineStream(tpass_handle, chunk_size);
}
_FUNASRAPI void FunTpassOnlineReset(FUNASR_HANDLE tpass_online_handle)
{
funasr::TpassOnlineCacheReset(tpass_online_handle);
}
// APIs for ASR Infer
_FUNASRAPI FUNASR_RESULT FunASRInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback, bool input_finished, int sampling_rate, std::string wav_format)
{
@ -206,8 +211,9 @@
// APIs for Offline-stream Infer
_FUNASRAPI FUNASR_RESULT FunOfflineInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len,
FUNASR_MODE mode, QM_CALLBACK fn_callback, const std::vector<std::vector<float>> &hw_emb,
int sampling_rate, std::string wav_format, bool itn, FUNASR_DEC_HANDLE dec_handle)
FUNASR_MODE mode, QM_CALLBACK fn_callback, const std::vector<std::vector<float>> &hw_emb,
std::vector<std::vector<float>> &voice_feats, int sampling_rate, std::string wav_format,
bool use_itn, bool use_sv, FUNASR_DEC_HANDLE dec_handle)
{
funasr::OfflineStream* offline_stream = (funasr::OfflineStream*)handle;
if (!offline_stream)
@ -307,7 +313,7 @@
p_result->msg = punc_res;
}
#if !defined(__APPLE__)
if(offline_stream->UseITN() && itn){
if(offline_stream->UseITN() && use_itn){
string msg_itn = offline_stream->itn_handle->Normalize(p_result->msg);
if(!(p_result->stamp).empty()){
std::string new_stamp = funasr::TimestampSmooth(p_result->msg, msg_itn, p_result->stamp);
@ -319,13 +325,14 @@
}
#endif
if (!(p_result->stamp).empty()){
p_result->stamp_sents = funasr::TimestampSentence(p_result->msg, p_result->stamp);
p_result->stamp_sents = funasr::TimestampSentence(p_result->msg, p_result->stamp, p_result->speaker_idxs);
}
return p_result;
}
_FUNASRAPI FUNASR_RESULT FunOfflineInfer(FUNASR_HANDLE handle, const char* sz_filename, FUNASR_MODE mode, QM_CALLBACK fn_callback,
const std::vector<std::vector<float>> &hw_emb, int sampling_rate, bool itn, FUNASR_DEC_HANDLE dec_handle)
const std::vector<std::vector<float>> &hw_emb, std::vector<std::vector<float>> &voice_feats,
int sampling_rate, bool use_itn, bool use_sv, FUNASR_DEC_HANDLE dec_handle)
{
funasr::OfflineStream* offline_stream = (funasr::OfflineStream*)handle;
if (!offline_stream)
@ -357,11 +364,13 @@
}
std::vector<int> index_vector={0};
int msg_idx = 0;
int svs_idx = 0;
if(offline_stream->UseVad()){
audio.CutSplit(offline_stream, index_vector);
}
std::vector<string> msgs(index_vector.size());
std::vector<float> msg_stimes(index_vector.size());
std::vector<int> svs(index_vector.size(), -1);
float** buff;
int* len;
@ -387,7 +396,30 @@
msg_idx++;
}else{
LOG(ERROR) << "msg_idx: " << msg_idx <<" is out of range " << index_vector.size();
}
}
}
if(offline_stream->UseSv() && use_sv){
for(int index=0; index<batch_in; index++){
if (len[index] > 1600){
if (voice_feats.size() < MAX_SPKS_NUM){
std::vector<float> wave(buff[index], buff[index]+len[index]);
std::vector<std::vector<float>> sv_result = offline_stream->sv_handle->Infer(wave);
float threshold = offline_stream->sv_handle->threshold;
int speaker_idx = funasr::GetSpeakersID(sv_result[0], voice_feats, threshold);
if(svs_idx < index_vector.size()){
svs[index_vector[svs_idx]] = speaker_idx;
svs_idx++;
}else{
LOG(ERROR) << "svs_idx: " << svs_idx <<" is out of range " << index_vector.size();
}
}else{
LOG(ERROR) << "Exceeding the maximum speaker limit!";
LOG(ERROR) << "speaker_idx: " << MAX_SPKS_NUM;
}
}
}
p_result->speaker_idxs = svs;
}
// release
@ -429,7 +461,7 @@
p_result->msg = punc_res;
}
#if !defined(__APPLE__)
if(offline_stream->UseITN() && itn){
if(offline_stream->UseITN() && use_itn){
string msg_itn = offline_stream->itn_handle->Normalize(p_result->msg);
if(!(p_result->stamp).empty()){
std::string new_stamp = funasr::TimestampSmooth(p_result->msg, msg_itn, p_result->stamp);
@ -440,8 +472,8 @@
p_result->msg = msg_itn;
}
#endif
if (!(p_result->stamp).empty()){
p_result->stamp_sents = funasr::TimestampSentence(p_result->msg, p_result->stamp);
if (!(p_result->stamp).empty() || !(p_result->speaker_idxs.empty())){
p_result->stamp_sents = funasr::TimestampSentence(p_result->msg, p_result->stamp, p_result->speaker_idxs);
}
return p_result;
}
@ -473,11 +505,10 @@
//#endif
// APIs for 2pass-stream Infer
_FUNASRAPI FUNASR_RESULT FunTpassInferBuffer(FUNASR_HANDLE handle, FUNASR_HANDLE online_handle,
std::vector<std::vector<float>>& voice_feats, bool sv_mode, const char* sz_buf,
int n_len, std::vector<std::vector<std::string>> &punc_cache, bool input_finished,
_FUNASRAPI FUNASR_RESULT FunTpassInferBuffer(FUNASR_HANDLE handle, FUNASR_HANDLE online_handle, const char* sz_buf,
int n_len, std::vector<std::vector<std::string>> &punc_cache, bool input_finished,
int sampling_rate, std::string wav_format, ASR_TYPE mode,
const std::vector<std::vector<float>> &hw_emb, bool itn, FUNASR_DEC_HANDLE dec_handle)
const std::vector<std::vector<float>> &hw_emb, bool use_itn, bool use_sv, FUNASR_DEC_HANDLE dec_handle)
{
funasr::TpassStream* tpass_stream = (funasr::TpassStream*)handle;
funasr::TpassOnlineStream* tpass_online_stream = (funasr::TpassOnlineStream*)online_handle;
@ -501,9 +532,6 @@
funasr::PuncModel* punc_online_handle = (tpass_stream->punc_online_handle).get();
if (!punc_online_handle)
return nullptr;
funasr::SvModel* sv_handle = (tpass_stream->sv_handle).get();
if (!sv_handle &&sv_mode)
return nullptr;
if(wav_format == "pcm" || wav_format == "PCM"){
@ -532,7 +560,7 @@
p_result->tpass_msg = msg_punc;
#if !defined(__APPLE__)
// ITN
if(tpass_stream->UseITN() && itn){
if(tpass_stream->UseITN() && use_itn){
string msg_itn = tpass_stream->itn_handle->Normalize(msg_punc);
p_result->tpass_msg = msg_itn;
}
@ -574,21 +602,21 @@
msg = msg_vec[0];
//sv-cam for Speaker verification
if(sv_mode&&frame->len>1600)//Filter audio clips less than 100ms
if(tpass_stream->UseSv() && use_sv &&frame->len>1600)//Filter audio clips less than 100ms
{
std::vector<std::vector<float>>& voice_feats = tpass_online_stream->voice_feats;
if (voice_feats.size()<MAX_SPKS_NUM)
{
std::vector<float>wave(frame->data,frame->data+frame->len);
std::vector<std::vector<float>>sv_result=sv_handle->Infer(wave);
// std::vector<std::vector<float>>sv_result = CamPPlusSvInfer(sv_handle, wave);
float threshold =sv_handle->threshold;
int speaker_idx = funasr::GetSpeakersID(sv_result[0], voice_feats,threshold);
p_result->speaker_idx = speaker_idx;
std::vector<float>wave(frame->data, frame->data+frame->len);
std::vector<std::vector<float>> sv_result = tpass_stream->sv_handle->Infer(wave);
float threshold = tpass_stream->sv_handle->threshold;
int speaker_idx = funasr::GetSpeakersID(sv_result[0], voice_feats, threshold);
// p_result->speaker_idx = speaker_idx;
}
else
{
LOG(ERROR)<<"Exceeding the maximum speaker limit!\n";
p_result->speaker_idx = MAX_SPKS_NUM;
// p_result->speaker_idx = MAX_SPKS_NUM;
}
}
//timestamp
@ -612,7 +640,7 @@
}
p_result->tpass_msg = msg_punc;
#if !defined(__APPLE__)
if(tpass_stream->UseITN() && itn){
if(tpass_stream->UseITN() && use_itn){
string msg_itn = tpass_stream->itn_handle->Normalize(msg_punc);
// TimestampSmooth
if(!(p_result->stamp).empty()){
@ -625,7 +653,7 @@
}
#endif
if (!(p_result->stamp).empty()){
p_result->stamp_sents = funasr::TimestampSentence(p_result->tpass_msg, p_result->stamp);
p_result->stamp_sents = funasr::TimestampSentence(p_result->tpass_msg, p_result->stamp, p_result->speaker_idxs);
}
if(frame != nullptr){
delete frame;
@ -902,10 +930,7 @@
// APIs for CamPPlusSv Infer
_FUNASRAPI FUNASR_HANDLE CamPPlusSvInit(std::map<std::string, std::string>& model_path, int thread_num)
{
// funasr::SvModel *mm = funasr::CreateSvModel(model_path, thread_num);
// return mm;
// std::vector<float> wave;
funasr::SvModel* mm = funasr::CreateAndInferSvModel(model_path, thread_num);
funasr::SvModel* mm = funasr::CreateSVModel(model_path, thread_num);
return mm;
}
@ -935,7 +960,7 @@ _FUNASRAPI const int FunASRGetSvResult(FUNASR_RESULT result, int n_index)
if (!p_result)
return -1;
return p_result->speaker_idx;
return p_result->speaker_idxs[n_index];
}
_FUNASRAPI const std::vector<float> FunASRGetSvEmbResult(FUNASR_RESULT result, int n_index)
{
@ -944,5 +969,5 @@ _FUNASRAPI const std::vector<float> FunASRGetSvEmbResult(FUNASR_RESULT result, i
if (!p_result)
return speaker_emb;
return p_result->speaker_emb;
return p_result->speaker_embs[n_index];
}

View File

@ -138,6 +138,31 @@ OfflineStream::OfflineStream(std::map<std::string, std::string>& model_path, int
}
}
#endif
// sv cam
if (model_path.find(SV_DIR) != model_path.end() && model_path.at(SV_DIR) != "")
{
string sv_model_path;
string sv_config_path;
sv_model_path = PathAppend(model_path.at(SV_DIR), MODEL_NAME);
if (model_path.find(SV_QUANT) != model_path.end() && model_path.at(SV_QUANT) == "true")
{
sv_model_path = PathAppend(model_path.at(SV_DIR), QUANT_MODEL_NAME);
}
sv_config_path = PathAppend(model_path.at(SV_DIR), SV_CONFIG_NAME);
if (access(sv_model_path.c_str(), F_OK) != 0 ||
access(sv_config_path.c_str(), F_OK) != 0)
{
LOG(INFO) << "CAMPlusPlus model file is not exist, skip load model.";
}else
{
sv_handle = make_unique<CamPPlusSv>();
sv_handle->InitSv(sv_model_path, sv_config_path, thread_num);
use_sv = true;
}
}
}
OfflineStream *CreateOfflineStream(std::map<std::string, std::string>& model_path, int thread_num, bool use_gpu, int batch_size)

View File

@ -18,6 +18,11 @@ TpassOnlineStream::TpassOnlineStream(TpassStream* tpass_stream, std::vector<int>
}
}
void TpassOnlineCacheReset(void* tpass_online_stream){
TpassOnlineStream* tpass_online_obj = (TpassOnlineStream*)tpass_online_stream;
tpass_online_obj->voice_feats.clear();
}
TpassOnlineStream* CreateTpassOnlineStream(void* tpass_stream, std::vector<int> chunk_size)
{
TpassOnlineStream *mm;

View File

@ -126,10 +126,9 @@ TpassStream::TpassStream(std::map<std::string, std::string>& model_path, int thr
#endif
// sv cam
if (model_path.find(SV_DIR) != model_path.end())
if (model_path.find(SV_DIR) != model_path.end() && model_path.at(SV_DIR) != "")
{
string sv_model_path;
string sv_cmvn_path;
string sv_config_path;
sv_model_path = PathAppend(model_path.at(SV_DIR), MODEL_NAME);
@ -137,18 +136,16 @@ TpassStream::TpassStream(std::map<std::string, std::string>& model_path, int thr
{
sv_model_path = PathAppend(model_path.at(SV_DIR), QUANT_MODEL_NAME);
}
sv_cmvn_path = PathAppend(model_path.at(SV_DIR), SV_CMVN_NAME);
sv_config_path = PathAppend(model_path.at(SV_DIR), SV_CONFIG_NAME);
if (access(sv_model_path.c_str(), F_OK) != 0 ||
// access(vad_cmvn_path.c_str(), F_OK) != 0 ||
access(sv_config_path.c_str(), F_OK) != 0)
{
LOG(INFO) << "CAMPlusPlus model file is not exist, skip load vad model.";
LOG(INFO) << "CAMPlusPlus model file is not exist, skip load model.";
}
else
{
sv_handle = make_unique<CamPPlusSv>();
sv_handle->InitSv(sv_model_path, sv_cmvn_path, sv_config_path, thread_num);
sv_handle->InitSv(sv_model_path, sv_config_path, thread_num);
use_sv = true;
}
}

View File

@ -566,7 +566,7 @@ std::string TimestampSmooth(std::string &text, std::string &text_itn, std::strin
return timestamps_str;
}
std::string TimestampSentence(std::string &text, std::string &str_time){
std::string TimestampSentence(std::string &text, std::string &str_time, std::vector<int> speaker_idxs){
std::vector<std::string> characters;
funasr::TimestampSplitChiEngCharacters(text, characters);
vector<vector<int>> timestamps = funasr::ParseTimestamps(str_time);

View File

@ -47,7 +47,7 @@ void TimestampSplitChiEngCharacters(const std::string &input_str,
std::vector<std::string> &characters);
std::string VectorToString(const std::vector<std::vector<int>>& vec, bool out_empty=true);
std::string TimestampSmooth(std::string &text, std::string &text_itn, std::string &str_time);
std::string TimestampSentence(std::string &text, std::string &str_time);
std::string TimestampSentence(std::string &text, std::string &str_time, std::vector<int> speaker_idxs);
std::vector<std::string> split(const std::string &s, char delim);
template<typename T>
@ -71,6 +71,5 @@ void ExtractHws(string hws_file, unordered_map<string, int> &hws_map);
void ExtractHws(string hws_file, unordered_map<string, int> &hws_map, string& nn_hotwords_);
float CosineSimilarity(const std::vector<float> &emb1, const std::vector<float> &emb2);
int GetSpeakersID(const std::vector<float> &emb1, std::vector<std::vector<float>> &emb_list, float threshold = 0.40);
} // namespace funasr
#endif

View File

@ -140,7 +140,7 @@ void WebSocketServer::do_decoder(
subvector.data(), subvector.size(),
punc_cache, false, audio_fs,
wav_format, (ASR_TYPE)asr_mode_,
hotwords_embedding, itn, decoder_handle);
hotwords_embedding, itn, true, decoder_handle);
} else {
scoped_lock guard(thread_lock);
@ -177,7 +177,7 @@ void WebSocketServer::do_decoder(
buffer.data(), buffer.size(), punc_cache,
is_final, audio_fs,
wav_format, (ASR_TYPE)asr_mode_,
hotwords_embedding, itn, decoder_handle);
hotwords_embedding, itn, true, decoder_handle);
} else {
scoped_lock guard(thread_lock);
msg["access_num"]=(int)msg["access_num"]-1;