fix eng oov hotwords

This commit is contained in:
雾聪 2023-08-22 17:32:03 +08:00
parent f70f707110
commit 639ae933aa
2 changed files with 7 additions and 0 deletions

View File

@ -719,6 +719,7 @@ std::vector<std::vector<float>> Paraformer::CompileHotwordEmbedding(std::string
std::vector<int32_t> hotword_matrix;
std::vector<int32_t> lengths;
int hotword_size = 1;
int real_hw_size = 0;
if (!hotwords.empty()) {
std::vector<std::string> hotword_array = split(hotwords, ' ');
hotword_size = hotword_array.size() + 1;
@ -735,6 +736,9 @@ std::vector<std::vector<float>> Paraformer::CompileHotwordEmbedding(std::string
chars.insert(chars.end(), tokens.begin(), tokens.end());
}
}
if(chars.size()==0){
continue;
}
std::vector<int32_t> hw_vector(max_hotword_len, 0);
int vector_len = std::min(max_hotword_len, (int)chars.size());
for (int i=0; i<chars.size(); i++) {
@ -743,8 +747,10 @@ std::vector<std::vector<float>> Paraformer::CompileHotwordEmbedding(std::string
}
std::cout << std::endl;
lengths.push_back(vector_len);
real_hw_size += 1;
hotword_matrix.insert(hotword_matrix.end(), hw_vector.begin(), hw_vector.end());
}
hotword_size = real_hw_size + 1;
}
std::vector<int32_t> blank_vec(max_hotword_len, 0);
blank_vec[0] = 1;

View File

@ -40,6 +40,7 @@ std::vector<std::string> SegDict::GetTokensByWord(const std::string &word) {
if (seg_dict.count(word))
return seg_dict[word];
else {
LOG(INFO)<< word <<" is OOV!";
std::vector<string> vec;
return vec;
}