diff --git a/funasr/runtime/onnxruntime/readme.md b/funasr/runtime/onnxruntime/readme.md index 3e34a678a..4ed184f64 100644 --- a/funasr/runtime/onnxruntime/readme.md +++ b/funasr/runtime/onnxruntime/readme.md @@ -127,6 +127,8 @@ For example: ### funasr-onnx-offline-rtf ```shell ./funasr-onnx-offline-rtf --model-dir <string> [--quantize <string>] + [--vad-dir <string>] [--vad-quant <string>] + [--punc-dir <string>] [--punc-quant <string>] --wav-path <string> --thread-num <int32_t> [--] [--version] [-h] Where: @@ -136,6 +138,17 @@ Where: (required) the model path, which contains model.onnx, config.yaml, am.mvn --quantize <string> false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir + + --vad-dir <string> + the vad model path, which contains model.onnx, vad.yaml, vad.mvn + --vad-quant <string> + false (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir + + --punc-dir <string> + the punc model path, which contains model.onnx, punc.yaml + --punc-quant <string> + false (Default), load the model of model.onnx in punc_dir. 
If set true, load the model of model_quant.onnx in punc_dir + --wav-path (required) the input could be: wav_path, e.g.: asr_example.wav; diff --git a/funasr/runtime/onnxruntime/src/funasr-onnx-offline-rtf.cpp b/funasr/runtime/onnxruntime/src/funasr-onnx-offline-rtf.cpp index 6ba65c6c4..cf1469d3e 100644 --- a/funasr/runtime/onnxruntime/src/funasr-onnx-offline-rtf.cpp +++ b/funasr/runtime/onnxruntime/src/funasr-onnx-offline-rtf.cpp @@ -39,7 +39,7 @@ void runReg(FUNASR_HANDLE asr_handle, vector wav_list, // warm up for (size_t i = 0; i < 1; i++) { - FUNASR_RESULT result=FunASRInfer(asr_handle, wav_list[0].c_str(), RASR_NONE, NULL, 16000); + FUNASR_RESULT result=FunOfflineInfer(asr_handle, wav_list[0].c_str(), RASR_NONE, NULL, 16000); } while (true) { @@ -50,7 +50,7 @@ void runReg(FUNASR_HANDLE asr_handle, vector wav_list, } gettimeofday(&start, NULL); - FUNASR_RESULT result=FunASRInfer(asr_handle, wav_list[i].c_str(), RASR_NONE, NULL, 16000); + FUNASR_RESULT result=FunOfflineInfer(asr_handle, wav_list[i].c_str(), RASR_NONE, NULL, 16000); gettimeofday(&end, NULL); seconds = (end.tv_sec - start.tv_sec); @@ -102,12 +102,20 @@ int main(int argc, char *argv[]) TCLAP::CmdLine cmd("funasr-onnx-offline-rtf", ' ', "1.0"); TCLAP::ValueArg model_dir("", MODEL_DIR, "the model path, which contains model.onnx, config.yaml, am.mvn", true, "", "string"); TCLAP::ValueArg quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "false", "string"); + TCLAP::ValueArg vad_dir("", VAD_DIR, "the vad model path, which contains model.onnx, vad.yaml, vad.mvn", false, "", "string"); + TCLAP::ValueArg vad_quant("", VAD_QUANT, "false (Default), load the model of model.onnx in vad_dir. 
If set true, load the model of model_quant.onnx in vad_dir", false, "false", "string"); + TCLAP::ValueArg punc_dir("", PUNC_DIR, "the punc model path, which contains model.onnx, punc.yaml", false, "", "string"); + TCLAP::ValueArg punc_quant("", PUNC_QUANT, "false (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "false", "string"); TCLAP::ValueArg wav_path("", WAV_PATH, "the input could be: wav_path, e.g.: asr_example.wav; pcm_path, e.g.: asr_example.pcm; wav.scp, kaldi style wav list (wav_id \t wav_path)", true, "", "string"); TCLAP::ValueArg thread_num("", THREAD_NUM, "multi-thread num for rtf", true, 0, "int32_t"); cmd.add(model_dir); cmd.add(quantize); + cmd.add(vad_dir); + cmd.add(vad_quant); + cmd.add(punc_dir); + cmd.add(punc_quant); cmd.add(wav_path); cmd.add(thread_num); cmd.parse(argc, argv); @@ -115,11 +123,15 @@ int main(int argc, char *argv[]) std::map model_path; GetValue(model_dir, MODEL_DIR, model_path); GetValue(quantize, QUANTIZE, model_path); + GetValue(vad_dir, VAD_DIR, model_path); + GetValue(vad_quant, VAD_QUANT, model_path); + GetValue(punc_dir, PUNC_DIR, model_path); + GetValue(punc_quant, PUNC_QUANT, model_path); GetValue(wav_path, WAV_PATH, model_path); struct timeval start, end; gettimeofday(&start, NULL); - FUNASR_HANDLE asr_handle=FunASRInit(model_path, 1); + FUNASR_HANDLE asr_handle=FunOfflineInit(model_path, 1); if (!asr_handle) { @@ -132,7 +144,7 @@ int main(int argc, char *argv[]) long modle_init_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec); LOG(INFO) << "Model initialization takes " << (double)modle_init_micros / 1000000 << " s"; - // read wav_scp + // read wav_path vector wav_list; string wav_path_ = model_path.at(WAV_PATH); if(is_target_file(wav_path_, "wav") || is_target_file(wav_path_, "pcm")){