mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
support en-bpe model
This commit is contained in:
parent
89e68b28c5
commit
4984724f6a
@ -18,6 +18,7 @@ class Model {
|
||||
virtual void InitHwCompiler(const std::string &hw_model, int thread_num){};
|
||||
virtual void InitSegDict(const std::string &seg_dict_model){};
|
||||
virtual std::vector<std::vector<float>> CompileHotwordEmbedding(std::string &hotwords){return std::vector<std::vector<float>>();};
|
||||
virtual std::string GetLang(){return "";};
|
||||
};
|
||||
|
||||
Model *CreateModel(std::map<std::string, std::string>& model_path, int thread_num=1, ASR_TYPE type=ASR_OFFLINE);
|
||||
|
||||
@ -12,8 +12,8 @@ class PuncModel {
|
||||
public:
|
||||
virtual ~PuncModel(){};
|
||||
virtual void InitPunc(const std::string &punc_model, const std::string &punc_config, int thread_num)=0;
|
||||
virtual std::string AddPunc(const char* sz_input){return "";};
|
||||
virtual std::string AddPunc(const char* sz_input, std::vector<std::string>& arr_cache){return "";};
|
||||
virtual std::string AddPunc(const char* sz_input, std::string language="zh-cn"){return "";};
|
||||
virtual std::string AddPunc(const char* sz_input, std::vector<std::string>& arr_cache, std::string language="zh-cn"){return "";};
|
||||
};
|
||||
|
||||
PuncModel *CreatePuncModel(std::map<std::string, std::string>& model_path, int thread_num, PUNC_TYPE type=PUNC_OFFLINE);
|
||||
|
||||
@ -50,7 +50,7 @@ CTTransformerOnline::~CTTransformerOnline()
|
||||
{
|
||||
}
|
||||
|
||||
string CTTransformerOnline::AddPunc(const char* sz_input, vector<string> &arr_cache)
|
||||
string CTTransformerOnline::AddPunc(const char* sz_input, vector<string> &arr_cache, std::string language)
|
||||
{
|
||||
string strResult;
|
||||
vector<string> strOut;
|
||||
|
||||
@ -29,7 +29,7 @@ public:
|
||||
void InitPunc(const std::string &punc_model, const std::string &punc_config, int thread_num);
|
||||
~CTTransformerOnline();
|
||||
vector<int> Infer(vector<int32_t> input_data, int nCacheSize);
|
||||
string AddPunc(const char* sz_input, vector<string> &arr_cache);
|
||||
string AddPunc(const char* sz_input, vector<string> &arr_cache, std::string language="zh-cn");
|
||||
void Transport(vector<float>& In, int nRows, int nCols);
|
||||
void VadMask(int size, int vad_pos,vector<float>& Result);
|
||||
void Triangle(int text_length, vector<float>& Result);
|
||||
|
||||
@ -46,7 +46,7 @@ CTTransformer::~CTTransformer()
|
||||
{
|
||||
}
|
||||
|
||||
string CTTransformer::AddPunc(const char* sz_input)
|
||||
string CTTransformer::AddPunc(const char* sz_input, std::string language)
|
||||
{
|
||||
string strResult;
|
||||
vector<string> strOut;
|
||||
@ -139,8 +139,28 @@ string CTTransformer::AddPunc(const char* sz_input)
|
||||
}
|
||||
}
|
||||
}
|
||||
for (auto& item : NewSentenceOut)
|
||||
|
||||
for (auto& item : NewSentenceOut){
|
||||
strResult += item;
|
||||
}
|
||||
|
||||
if(language == "en-bpe"){
|
||||
std::vector<std::string> chineseSymbols;
|
||||
chineseSymbols.push_back(",");
|
||||
chineseSymbols.push_back("。");
|
||||
chineseSymbols.push_back("、");
|
||||
chineseSymbols.push_back("?");
|
||||
|
||||
std::string englishSymbols = ",.,?";
|
||||
for (size_t i = 0; i < chineseSymbols.size(); i++) {
|
||||
size_t pos = 0;
|
||||
while ((pos = strResult.find(chineseSymbols[i], pos)) != std::string::npos) {
|
||||
strResult.replace(pos, 3, 1, englishSymbols[i]);
|
||||
pos++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return strResult;
|
||||
}
|
||||
|
||||
|
||||
@ -29,6 +29,6 @@ public:
|
||||
void InitPunc(const std::string &punc_model, const std::string &punc_config, int thread_num);
|
||||
~CTTransformer();
|
||||
vector<int> Infer(vector<int32_t> input_data);
|
||||
string AddPunc(const char* sz_input);
|
||||
string AddPunc(const char* sz_input, std::string language="zh-cn");
|
||||
};
|
||||
} // namespace funasr
|
||||
@ -282,7 +282,8 @@ extern "C" {
|
||||
p_result->stamp += cur_stamp + "]";
|
||||
}
|
||||
if(offline_stream->UsePunc()){
|
||||
string punc_res = (offline_stream->punc_handle)->AddPunc((p_result->msg).c_str());
|
||||
string lang = (offline_stream->asr_handle)->GetLang();
|
||||
string punc_res = (offline_stream->punc_handle)->AddPunc((p_result->msg).c_str(), lang);
|
||||
p_result->msg = punc_res;
|
||||
}
|
||||
#if !defined(__APPLE__)
|
||||
@ -363,7 +364,8 @@ extern "C" {
|
||||
p_result->stamp += cur_stamp + "]";
|
||||
}
|
||||
if(offline_stream->UsePunc()){
|
||||
string punc_res = (offline_stream->punc_handle)->AddPunc((p_result->msg).c_str());
|
||||
string lang = (offline_stream->asr_handle)->GetLang();
|
||||
string punc_res = (offline_stream->punc_handle)->AddPunc((p_result->msg).c_str(), lang);
|
||||
p_result->msg = punc_res;
|
||||
}
|
||||
#if !defined(__APPLE__)
|
||||
|
||||
@ -33,7 +33,6 @@ namespace funasr {
|
||||
vector<const char*> hw_m_szInputNames;
|
||||
vector<const char*> hw_m_szOutputNames;
|
||||
bool use_hotword;
|
||||
std::string language="zh-cn";
|
||||
|
||||
public:
|
||||
Paraformer();
|
||||
@ -55,6 +54,7 @@ namespace funasr {
|
||||
string PostProcess(std::vector<string> &raw_char, std::vector<std::vector<float>> ×tamp_list);
|
||||
|
||||
string Rescoring();
|
||||
string GetLang(){return language;};
|
||||
|
||||
knf::FbankOptions fbank_opts_;
|
||||
vector<float> means_list_;
|
||||
@ -71,6 +71,8 @@ namespace funasr {
|
||||
vector<const char*> m_szInputNames;
|
||||
vector<const char*> m_szOutputNames;
|
||||
|
||||
std::string language="zh-cn";
|
||||
|
||||
// paraformer-online
|
||||
std::shared_ptr<Ort::Session> encoder_session_ = nullptr;
|
||||
std::shared_ptr<Ort::Session> decoder_session_ = nullptr;
|
||||
|
||||
@ -75,6 +75,21 @@ bool Vocab::IsChinese(string ch)
|
||||
return false;
|
||||
}
|
||||
|
||||
string Vocab::WordFormat(std::string word)
|
||||
{
|
||||
if(word == "i"){
|
||||
return "I";
|
||||
}else if(word == "i'm"){
|
||||
return "I'm";
|
||||
}else if(word == "i've"){
|
||||
return "I've";
|
||||
}else if(word == "i'll"){
|
||||
return "I'll";
|
||||
}else{
|
||||
return word;
|
||||
}
|
||||
}
|
||||
|
||||
string Vocab::Vector2StringV2(vector<int> in, std::string language)
|
||||
{
|
||||
int i;
|
||||
@ -94,6 +109,7 @@ string Vocab::Vector2StringV2(vector<int> in, std::string language)
|
||||
size_t found = word.find(unicodeChar);
|
||||
if(found != std::string::npos){
|
||||
if (combine != ""){
|
||||
combine = WordFormat(combine);
|
||||
if (words.size() != 0){
|
||||
combine = " " + combine;
|
||||
}
|
||||
@ -164,6 +180,7 @@ string Vocab::Vector2StringV2(vector<int> in, std::string language)
|
||||
}
|
||||
|
||||
if (language == "en-bpe" and combine != ""){
|
||||
combine = WordFormat(combine);
|
||||
if (words.size() != 0){
|
||||
combine = " " + combine;
|
||||
}
|
||||
|
||||
@ -23,6 +23,7 @@ class Vocab {
|
||||
bool IsChinese(string ch);
|
||||
void Vector2String(vector<int> in, std::vector<std::string> &preds);
|
||||
string Vector2StringV2(vector<int> in, std::string language="");
|
||||
string WordFormat(std::string word);
|
||||
int GetIdByToken(const std::string &token);
|
||||
};
|
||||
|
||||
|
||||
@ -195,11 +195,16 @@ int main(int argc, char* argv[]) {
|
||||
size_t found = s_asr_path.find("speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404");
|
||||
if (found != std::string::npos) {
|
||||
model_path["model-revision"]="v1.2.4";
|
||||
}else{
|
||||
found = s_asr_path.find("speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404");
|
||||
if (found != std::string::npos) {
|
||||
model_path["model-revision"]="v1.0.5";
|
||||
}
|
||||
}
|
||||
|
||||
found = s_asr_path.find("speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404");
|
||||
if (found != std::string::npos) {
|
||||
model_path["model-revision"]="v1.0.5";
|
||||
}
|
||||
|
||||
found = s_asr_path.find("speech_paraformer-large_asr_nat-en-16k-common-vocab10020");
|
||||
if (found != std::string::npos) {
|
||||
model_path["model-revision"]="v1.0.0";
|
||||
}
|
||||
|
||||
// modelscope
|
||||
|
||||
Loading…
Reference in New Issue
Block a user