diff --git a/funasr/runtime/onnxruntime/include/model.h b/funasr/runtime/onnxruntime/include/model.h index 8019a07bb..7f1e0acb3 100644 --- a/funasr/runtime/onnxruntime/include/model.h +++ b/funasr/runtime/onnxruntime/include/model.h @@ -18,6 +18,7 @@ class Model { virtual void InitHwCompiler(const std::string &hw_model, int thread_num){}; virtual void InitSegDict(const std::string &seg_dict_model){}; virtual std::vector> CompileHotwordEmbedding(std::string &hotwords){return std::vector>();}; + virtual std::string GetLang(){return "";}; }; Model *CreateModel(std::map& model_path, int thread_num=1, ASR_TYPE type=ASR_OFFLINE); diff --git a/funasr/runtime/onnxruntime/include/punc-model.h b/funasr/runtime/onnxruntime/include/punc-model.h index 4266eea34..214c7700a 100644 --- a/funasr/runtime/onnxruntime/include/punc-model.h +++ b/funasr/runtime/onnxruntime/include/punc-model.h @@ -12,8 +12,8 @@ class PuncModel { public: virtual ~PuncModel(){}; virtual void InitPunc(const std::string &punc_model, const std::string &punc_config, int thread_num)=0; - virtual std::string AddPunc(const char* sz_input){return "";}; - virtual std::string AddPunc(const char* sz_input, std::vector& arr_cache){return "";}; + virtual std::string AddPunc(const char* sz_input, std::string language="zh-cn"){return "";}; + virtual std::string AddPunc(const char* sz_input, std::vector& arr_cache, std::string language="zh-cn"){return "";}; }; PuncModel *CreatePuncModel(std::map& model_path, int thread_num, PUNC_TYPE type=PUNC_OFFLINE); diff --git a/funasr/runtime/onnxruntime/src/ct-transformer-online.cpp b/funasr/runtime/onnxruntime/src/ct-transformer-online.cpp index 5fe692b1d..51f2a6a78 100644 --- a/funasr/runtime/onnxruntime/src/ct-transformer-online.cpp +++ b/funasr/runtime/onnxruntime/src/ct-transformer-online.cpp @@ -50,7 +50,7 @@ CTTransformerOnline::~CTTransformerOnline() { } -string CTTransformerOnline::AddPunc(const char* sz_input, vector &arr_cache) +string CTTransformerOnline::AddPunc(const char* sz_input, vector &arr_cache, std::string language) { string strResult; vector strOut; diff --git a/funasr/runtime/onnxruntime/src/ct-transformer-online.h b/funasr/runtime/onnxruntime/src/ct-transformer-online.h index 5db183a91..ea7edb7fa 100644 --- a/funasr/runtime/onnxruntime/src/ct-transformer-online.h +++ b/funasr/runtime/onnxruntime/src/ct-transformer-online.h @@ -29,7 +29,7 @@ public: void InitPunc(const std::string &punc_model, const std::string &punc_config, int thread_num); ~CTTransformerOnline(); vector Infer(vector input_data, int nCacheSize); - string AddPunc(const char* sz_input, vector &arr_cache); + string AddPunc(const char* sz_input, vector &arr_cache, std::string language="zh-cn"); void Transport(vector& In, int nRows, int nCols); void VadMask(int size, int vad_pos,vector& Result); void Triangle(int text_length, vector& Result); diff --git a/funasr/runtime/onnxruntime/src/ct-transformer.cpp b/funasr/runtime/onnxruntime/src/ct-transformer.cpp index a6c75fbe6..64a70da94 100644 --- a/funasr/runtime/onnxruntime/src/ct-transformer.cpp +++ b/funasr/runtime/onnxruntime/src/ct-transformer.cpp @@ -46,7 +46,7 @@ CTTransformer::~CTTransformer() { } -string CTTransformer::AddPunc(const char* sz_input) +string CTTransformer::AddPunc(const char* sz_input, std::string language) { string strResult; vector strOut; @@ -139,8 +139,28 @@ string CTTransformer::AddPunc(const char* sz_input) } } } - for (auto& item : NewSentenceOut) + + for (auto& item : NewSentenceOut){ strResult += item; + } + + if(language == "en-bpe"){ + std::vector chineseSymbols; + chineseSymbols.push_back(","); + chineseSymbols.push_back("。"); + chineseSymbols.push_back("、"); + chineseSymbols.push_back("?"); + + std::string englishSymbols = ",.,?"; + for (size_t i = 0; i < chineseSymbols.size(); i++) { + size_t pos = 0; + while ((pos = strResult.find(chineseSymbols[i], pos)) != std::string::npos) { + strResult.replace(pos, 3, 1, englishSymbols[i]); + pos++; + } + } + } + return strResult; } diff --git a/funasr/runtime/onnxruntime/src/ct-transformer.h b/funasr/runtime/onnxruntime/src/ct-transformer.h index 49ed1b7bf..b33dcf55b 100644 --- a/funasr/runtime/onnxruntime/src/ct-transformer.h +++ b/funasr/runtime/onnxruntime/src/ct-transformer.h @@ -29,6 +29,6 @@ public: void InitPunc(const std::string &punc_model, const std::string &punc_config, int thread_num); ~CTTransformer(); vector Infer(vector input_data); - string AddPunc(const char* sz_input); + string AddPunc(const char* sz_input, std::string language="zh-cn"); }; } // namespace funasr \ No newline at end of file diff --git a/funasr/runtime/onnxruntime/src/funasrruntime.cpp b/funasr/runtime/onnxruntime/src/funasrruntime.cpp index 0d4af5cff..73738c76f 100644 --- a/funasr/runtime/onnxruntime/src/funasrruntime.cpp +++ b/funasr/runtime/onnxruntime/src/funasrruntime.cpp @@ -282,7 +282,8 @@ extern "C" { p_result->stamp += cur_stamp + "]"; } if(offline_stream->UsePunc()){ - string punc_res = (offline_stream->punc_handle)->AddPunc((p_result->msg).c_str()); + string lang = (offline_stream->asr_handle)->GetLang(); + string punc_res = (offline_stream->punc_handle)->AddPunc((p_result->msg).c_str(), lang); p_result->msg = punc_res; } #if !defined(__APPLE__) @@ -363,7 +364,8 @@ extern "C" { p_result->stamp += cur_stamp + "]"; } if(offline_stream->UsePunc()){ - string punc_res = (offline_stream->punc_handle)->AddPunc((p_result->msg).c_str()); + string lang = (offline_stream->asr_handle)->GetLang(); + string punc_res = (offline_stream->punc_handle)->AddPunc((p_result->msg).c_str(), lang); p_result->msg = punc_res; } #if !defined(__APPLE__) diff --git a/funasr/runtime/onnxruntime/src/paraformer.h b/funasr/runtime/onnxruntime/src/paraformer.h index bac8fad8d..455078e1b 100644 --- a/funasr/runtime/onnxruntime/src/paraformer.h +++ b/funasr/runtime/onnxruntime/src/paraformer.h @@ -33,7 +33,6 @@ namespace funasr { vector hw_m_szInputNames; vector hw_m_szOutputNames; bool use_hotword; - std::string language="zh-cn"; public: Paraformer(); @@ -55,6 +54,7 @@ namespace funasr { string PostProcess(std::vector &raw_char, std::vector> ×tamp_list); string Rescoring(); + string GetLang(){return language;}; knf::FbankOptions fbank_opts_; vector means_list_; @@ -71,6 +71,8 @@ namespace funasr { vector m_szInputNames; vector m_szOutputNames; + std::string language="zh-cn"; + // paraformer-online std::shared_ptr encoder_session_ = nullptr; std::shared_ptr decoder_session_ = nullptr; diff --git a/funasr/runtime/onnxruntime/src/vocab.cpp b/funasr/runtime/onnxruntime/src/vocab.cpp index 3f5191100..2babc4082 100644 --- a/funasr/runtime/onnxruntime/src/vocab.cpp +++ b/funasr/runtime/onnxruntime/src/vocab.cpp @@ -75,6 +75,21 @@ bool Vocab::IsChinese(string ch) return false; } +string Vocab::WordFormat(std::string word) +{ + if(word == "i"){ + return "I"; + }else if(word == "i'm"){ + return "I'm"; + }else if(word == "i've"){ + return "I've"; + }else if(word == "i'll"){ + return "I'll"; + }else{ + return word; + } +} + string Vocab::Vector2StringV2(vector in, std::string language) { int i; @@ -94,6 +109,7 @@ string Vocab::Vector2StringV2(vector in, std::string language) size_t found = word.find(unicodeChar); if(found != std::string::npos){ if (combine != ""){ + combine = WordFormat(combine); if (words.size() != 0){ combine = " " + combine; } @@ -164,6 +180,7 @@ string Vocab::Vector2StringV2(vector in, std::string language) } if (language == "en-bpe" and combine != ""){ + combine = WordFormat(combine); if (words.size() != 0){ combine = " " + combine; } diff --git a/funasr/runtime/onnxruntime/src/vocab.h b/funasr/runtime/onnxruntime/src/vocab.h index eecb9c861..23b4bd6e9 100644 --- a/funasr/runtime/onnxruntime/src/vocab.h +++ b/funasr/runtime/onnxruntime/src/vocab.h @@ -23,6 +23,7 @@ class Vocab { bool IsChinese(string ch); void Vector2String(vector in, std::vector &preds); string Vector2StringV2(vector in, std::string language=""); + string WordFormat(std::string word); int GetIdByToken(const std::string &token); }; diff --git a/funasr/runtime/websocket/bin/funasr-wss-server.cpp b/funasr/runtime/websocket/bin/funasr-wss-server.cpp index e64667b28..eb1402b26 100644 --- a/funasr/runtime/websocket/bin/funasr-wss-server.cpp +++ b/funasr/runtime/websocket/bin/funasr-wss-server.cpp @@ -195,11 +195,16 @@ int main(int argc, char* argv[]) { size_t found = s_asr_path.find("speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404"); if (found != std::string::npos) { model_path["model-revision"]="v1.2.4"; - }else{ - found = s_asr_path.find("speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404"); - if (found != std::string::npos) { - model_path["model-revision"]="v1.0.5"; - } + } + + found = s_asr_path.find("speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404"); + if (found != std::string::npos) { + model_path["model-revision"]="v1.0.5"; + } + + found = s_asr_path.find("speech_paraformer-large_asr_nat-en-16k-common-vocab10020"); + if (found != std::string::npos) { + model_path["model-revision"]="v1.0.0"; } // modelscope