Mirror of https://github.com/modelscope/FunASR (synced 2025-09-15 14:48:36 +08:00)

support ngram and fst hotword for 2pass-offline (#1205)

Parent: b635c062f1
Commit: b8825902d9
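This commit threads an ngram-LM / FST-hotword WFST decoder handle through the 2pass demos, the runtime API and the websocket server. A minimal sketch of the lifecycle implied by the calls added below; the header name, the handles and the audio buffer are assumptions, not code from this commit:

// Minimal sketch of the decoder lifecycle this commit wires into the 2pass demos and
// server. The FUNASR_* calls, their argument order and ASR_TWO_PASS are taken from the
// diff below; the include path, handles and buffer are assumptions.
#include <string>
#include <unordered_map>
#include <vector>
#include "funasrruntime.h"  // assumed public header declaring the FUNASR_* API used below

void decode_with_hotwords(FUNASR_HANDLE tpass_handle, FUNASR_HANDLE tpass_online_handle,
                          const char* pcm, int pcm_len, int sampling_rate,
                          std::string nn_hotwords,
                          std::unordered_map<std::string, int> fst_hotwords) {
    // 1. One WFST decoder per worker/connection, built on the ngram LM loaded in the stream.
    FUNASR_DEC_HANDLE dec = FunASRWfstDecoderInit(tpass_handle, ASR_TWO_PASS,
                                                  3.0f /*glob_beam*/, 3.0f /*lat_beam*/, 10.0f /*am_scale*/);
    // 2. Bias the decoding graph with FST hotwords (word -> weight), incremental bias 20.
    FunWfstDecoderLoadHwsRes(dec, 20, fst_hotwords);
    // 3. NN (contextual) hotwords are compiled separately into embeddings.
    std::vector<std::vector<float>> hw_emb =
        CompileHotwordEmbedding(tpass_handle, nn_hotwords, ASR_TWO_PASS);

    // 4. Inference now also receives the decoder handle as its last argument.
    std::vector<std::vector<std::string>> punc_cache(2);
    FUNASR_RESULT result = FunTpassInferBuffer(tpass_handle, tpass_online_handle,
                                               pcm, pcm_len, punc_cache, true /*is_final*/,
                                               sampling_rate, "pcm", ASR_TWO_PASS,
                                               hw_emb, true /*itn*/, dec);
    if (result) {
        FunASRFreeResult(result);
    }

    // 5. Tear down in the same order the demos below use.
    FunWfstDecoderUnloadHwsRes(dec);
    FunASRWfstDecoderUninit(dec);
}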
@@ -44,12 +44,17 @@ void GetValue(TCLAP::ValueArg<std::string>& value_arg, string key, std::map<std:
 }
 
 void runReg(FUNASR_HANDLE tpass_handle, std::vector<int> chunk_size, vector<string> wav_list, vector<string> wav_ids, int audio_fs,
-            float* total_length, long* total_time, int core_id, ASR_TYPE asr_mode_, string nn_hotwords_) {
+            float* total_length, long* total_time, int core_id, ASR_TYPE asr_mode_, string nn_hotwords_,
+            float glob_beam, float lat_beam, float am_scale, int inc_bias, unordered_map<string, int> hws_map) {
     struct timeval start, end;
     long seconds = 0;
     float n_total_length = 0.0f;
     long n_total_time = 0;
 
+    FUNASR_DEC_HANDLE decoder_handle = FunASRWfstDecoderInit(tpass_handle, ASR_TWO_PASS, glob_beam, lat_beam, am_scale);
+    // load hotwords list and build graph
+    FunWfstDecoderLoadHwsRes(decoder_handle, inc_bias, hws_map);
+
     std::vector<std::vector<float>> hotwords_embedding = CompileHotwordEmbedding(tpass_handle, nn_hotwords_, ASR_TWO_PASS);
@@ -90,7 +95,8 @@ void runReg(FUNASR_HANDLE tpass_handle, std::vector<int> chunk_size, vector<stri
         } else {
             is_final = false;
         }
-        FUNASR_RESULT result = FunTpassInferBuffer(tpass_handle, tpass_online_handle, speech_buff+sample_offset, step, punc_cache, is_final, sampling_rate_, "pcm", (ASR_TYPE)asr_mode_, hotwords_embedding);
+        FUNASR_RESULT result = FunTpassInferBuffer(tpass_handle, tpass_online_handle, speech_buff+sample_offset, step, punc_cache, is_final,
+                                                   sampling_rate_, "pcm", (ASR_TYPE)asr_mode_, hotwords_embedding, true, decoder_handle);
         if (result)
         {
             FunASRFreeResult(result);
@@ -139,7 +145,8 @@ void runReg(FUNASR_HANDLE tpass_handle, std::vector<int> chunk_size, vector<stri
             is_final = false;
         }
         gettimeofday(&start, NULL);
-        FUNASR_RESULT result = FunTpassInferBuffer(tpass_handle, tpass_online_handle, speech_buff+sample_offset, step, punc_cache, is_final, sampling_rate_, "pcm", (ASR_TYPE)asr_mode_, hotwords_embedding);
+        FUNASR_RESULT result = FunTpassInferBuffer(tpass_handle, tpass_online_handle, speech_buff+sample_offset, step, punc_cache, is_final,
+                                                   sampling_rate_, "pcm", (ASR_TYPE)asr_mode_, hotwords_embedding, true, decoder_handle);
         gettimeofday(&end, NULL);
         seconds = (end.tv_sec - start.tv_sec);
         long taking_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
@@ -197,6 +204,8 @@ void runReg(FUNASR_HANDLE tpass_handle, std::vector<int> chunk_size, vector<stri
            *total_time = n_total_time;
        }
    }
+   FunWfstDecoderUnloadHwsRes(decoder_handle);
+   FunASRWfstDecoderUninit(decoder_handle);
    FunTpassOnlineUninit(tpass_online_handle);
 }
 
@@ -215,6 +224,11 @@ int main(int argc, char** argv)
     TCLAP::ValueArg<std::string> punc_dir("", PUNC_DIR, "the punc online model path, which contains model.onnx, punc.yaml", false, "", "string");
     TCLAP::ValueArg<std::string> punc_quant("", PUNC_QUANT, "true (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "true", "string");
     TCLAP::ValueArg<std::string> itn_dir("", ITN_DIR, "the itn model(fst) path, which contains zh_itn_tagger.fst and zh_itn_verbalizer.fst", false, "", "string");
+    TCLAP::ValueArg<std::string> lm_dir("", LM_DIR, "the lm model path, which contains compiled models: TLG.fst, config.yaml, lexicon.txt ", false, "", "string");
+    TCLAP::ValueArg<float> global_beam("", GLOB_BEAM, "the decoding beam for beam searching ", false, 3.0, "float");
+    TCLAP::ValueArg<float> lattice_beam("", LAT_BEAM, "the lattice generation beam for beam searching ", false, 3.0, "float");
+    TCLAP::ValueArg<float> am_scale("", AM_SCALE, "the acoustic scale for beam searching ", false, 10.0, "float");
+    TCLAP::ValueArg<std::int32_t> fst_inc_wts("", FST_INC_WTS, "the fst hotwords incremental bias", false, 20, "int32_t");
 
     TCLAP::ValueArg<std::string> asr_mode("", ASR_MODE, "offline, online, 2pass", false, "2pass", "string");
     TCLAP::ValueArg<std::int32_t> onnx_thread("", "model-thread-num", "onnxruntime SetIntraOpNumThreads", false, 1, "int32_t");
@@ -231,6 +245,11 @@ int main(int argc, char** argv)
     cmd.add(punc_dir);
     cmd.add(punc_quant);
     cmd.add(itn_dir);
+    cmd.add(lm_dir);
+    cmd.add(global_beam);
+    cmd.add(lattice_beam);
+    cmd.add(am_scale);
+    cmd.add(fst_inc_wts);
     cmd.add(wav_path);
     cmd.add(audio_fs);
     cmd.add(asr_mode);
@@ -248,6 +267,7 @@ int main(int argc, char** argv)
     GetValue(punc_dir, PUNC_DIR, model_path);
     GetValue(punc_quant, PUNC_QUANT, model_path);
     GetValue(itn_dir, ITN_DIR, model_path);
+    GetValue(lm_dir, LM_DIR, model_path);
     GetValue(wav_path, WAV_PATH, model_path);
     GetValue(asr_mode, ASR_MODE, model_path);
 
@@ -272,6 +292,14 @@ int main(int argc, char** argv)
         LOG(ERROR) << "FunTpassInit init failed";
         exit(-1);
     }
+    float glob_beam = 3.0f;
+    float lat_beam = 3.0f;
+    float am_sc = 10.0f;
+    if (lm_dir.isSet()) {
+        glob_beam = global_beam.getValue();
+        lat_beam = lattice_beam.getValue();
+        am_sc = am_scale.getValue();
+    }
 
     gettimeofday(&end, NULL);
     long seconds = (end.tv_sec - start.tv_sec);
@@ -321,7 +349,8 @@ int main(int argc, char** argv)
     int rtf_threds = thread_num_.getValue();
     for (int i = 0; i < rtf_threds; i++)
     {
-        threads.emplace_back(thread(runReg, tpass_hanlde, chunk_size, wav_list, wav_ids, audio_fs.getValue(), &total_length, &total_time, i, (ASR_TYPE)asr_mode_, nn_hotwords_));
+        threads.emplace_back(thread(runReg, tpass_hanlde, chunk_size, wav_list, wav_ids, audio_fs.getValue(), &total_length, &total_time, i, (ASR_TYPE)asr_mode_, nn_hotwords_,
+                                    glob_beam, lat_beam, am_sc, fst_inc_wts.getValue(), hws_map));
     }
 
     for (auto& thread : threads)
@@ -51,6 +51,11 @@ int main(int argc, char** argv)
     TCLAP::ValueArg<std::string> punc_dir("", PUNC_DIR, "the punc online model path, which contains model.onnx, punc.yaml", false, "", "string");
     TCLAP::ValueArg<std::string> punc_quant("", PUNC_QUANT, "true (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "true", "string");
     TCLAP::ValueArg<std::string> itn_dir("", ITN_DIR, "the itn model(fst) path, which contains zh_itn_tagger.fst and zh_itn_verbalizer.fst", false, "", "string");
+    TCLAP::ValueArg<std::string> lm_dir("", LM_DIR, "the lm model path, which contains compiled models: TLG.fst, config.yaml, lexicon.txt ", false, "", "string");
+    TCLAP::ValueArg<float> global_beam("", GLOB_BEAM, "the decoding beam for beam searching ", false, 3.0, "float");
+    TCLAP::ValueArg<float> lattice_beam("", LAT_BEAM, "the lattice generation beam for beam searching ", false, 3.0, "float");
+    TCLAP::ValueArg<float> am_scale("", AM_SCALE, "the acoustic scale for beam searching ", false, 10.0, "float");
+    TCLAP::ValueArg<std::int32_t> fst_inc_wts("", FST_INC_WTS, "the fst hotwords incremental bias", false, 20, "int32_t");
     TCLAP::ValueArg<std::string> asr_mode("", ASR_MODE, "offline, online, 2pass", false, "2pass", "string");
     TCLAP::ValueArg<std::int32_t> onnx_thread("", "model-thread-num", "onnxruntime SetIntraOpNumThreads", false, 1, "int32_t");
 
@@ -65,6 +70,11 @@ int main(int argc, char** argv)
     cmd.add(vad_quant);
     cmd.add(punc_dir);
     cmd.add(punc_quant);
+    cmd.add(lm_dir);
+    cmd.add(global_beam);
+    cmd.add(lattice_beam);
+    cmd.add(am_scale);
+    cmd.add(fst_inc_wts);
     cmd.add(itn_dir);
     cmd.add(wav_path);
     cmd.add(audio_fs);
@@ -81,6 +91,7 @@ int main(int argc, char** argv)
     GetValue(vad_quant, VAD_QUANT, model_path);
     GetValue(punc_dir, PUNC_DIR, model_path);
     GetValue(punc_quant, PUNC_QUANT, model_path);
+    GetValue(lm_dir, LM_DIR, model_path);
     GetValue(itn_dir, ITN_DIR, model_path);
     GetValue(wav_path, WAV_PATH, model_path);
     GetValue(asr_mode, ASR_MODE, model_path);
@@ -106,6 +117,16 @@ int main(int argc, char** argv)
         LOG(ERROR) << "FunTpassInit init failed";
         exit(-1);
     }
+    float glob_beam = 3.0f;
+    float lat_beam = 3.0f;
+    float am_sc = 10.0f;
+    if (lm_dir.isSet()) {
+        glob_beam = global_beam.getValue();
+        lat_beam = lattice_beam.getValue();
+        am_sc = am_scale.getValue();
+    }
+    // init wfst decoder
+    FUNASR_DEC_HANDLE decoder_handle = FunASRWfstDecoderInit(tpass_handle, ASR_TWO_PASS, glob_beam, lat_beam, am_sc);
 
     gettimeofday(&end, NULL);
     long seconds = (end.tv_sec - start.tv_sec);
@@ -146,6 +167,9 @@ int main(int argc, char** argv)
         wav_ids.emplace_back(default_id);
     }
 
+    // load hotwords list and build graph
+    FunWfstDecoderLoadHwsRes(decoder_handle, fst_inc_wts.getValue(), hws_map);
+
     std::vector<std::vector<float>> hotwords_embedding = CompileHotwordEmbedding(tpass_handle, nn_hotwords_, ASR_TWO_PASS);
     // init online features
     std::vector<int> chunk_size = {5,10,5};
@@ -191,7 +215,9 @@ int main(int argc, char** argv)
             is_final = false;
         }
         gettimeofday(&start, NULL);
-        FUNASR_RESULT result = FunTpassInferBuffer(tpass_handle, tpass_online_handle, speech_buff+sample_offset, step, punc_cache, is_final, sampling_rate_, "pcm", (ASR_TYPE)asr_mode_, hotwords_embedding);
+        FUNASR_RESULT result = FunTpassInferBuffer(tpass_handle, tpass_online_handle,
+                                                   speech_buff+sample_offset, step, punc_cache, is_final, sampling_rate_, "pcm",
+                                                   (ASR_TYPE)asr_mode_, hotwords_embedding, true, decoder_handle);
         gettimeofday(&end, NULL);
         seconds = (end.tv_sec - start.tv_sec);
         taking_micros += ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
@@ -235,10 +261,12 @@ int main(int argc, char** argv)
             }
         }
     }
 
+    FunWfstDecoderUnloadHwsRes(decoder_handle);
     LOG(INFO) << "Audio length: " << (double)snippet_time << " s";
     LOG(INFO) << "Model inference takes: " << (double)taking_micros / 1000000 <<" s";
     LOG(INFO) << "Model inference RTF: " << (double)taking_micros/ (snippet_time*1000000);
+    FunASRWfstDecoderUninit(decoder_handle);
     FunTpassOnlineUninit(tpass_online_handle);
     FunTpassUninit(tpass_handle);
     return 0;
@@ -54,7 +54,6 @@ void runReg(FUNASR_HANDLE asr_handle, vector<string> wav_list, vector<string> wa
     // warm up
     for (size_t i = 0; i < 1; i++)
     {
-        FunOfflineReset(asr_handle, decoder_handle);
         FUNASR_RESULT result=FunOfflineInfer(asr_handle, wav_list[0].c_str(), RASR_NONE, NULL, hotwords_embedding, audio_fs, true, decoder_handle);
         if(result){
             FunASRFreeResult(result);
@@ -50,7 +50,7 @@ int main(int argc, char** argv)
     TCLAP::ValueArg<std::string> vad_quant("", VAD_QUANT, "true (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir", false, "true", "string");
     TCLAP::ValueArg<std::string> punc_dir("", PUNC_DIR, "the punc model path, which contains model.onnx, punc.yaml", false, "", "string");
     TCLAP::ValueArg<std::string> punc_quant("", PUNC_QUANT, "true (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "true", "string");
-    TCLAP::ValueArg<std::string> lm_dir("", LM_DIR, "the lm model path, which contains compiled models: TLG.fst, config.yaml ", false, "", "string");
+    TCLAP::ValueArg<std::string> lm_dir("", LM_DIR, "the lm model path, which contains compiled models: TLG.fst, config.yaml, lexicon.txt ", false, "", "string");
     TCLAP::ValueArg<float> global_beam("", GLOB_BEAM, "the decoding beam for beam searching ", false, 3.0, "float");
     TCLAP::ValueArg<float> lattice_beam("", LAT_BEAM, "the lattice generation beam for beam searching ", false, 3.0, "float");
     TCLAP::ValueArg<float> am_scale("", AM_SCALE, "the acoustic scale for beam searching ", false, 10.0, "float");
 
@@ -119,7 +119,7 @@ _FUNASRAPI FUNASR_HANDLE FunTpassOnlineInit(FUNASR_HANDLE tpass_handle, std::
 _FUNASRAPI FUNASR_RESULT FunTpassInferBuffer(FUNASR_HANDLE handle, FUNASR_HANDLE online_handle, const char* sz_buf,
                                              int n_len, std::vector<std::vector<std::string>> &punc_cache, bool input_finished=true,
                                              int sampling_rate=16000, std::string wav_format="pcm", ASR_TYPE mode=ASR_TWO_PASS,
-                                             const std::vector<std::vector<float>> &hw_emb={{0.0}}, bool itn=true);
+                                             const std::vector<std::vector<float>> &hw_emb={{0.0}}, bool itn=true, FUNASR_DEC_HANDLE dec_handle=nullptr);
 _FUNASRAPI void FunTpassUninit(FUNASR_HANDLE handle);
 _FUNASRAPI void FunTpassOnlineUninit(FUNASR_HANDLE handle);
 
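The extra `dec_handle` parameter defaults to `nullptr`, so existing call sites keep compiling; callers that want LM / FST-hotword decoding pass the handle explicitly. A hedged fragment, assuming the handles, buffer, `punc_cache` and `hw_emb` already exist in the caller:

// Pre-existing style: dec_handle falls back to its nullptr default.
FUNASR_RESULT r1 = FunTpassInferBuffer(tpass_handle, tpass_online_handle, buf, len,
                                       punc_cache, true /*input_finished*/);

// New style: pass the handle created by FunASRWfstDecoderInit to enable
// ngram / FST-hotword decoding on the offline pass.
FUNASR_RESULT r2 = FunTpassInferBuffer(tpass_handle, tpass_online_handle, buf, len,
                                       punc_cache, true /*input_finished*/,
                                       16000, "pcm", ASR_TWO_PASS,
                                       hw_emb, true /*itn*/, decoder_handle);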
@@ -254,9 +254,9 @@ float Audio::GetTimeLen()
 void Audio::WavResample(int32_t sampling_rate, const float *waveform,
                         int32_t n)
 {
-    LOG(INFO) << "Creating a resampler:\n"
-              << " in_sample_rate: "<< sampling_rate << "\n"
-              << " output_sample_rate: " << static_cast<int32_t>(dest_sample_rate);
+    LOG(INFO) << "Creating a resampler: "
+              << " in_sample_rate: "<< sampling_rate
+              << " output_sample_rate: " << static_cast<int32_t>(dest_sample_rate);
     float min_freq =
         std::min<int32_t>(sampling_rate, dest_sample_rate);
     float lowpass_cutoff = 0.99 * 0.5 * min_freq;
@@ -437,7 +437,7 @@
 _FUNASRAPI FUNASR_RESULT FunTpassInferBuffer(FUNASR_HANDLE handle, FUNASR_HANDLE online_handle, const char* sz_buf,
                                              int n_len, std::vector<std::vector<std::string>> &punc_cache, bool input_finished,
                                              int sampling_rate, std::string wav_format, ASR_TYPE mode,
-                                             const std::vector<std::vector<float>> &hw_emb, bool itn)
+                                             const std::vector<std::vector<float>> &hw_emb, bool itn, FUNASR_DEC_HANDLE dec_handle)
 {
     funasr::TpassStream* tpass_stream = (funasr::TpassStream*)handle;
     funasr::TpassOnlineStream* tpass_online_stream = (funasr::TpassOnlineStream*)online_handle;
@@ -511,7 +511,12 @@
         // timestamp
         std::string cur_stamp = "[";
         while(audio->FetchTpass(frame) > 0){
-            string msg = ((funasr::Paraformer*)asr_handle)->Forward(frame->data, frame->len, frame->is_final, hw_emb);
+            // dec reset
+            funasr::WfstDecoder* wfst_decoder = (funasr::WfstDecoder*)dec_handle;
+            if (wfst_decoder){
+                wfst_decoder->StartUtterance();
+            }
+            string msg = ((funasr::Paraformer*)asr_handle)->Forward(frame->data, frame->len, frame->is_final, hw_emb, dec_handle);
 
             std::vector<std::string> msg_vec = funasr::split(msg, '|'); // split with timestamp
             if(msg_vec.size()==0){
@@ -761,9 +766,15 @@
     if (asr_type == ASR_OFFLINE) {
         funasr::OfflineStream* offline_stream = (funasr::OfflineStream*)handle;
         funasr::Paraformer* paraformer = (funasr::Paraformer*)offline_stream->asr_handle.get();
         if (paraformer->lm_)
             mm = new funasr::WfstDecoder(paraformer->lm_.get(),
+                paraformer->GetPhoneSet(), paraformer->GetLmVocab(), glob_beam, lat_beam, am_scale);
+    } else if (asr_type == ASR_TWO_PASS){
+        funasr::TpassStream* tpass_stream = (funasr::TpassStream*)handle;
+        funasr::Paraformer* paraformer = (funasr::Paraformer*)tpass_stream->asr_handle.get();
+        if (paraformer->lm_)
+            mm = new funasr::WfstDecoder(paraformer->lm_.get(),
-                paraformer->GetPhoneSet(), paraformer->GetVocab(), glob_beam, lat_beam, am_scale);
+                paraformer->GetPhoneSet(), paraformer->GetLmVocab(), glob_beam, lat_beam, am_scale);
     }
     return mm;
 }
@@ -193,8 +193,7 @@ void Paraformer::InitLm(const std::string &lm_file,
     lm_ = std::shared_ptr<fst::Fst<fst::StdArc>>(
         fst::Fst<fst::StdArc>::Read(lm_file));
     if (lm_){
-        if (vocab) { delete vocab; }
-        vocab = new Vocab(lm_cfg_file.c_str(), lex_file.c_str());
+        lm_vocab = new Vocab(lm_cfg_file.c_str(), lex_file.c_str());
         LOG(INFO) << "Successfully load lm file " << lm_file;
     }else{
         LOG(ERROR) << "Failed to load lm file " << lm_file;
@@ -310,6 +309,9 @@ Paraformer::~Paraformer()
     if(vocab){
         delete vocab;
     }
+    if(lm_vocab){
+        delete lm_vocab;
+    }
     if(seg_dict){
         delete seg_dict;
     }
@@ -687,6 +689,11 @@ Vocab* Paraformer::GetVocab()
     return vocab;
 }
 
+Vocab* Paraformer::GetLmVocab()
+{
+    return lm_vocab;
+}
+
 PhoneSet* Paraformer::GetPhoneSet()
 {
     return phone_set_;
@@ -20,6 +20,7 @@ namespace funasr {
     */
     private:
        Vocab* vocab = nullptr;
+       Vocab* lm_vocab = nullptr;
        SegDict* seg_dict = nullptr;
        PhoneSet* phone_set_ = nullptr;
        //const float scale = 22.6274169979695;
@@ -65,6 +66,7 @@ namespace funasr {
        string FinalizeDecode(WfstDecoder* &wfst_decoder,
                              bool is_stamp=false, std::vector<float> us_alphas={0}, std::vector<float> us_cif_peak={0});
        Vocab* GetVocab();
+       Vocab* GetLmVocab();
        PhoneSet* GetPhoneSet();
 
        knf::FbankOptions fbank_opts_;
@@ -66,6 +66,20 @@ TpassStream::TpassStream(std::map<std::string, std::string>& model_path, int thr
         LOG(ERROR) <<"Can not find offline-model-dir or online-model-dir";
         exit(-1);
     }
+    // Lm resource
+    if (model_path.find(LM_DIR) != model_path.end() && model_path.at(LM_DIR) != "") {
+        string fst_path, lm_config_path, lex_path;
+        fst_path = PathAppend(model_path.at(LM_DIR), LM_FST_RES);
+        lm_config_path = PathAppend(model_path.at(LM_DIR), LM_CONFIG_NAME);
+        lex_path = PathAppend(model_path.at(LM_DIR), LEX_PATH);
+        if (access(lex_path.c_str(), F_OK) != 0 )
+        {
+            LOG(ERROR) << "Lexicon.txt file is not exist, please use the latest version. Skip load LM model.";
+        }else{
+            asr_handle->InitLm(fst_path, lm_config_path, lex_path);
+        }
+    }
+
     // PUNC model
     if(model_path.find(PUNC_DIR) != model_path.end()){
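The new block only enables the LM when the directory passed as lm-dir contains the compiled resources named in this diff (TLG.fst, config.yaml, lexicon.txt); a missing lexicon.txt just skips the LM. A small standalone check along the same lines; the helper name and file layout assumption are mine, only the file names come from the diff:

#include <string>
#include <unistd.h>  // access()

// Hypothetical helper: returns true if lm_dir looks like a usable LM resource dir.
static bool LmDirUsable(const std::string& lm_dir) {
    const char* required[] = {"TLG.fst", "config.yaml", "lexicon.txt"};
    for (const char* name : required) {
        std::string path = lm_dir + "/" + name;
        if (access(path.c_str(), F_OK) != 0) {
            return false;  // missing file: InitLm would be skipped (lexicon.txt) or fail
        }
    }
    return true;
}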
@@ -192,7 +192,10 @@ class WebsocketClient {
     funasr::Audio audio(1);
     int32_t sampling_rate = audio_fs;
     std::string wav_format = "pcm";
-    if (funasr::IsTargetFile(wav_path.c_str(), "pcm")) {
+    if (funasr::IsTargetFile(wav_path.c_str(), "wav")) {
+        if (!audio.LoadWav(wav_path.c_str(), &sampling_rate, false))
+            return;
+    } else if (funasr::IsTargetFile(wav_path.c_str(), "pcm")) {
         if (!audio.LoadPcmwav(wav_path.c_str(), &sampling_rate, false)) return;
     } else {
         wav_format = "others";
@@ -16,6 +16,7 @@
 // hotwords
 std::unordered_map<std::string, int> hws_map_;
 int fst_inc_wts_=20;
+float global_beam_, lattice_beam_, am_scale_;
 
 using namespace std;
 void GetValue(TCLAP::ValueArg<std::string>& value_arg, string key,
@@ -120,6 +121,14 @@ int main(int argc, char* argv[]) {
                                       "connection",
                                       false, "../../../ssl_key/server.key", "string");
 
+  TCLAP::ValueArg<float> global_beam("", GLOB_BEAM, "the decoding beam for beam searching ", false, 3.0, "float");
+  TCLAP::ValueArg<float> lattice_beam("", LAT_BEAM, "the lattice generation beam for beam searching ", false, 3.0, "float");
+  TCLAP::ValueArg<float> am_scale("", AM_SCALE, "the acoustic scale for beam searching ", false, 10.0, "float");
+
+  TCLAP::ValueArg<std::string> lm_dir("", LM_DIR,
+      "the LM model path, which contains compiled models: TLG.fst, config.yaml ", false, "damo/speech_ngram_lm_zh-cn-ai-wesp-fst", "string");
+  TCLAP::ValueArg<std::string> lm_revision(
+      "", "lm-revision", "LM model revision", false, "v1.0.2", "string");
   TCLAP::ValueArg<std::string> hotword("", HOTWORD,
       "the hotword file, one hotword perline, Format: Hotword Weight (could be: 阿里巴巴 20)",
      false, "/workspace/resources/hotwords.txt", "string");
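The hotword file referenced above holds one "Hotword Weight" pair per line (e.g. 阿里巴巴 20). A hedged sketch of a reader producing the unordered_map<string, int> shape consumed by FunWfstDecoderLoadHwsRes; this is an illustration, not the server's actual parser:

#include <fstream>
#include <sstream>
#include <string>
#include <unordered_map>

// Hypothetical reader for the "<hotword> <weight>" per-line format described above.
static std::unordered_map<std::string, int> ReadHotwords(const std::string& path) {
    std::unordered_map<std::string, int> hws_map;
    std::ifstream in(path);
    std::string line;
    while (std::getline(in, line)) {
        std::istringstream iss(line);
        std::string word;
        int weight = 0;
        if (iss >> word >> weight) {
            hws_map[word] = weight;  // same shape as the hws_map_ passed to FunWfstDecoderLoadHwsRes
        }
    }
    return hws_map;
}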
@@ -128,6 +137,10 @@ int main(int argc, char* argv[]) {
 
   // add file
   cmd.add(hotword);
   cmd.add(fst_inc_wts);
+  cmd.add(global_beam);
+  cmd.add(lattice_beam);
+  cmd.add(am_scale);
+
   cmd.add(certfile);
   cmd.add(keyfile);
@@ -146,6 +159,8 @@ int main(int argc, char* argv[]) {
   cmd.add(punc_quant);
   cmd.add(itn_dir);
   cmd.add(itn_revision);
+  cmd.add(lm_dir);
+  cmd.add(lm_revision);
 
   cmd.add(listen_ip);
   cmd.add(port);
@@ -163,6 +178,7 @@ int main(int argc, char* argv[]) {
   GetValue(punc_dir, PUNC_DIR, model_path);
   GetValue(punc_quant, PUNC_QUANT, model_path);
   GetValue(itn_dir, ITN_DIR, model_path);
+  GetValue(lm_dir, LM_DIR, model_path);
   GetValue(hotword, HOTWORD, model_path);
 
   GetValue(offline_model_revision, "offline-model-revision", model_path);
@@ -170,6 +186,11 @@ int main(int argc, char* argv[]) {
   GetValue(vad_revision, "vad-revision", model_path);
   GetValue(punc_revision, "punc-revision", model_path);
   GetValue(itn_revision, "itn-revision", model_path);
+  GetValue(lm_revision, "lm-revision", model_path);
 
+  global_beam_ = global_beam.getValue();
+  lattice_beam_ = lattice_beam.getValue();
+  am_scale_ = am_scale.getValue();
+
   // Download model form Modelscope
   try {
@@ -183,6 +204,7 @@ int main(int argc, char* argv[]) {
     std::string s_punc_path = model_path[PUNC_DIR];
     std::string s_punc_quant = model_path[PUNC_QUANT];
     std::string s_itn_path = model_path[ITN_DIR];
+    std::string s_lm_path = model_path[LM_DIR];
 
     std::string python_cmd =
         "python -m funasr.utils.runtime_sdk_download_tool --type onnx --quantize True ";
@@ -241,11 +263,18 @@ int main(int argc, char* argv[]) {
     size_t found = s_offline_asr_path.find("speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404");
     if (found != std::string::npos) {
         model_path["offline-model-revision"]="v1.2.4";
-    } else{
-        found = s_offline_asr_path.find("speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404");
-        if (found != std::string::npos) {
-            model_path["offline-model-revision"]="v1.0.5";
-        }
     }
+
+    found = s_offline_asr_path.find("speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404");
+    if (found != std::string::npos) {
+        model_path["offline-model-revision"]="v1.0.5";
+    }
+
+    found = s_offline_asr_path.find("speech_paraformer-large_asr_nat-en-16k-common-vocab10020");
+    if (found != std::string::npos) {
+        model_path["model-revision"]="v1.0.0";
+        s_itn_path="";
+        s_lm_path="";
+    }
 
     if (access(s_offline_asr_path.c_str(), F_OK) == 0) {
@@ -332,6 +361,49 @@ int main(int argc, char* argv[]) {
         LOG(INFO) << "ASR online model is not set, use default.";
     }
 
+    if (!s_lm_path.empty() && s_lm_path != "NONE" && s_lm_path != "none") {
+        std::string python_cmd_lm;
+        std::string down_lm_path;
+        std::string down_lm_model;
+
+        if (access(s_lm_path.c_str(), F_OK) == 0) {
+            // local
+            python_cmd_lm = python_cmd + " --model-name " + s_lm_path +
+                            " --export-dir ./ " + " --model_revision " +
+                            model_path["lm-revision"] + " --export False ";
+            down_lm_path = s_lm_path;
+        } else {
+            // modelscope
+            LOG(INFO) << "Download model: " << s_lm_path
+                      << " from modelscope : ";
+            python_cmd_lm = python_cmd + " --model-name " +
+                            s_lm_path +
+                            " --export-dir " + s_download_model_dir +
+                            " --model_revision " + model_path["lm-revision"]
+                            + " --export False ";
+            down_lm_path =
+                s_download_model_dir +
+                "/" + s_lm_path;
+        }
+
+        int ret = system(python_cmd_lm.c_str());
+        if (ret != 0) {
+            LOG(INFO) << "Failed to download model from modelscope. If you set local lm model path, you can ignore the errors.";
+        }
+        down_lm_model = down_lm_path + "/TLG.fst";
+
+        if (access(down_lm_model.c_str(), F_OK) != 0) {
+            LOG(ERROR) << down_lm_model << " do not exists.";
+            exit(-1);
+        } else {
+            model_path[LM_DIR] = down_lm_path;
+            LOG(INFO) << "Set " << LM_DIR << " : " << model_path[LM_DIR];
+        }
+    } else {
+        LOG(INFO) << "LM model is not set, not executed.";
+        model_path[LM_DIR] = "";
+    }
+
     if (!s_punc_path.empty()) {
         std::string python_cmd_punc;
         std::string down_punc_path;
@@ -18,6 +18,7 @@
 
 extern std::unordered_map<std::string, int> hws_map_;
 extern int fst_inc_wts_;
+extern float global_beam_, lattice_beam_, am_scale_;
 
 context_ptr WebSocketServer::on_tls_init(tls_mode mode,
                                          websocketpp::connection_hdl hdl,
@@ -102,7 +103,8 @@ void WebSocketServer::do_decoder(
     bool itn,
     int audio_fs,
     std::string wav_format,
-    FUNASR_HANDLE& tpass_online_handle) {
+    FUNASR_HANDLE& tpass_online_handle,
+    FUNASR_DEC_HANDLE& decoder_handle) {
   // lock for each connection
   if(!tpass_online_handle){
     scoped_lock guard(thread_lock);
@@ -131,7 +133,7 @@ void WebSocketServer::do_decoder(
             subvector.data(), subvector.size(),
             punc_cache, false, audio_fs,
             wav_format, (ASR_TYPE)asr_mode_,
-            hotwords_embedding, itn);
+            hotwords_embedding, itn, decoder_handle);
 
       } else {
         scoped_lock guard(thread_lock);
@@ -168,7 +170,7 @@ void WebSocketServer::do_decoder(
             buffer.data(), buffer.size(), punc_cache,
             is_final, audio_fs,
             wav_format, (ASR_TYPE)asr_mode_,
-            hotwords_embedding, itn);
+            hotwords_embedding, itn, decoder_handle);
       } else {
         scoped_lock guard(thread_lock);
         msg["access_num"]=(int)msg["access_num"]-1;
@@ -241,6 +243,9 @@ void WebSocketServer::on_open(websocketpp::connection_hdl hdl) {
   data_msg->msg["audio_fs"] = 16000; // default is 16k
   data_msg->msg["access_num"] = 0; // the number of access for this object, when it is 0, we can free it saftly
   data_msg->msg["is_eof"]=false; // if this connection is closed
+  FUNASR_DEC_HANDLE decoder_handle =
+      FunASRWfstDecoderInit(tpass_handle, ASR_TWO_PASS, global_beam_, lattice_beam_, am_scale_);
+  data_msg->decoder_handle = decoder_handle;
   data_msg->punc_cache =
       std::make_shared<std::vector<std::vector<std::string>>>(2);
   data_msg->strand_ = std::make_shared<asio::io_context::strand>(io_decoder_);
@@ -267,6 +272,9 @@ void remove_hdl(
   // finished and avoid access freed tpass_online_handle
   unique_lock guard_decoder(*(data_msg->thread_lock));
   if (data_msg->msg["access_num"]==0 && data_msg->msg["is_eof"]==true) {
+    FunWfstDecoderUnloadHwsRes(data_msg->decoder_handle);
+    FunASRWfstDecoderUninit(data_msg->decoder_handle);
+    data_msg->decoder_handle = nullptr;
     FunTpassOnlineUninit(data_msg->tpass_online_handle);
     data_msg->tpass_online_handle = nullptr;
     data_map.erase(hdl);
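Taken together with the do_decoder, on_open and on_message hunks, each websocket connection now owns one decoder handle. A condensed view of that lifecycle, using only names that appear in this diff:

// Condensed per-connection lifecycle:
//   on_open     : data_msg->decoder_handle = FunASRWfstDecoderInit(tpass_handle, ASR_TWO_PASS,
//                                                                  global_beam_, lattice_beam_, am_scale_);
//   on_message  : FunWfstDecoderLoadHwsRes(msg_data->decoder_handle, fst_inc_wts_, merged_hws_map);
//                 (bias the graph once the client sends its hotword list)
//   do_decoder  : every FunTpassInferBuffer call receives std::ref(msg_data->decoder_handle)
//   remove_hdl  : FunWfstDecoderUnloadHwsRes + FunASRWfstDecoderUninit, then the handle is cleared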
@@ -431,7 +439,7 @@ void WebSocketServer::on_message(websocketpp::connection_hdl hdl,
         nn_hotwords += " " + pair.first;
         LOG(INFO) << pair.first << " : " << pair.second;
       }
-      // FunWfstDecoderLoadHwsRes(msg_data->decoder_handle, fst_inc_wts_, merged_hws_map);
+      FunWfstDecoderLoadHwsRes(msg_data->decoder_handle, fst_inc_wts_, merged_hws_map);
 
       // nn
       std::vector<std::vector<float>> new_hotwords_embedding = CompileHotwordEmbedding(tpass_handle, nn_hotwords, ASR_TWO_PASS);
@@ -483,7 +491,8 @@ void WebSocketServer::on_message(websocketpp::connection_hdl hdl,
                       msg_data->msg["itn"],
                       msg_data->msg["audio_fs"],
                       msg_data->msg["wav_format"],
-                      std::ref(msg_data->tpass_online_handle)));
+                      std::ref(msg_data->tpass_online_handle),
+                      std::ref(msg_data->decoder_handle)));
         msg_data->msg["access_num"]=(int)(msg_data->msg["access_num"])+1;
       }
       catch (std::exception const &e)
@@ -530,7 +539,8 @@ void WebSocketServer::on_message(websocketpp::connection_hdl hdl,
                       msg_data->msg["itn"],
                       msg_data->msg["audio_fs"],
                       msg_data->msg["wav_format"],
-                      std::ref(msg_data->tpass_online_handle)));
+                      std::ref(msg_data->tpass_online_handle),
+                      std::ref(msg_data->decoder_handle)));
         msg_data->msg["access_num"]=(int)(msg_data->msg["access_num"])+1;
       }
     }
@@ -60,7 +60,8 @@ typedef struct {
   FUNASR_HANDLE tpass_online_handle=NULL;
   std::string online_res = "";
   std::string tpass_res = "";
-  std::shared_ptr<asio::io_context::strand> strand_; // for data execute in order
+  std::shared_ptr<asio::io_context::strand> strand_; // for data execute in order
+  FUNASR_DEC_HANDLE decoder_handle=NULL;
 } FUNASR_MESSAGE;
 
 // See https://wiki.mozilla.org/Security/Server_Side_TLS for more details about
@@ -123,7 +124,8 @@ class WebSocketServer {
       bool itn,
       int audio_fs,
       std::string wav_format,
-      FUNASR_HANDLE& tpass_online_handle);
+      FUNASR_HANDLE& tpass_online_handle,
+      FUNASR_DEC_HANDLE& decoder_handle);
 
   void initAsr(std::map<std::string, std::string>& model_path, int thread_num);
   void on_message(websocketpp::connection_hdl hdl, message_ptr msg);