diff --git a/funasr/runtime/onnxruntime/src/paraformer.cpp b/funasr/runtime/onnxruntime/src/paraformer.cpp index dfa2b1fa2..763d01ec5 100644 --- a/funasr/runtime/onnxruntime/src/paraformer.cpp +++ b/funasr/runtime/onnxruntime/src/paraformer.cpp @@ -65,6 +65,7 @@ void Paraformer::InitAsr(const std::string &am_model, const std::string &am_cmvn for (auto& item : m_strOutputNames) m_szOutputNames.push_back(item.c_str()); vocab = new Vocab(am_config.c_str()); + LoadConfigFromYaml(am_config.c_str()); LoadCmvn(am_cmvn.c_str()); } @@ -183,6 +184,27 @@ void Paraformer::InitAsr(const std::string &am_model, const std::string &en_mode m_szOutputNames.push_back(item.c_str()); } +void Paraformer::LoadConfigFromYaml(const char* filename){ + + YAML::Node config; + try{ + config = YAML::LoadFile(filename); + }catch(exception const &e){ + LOG(ERROR) << "Error loading file, yaml file error or not exist."; + exit(-1); + } + + try{ + YAML::Node lang_conf = config["lang"]; + if (lang_conf.IsDefined()){ + language = lang_conf.as(); + } + }catch(exception const &e){ + LOG(ERROR) << "Error when load argument from vad config YAML."; + exit(-1); + } +} + void Paraformer::LoadOnlineConfigFromYaml(const char* filename){ YAML::Node config; @@ -342,7 +364,7 @@ string Paraformer::GreedySearch(float * in, int n_len, int64_t token_nums, bool hyps.push_back(max_idx); } if(!is_stamp){ - return vocab->Vector2StringV2(hyps); + return vocab->Vector2StringV2(hyps, language); }else{ std::vector char_list; std::vector> timestamp_list; @@ -707,17 +729,6 @@ string Paraformer::Forward(float* din, int len, bool input_finished, const std:: }else{ result = GreedySearch(floatData, *encoder_out_lens, outputShape[2]); } -// int pos = 0; -// std::vector> logits; -// for (int j = 0; j < outputShape[1]; j++) -// { -// std::vector vec_token; -// vec_token.insert(vec_token.begin(), floatData + pos, floatData + pos + outputShape[2]); -// logits.push_back(vec_token); -// pos += outputShape[2]; -// } -// //PrintMat(logits, "logits_out"); -// result = GreedySearch(floatData, *encoder_out_lens, outputShape[2]); } catch (std::exception const &e) { diff --git a/funasr/runtime/onnxruntime/src/paraformer.h b/funasr/runtime/onnxruntime/src/paraformer.h index 4080881b5..bac8fad8d 100644 --- a/funasr/runtime/onnxruntime/src/paraformer.h +++ b/funasr/runtime/onnxruntime/src/paraformer.h @@ -20,6 +20,7 @@ namespace funasr { //const float scale = 22.6274169979695; const float scale = 1.0; + void LoadConfigFromYaml(const char* filename); void LoadOnlineConfigFromYaml(const char* filename); void LoadCmvn(const char *filename); vector ApplyLfr(const vector &in); @@ -32,6 +33,7 @@ namespace funasr { vector hw_m_szInputNames; vector hw_m_szOutputNames; bool use_hotword; + std::string language="zh-cn"; public: Paraformer(); diff --git a/funasr/runtime/onnxruntime/src/vocab.cpp b/funasr/runtime/onnxruntime/src/vocab.cpp index c29156f6a..95174c728 100644 --- a/funasr/runtime/onnxruntime/src/vocab.cpp +++ b/funasr/runtime/onnxruntime/src/vocab.cpp @@ -75,20 +75,36 @@ bool Vocab::IsChinese(string ch) return false; } -string Vocab::Vector2StringV2(vector in) +string Vocab::Vector2StringV2(vector in, std::string language) { int i; list words; int is_pre_english = false; int pre_english_len = 0; int is_combining = false; - string combine = ""; + std::string combine = ""; + std::string unicodeChar = "▁"; for (auto it = in.begin(); it != in.end(); it++) { string word = vocab[*it]; // step1 space character skips if (word == "" || word == "" || word == "") continue; + if (language == "en-bpe"){ + size_t found = word.find(unicodeChar); + if(found != std::string::npos){ + if (combine != ""){ + if (words.size() != 0){ + combine = " " + combine; + } + words.push_back(combine); + } + combine = word.substr(3); + }else{ + combine += word; + } + continue; + } // step2 combie phoneme to full word { int sub_word = !(word.find("@@") == string::npos); diff --git a/funasr/runtime/onnxruntime/src/vocab.h b/funasr/runtime/onnxruntime/src/vocab.h index 808852ac2..eecb9c861 100644 --- a/funasr/runtime/onnxruntime/src/vocab.h +++ b/funasr/runtime/onnxruntime/src/vocab.h @@ -22,7 +22,7 @@ class Vocab { int Size(); bool IsChinese(string ch); void Vector2String(vector in, std::vector &preds); - string Vector2StringV2(vector in); + string Vector2StringV2(vector in, std::string language=""); int GetIdByToken(const std::string &token); };