support en-bpe model

This commit is contained in:
雾聪 2023-10-10 16:12:40 +08:00
parent 89e68b28c5
commit 4984724f6a
11 changed files with 63 additions and 15 deletions

View File

@ -18,6 +18,7 @@ class Model {
virtual void InitHwCompiler(const std::string &hw_model, int thread_num){};
virtual void InitSegDict(const std::string &seg_dict_model){};
virtual std::vector<std::vector<float>> CompileHotwordEmbedding(std::string &hotwords){return std::vector<std::vector<float>>();};
virtual std::string GetLang(){return "";};
};
Model *CreateModel(std::map<std::string, std::string>& model_path, int thread_num=1, ASR_TYPE type=ASR_OFFLINE);

View File

@ -12,8 +12,8 @@ class PuncModel {
public:
virtual ~PuncModel(){};
virtual void InitPunc(const std::string &punc_model, const std::string &punc_config, int thread_num)=0;
virtual std::string AddPunc(const char* sz_input){return "";};
virtual std::string AddPunc(const char* sz_input, std::vector<std::string>& arr_cache){return "";};
virtual std::string AddPunc(const char* sz_input, std::string language="zh-cn"){return "";};
virtual std::string AddPunc(const char* sz_input, std::vector<std::string>& arr_cache, std::string language="zh-cn"){return "";};
};
PuncModel *CreatePuncModel(std::map<std::string, std::string>& model_path, int thread_num, PUNC_TYPE type=PUNC_OFFLINE);

View File

@ -50,7 +50,7 @@ CTTransformerOnline::~CTTransformerOnline()
{
}
string CTTransformerOnline::AddPunc(const char* sz_input, vector<string> &arr_cache)
string CTTransformerOnline::AddPunc(const char* sz_input, vector<string> &arr_cache, std::string language)
{
string strResult;
vector<string> strOut;

View File

@ -29,7 +29,7 @@ public:
void InitPunc(const std::string &punc_model, const std::string &punc_config, int thread_num);
~CTTransformerOnline();
vector<int> Infer(vector<int32_t> input_data, int nCacheSize);
string AddPunc(const char* sz_input, vector<string> &arr_cache);
string AddPunc(const char* sz_input, vector<string> &arr_cache, std::string language="zh-cn");
void Transport(vector<float>& In, int nRows, int nCols);
void VadMask(int size, int vad_pos,vector<float>& Result);
void Triangle(int text_length, vector<float>& Result);

View File

@ -46,7 +46,7 @@ CTTransformer::~CTTransformer()
{
}
string CTTransformer::AddPunc(const char* sz_input)
string CTTransformer::AddPunc(const char* sz_input, std::string language)
{
string strResult;
vector<string> strOut;
@ -139,8 +139,28 @@ string CTTransformer::AddPunc(const char* sz_input)
}
}
}
for (auto& item : NewSentenceOut)
for (auto& item : NewSentenceOut){
strResult += item;
}
if(language == "en-bpe"){
std::vector<std::string> chineseSymbols;
chineseSymbols.push_back("");
chineseSymbols.push_back("");
chineseSymbols.push_back("");
chineseSymbols.push_back("");
std::string englishSymbols = ",.,?";
for (size_t i = 0; i < chineseSymbols.size(); i++) {
size_t pos = 0;
while ((pos = strResult.find(chineseSymbols[i], pos)) != std::string::npos) {
strResult.replace(pos, 3, 1, englishSymbols[i]);
pos++;
}
}
}
return strResult;
}

View File

@ -29,6 +29,6 @@ public:
void InitPunc(const std::string &punc_model, const std::string &punc_config, int thread_num);
~CTTransformer();
vector<int> Infer(vector<int32_t> input_data);
string AddPunc(const char* sz_input);
string AddPunc(const char* sz_input, std::string language="zh-cn");
};
} // namespace funasr

View File

@ -282,7 +282,8 @@ extern "C" {
p_result->stamp += cur_stamp + "]";
}
if(offline_stream->UsePunc()){
string punc_res = (offline_stream->punc_handle)->AddPunc((p_result->msg).c_str());
string lang = (offline_stream->asr_handle)->GetLang();
string punc_res = (offline_stream->punc_handle)->AddPunc((p_result->msg).c_str(), lang);
p_result->msg = punc_res;
}
#if !defined(__APPLE__)
@ -363,7 +364,8 @@ extern "C" {
p_result->stamp += cur_stamp + "]";
}
if(offline_stream->UsePunc()){
string punc_res = (offline_stream->punc_handle)->AddPunc((p_result->msg).c_str());
string lang = (offline_stream->asr_handle)->GetLang();
string punc_res = (offline_stream->punc_handle)->AddPunc((p_result->msg).c_str(), lang);
p_result->msg = punc_res;
}
#if !defined(__APPLE__)

View File

@ -33,7 +33,6 @@ namespace funasr {
vector<const char*> hw_m_szInputNames;
vector<const char*> hw_m_szOutputNames;
bool use_hotword;
std::string language="zh-cn";
public:
Paraformer();
@ -55,6 +54,7 @@ namespace funasr {
string PostProcess(std::vector<string> &raw_char, std::vector<std::vector<float>> &timestamp_list);
string Rescoring();
string GetLang(){return language;};
knf::FbankOptions fbank_opts_;
vector<float> means_list_;
@ -71,6 +71,8 @@ namespace funasr {
vector<const char*> m_szInputNames;
vector<const char*> m_szOutputNames;
std::string language="zh-cn";
// paraformer-online
std::shared_ptr<Ort::Session> encoder_session_ = nullptr;
std::shared_ptr<Ort::Session> decoder_session_ = nullptr;

View File

@ -75,6 +75,21 @@ bool Vocab::IsChinese(string ch)
return false;
}
string Vocab::WordFormat(std::string word)
{
if(word == "i"){
return "I";
}else if(word == "i'm"){
return "I'm";
}else if(word == "i've"){
return "I've";
}else if(word == "i'll"){
return "I'll";
}else{
return word;
}
}
string Vocab::Vector2StringV2(vector<int> in, std::string language)
{
int i;
@ -94,6 +109,7 @@ string Vocab::Vector2StringV2(vector<int> in, std::string language)
size_t found = word.find(unicodeChar);
if(found != std::string::npos){
if (combine != ""){
combine = WordFormat(combine);
if (words.size() != 0){
combine = " " + combine;
}
@ -164,6 +180,7 @@ string Vocab::Vector2StringV2(vector<int> in, std::string language)
}
if (language == "en-bpe" and combine != ""){
combine = WordFormat(combine);
if (words.size() != 0){
combine = " " + combine;
}

View File

@ -23,6 +23,7 @@ class Vocab {
bool IsChinese(string ch);
void Vector2String(vector<int> in, std::vector<std::string> &preds);
string Vector2StringV2(vector<int> in, std::string language="");
string WordFormat(std::string word);
int GetIdByToken(const std::string &token);
};

View File

@ -195,11 +195,16 @@ int main(int argc, char* argv[]) {
size_t found = s_asr_path.find("speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404");
if (found != std::string::npos) {
model_path["model-revision"]="v1.2.4";
}else{
found = s_asr_path.find("speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404");
if (found != std::string::npos) {
model_path["model-revision"]="v1.0.5";
}
}
found = s_asr_path.find("speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404");
if (found != std::string::npos) {
model_path["model-revision"]="v1.0.5";
}
found = s_asr_path.find("speech_paraformer-large_asr_nat-en-16k-common-vocab10020");
if (found != std::string::npos) {
model_path["model-revision"]="v1.0.0";
}
// modelscope