diff --git a/runtime/onnxruntime/bin/CMakeLists.txt b/runtime/onnxruntime/bin/CMakeLists.txt index c825ad50a..a9a87bc63 100644 --- a/runtime/onnxruntime/bin/CMakeLists.txt +++ b/runtime/onnxruntime/bin/CMakeLists.txt @@ -9,45 +9,45 @@ include_directories(${PROJECT_SOURCE_DIR}/third_party) SET(RELATION_SOURCE "../src/resample.cpp" "../src/util.cpp" "../src/alignedmem.cpp" "../src/encode_converter.cpp") endif() -# add_executable(funasr-onnx-offline "funasr-onnx-offline.cpp" ${RELATION_SOURCE}) -# target_link_options(funasr-onnx-offline PRIVATE "-Wl,--no-as-needed") -# target_link_libraries(funasr-onnx-offline PUBLIC funasr) +add_executable(funasr-onnx-offline "funasr-onnx-offline.cpp" ${RELATION_SOURCE}) +target_link_options(funasr-onnx-offline PRIVATE "-Wl,--no-as-needed") +target_link_libraries(funasr-onnx-offline PUBLIC funasr) -# add_executable(funasr-onnx-offline-vad "funasr-onnx-offline-vad.cpp" ${RELATION_SOURCE}) -# target_link_options(funasr-onnx-offline-vad PRIVATE "-Wl,--no-as-needed") -# target_link_libraries(funasr-onnx-offline-vad PUBLIC funasr) +add_executable(funasr-onnx-offline-vad "funasr-onnx-offline-vad.cpp" ${RELATION_SOURCE}) +target_link_options(funasr-onnx-offline-vad PRIVATE "-Wl,--no-as-needed") +target_link_libraries(funasr-onnx-offline-vad PUBLIC funasr) -# add_executable(funasr-onnx-online-vad "funasr-onnx-online-vad.cpp" ${RELATION_SOURCE}) -# target_link_options(funasr-onnx-online-vad PRIVATE "-Wl,--no-as-needed") -# target_link_libraries(funasr-onnx-online-vad PUBLIC funasr) +add_executable(funasr-onnx-online-vad "funasr-onnx-online-vad.cpp" ${RELATION_SOURCE}) +target_link_options(funasr-onnx-online-vad PRIVATE "-Wl,--no-as-needed") +target_link_libraries(funasr-onnx-online-vad PUBLIC funasr) -# add_executable(funasr-onnx-online-asr "funasr-onnx-online-asr.cpp" ${RELATION_SOURCE}) -# target_link_options(funasr-onnx-online-asr PRIVATE "-Wl,--no-as-needed") -# target_link_libraries(funasr-onnx-online-asr PUBLIC funasr) +add_executable(funasr-onnx-online-asr "funasr-onnx-online-asr.cpp" ${RELATION_SOURCE}) +target_link_options(funasr-onnx-online-asr PRIVATE "-Wl,--no-as-needed") +target_link_libraries(funasr-onnx-online-asr PUBLIC funasr) -# add_executable(funasr-onnx-offline-punc "funasr-onnx-offline-punc.cpp" ${RELATION_SOURCE}) -# target_link_options(funasr-onnx-offline-punc PRIVATE "-Wl,--no-as-needed") -# target_link_libraries(funasr-onnx-offline-punc PUBLIC funasr) +add_executable(funasr-onnx-offline-punc "funasr-onnx-offline-punc.cpp" ${RELATION_SOURCE}) +target_link_options(funasr-onnx-offline-punc PRIVATE "-Wl,--no-as-needed") +target_link_libraries(funasr-onnx-offline-punc PUBLIC funasr) -# add_executable(funasr-onnx-online-punc "funasr-onnx-online-punc.cpp" ${RELATION_SOURCE}) -# target_link_options(funasr-onnx-online-punc PRIVATE "-Wl,--no-as-needed") -# target_link_libraries(funasr-onnx-online-punc PUBLIC funasr) +add_executable(funasr-onnx-online-punc "funasr-onnx-online-punc.cpp" ${RELATION_SOURCE}) +target_link_options(funasr-onnx-online-punc PRIVATE "-Wl,--no-as-needed") +target_link_libraries(funasr-onnx-online-punc PUBLIC funasr) -# add_executable(funasr-onnx-offline-rtf "funasr-onnx-offline-rtf.cpp" ${RELATION_SOURCE}) -# target_link_options(funasr-onnx-offline-rtf PRIVATE "-Wl,--no-as-needed") -# target_link_libraries(funasr-onnx-offline-rtf PUBLIC funasr) +add_executable(funasr-onnx-offline-rtf "funasr-onnx-offline-rtf.cpp" ${RELATION_SOURCE}) +target_link_options(funasr-onnx-offline-rtf PRIVATE "-Wl,--no-as-needed") +target_link_libraries(funasr-onnx-offline-rtf PUBLIC funasr) -# add_executable(funasr-onnx-2pass "funasr-onnx-2pass.cpp" ${RELATION_SOURCE}) -# target_link_options(funasr-onnx-2pass PRIVATE "-Wl,--no-as-needed") -# target_link_libraries(funasr-onnx-2pass PUBLIC funasr) +add_executable(funasr-onnx-2pass "funasr-onnx-2pass.cpp" ${RELATION_SOURCE}) +target_link_options(funasr-onnx-2pass PRIVATE "-Wl,--no-as-needed") +target_link_libraries(funasr-onnx-2pass PUBLIC funasr) -# add_executable(funasr-onnx-2pass-rtf "funasr-onnx-2pass-rtf.cpp" ${RELATION_SOURCE}) -# target_link_options(funasr-onnx-2pass-rtf PRIVATE "-Wl,--no-as-needed") -# target_link_libraries(funasr-onnx-2pass-rtf PUBLIC funasr) +add_executable(funasr-onnx-2pass-rtf "funasr-onnx-2pass-rtf.cpp" ${RELATION_SOURCE}) +target_link_options(funasr-onnx-2pass-rtf PRIVATE "-Wl,--no-as-needed") +target_link_libraries(funasr-onnx-2pass-rtf PUBLIC funasr) -# add_executable(funasr-onnx-online-rtf "funasr-onnx-online-rtf.cpp" ${RELATION_SOURCE}) -# target_link_options(funasr-onnx-online-rtf PRIVATE "-Wl,--no-as-needed") -# target_link_libraries(funasr-onnx-online-rtf PUBLIC funasr) +add_executable(funasr-onnx-online-rtf "funasr-onnx-online-rtf.cpp" ${RELATION_SOURCE}) +target_link_options(funasr-onnx-online-rtf PRIVATE "-Wl,--no-as-needed") +target_link_libraries(funasr-onnx-online-rtf PUBLIC funasr) add_executable(funasr-onnx-offline-sv "funasr-onnx-offline-sv.cpp" ${RELATION_SOURCE}) target_link_options(funasr-onnx-offline-sv PRIVATE "-Wl,--no-as-needed") diff --git a/runtime/onnxruntime/bin/funasr-onnx-2pass-rtf.cpp b/runtime/onnxruntime/bin/funasr-onnx-2pass-rtf.cpp index d4abacd1f..e37ef2092 100644 --- a/runtime/onnxruntime/bin/funasr-onnx-2pass-rtf.cpp +++ b/runtime/onnxruntime/bin/funasr-onnx-2pass-rtf.cpp @@ -88,6 +88,7 @@ void runReg(FUNASR_HANDLE tpass_handle, std::vector chunk_size, vector> punc_cache(2); + FunTpassOnlineReset(tpass_online_handle); for (int sample_offset = 0; sample_offset < buff_len; sample_offset += std::min(step, buff_len - sample_offset)) { if (sample_offset + step >= buff_len - 1) { step = buff_len - sample_offset; @@ -96,7 +97,7 @@ void runReg(FUNASR_HANDLE tpass_handle, std::vector chunk_size, vector chunk_size, vector> punc_cache(2); + FunTpassOnlineReset(tpass_online_handle); for (int sample_offset = 0; sample_offset < buff_len; sample_offset += std::min(step, buff_len - sample_offset)) { if (sample_offset + step >= buff_len - 1) { step = buff_len - sample_offset; @@ -146,7 +148,7 @@ void runReg(FUNASR_HANDLE tpass_handle, std::vector chunk_size, vector chunk_size, vector vad_quant("", VAD_QUANT, "true (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir", false, "true", "string"); TCLAP::ValueArg punc_dir("", PUNC_DIR, "the punc online model path, which contains model.onnx, punc.yaml", false, "", "string"); TCLAP::ValueArg punc_quant("", PUNC_QUANT, "true (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "true", "string"); + TCLAP::ValueArg sv_dir("", SV_DIR, "the sv online model path, which contains model.onnx, config.yaml", false, "", "string"); + TCLAP::ValueArg sv_quant("", SV_QUANT, "true (Default), load the model of model.onnx in sv_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "true", "string"); TCLAP::ValueArg itn_dir("", ITN_DIR, "the itn model(fst) path, which contains zh_itn_tagger.fst and zh_itn_verbalizer.fst", false, "", "string"); TCLAP::ValueArg lm_dir("", LM_DIR, "the lm model path, which contains compiled models: TLG.fst, config.yaml, lexicon.txt ", false, "", "string"); TCLAP::ValueArg global_beam("", GLOB_BEAM, "the decoding beam for beam searching ", false, 3.0, "float"); @@ -245,6 +257,8 @@ int main(int argc, char** argv) cmd.add(punc_dir); cmd.add(punc_quant); cmd.add(itn_dir); + cmd.add(sv_dir); + cmd.add(sv_quant); cmd.add(lm_dir); cmd.add(global_beam); cmd.add(lattice_beam); @@ -266,8 +280,10 @@ int main(int argc, char** argv) GetValue(vad_quant, VAD_QUANT, model_path); GetValue(punc_dir, PUNC_DIR, model_path); GetValue(punc_quant, PUNC_QUANT, model_path); - GetValue(itn_dir, ITN_DIR, model_path); + GetValue(sv_dir, SV_DIR, model_path); + GetValue(sv_quant, SV_QUANT, model_path); GetValue(lm_dir, LM_DIR, model_path); + GetValue(itn_dir, ITN_DIR, model_path); GetValue(wav_path, WAV_PATH, model_path); GetValue(asr_mode, ASR_MODE, model_path); @@ -285,9 +301,9 @@ int main(int argc, char** argv) LOG(ERROR) << "Wrong asr-mode : " << model_path[ASR_MODE]; exit(-1); } - FUNASR_HANDLE tpass_hanlde=FunTpassInit(model_path, thread_num); + FUNASR_HANDLE tpass_handle = FunTpassInit(model_path, thread_num); - if (!tpass_hanlde) + if (!tpass_handle) { LOG(ERROR) << "FunTpassInit init failed"; exit(-1); @@ -326,7 +342,7 @@ int main(int argc, char** argv) return 0; } string line; - while(getline(in, line)) + while (getline(in, line)) { istringstream iss(line); string column1, column2; @@ -349,7 +365,7 @@ int main(int argc, char** argv) int rtf_threds = thread_num_.getValue(); for (int i = 0; i < rtf_threds; i++) { - threads.emplace_back(thread(runReg, tpass_hanlde, chunk_size, wav_list, wav_ids, audio_fs.getValue(), &total_length, &total_time, i, (ASR_TYPE)asr_mode_, nn_hotwords_, + threads.emplace_back(thread(runReg, tpass_handle, chunk_size, wav_list, wav_ids, audio_fs.getValue(), &total_length, &total_time, i, (ASR_TYPE)asr_mode_, nn_hotwords_, glob_beam, lat_beam, am_sc, fst_inc_wts.getValue(), hws_map)); } @@ -363,7 +379,7 @@ int main(int argc, char** argv) LOG(INFO) << "total_rtf " << (double)total_time/ (total_length*1000000); LOG(INFO) << "speedup " << 1.0/((double)total_time/ (total_length*1000000)); - FunTpassUninit(tpass_hanlde); + FunTpassUninit(tpass_handle); return 0; } diff --git a/runtime/onnxruntime/bin/funasr-onnx-2pass.cpp b/runtime/onnxruntime/bin/funasr-onnx-2pass.cpp index 1865ebc7d..119ef3d84 100644 --- a/runtime/onnxruntime/bin/funasr-onnx-2pass.cpp +++ b/runtime/onnxruntime/bin/funasr-onnx-2pass.cpp @@ -48,11 +48,11 @@ int main(int argc, char **argv) TCLAP::ValueArg offline_model_dir("", OFFLINE_MODEL_DIR, "the asr offline model path, which contains model.onnx, config.yaml, am.mvn", true, "", "string"); TCLAP::ValueArg online_model_dir("", ONLINE_MODEL_DIR, "the asr online model path, which contains model.onnx, decoder.onnx, config.yaml, am.mvn", true, "", "string"); TCLAP::ValueArg quantize("", QUANTIZE, "true (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "true", "string"); - TCLAP::ValueArg vad_dir("", VAD_DIR, "the vad online model path, which contains model.onnx, vad.yaml, vad.mvn", true, "", "string"); + TCLAP::ValueArg vad_dir("", VAD_DIR, "the vad online model path, which contains model.onnx, vad.yaml, vad.mvn", false, "", "string"); TCLAP::ValueArg vad_quant("", VAD_QUANT, "true (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir", false, "true", "string"); - TCLAP::ValueArg punc_dir("", PUNC_DIR, "the punc online model path, which contains model.onnx, punc.yaml", true, "", "string"); + TCLAP::ValueArg punc_dir("", PUNC_DIR, "the punc online model path, which contains model.onnx, punc.yaml", false, "", "string"); TCLAP::ValueArg punc_quant("", PUNC_QUANT, "true (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "true", "string"); - TCLAP::ValueArg sv_dir("", SV_DIR, "the sv online model path, which contains model.onnx, config.yaml", true, "", "string"); + TCLAP::ValueArg sv_dir("", SV_DIR, "the sv online model path, which contains model.onnx, config.yaml", false, "", "string"); TCLAP::ValueArg sv_quant("", SV_QUANT, "true (Default), load the model of model.onnx in sv_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "true", "string"); TCLAP::ValueArg itn_dir("", ITN_DIR, "the itn model(fst) path, which contains zh_itn_tagger.fst and zh_itn_verbalizer.fst", false, "", "string"); TCLAP::ValueArg lm_dir("", LM_DIR, "the lm model path, which contains compiled models: TLG.fst, config.yaml, lexicon.txt ", false, "", "string"); @@ -74,6 +74,7 @@ int main(int argc, char **argv) cmd.add(vad_quant); cmd.add(punc_dir); cmd.add(punc_quant); + cmd.add(itn_dir); cmd.add(sv_dir); cmd.add(sv_quant); cmd.add(lm_dir); @@ -81,7 +82,6 @@ int main(int argc, char **argv) cmd.add(lattice_beam); cmd.add(am_scale); cmd.add(fst_inc_wts); - cmd.add(itn_dir); cmd.add(wav_path); cmd.add(audio_fs); cmd.add(asr_mode); @@ -135,8 +135,7 @@ int main(int argc, char **argv) float glob_beam = 3.0f; float lat_beam = 3.0f; float am_sc = 10.0f; - if (lm_dir.isSet()) - { + if (lm_dir.isSet()) { glob_beam = global_beam.getValue(); lat_beam = lattice_beam.getValue(); am_sc = am_scale.getValue(); @@ -144,9 +143,6 @@ int main(int argc, char **argv) // init wfst decoder FUNASR_DEC_HANDLE decoder_handle = FunASRWfstDecoderInit(tpass_handle, ASR_TWO_PASS, glob_beam, lat_beam, am_sc); - // init sv-cam - bool sv_mode = 1; // use cam model for speaker verification - std::vector> voice_feats; gettimeofday(&end, nullptr); long seconds = (end.tv_sec - start.tv_sec); long modle_init_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec); @@ -165,8 +161,7 @@ int main(int argc, char **argv) string default_id = "wav_default_id"; string wav_path_ = model_path.at(WAV_PATH); - if (is_target_file(wav_path_, "scp")) - { + if (is_target_file(wav_path_, "scp")) { ifstream in(wav_path_); if (!in.is_open()) { @@ -183,8 +178,7 @@ int main(int argc, char **argv) wav_ids.emplace_back(column1); } in.close(); - } - else + }else { wav_list.emplace_back(wav_path_); wav_ids.emplace_back(default_id); @@ -239,6 +233,7 @@ int main(int argc, char **argv) string tpass_res = ""; string time_stamp_res = ""; std::vector> punc_cache(2); + FunTpassOnlineReset(tpass_online_handle); for (int sample_offset = 0; sample_offset < buff_len; sample_offset += std::min(step, buff_len - sample_offset)) { if (sample_offset + step >= buff_len - 1) @@ -252,9 +247,8 @@ int main(int argc, char **argv) } gettimeofday(&start, nullptr); FUNASR_RESULT result = FunTpassInferBuffer(tpass_handle, tpass_online_handle, - voice_feats, sv_mode, speech_buff + sample_offset, step, punc_cache, is_final, sampling_rate_, "pcm", - (ASR_TYPE)asr_mode_, hotwords_embedding, true, decoder_handle); + (ASR_TYPE)asr_mode_, hotwords_embedding, true, true, decoder_handle); gettimeofday(&end, nullptr); seconds = (end.tv_sec - start.tv_sec); taking_micros += ((seconds * 1000000) + end.tv_usec) - (start.tv_usec); @@ -263,35 +257,36 @@ int main(int argc, char **argv) { string online_msg = FunASRGetResult(result, 0); online_res += online_msg; - // if (online_msg != "") - // { - // LOG(INFO) << wav_id << " : " << online_msg; - // } + if (online_msg != "") + { + LOG(INFO) << wav_id << " : " << online_msg; + } string tpass_msg = FunASRGetTpassResult(result, 0); tpass_res += tpass_msg; if (tpass_msg != "") { - LOG(INFO) << wav_id << " offline results : " << tpass_msg; int speaker_idx = FunASRGetSvResult(result, 0); - if (speaker_idx != -999) + if (speaker_idx != -1) { - LOG(INFO) << "speaker_idx: " << speaker_idx; + LOG(INFO) << wav_id << " speaker_id : " << speaker_idx << " offline results : " << tpass_msg; + } else{ + LOG(INFO) << wav_id << " offline results : " << tpass_msg; + } + } + string stamp = FunASRGetStamp(result); + if (stamp != "") + { + LOG(INFO) << wav_ids[i] << " time stamp : " << stamp; + if (time_stamp_res == "") + { + time_stamp_res += stamp; + } + else + { + time_stamp_res = time_stamp_res.erase(time_stamp_res.length() - 1) + "," + stamp.substr(1); } } - // string stamp = FunASRGetStamp(result); - // if (stamp != "") - // { - // LOG(INFO) << wav_ids[i] << " time stamp : " << stamp; - // if (time_stamp_res == "") - // { - // time_stamp_res += stamp; - // } - // else - // { - // time_stamp_res = time_stamp_res.erase(time_stamp_res.length() - 1) + "," + stamp.substr(1); - // } - // } snippet_time += FunASRGetRetSnippetTime(result); FunASRFreeResult(result); } diff --git a/runtime/onnxruntime/bin/funasr-onnx-offline-sv.cpp b/runtime/onnxruntime/bin/funasr-onnx-offline-sv.cpp index aa02a7bf6..9f8a51031 100644 --- a/runtime/onnxruntime/bin/funasr-onnx-offline-sv.cpp +++ b/runtime/onnxruntime/bin/funasr-onnx-offline-sv.cpp @@ -92,16 +92,10 @@ int main(int argc, char *argv[]) google::InitGoogleLogging(argv[0]); FLAGS_logtostderr = true; TCLAP::CmdLine cmd("funasr-onnx-offline-sv", ' ', "1.0"); - // TCLAP::ValueArg model_dir("", SV_DIR, "the cam model path, which contains model.onnx, cam.yaml", true, "", "string"); - // TCLAP::ValueArg sv_quant("", SV_QUANT, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "true", "string"); - // TCLAP::ValueArg wav_file1("", "wav_file1", "the input could be: wav_path, e.g.: asr_example1.wav; pcm_path, e.g.: asr_example.pcm; wav.scp, ", true, "", "string"); - // TCLAP::ValueArg wav_file2("", "wav_file2", "the input could be: wav_path, e.g.: asr_example2.wav; pcm_path, e.g.: asr_example.pcm; wav.scp,", true, "", "string"); - // TCLAP::ValueArg onnx_thread("", "model-thread-num", "onnxruntime SetIntraOpNumThreads", false, 1, "int32_t"); - - TCLAP::ValueArg model_dir("", SV_DIR, "the cam model path, which contains model.onnx, cam.yaml", false, "/workspace/models/weights2/camplus_sv_zh-cn-16k-common-onnx", "string"); - TCLAP::ValueArg sv_quant("", SV_QUANT, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "flase", "string"); - TCLAP::ValueArg wav_file1("", "wav_file1", "the input could be: wav_path, e.g.: asr_example1.wav; pcm_path, e.g.: asr_example.pcm; wav.scp, ", false, "/home/wzp/project/FunASR/speaker1_a_cn_16k.wav", "string"); - TCLAP::ValueArg wav_file2("", "wav_file2", "the input could be: wav_path, e.g.: asr_example2.wav; pcm_path, e.g.: asr_example.pcm; wav.scp,", false, "/home/wzp/project/FunASR/speaker1_b_cn_16k.wav", "string"); + TCLAP::ValueArg model_dir("", SV_DIR, "the cam model path, which contains model.onnx, cam.yaml", true, "", "string"); + TCLAP::ValueArg sv_quant("", SV_QUANT, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "true", "string"); + TCLAP::ValueArg wav_file1("", "wav_file1", "the input could be: wav_path, e.g.: asr_example1.wav; pcm_path, e.g.: asr_example.pcm; wav.scp,", false, "/workspace/models/iic/speech_campplus_sv_zh-cn_16k-common/examples/speaker1_a_cn_16k.wav", "string"); + TCLAP::ValueArg wav_file2("", "wav_file2", "the input could be: wav_path, e.g.: asr_example2.wav; pcm_path, e.g.: asr_example.pcm; wav.scp,", false, "/workspace/models/iic/speech_campplus_sv_zh-cn_16k-common/examples/speaker1_b_cn_16k.wav", "string"); TCLAP::ValueArg onnx_thread("", "model-thread-num", "onnxruntime SetIntraOpNumThreads", false, 1, "int32_t"); cmd.add(model_dir); diff --git a/runtime/onnxruntime/bin/funasr-onnx-offline.cpp b/runtime/onnxruntime/bin/funasr-onnx-offline.cpp index 87928f59f..61b80cb8b 100644 --- a/runtime/onnxruntime/bin/funasr-onnx-offline.cpp +++ b/runtime/onnxruntime/bin/funasr-onnx-offline.cpp @@ -51,6 +51,8 @@ int main(int argc, char** argv) TCLAP::ValueArg punc_dir("", PUNC_DIR, "the punc model path, which contains model.onnx, punc.yaml", false, "", "string"); TCLAP::ValueArg punc_quant("", PUNC_QUANT, "true (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "true", "string"); TCLAP::ValueArg lm_dir("", LM_DIR, "the lm model path, which contains compiled models: TLG.fst, config.yaml, lexicon.txt ", false, "", "string"); + TCLAP::ValueArg sv_dir("", SV_DIR, "the sv online model path, which contains model.onnx, config.yaml", false, "", "string"); + TCLAP::ValueArg sv_quant("", SV_QUANT, "true (Default), load the model of model.onnx in sv_dir. If set true, load the model of model_quant.onnx in sv_dir", false, "true", "string"); TCLAP::ValueArg global_beam("", GLOB_BEAM, "the decoding beam for beam searching ", false, 3.0, "float"); TCLAP::ValueArg lattice_beam("", LAT_BEAM, "the lattice generation beam for beam searching ", false, 3.0, "float"); TCLAP::ValueArg am_scale("", AM_SCALE, "the acoustic scale for beam searching ", false, 10.0, "float"); @@ -71,6 +73,8 @@ int main(int argc, char** argv) cmd.add(punc_quant); cmd.add(itn_dir); cmd.add(lm_dir); + cmd.add(sv_dir); + cmd.add(sv_quant); cmd.add(global_beam); cmd.add(lattice_beam); cmd.add(am_scale); @@ -90,6 +94,8 @@ int main(int argc, char** argv) GetValue(vad_quant, VAD_QUANT, model_path); GetValue(punc_dir, PUNC_DIR, model_path); GetValue(punc_quant, PUNC_QUANT, model_path); + GetValue(sv_dir, SV_DIR, model_path); + GetValue(sv_quant, SV_QUANT, model_path); GetValue(itn_dir, ITN_DIR, model_path); GetValue(lm_dir, LM_DIR, model_path); GetValue(wav_path, WAV_PATH, model_path); @@ -190,10 +196,10 @@ int main(int argc, char** argv) // int buff_len = audio.GetSpeechLen()*2; // gettimeofday(&start, nullptr); - // FUNASR_RESULT result=FunOfflineInferBuffer(asr_hanlde, speech_buff, buff_len, RASR_NONE, nullptr, hotwords_embedding, audio_fs.getValue(), "pcm", true, decoder_handle); + // FUNASR_RESULT result=FunOfflineInferBuffer(asr_hanlde, speech_buff, buff_len, RASR_NONE, nullptr, hotwords_embedding, voice_feats, audio_fs.getValue(), "pcm", true, true, decoder_handle); // For debug:end - - FUNASR_RESULT result=FunOfflineInfer(asr_hanlde, wav_file.c_str(), RASR_NONE, nullptr, hotwords_embedding, audio_fs.getValue(), true, decoder_handle); + std::vector> voice_feats; + FUNASR_RESULT result=FunOfflineInfer(asr_hanlde, wav_file.c_str(), RASR_NONE, nullptr, hotwords_embedding, voice_feats, audio_fs.getValue(), true, true, decoder_handle); gettimeofday(&end, nullptr); seconds = (end.tv_sec - start.tv_sec); taking_micros += ((seconds * 1000000) + end.tv_usec) - (start.tv_usec); diff --git a/runtime/onnxruntime/include/cam-sv-model.h b/runtime/onnxruntime/include/cam-sv-model.h index 6ff7a0cd1..947642167 100644 --- a/runtime/onnxruntime/include/cam-sv-model.h +++ b/runtime/onnxruntime/include/cam-sv-model.h @@ -9,14 +9,10 @@ namespace funasr { class SvModel { public: virtual ~SvModel(){}; - virtual void InitSv(const std::string &model, const std::string &cmvn, const std::string &config, int thread_num)=0; + virtual void InitSv(const std::string &model, const std::string &config, int thread_num)=0; virtual std::vector> Infer(std::vector &waves)=0; float threshold=0.40; }; -SvModel *CreateSvModel(std::map& model_path, int thread_num); - -SvModel *CreateAndInferSvModel(std::map& model_path, int thread_num); -// std::vector> InferSvModel(std::map& model_path, int thread_num, std::vectorwave); - +SvModel *CreateSVModel(std::map& model_path, int thread_num); } // namespace funasr diff --git a/runtime/onnxruntime/include/funasrruntime.h b/runtime/onnxruntime/include/funasrruntime.h index b1f1ecf34..9d8055404 100644 --- a/runtime/onnxruntime/include/funasrruntime.h +++ b/runtime/onnxruntime/include/funasrruntime.h @@ -101,12 +101,12 @@ _FUNASRAPI FUNASR_HANDLE FunOfflineInit(std::map& mo _FUNASRAPI void FunOfflineReset(FUNASR_HANDLE handle, FUNASR_DEC_HANDLE dec_handle=nullptr); // buffer _FUNASRAPI FUNASR_RESULT FunOfflineInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, - FUNASR_MODE mode, QM_CALLBACK fn_callback, const std::vector> &hw_emb, - int sampling_rate=16000, std::string wav_format="pcm", bool itn=true, FUNASR_DEC_HANDLE dec_handle=nullptr); + FUNASR_MODE mode, QM_CALLBACK fn_callback, const std::vector> &hw_emb, std::vector> &voice_feats, + int sampling_rate=16000, std::string wav_format="pcm", bool use_itn=true, bool use_sv=true, FUNASR_DEC_HANDLE dec_handle=nullptr); // file, support wav & pcm _FUNASRAPI FUNASR_RESULT FunOfflineInfer(FUNASR_HANDLE handle, const char* sz_filename, FUNASR_MODE mode, - QM_CALLBACK fn_callback, const std::vector> &hw_emb, - int sampling_rate=16000, bool itn=true, FUNASR_DEC_HANDLE dec_handle=nullptr); + QM_CALLBACK fn_callback, const std::vector> &hw_emb, std::vector> &voice_feats, + int sampling_rate=16000, bool use_itn=true, bool use_sv=true, FUNASR_DEC_HANDLE dec_handle=nullptr); //#if !defined(__APPLE__) _FUNASRAPI const std::vector> CompileHotwordEmbedding(FUNASR_HANDLE handle, std::string &hotwords, ASR_TYPE mode=ASR_OFFLINE); //#endif @@ -116,12 +116,12 @@ _FUNASRAPI void FunOfflineUninit(FUNASR_HANDLE handle); //2passStream _FUNASRAPI FUNASR_HANDLE FunTpassInit(std::map& model_path, int thread_num); _FUNASRAPI FUNASR_HANDLE FunTpassOnlineInit(FUNASR_HANDLE tpass_handle, std::vector chunk_size={5,10,5}); +_FUNASRAPI void FunTpassOnlineReset(FUNASR_HANDLE tpass_online_handle); // buffer -_FUNASRAPI FUNASR_RESULT FunTpassInferBuffer(FUNASR_HANDLE handle, FUNASR_HANDLE online_handle, - std::vector>& voice_feats, bool sv_mode, const char* sz_buf, +_FUNASRAPI FUNASR_RESULT FunTpassInferBuffer(FUNASR_HANDLE handle, FUNASR_HANDLE online_handle, const char* sz_buf, int n_len, std::vector> &punc_cache, bool input_finished=true, - int sampling_rate=16000, std::string wav_format="pcm", ASR_TYPE mode=ASR_TWO_PASS, - const std::vector> &hw_emb={{0.0}}, bool itn=true, FUNASR_DEC_HANDLE dec_handle=nullptr); + int sampling_rate=16000, std::string wav_format="pcm", ASR_TYPE mode=ASR_TWO_PASS, + const std::vector> &hw_emb={{0.0}}, bool use_itn=true, bool use_sv=true, FUNASR_DEC_HANDLE dec_handle=nullptr); _FUNASRAPI void FunTpassUninit(FUNASR_HANDLE handle); _FUNASRAPI void FunTpassOnlineUninit(FUNASR_HANDLE handle); diff --git a/runtime/onnxruntime/include/offline-stream.h b/runtime/onnxruntime/include/offline-stream.h index 568629db0..1e25707cf 100644 --- a/runtime/onnxruntime/include/offline-stream.h +++ b/runtime/onnxruntime/include/offline-stream.h @@ -7,6 +7,7 @@ #include "model.h" #include "punc-model.h" #include "vad-model.h" +#include "cam-sv-model.h" #if !defined(__APPLE__) #include "itn-model.h" #endif @@ -20,17 +21,20 @@ class OfflineStream { std::unique_ptr vad_handle= nullptr; std::unique_ptr asr_handle= nullptr; std::unique_ptr punc_handle= nullptr; + std::unique_ptr sv_handle = nullptr; #if !defined(__APPLE__) std::unique_ptr itn_handle = nullptr; #endif bool UseVad(){return use_vad;}; bool UsePunc(){return use_punc;}; bool UseITN(){return use_itn;}; + bool UseSv(){return use_sv;}; private: bool use_vad=false; bool use_punc=false; bool use_itn=false; + bool use_sv=false; }; OfflineStream *CreateOfflineStream(std::map& model_path, int thread_num=1, bool use_gpu=false, int batch_size=1); diff --git a/runtime/onnxruntime/include/tpass-online-stream.h b/runtime/onnxruntime/include/tpass-online-stream.h index e092880b3..97baf262b 100644 --- a/runtime/onnxruntime/include/tpass-online-stream.h +++ b/runtime/onnxruntime/include/tpass-online-stream.h @@ -14,7 +14,11 @@ class TpassOnlineStream { std::unique_ptr vad_online_handle = nullptr; std::unique_ptr asr_online_handle = nullptr; + + //for sv-cam + std::vector> voice_feats; }; TpassOnlineStream* CreateTpassOnlineStream(void* tpass_stream, std::vector chunk_size); +void TpassOnlineCacheReset(void* tpass_online_stream); } // namespace funasr #endif diff --git a/runtime/onnxruntime/src/cam-sv-model.cpp b/runtime/onnxruntime/src/cam-sv-model.cpp index 3ec03e245..77907f723 100644 --- a/runtime/onnxruntime/src/cam-sv-model.cpp +++ b/runtime/onnxruntime/src/cam-sv-model.cpp @@ -2,36 +2,12 @@ namespace funasr { - SvModel *CreateSVModel(std::map &model_path, int thread_num) { SvModel *mm; mm = new CamPPlusSv(); string vad_model_path; - string vad_cmvn_path; - string vad_config_path; - - vad_model_path = PathAppend(model_path.at(MODEL_DIR), MODEL_NAME); - if (model_path.find(QUANTIZE) != model_path.end() && model_path.at(QUANTIZE) == "true") - { - vad_model_path = PathAppend(model_path.at(MODEL_DIR), QUANT_MODEL_NAME); - } - vad_cmvn_path = PathAppend(model_path.at(MODEL_DIR), SV_CMVN_NAME); - vad_config_path = PathAppend(model_path.at(MODEL_DIR), SV_CONFIG_NAME); - - mm->InitSv(vad_model_path, vad_cmvn_path, vad_config_path, thread_num); - return mm; - } - - // SvModel *CreateAndInferSvModel(std::map &model_path, int thread_num, std::vector wave) - SvModel *CreateAndInferSvModel(std::map &model_path, int thread_num) - { - SvModel *mm; - mm = new CamPPlusSv(); - - string vad_model_path; - string vad_cmvn_path; string vad_config_path; vad_model_path = PathAppend(model_path.at(SV_DIR), MODEL_NAME); @@ -39,35 +15,10 @@ namespace funasr { vad_model_path = PathAppend(model_path.at(SV_DIR), QUANT_MODEL_NAME); } - vad_cmvn_path = PathAppend(model_path.at(SV_DIR), SV_CMVN_NAME); vad_config_path = PathAppend(model_path.at(SV_DIR), SV_CONFIG_NAME); - mm->InitSv(vad_model_path, vad_cmvn_path, vad_config_path, thread_num); - + mm->InitSv(vad_model_path, vad_config_path, thread_num); return mm; } - // std::vector> InferSvModel(std::map &model_path, int thread_num, std::vector wave) - // { - // SvModel *mm; - // mm = new CamPPlusSv(); - - // string vad_model_path; - // string vad_cmvn_path; - // string vad_config_path; - - // vad_model_path = PathAppend(model_path.at(MODEL_DIR), MODEL_NAME); - // if (model_path.find(QUANTIZE) != model_path.end() && model_path.at(QUANTIZE) == "true") - // { - // vad_model_path = PathAppend(model_path.at(MODEL_DIR), QUANT_MODEL_NAME); - // } - // vad_cmvn_path = PathAppend(model_path.at(MODEL_DIR), SV_CMVN_NAME); - // vad_config_path = PathAppend(model_path.at(MODEL_DIR), SV_CONFIG_NAME); - - // mm->InitSv(vad_model_path, vad_cmvn_path, vad_config_path, thread_num); - // std::vector> result = mm->Infer(wave); - // delete mm; - // return result; - // } - } // namespace funasr \ No newline at end of file diff --git a/runtime/onnxruntime/src/campplus-sv.cpp b/runtime/onnxruntime/src/campplus-sv.cpp index df66b3167..4e8f58fb9 100644 --- a/runtime/onnxruntime/src/campplus-sv.cpp +++ b/runtime/onnxruntime/src/campplus-sv.cpp @@ -7,29 +7,16 @@ #include "precomp.h" #include -template -void print_vec_shape(const std::vector> &data) -{ - std::cout << "vec_shape= [" << data.size() << ", "; - if (!data.empty()) - { - std::cout << data[0].size(); - } - std::cout << "]" << std::endl; -} - namespace funasr { - void CamPPlusSv::InitSv(const std::string &model, const std::string &cmvn, const std::string &config, int thread_num) + void CamPPlusSv::InitSv(const std::string &model, const std::string &config, int thread_num) { session_options_.SetIntraOpNumThreads(thread_num); session_options_.SetGraphOptimizationLevel(ORT_ENABLE_ALL); session_options_.DisableCpuMemArena(); ReadModel(model.c_str()); - // LoadCmvn(cmvn.c_str()); LoadConfigFromYaml(config.c_str()); - } void CamPPlusSv::LoadConfigFromYaml(const char *filename) @@ -104,8 +91,8 @@ namespace funasr shape << j; shape << " "; } - LOG(INFO) << "\tInput " << i << " : name=" << name.get() << " type=" << type - << " dims=" << shape.str(); + // LOG(INFO) << "\tInput " << i << " : name=" << name.get() << " type=" << type + // << " dims=" << shape.str(); (*in_names)[i] = name.get(); name.release(); } @@ -125,8 +112,8 @@ namespace funasr shape << j; shape << " "; } - LOG(INFO) << "\tOutput " << i << " : name=" << name.get() << " type=" << type - << " dims=" << shape.str(); + // LOG(INFO) << "\tOutput " << i << " : name=" << name.get() << " type=" << type + // << " dims=" << shape.str(); (*out_names)[i] = name.get(); name.release(); } @@ -205,116 +192,6 @@ namespace funasr } } - void CamPPlusSv::LoadCmvn(const char *filename) - { - try - { - using namespace std; - ifstream cmvn_stream(filename); - if (!cmvn_stream.is_open()) - { - LOG(ERROR) << "Failed to open file: " << filename; - exit(-1); - } - string line; - - while (getline(cmvn_stream, line)) - { - istringstream iss(line); - vector line_item{istream_iterator{iss}, istream_iterator{}}; - if (line_item[0] == "") - { - getline(cmvn_stream, line); - istringstream means_lines_stream(line); - vector means_lines{istream_iterator{means_lines_stream}, istream_iterator{}}; - if (means_lines[0] == "") - { - for (int j = 3; j < means_lines.size() - 1; j++) - { - means_list_.push_back(stof(means_lines[j])); - } - continue; - } - } - else if (line_item[0] == "") - { - getline(cmvn_stream, line); - istringstream vars_lines_stream(line); - vector vars_lines{istream_iterator{vars_lines_stream}, istream_iterator{}}; - if (vars_lines[0] == "") - { - for (int j = 3; j < vars_lines.size() - 1; j++) - { - vars_list_.push_back(stof(vars_lines[j])); - } - continue; - } - } - } - } - catch (std::exception const &e) - { - LOG(ERROR) << "Error when load vad cmvn : " << e.what(); - exit(-1); - } - } - - void CamPPlusSv::LfrCmvn(std::vector> &vad_feats) - { - - // std::vector> out_feats; - // int T = vad_feats.size(); - // int T_lrf = ceil(1.0 * T / lfr_n); - - // // Pad frames at start(copy first frame) - // for (int i = 0; i < (lfr_m - 1) / 2; i++) - // { - // vad_feats.insert(vad_feats.begin(), vad_feats[0]); - // } - // // Merge lfr_m frames as one,lfr_n frames per window - // T = T + (lfr_m - 1) / 2; - // std::vector p; - // for (int i = 0; i < T_lrf; i++) - // { - // if (lfr_m <= T - i * lfr_n) - // { - // for (int j = 0; j < lfr_m; j++) - // { - // p.insert(p.end(), vad_feats[i * lfr_n + j].begin(), vad_feats[i * lfr_n + j].end()); - // } - // out_feats.emplace_back(p); - // p.clear(); - // } - // else - // { - // // Fill to lfr_m frames at last window if less than lfr_m frames (copy last frame) - // int num_padding = lfr_m - (T - i * lfr_n); - // for (int j = 0; j < (vad_feats.size() - i * lfr_n); j++) - // { - // p.insert(p.end(), vad_feats[i * lfr_n + j].begin(), vad_feats[i * lfr_n + j].end()); - // } - // for (int j = 0; j < num_padding; j++) - // { - // p.insert(p.end(), vad_feats[vad_feats.size() - 1].begin(), vad_feats[vad_feats.size() - 1].end()); - // } - // out_feats.emplace_back(p); - // p.clear(); - // } - // } - // //Apply cmvn - std::vector> out_feats; - out_feats = vad_feats; - // only Apply cmvn - for (auto &out_feat : out_feats) - { - // for (int j = 0; j < means_list_.size(); j++) { - for (int j = 0; j < std::min(out_feats[0].size(), means_list_.size()); j++) - { - out_feat[j] = (out_feat[j] + means_list_[j]) * vars_list_[j]; - } - } - vad_feats = out_feats; - } void CamPPlusSv::SubMean(std::vector> &voice_feats) { if (voice_feats.size() > 0) diff --git a/runtime/onnxruntime/src/campplus-sv.h b/runtime/onnxruntime/src/campplus-sv.h index 31d625b8e..ee1ac6e6e 100644 --- a/runtime/onnxruntime/src/campplus-sv.h +++ b/runtime/onnxruntime/src/campplus-sv.h @@ -19,7 +19,7 @@ namespace funasr public: CamPPlusSv(); ~CamPPlusSv(); - void InitSv(const std::string &model, const std::string &cmvn, const std::string &config, int thread_num); + void InitSv(const std::string &model, const std::string &config, int thread_num); std::vector> Infer(std::vector &waves); void Forward( const std::vector> &chunk_feats, @@ -30,29 +30,19 @@ namespace funasr Ort::SessionOptions session_options_; std::vector cam_in_names_; std::vector cam_out_names_; - std::vector> in_cache_; - knf::FbankOptions fbank_opts_; - std::vector means_list_; - std::vector vars_list_; int sample_rate_ = MODEL_SAMPLE_RATE; - int lfr_m = VAD_LFR_M; - int lfr_n = VAD_LFR_N; private: void ReadModel(const char *cam_model); void LoadConfigFromYaml(const char *filename); - static void GetInputOutputInfo( const std::shared_ptr &session, std::vector *in_names, std::vector *out_names); void FbankKaldi(float sample_rate, std::vector> &vad_feats, std::vector &waves); - - void LfrCmvn(std::vector> &vad_feats); - void LoadCmvn(const char *filename); void SubMean(std::vector>& voice_feats); }; diff --git a/runtime/onnxruntime/src/commonfunc.h b/runtime/onnxruntime/src/commonfunc.h index 034da5e22..ccb6b5a70 100644 --- a/runtime/onnxruntime/src/commonfunc.h +++ b/runtime/onnxruntime/src/commonfunc.h @@ -12,8 +12,8 @@ typedef struct std::string stamp_sents; std::string tpass_msg; float snippet_time; - int speaker_idx = -999; - std::vector speaker_emb; + std::vector speaker_idxs; + std::vector> speaker_embs; }FUNASR_RECOG_RESULT; typedef struct diff --git a/runtime/onnxruntime/src/funasrruntime.cpp b/runtime/onnxruntime/src/funasrruntime.cpp index 864d53765..fe37f3a98 100644 --- a/runtime/onnxruntime/src/funasrruntime.cpp +++ b/runtime/onnxruntime/src/funasrruntime.cpp @@ -50,6 +50,11 @@ return funasr::CreateTpassOnlineStream(tpass_handle, chunk_size); } + _FUNASRAPI void FunTpassOnlineReset(FUNASR_HANDLE tpass_online_handle) + { + funasr::TpassOnlineCacheReset(tpass_online_handle); + } + // APIs for ASR Infer _FUNASRAPI FUNASR_RESULT FunASRInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback, bool input_finished, int sampling_rate, std::string wav_format) { @@ -206,8 +211,9 @@ // APIs for Offline-stream Infer _FUNASRAPI FUNASR_RESULT FunOfflineInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, - FUNASR_MODE mode, QM_CALLBACK fn_callback, const std::vector> &hw_emb, - int sampling_rate, std::string wav_format, bool itn, FUNASR_DEC_HANDLE dec_handle) + FUNASR_MODE mode, QM_CALLBACK fn_callback, const std::vector> &hw_emb, + std::vector> &voice_feats, int sampling_rate, std::string wav_format, + bool use_itn, bool use_sv, FUNASR_DEC_HANDLE dec_handle) { funasr::OfflineStream* offline_stream = (funasr::OfflineStream*)handle; if (!offline_stream) @@ -307,7 +313,7 @@ p_result->msg = punc_res; } #if !defined(__APPLE__) - if(offline_stream->UseITN() && itn){ + if(offline_stream->UseITN() && use_itn){ string msg_itn = offline_stream->itn_handle->Normalize(p_result->msg); if(!(p_result->stamp).empty()){ std::string new_stamp = funasr::TimestampSmooth(p_result->msg, msg_itn, p_result->stamp); @@ -319,13 +325,14 @@ } #endif if (!(p_result->stamp).empty()){ - p_result->stamp_sents = funasr::TimestampSentence(p_result->msg, p_result->stamp); + p_result->stamp_sents = funasr::TimestampSentence(p_result->msg, p_result->stamp, p_result->speaker_idxs); } return p_result; } _FUNASRAPI FUNASR_RESULT FunOfflineInfer(FUNASR_HANDLE handle, const char* sz_filename, FUNASR_MODE mode, QM_CALLBACK fn_callback, - const std::vector> &hw_emb, int sampling_rate, bool itn, FUNASR_DEC_HANDLE dec_handle) + const std::vector> &hw_emb, std::vector> &voice_feats, + int sampling_rate, bool use_itn, bool use_sv, FUNASR_DEC_HANDLE dec_handle) { funasr::OfflineStream* offline_stream = (funasr::OfflineStream*)handle; if (!offline_stream) @@ -357,11 +364,13 @@ } std::vector index_vector={0}; int msg_idx = 0; + int svs_idx = 0; if(offline_stream->UseVad()){ audio.CutSplit(offline_stream, index_vector); } std::vector msgs(index_vector.size()); std::vector msg_stimes(index_vector.size()); + std::vector svs(index_vector.size(), -1); float** buff; int* len; @@ -387,7 +396,30 @@ msg_idx++; }else{ LOG(ERROR) << "msg_idx: " << msg_idx <<" is out of range " << index_vector.size(); - } + } + } + + if(offline_stream->UseSv() && use_sv){ + for(int index=0; index 1600){ + if (voice_feats.size() < MAX_SPKS_NUM){ + std::vector wave(buff[index], buff[index]+len[index]); + std::vector> sv_result = offline_stream->sv_handle->Infer(wave); + float threshold = offline_stream->sv_handle->threshold; + int speaker_idx = funasr::GetSpeakersID(sv_result[0], voice_feats, threshold); + if(svs_idx < index_vector.size()){ + svs[index_vector[svs_idx]] = speaker_idx; + svs_idx++; + }else{ + LOG(ERROR) << "svs_idx: " << svs_idx <<" is out of range " << index_vector.size(); + } + }else{ + LOG(ERROR) << "Exceeding the maximum speaker limit!"; + LOG(ERROR) << "speaker_idx: " << MAX_SPKS_NUM; + } + } + } + p_result->speaker_idxs = svs; } // release @@ -429,7 +461,7 @@ p_result->msg = punc_res; } #if !defined(__APPLE__) - if(offline_stream->UseITN() && itn){ + if(offline_stream->UseITN() && use_itn){ string msg_itn = offline_stream->itn_handle->Normalize(p_result->msg); if(!(p_result->stamp).empty()){ std::string new_stamp = funasr::TimestampSmooth(p_result->msg, msg_itn, p_result->stamp); @@ -440,8 +472,8 @@ p_result->msg = msg_itn; } #endif - if (!(p_result->stamp).empty()){ - p_result->stamp_sents = funasr::TimestampSentence(p_result->msg, p_result->stamp); + if (!(p_result->stamp).empty() || !(p_result->speaker_idxs.empty())){ + p_result->stamp_sents = funasr::TimestampSentence(p_result->msg, p_result->stamp, p_result->speaker_idxs); } return p_result; } @@ -473,11 +505,10 @@ //#endif // APIs for 2pass-stream Infer - _FUNASRAPI FUNASR_RESULT FunTpassInferBuffer(FUNASR_HANDLE handle, FUNASR_HANDLE online_handle, - std::vector>& voice_feats, bool sv_mode, const char* sz_buf, - int n_len, std::vector> &punc_cache, bool input_finished, + _FUNASRAPI FUNASR_RESULT FunTpassInferBuffer(FUNASR_HANDLE handle, FUNASR_HANDLE online_handle, const char* sz_buf, + int n_len, std::vector> &punc_cache, bool input_finished, int sampling_rate, std::string wav_format, ASR_TYPE mode, - const std::vector> &hw_emb, bool itn, FUNASR_DEC_HANDLE dec_handle) + const std::vector> &hw_emb, bool use_itn, bool use_sv, FUNASR_DEC_HANDLE dec_handle) { funasr::TpassStream* tpass_stream = (funasr::TpassStream*)handle; funasr::TpassOnlineStream* tpass_online_stream = (funasr::TpassOnlineStream*)online_handle; @@ -501,9 +532,6 @@ funasr::PuncModel* punc_online_handle = (tpass_stream->punc_online_handle).get(); if (!punc_online_handle) - return nullptr; - funasr::SvModel* sv_handle = (tpass_stream->sv_handle).get(); - if (!sv_handle &&sv_mode) return nullptr; if(wav_format == "pcm" || wav_format == "PCM"){ @@ -532,7 +560,7 @@ p_result->tpass_msg = msg_punc; #if !defined(__APPLE__) // ITN - if(tpass_stream->UseITN() && itn){ + if(tpass_stream->UseITN() && use_itn){ string msg_itn = tpass_stream->itn_handle->Normalize(msg_punc); p_result->tpass_msg = msg_itn; } @@ -574,21 +602,21 @@ msg = msg_vec[0]; //sv-cam for Speaker verification - if(sv_mode&&frame->len>1600)//Filter audio clips less than 100ms + if(tpass_stream->UseSv() && use_sv &&frame->len>1600)//Filter audio clips less than 100ms { + std::vector>& voice_feats = tpass_online_stream->voice_feats; if (voice_feats.size()wave(frame->data,frame->data+frame->len); - std::vector>sv_result=sv_handle->Infer(wave); - // std::vector>sv_result = CamPPlusSvInfer(sv_handle, wave); - float threshold =sv_handle->threshold; - int speaker_idx = funasr::GetSpeakersID(sv_result[0], voice_feats,threshold); - p_result->speaker_idx = speaker_idx; + std::vectorwave(frame->data, frame->data+frame->len); + std::vector> sv_result = tpass_stream->sv_handle->Infer(wave); + float threshold = tpass_stream->sv_handle->threshold; + int speaker_idx = funasr::GetSpeakersID(sv_result[0], voice_feats, threshold); + // p_result->speaker_idx = speaker_idx; } else { LOG(ERROR)<<"Exceeding the maximum speaker limit!\n"; - p_result->speaker_idx = MAX_SPKS_NUM; + // p_result->speaker_idx = MAX_SPKS_NUM; } } //timestamp @@ -612,7 +640,7 @@ } p_result->tpass_msg = msg_punc; #if !defined(__APPLE__) - if(tpass_stream->UseITN() && itn){ + if(tpass_stream->UseITN() && use_itn){ string msg_itn = tpass_stream->itn_handle->Normalize(msg_punc); // TimestampSmooth if(!(p_result->stamp).empty()){ @@ -625,7 +653,7 @@ } #endif if (!(p_result->stamp).empty()){ - p_result->stamp_sents = funasr::TimestampSentence(p_result->tpass_msg, p_result->stamp); + p_result->stamp_sents = funasr::TimestampSentence(p_result->tpass_msg, p_result->stamp, p_result->speaker_idxs); } if(frame != nullptr){ delete frame; @@ -902,10 +930,7 @@ // APIs for CamPPlusSv Infer _FUNASRAPI FUNASR_HANDLE CamPPlusSvInit(std::map& model_path, int thread_num) { - // funasr::SvModel *mm = funasr::CreateSvModel(model_path, thread_num); - // return mm; - // std::vector wave; - funasr::SvModel* mm = funasr::CreateAndInferSvModel(model_path, thread_num); + funasr::SvModel* mm = funasr::CreateSVModel(model_path, thread_num); return mm; } @@ -935,7 +960,7 @@ _FUNASRAPI const int FunASRGetSvResult(FUNASR_RESULT result, int n_index) if (!p_result) return -1; - return p_result->speaker_idx; + return p_result->speaker_idxs[n_index]; } _FUNASRAPI const std::vector FunASRGetSvEmbResult(FUNASR_RESULT result, int n_index) { @@ -944,5 +969,5 @@ _FUNASRAPI const std::vector FunASRGetSvEmbResult(FUNASR_RESULT result, i if (!p_result) return speaker_emb; - return p_result->speaker_emb; + return p_result->speaker_embs[n_index]; } \ No newline at end of file diff --git a/runtime/onnxruntime/src/offline-stream.cpp b/runtime/onnxruntime/src/offline-stream.cpp index 166d3c9ea..3aa71aed4 100644 --- a/runtime/onnxruntime/src/offline-stream.cpp +++ b/runtime/onnxruntime/src/offline-stream.cpp @@ -138,6 +138,31 @@ OfflineStream::OfflineStream(std::map& model_path, int } } #endif + + // sv cam + if (model_path.find(SV_DIR) != model_path.end() && model_path.at(SV_DIR) != "") + { + string sv_model_path; + string sv_config_path; + + sv_model_path = PathAppend(model_path.at(SV_DIR), MODEL_NAME); + if (model_path.find(SV_QUANT) != model_path.end() && model_path.at(SV_QUANT) == "true") + { + sv_model_path = PathAppend(model_path.at(SV_DIR), QUANT_MODEL_NAME); + } + sv_config_path = PathAppend(model_path.at(SV_DIR), SV_CONFIG_NAME); + if (access(sv_model_path.c_str(), F_OK) != 0 || + access(sv_config_path.c_str(), F_OK) != 0) + { + LOG(INFO) << "CAMPlusPlus model file is not exist, skip load model."; + }else + { + sv_handle = make_unique(); + sv_handle->InitSv(sv_model_path, sv_config_path, thread_num); + use_sv = true; + } + } + } OfflineStream *CreateOfflineStream(std::map& model_path, int thread_num, bool use_gpu, int batch_size) diff --git a/runtime/onnxruntime/src/tpass-online-stream.cpp b/runtime/onnxruntime/src/tpass-online-stream.cpp index 7788e0b12..9a74d0896 100644 --- a/runtime/onnxruntime/src/tpass-online-stream.cpp +++ b/runtime/onnxruntime/src/tpass-online-stream.cpp @@ -18,6 +18,11 @@ TpassOnlineStream::TpassOnlineStream(TpassStream* tpass_stream, std::vector } } +void TpassOnlineCacheReset(void* tpass_online_stream){ + TpassOnlineStream* tpass_online_obj = (TpassOnlineStream*)tpass_online_stream; + tpass_online_obj->voice_feats.clear(); +} + TpassOnlineStream* CreateTpassOnlineStream(void* tpass_stream, std::vector chunk_size) { TpassOnlineStream *mm; diff --git a/runtime/onnxruntime/src/tpass-stream.cpp b/runtime/onnxruntime/src/tpass-stream.cpp index 42806dc55..e4349c721 100644 --- a/runtime/onnxruntime/src/tpass-stream.cpp +++ b/runtime/onnxruntime/src/tpass-stream.cpp @@ -126,10 +126,9 @@ TpassStream::TpassStream(std::map& model_path, int thr #endif // sv cam - if (model_path.find(SV_DIR) != model_path.end()) + if (model_path.find(SV_DIR) != model_path.end() && model_path.at(SV_DIR) != "") { string sv_model_path; - string sv_cmvn_path; string sv_config_path; sv_model_path = PathAppend(model_path.at(SV_DIR), MODEL_NAME); @@ -137,18 +136,16 @@ TpassStream::TpassStream(std::map& model_path, int thr { sv_model_path = PathAppend(model_path.at(SV_DIR), QUANT_MODEL_NAME); } - sv_cmvn_path = PathAppend(model_path.at(SV_DIR), SV_CMVN_NAME); sv_config_path = PathAppend(model_path.at(SV_DIR), SV_CONFIG_NAME); if (access(sv_model_path.c_str(), F_OK) != 0 || - // access(vad_cmvn_path.c_str(), F_OK) != 0 || access(sv_config_path.c_str(), F_OK) != 0) { - LOG(INFO) << "CAMPlusPlus model file is not exist, skip load vad model."; + LOG(INFO) << "CAMPlusPlus model file is not exist, skip load model."; } else { sv_handle = make_unique(); - sv_handle->InitSv(sv_model_path, sv_cmvn_path, sv_config_path, thread_num); + sv_handle->InitSv(sv_model_path, sv_config_path, thread_num); use_sv = true; } } diff --git a/runtime/onnxruntime/src/util.cpp b/runtime/onnxruntime/src/util.cpp index 4ad8a65fb..625c236d9 100644 --- a/runtime/onnxruntime/src/util.cpp +++ b/runtime/onnxruntime/src/util.cpp @@ -566,7 +566,7 @@ std::string TimestampSmooth(std::string &text, std::string &text_itn, std::strin return timestamps_str; } -std::string TimestampSentence(std::string &text, std::string &str_time){ +std::string TimestampSentence(std::string &text, std::string &str_time, std::vector speaker_idxs){ std::vector characters; funasr::TimestampSplitChiEngCharacters(text, characters); vector> timestamps = funasr::ParseTimestamps(str_time); diff --git a/runtime/onnxruntime/src/util.h b/runtime/onnxruntime/src/util.h index 0eda98da9..c9078488d 100644 --- a/runtime/onnxruntime/src/util.h +++ b/runtime/onnxruntime/src/util.h @@ -47,7 +47,7 @@ void TimestampSplitChiEngCharacters(const std::string &input_str, std::vector &characters); std::string VectorToString(const std::vector>& vec, bool out_empty=true); std::string TimestampSmooth(std::string &text, std::string &text_itn, std::string &str_time); -std::string TimestampSentence(std::string &text, std::string &str_time); +std::string TimestampSentence(std::string &text, std::string &str_time, std::vector speaker_idxs); std::vector split(const std::string &s, char delim); template @@ -71,6 +71,5 @@ void ExtractHws(string hws_file, unordered_map &hws_map); void ExtractHws(string hws_file, unordered_map &hws_map, string& nn_hotwords_); float CosineSimilarity(const std::vector &emb1, const std::vector &emb2); int GetSpeakersID(const std::vector &emb1, std::vector> &emb_list, float threshold = 0.40); - } // namespace funasr #endif diff --git a/runtime/websocket/bin/websocket-server-2pass.cpp b/runtime/websocket/bin/websocket-server-2pass.cpp index 8c8cab419..bbe7185bb 100644 --- a/runtime/websocket/bin/websocket-server-2pass.cpp +++ b/runtime/websocket/bin/websocket-server-2pass.cpp @@ -140,7 +140,7 @@ void WebSocketServer::do_decoder( subvector.data(), subvector.size(), punc_cache, false, audio_fs, wav_format, (ASR_TYPE)asr_mode_, - hotwords_embedding, itn, decoder_handle); + hotwords_embedding, itn, true, decoder_handle); } else { scoped_lock guard(thread_lock); @@ -177,7 +177,7 @@ void WebSocketServer::do_decoder( buffer.data(), buffer.size(), punc_cache, is_final, audio_fs, wav_format, (ASR_TYPE)asr_mode_, - hotwords_embedding, itn, decoder_handle); + hotwords_embedding, itn, true, decoder_handle); } else { scoped_lock guard(thread_lock); msg["access_num"]=(int)msg["access_num"]-1;