#include "precomp.h" namespace funasr { float *LoadParams(const char *filename) { FILE *fp; fp = fopen(filename, "rb"); fseek(fp, 0, SEEK_END); uint32_t nFileLen = ftell(fp); fseek(fp, 0, SEEK_SET); float *params_addr = (float *)AlignedMalloc(32, nFileLen); int n = fread(params_addr, 1, nFileLen, fp); fclose(fp); return params_addr; } int ValAlign(int val, int align) { float tmp = ceil((float)val / (float)align) * (float)align; return (int)tmp; } void DispParams(float *din, int size) { int i; for (i = 0; i < size; i++) { printf("%f ", din[i]); } printf("\n"); } void SaveDataFile(const char *filename, void *data, uint32_t len) { FILE *fp; fp = fopen(filename, "wb+"); fwrite(data, 1, len, fp); fclose(fp); } void BasicNorm(Tensor *&din, float norm) { int Tmax = din->size[2]; int i, j; for (i = 0; i < Tmax; i++) { float sum = 0; for (j = 0; j < 512; j++) { int ii = i * 512 + j; sum += din->buff[ii] * din->buff[ii]; } float mean = sqrt(sum / 512 + norm); for (j = 0; j < 512; j++) { int ii = i * 512 + j; din->buff[ii] = din->buff[ii] / mean; } } } void FindMax(float *din, int len, float &max_val, int &max_idx) { int i; max_val = -INFINITY; max_idx = -1; for (i = 0; i < len; i++) { if (din[i] > max_val) { max_val = din[i]; max_idx = i; } } } string PathAppend(const string &p1, const string &p2) { char sep = '/'; string tmp = p1; #ifdef _WIN32 sep = '\\'; #endif if (p1[p1.length()-1] != sep) { // Need to add a tmp += sep; // path separator return (tmp + p2); } else return (p1 + p2); } void Relu(Tensor *din) { int i; for (i = 0; i < din->buff_size; i++) { float val = din->buff[i]; din->buff[i] = val < 0 ? 0 : val; } } void Swish(Tensor *din) { int i; for (i = 0; i < din->buff_size; i++) { float val = din->buff[i]; din->buff[i] = val / (1 + exp(-val)); } } void Sigmoid(Tensor *din) { int i; for (i = 0; i < din->buff_size; i++) { float val = din->buff[i]; din->buff[i] = 1 / (1 + exp(-val)); } } void DoubleSwish(Tensor *din) { int i; for (i = 0; i < din->buff_size; i++) { float val = din->buff[i]; din->buff[i] = val / (1 + exp(-val + 1)); } } void Softmax(float *din, int mask, int len) { float *tmp = (float *)malloc(mask * sizeof(float)); int i; float sum = 0; float max = -INFINITY; for (i = 0; i < mask; i++) { max = max < din[i] ? din[i] : max; } for (i = 0; i < mask; i++) { tmp[i] = exp(din[i] - max); sum += tmp[i]; } for (i = 0; i < mask; i++) { din[i] = tmp[i] / sum; } free(tmp); for (i = mask; i < len; i++) { din[i] = 0; } } void LogSoftmax(float *din, int len) { float *tmp = (float *)malloc(len * sizeof(float)); int i; float sum = 0; for (i = 0; i < len; i++) { tmp[i] = exp(din[i]); sum += tmp[i]; } for (i = 0; i < len; i++) { din[i] = log(tmp[i] / sum); } free(tmp); } void Glu(Tensor *din, Tensor *dout) { int mm = din->buff_size / 1024; int i, j; for (i = 0; i < mm; i++) { for (j = 0; j < 512; j++) { int in_off = i * 1024 + j; int out_off = i * 512 + j; float a = din->buff[in_off]; float b = din->buff[in_off + 512]; dout->buff[out_off] = a / (1 + exp(-b)); } } } bool is_target_file(const std::string& filename, const std::string target) { std::size_t pos = filename.find_last_of("."); if (pos == std::string::npos) { return false; } std::string extension = filename.substr(pos + 1); return (extension == target); } void KeepChineseCharacterAndSplit(const std::string &input_str, std::vector &chinese_characters) { chinese_characters.resize(0); std::vector u16_buf; u16_buf.resize(std::max(u16_buf.size(), input_str.size() + 1)); U16CHAR_T* pu16 = u16_buf.data(); U8CHAR_T * pu8 = (U8CHAR_T*)input_str.data(); size_t ilen = input_str.size(); size_t len = EncodeConverter::Utf8ToUtf16(pu8, ilen, pu16, ilen + 1); for (size_t i = 0; i < len; i++) { if (EncodeConverter::IsChineseCharacter(pu16[i])) { U8CHAR_T u8buf[4]; size_t n = EncodeConverter::Utf16ToUtf8(pu16 + i, u8buf); u8buf[n] = '\0'; chinese_characters.push_back((const char*)u8buf); } } } void SplitChiEngCharacters(const std::string &input_str, std::vector &characters) { characters.resize(0); std::string eng_word = ""; U16CHAR_T space = 0x0020; std::vector u16_buf; u16_buf.resize(std::max(u16_buf.size(), input_str.size() + 1)); U16CHAR_T* pu16 = u16_buf.data(); U8CHAR_T * pu8 = (U8CHAR_T*)input_str.data(); size_t ilen = input_str.size(); size_t len = EncodeConverter::Utf8ToUtf16(pu8, ilen, pu16, ilen + 1); for (size_t i = 0; i < len; i++) { if (EncodeConverter::IsChineseCharacter(pu16[i])) { if(!eng_word.empty()){ characters.push_back(eng_word); eng_word = ""; } U8CHAR_T u8buf[4]; size_t n = EncodeConverter::Utf16ToUtf8(pu16 + i, u8buf); u8buf[n] = '\0'; characters.push_back((const char*)u8buf); } else if (pu16[i] == space){ if(!eng_word.empty()){ characters.push_back(eng_word); eng_word = ""; } }else{ U8CHAR_T u8buf[4]; size_t n = EncodeConverter::Utf16ToUtf8(pu16 + i, u8buf); u8buf[n] = '\0'; eng_word += (const char*)u8buf; } } if(!eng_word.empty()){ characters.push_back(eng_word); eng_word = ""; } } // Timestamp Smooth void TimestampAdd(std::deque &alignment_str1, std::string str_word){ if(!TimestampIsPunctuation(str_word)){ alignment_str1.push_front(str_word); } } bool TimestampIsPunctuation(const std::string& str) { const std::string punctuation = u8",。?、,.?"; for (char ch : str) { if (punctuation.find(ch) == std::string::npos) { return false; } } return true; } vector> ParseTimestamps(const std::string& str) { vector> timestamps; std::istringstream ss(str); std::string segment; // skip first'[' ss.ignore(1); while (std::getline(ss, segment, ']')) { std::istringstream segmentStream(segment); std::string number; vector ts; // skip'[' segmentStream.ignore(1); while (std::getline(segmentStream, number, ',')) { ts.push_back(std::stoi(number)); } if(ts.size() != 2){ LOG(ERROR) << "ParseTimestamps Failed"; timestamps.clear(); return timestamps; } timestamps.push_back(ts); ss.ignore(1); } return timestamps; } bool TimestampIsDigit(U16CHAR_T &u16) { return u16 >= L'0' && u16 <= L'9'; } bool TimestampIsAlpha(U16CHAR_T &u16) { return (u16 >= L'A' && u16 <= L'Z') || (u16 >= L'a' && u16 <= L'z'); } bool TimestampIsPunctuation(U16CHAR_T &u16) { return (u16 >= 0x21 && u16 <= 0x2F) // 标准ASCII标点 || (u16 >= 0x3A && u16 <= 0x40) // 标准ASCII标点 || (u16 >= 0x5B && u16 <= 0x60) // 标准ASCII标点 || (u16 >= 0x7B && u16 <= 0x7E) // 标准ASCII标点 || (u16 >= 0x2000 && u16 <= 0x206F) // 常用的Unicode标点 || (u16 >= 0x3000 && u16 <= 0x303F); // CJK符号和标点 } void TimestampSplitChiEngCharacters(const std::string &input_str, std::vector &characters) { characters.resize(0); std::string eng_word = ""; U16CHAR_T space = 0x0020; std::vector u16_buf; u16_buf.resize(std::max(u16_buf.size(), input_str.size() + 1)); U16CHAR_T* pu16 = u16_buf.data(); U8CHAR_T * pu8 = (U8CHAR_T*)input_str.data(); size_t ilen = input_str.size(); size_t len = EncodeConverter::Utf8ToUtf16(pu8, ilen, pu16, ilen + 1); for (size_t i = 0; i < len; i++) { if (EncodeConverter::IsChineseCharacter(pu16[i])) { if(!eng_word.empty()){ characters.push_back(eng_word); eng_word = ""; } U8CHAR_T u8buf[4]; size_t n = EncodeConverter::Utf16ToUtf8(pu16 + i, u8buf); u8buf[n] = '\0'; characters.push_back((const char*)u8buf); } else if (TimestampIsDigit(pu16[i]) || TimestampIsPunctuation(pu16[i])){ if(!eng_word.empty()){ characters.push_back(eng_word); eng_word = ""; } U8CHAR_T u8buf[4]; size_t n = EncodeConverter::Utf16ToUtf8(pu16 + i, u8buf); u8buf[n] = '\0'; characters.push_back((const char*)u8buf); } else if (pu16[i] == space){ if(!eng_word.empty()){ characters.push_back(eng_word); eng_word = ""; } }else{ U8CHAR_T u8buf[4]; size_t n = EncodeConverter::Utf16ToUtf8(pu16 + i, u8buf); u8buf[n] = '\0'; eng_word += (const char*)u8buf; } } if(!eng_word.empty()){ characters.push_back(eng_word); eng_word = ""; } } std::string VectorToString(const std::vector>& vec) { if(vec.size() == 0){ return ""; } std::ostringstream out; out << "["; for (size_t i = 0; i < vec.size(); ++i) { out << "["; for (size_t j = 0; j < vec[i].size(); ++j) { out << vec[i][j]; if (j < vec[i].size() - 1) { out << ","; } } out << "]"; if (i < vec.size() - 1) { out << ","; } } out << "]"; return out.str(); } std::string TimestampSmooth(std::string &text, std::string &text_itn, std::string &str_time){ vector> timestamps_out; std::string timestamps_str = ""; // process string to vector std::vector characters; funasr::TimestampSplitChiEngCharacters(text, characters); std::vector characters_itn; funasr::TimestampSplitChiEngCharacters(text_itn, characters_itn); //convert string to vector> vector> timestamps = funasr::ParseTimestamps(str_time); if (timestamps.size() == 0){ LOG(ERROR) << "Timestamp Smooth Failed: Length of timestamp is zero"; return timestamps_str; } // edit distance int m = characters.size(); int n = characters_itn.size(); std::vector> dp(m + 1, std::vector(n + 1, 0)); // init for (int i = 0; i <= m; ++i) { dp[i][0] = i; } for (int j = 0; j <= n; ++j) { dp[0][j] = j; } // dp for (int i = 1; i <= m; ++i) { for (int j = 1; j <= n; ++j) { if (characters[i - 1] == characters_itn[j - 1]) { dp[i][j] = dp[i - 1][j - 1]; } else { dp[i][j] = std::min({dp[i - 1][j], dp[i][j - 1], dp[i - 1][j - 1]}) + 1; } } } // backtrack std::deque alignment_str1, alignment_str2; int i = m, j = n; while (i > 0 || j > 0) { if (i > 0 && j > 0 && dp[i][j] == dp[i - 1][j - 1]) { funasr::TimestampAdd(alignment_str1, characters[i - 1]); funasr::TimestampAdd(alignment_str2, characters_itn[j - 1]); i -= 1; j -= 1; } else if (i > 0 && dp[i][j] == dp[i - 1][j] + 1) { funasr::TimestampAdd(alignment_str1, characters[i - 1]); alignment_str2.push_front(""); i -= 1; } else if (j > 0 && dp[i][j] == dp[i][j - 1] + 1) { alignment_str1.push_front(""); funasr::TimestampAdd(alignment_str2, characters_itn[j - 1]); j -= 1; } else{ funasr::TimestampAdd(alignment_str1, characters[i - 1]); funasr::TimestampAdd(alignment_str2, characters_itn[j - 1]); i -= 1; j -= 1; } } // smooth int itn_count = 0; int idx_tp = 0; int idx_itn = 0; vector> timestamps_tmp; for(int index = 0; index < alignment_str1.size(); index++){ if (alignment_str1[index] == alignment_str2[index]){ bool subsidy = false; if (itn_count > 0 && timestamps_tmp.size() == 0){ if(idx_tp >= timestamps.size()){ LOG(ERROR) << "Timestamp Smooth Failed: Index of tp is out of range. "; return timestamps_str; } timestamps_tmp.push_back(timestamps[idx_tp]); subsidy = true; itn_count++; } if (timestamps_tmp.size() > 0){ if (itn_count > 0){ int begin = timestamps_tmp[0][0]; int end = timestamps_tmp.back()[1]; int total_time = end - begin; int interval = total_time / itn_count; for(int idx_cnt=0; idx_cnt < itn_count; idx_cnt++){ vector ts; ts.push_back(begin + interval*idx_cnt); if(idx_cnt == itn_count-1){ ts.push_back(end); }else { ts.push_back(begin + interval*(idx_cnt + 1)); } timestamps_out.push_back(ts); } } timestamps_tmp.clear(); } if(!subsidy){ if(idx_tp >= timestamps.size()){ LOG(ERROR) << "Timestamp Smooth Failed: Index of tp is out of range. "; return timestamps_str; } timestamps_out.push_back(timestamps[idx_tp]); } idx_tp++; itn_count = 0; }else{ if (!alignment_str1[index].empty()){ if(idx_tp >= timestamps.size()){ LOG(ERROR) << "Timestamp Smooth Failed: Index of tp is out of range. "; return timestamps_str; } timestamps_tmp.push_back(timestamps[idx_tp]); idx_tp++; } if (!alignment_str2[index].empty()){ itn_count++; } } // count length of itn if (!alignment_str2[index].empty()){ idx_itn++; } } { if (itn_count > 0 && timestamps_tmp.size() == 0){ if (timestamps_out.size() > 0){ timestamps_tmp.push_back(timestamps_out.back()); itn_count++; timestamps_out.pop_back(); } else{ LOG(ERROR) << "Timestamp Smooth Failed: Last itn has no timestamp."; return timestamps_str; } } if (timestamps_tmp.size() > 0){ if (itn_count > 0){ int begin = timestamps_tmp[0][0]; int end = timestamps_tmp.back()[1]; int total_time = end - begin; int interval = total_time / itn_count; for(int idx_cnt=0; idx_cnt < itn_count; idx_cnt++){ vector ts; ts.push_back(begin + interval*idx_cnt); if(idx_cnt == itn_count-1){ ts.push_back(end); }else { ts.push_back(begin + interval*(idx_cnt + 1)); } timestamps_out.push_back(ts); } } timestamps_tmp.clear(); } } if(timestamps_out.size() != idx_itn){ LOG(ERROR) << "Timestamp Smooth Failed: Timestamp length does not matched."; return timestamps_str; } timestamps_str = VectorToString(timestamps_out); return timestamps_str; } std::vector split(const std::string &s, char delim) { std::vector elems; std::stringstream ss(s); std::string item; while(std::getline(ss, item, delim)) { elems.push_back(item); } return elems; } template void PrintMat(const std::vector> &mat, const std::string &name) { std::cout << name << ":" << std::endl; for (auto item : mat) { for (auto item_ : item) { std::cout << item_ << " "; } std::cout << std::endl; } } size_t Utf8ToCharset(const std::string &input, std::vector &output) { std::string ch; for (size_t i = 0, len = 0; i != input.length(); i += len) { unsigned char byte = (unsigned)input[i]; if (byte >= 0xFC) // lenght 6 len = 6; else if (byte >= 0xF8) len = 5; else if (byte >= 0xF0) len = 4; else if (byte >= 0xE0) len = 3; else if (byte >= 0xC0) len = 2; else len = 1; ch = input.substr(i, len); output.push_back(ch); } return output.size(); } int Str2IntFunc(string str) { const char *ch_array = str.c_str(); if (((ch_array[0] & 0xf0) != 0xe0) || ((ch_array[1] & 0xc0) != 0x80) || ((ch_array[2] & 0xc0) != 0x80)) return 0; int val = ((ch_array[0] & 0x0f) << 12) | ((ch_array[1] & 0x3f) << 6) | (ch_array[2] & 0x3f); return val; } bool IsChinese(string ch) { if (ch.size() != 3) { return false; } int unicode = Str2IntFunc(ch); if (unicode >= 19968 && unicode <= 40959) { return true; } return false; } string PostProcess(std::vector &raw_char, std::vector> ×tamp_list){ std::vector> timestamp_merge; int i; list words; int is_pre_english = false; int pre_english_len = 0; int is_combining = false; string combine = ""; float begin=-1; for (i=0; i" || word == "" || word == "") continue; // step2 combie phoneme to full word { int sub_word = !(word.find("@@") == string::npos); // process word start and middle part if (sub_word) { // if badcase: lo@@ chinese if (i == raw_char.size()-1 || i vec = {begin, timestamp_list[i][1]}; timestamp_merge.emplace_back(vec); begin = -1; pre_english_len = word.size(); } // pre word is english word else { // single letter turn to upper case // if (word.size() == 1) { // word[0] = word[0] - 32; // } if (pre_english_len > 1) { words.push_back(" "); words.push_back(word); begin = (begin==-1)?timestamp_list[i][0]:begin; std::vector vec = {begin, timestamp_list[i][1]}; timestamp_merge.emplace_back(vec); begin = -1; pre_english_len = word.size(); } else { // if (word.size() > 1) { // words.push_back(" "); // } words.push_back(" "); words.push_back(word); begin = (begin==-1)?timestamp_list[i][0]:begin; std::vector vec = {begin, timestamp_list[i][1]}; timestamp_merge.emplace_back(vec); begin = -1; pre_english_len = word.size(); } } is_pre_english = true; } } } string stamp_str=""; for (i=0; i& us_alphas, std::vector us_cif_peak, std::vector& char_list, std::string &res_str, std::vector> ×tamp_vec, float begin_time, float total_offset){ if (char_list.empty()) { return ; } const float START_END_THRESHOLD = 5.0; const float MAX_TOKEN_DURATION = 30.0; const float TIME_RATE = 10.0 * 6 / 1000 / 3; // 3 times upsampled, cif_peak is flattened into a 1D array std::vector cif_peak = us_cif_peak; int num_frames = cif_peak.size(); if (char_list.back() == "") { char_list.pop_back(); } if (char_list.empty()) { return ; } vector> timestamp_list; vector new_char_list; vector fire_place; // for bicif model trained with large data, cif2 actually fires when a character starts // so treat the frames between two peaks as the duration of the former token for (int i = 0; i < num_frames; i++) { if (cif_peak[i] > 1.0 - 1e-4) { fire_place.push_back(i + total_offset); } } int num_peak = fire_place.size(); if(num_peak != (int)char_list.size() + 1){ float sum = std::accumulate(us_alphas.begin(), us_alphas.end(), 0.0f); float scale = sum/((int)char_list.size() + 1); if(scale == 0){ return; } cif_peak.clear(); sum = 0.0; for(auto &alpha:us_alphas){ alpha = alpha/scale; sum += alpha; cif_peak.emplace_back(sum); if(sum>=1.0 - 1e-4){ sum -=(1.0 - 1e-4); } } fire_place.clear(); for (int i = 0; i < num_frames; i++) { if (cif_peak[i] > 1.0 - 1e-4) { fire_place.push_back(i + total_offset); } } } num_peak = fire_place.size(); if(fire_place.size() == 0){ return; } // begin silence if (fire_place[0] > START_END_THRESHOLD) { new_char_list.push_back(""); timestamp_list.push_back({0.0, fire_place[0] * TIME_RATE}); } // tokens timestamp for (int i = 0; i < num_peak - 1; i++) { new_char_list.push_back(char_list[i]); if (i == num_peak - 2 || MAX_TOKEN_DURATION < 0 || fire_place[i + 1] - fire_place[i] < MAX_TOKEN_DURATION) { timestamp_list.push_back({fire_place[i] * TIME_RATE, fire_place[i + 1] * TIME_RATE}); } else { // cut the duration to token and sil of the 0-weight frames last long float _split = fire_place[i] + MAX_TOKEN_DURATION; timestamp_list.push_back({fire_place[i] * TIME_RATE, _split * TIME_RATE}); timestamp_list.push_back({_split * TIME_RATE, fire_place[i + 1] * TIME_RATE}); new_char_list.push_back(""); } } // tail token and end silence if(timestamp_list.size()==0){ LOG(ERROR)<<"timestamp_list's size is 0!"; return; } if (num_frames - fire_place.back() > START_END_THRESHOLD) { float _end = (num_frames + fire_place.back()) / 2.0; timestamp_list.back()[1] = _end * TIME_RATE; timestamp_list.push_back({_end * TIME_RATE, num_frames * TIME_RATE}); new_char_list.push_back(""); } else { timestamp_list.back()[1] = num_frames * TIME_RATE; } if (begin_time) { // add offset time in model with vad for (auto& timestamp : timestamp_list) { timestamp[0] += begin_time / 1000.0; timestamp[1] += begin_time / 1000.0; } } assert(new_char_list.size() == timestamp_list.size()); for (int i = 0; i < (int)new_char_list.size(); i++) { res_str += new_char_list[i] + " " + to_string(timestamp_list[i][0]) + " " + to_string(timestamp_list[i][1]) + ";"; } for (int i = 0; i < (int)new_char_list.size(); i++) { if(new_char_list[i] != ""){ timestamp_vec.push_back(timestamp_list[i]); } } } bool IsTargetFile(const std::string& filename, const std::string target) { std::size_t pos = filename.find_last_of("."); if (pos == std::string::npos) { return false; } std::string extension = filename.substr(pos + 1); return (extension == target); } void Trim(std::string *str) { const char *white_chars = " \t\n\r\f\v"; std::string::size_type pos = str->find_last_not_of(white_chars); if (pos != std::string::npos) { str->erase(pos + 1); pos = str->find_first_not_of(white_chars); if (pos != std::string::npos) str->erase(0, pos); } else { str->erase(str->begin(), str->end()); } } void SplitStringToVector(const std::string &full, const char *delim, bool omit_empty_strings, std::vector *out) { size_t start = 0, found = 0, end = full.size(); out->clear(); while (found != std::string::npos) { found = full.find_first_of(delim, start); // start != end condition is for when the delimiter is at the end if (!omit_empty_strings || (found != start && start != end)) out->push_back(full.substr(start, found - start)); start = found + 1; } } void ExtractHws(string hws_file, unordered_map &hws_map) { if(hws_file.empty()){ return; } std::string line; std::ifstream ifs_hws(hws_file.c_str()); if(!ifs_hws.is_open()){ LOG(ERROR) << "Unable to open hotwords file: " << hws_file << ". If you have not set hotwords, please ignore this message."; return; } LOG(INFO) << "hotwords: "; while (getline(ifs_hws, line)) { Trim(&line); if (line.empty()) { continue; } float score = 1.0f; std::vector text; SplitStringToVector(line, " ", true, &text); if (text.size() > 1) { try{ score = std::stof(text[text.size() - 1]); }catch (std::exception const &e) { LOG(ERROR)< &hws_map, string& nn_hotwords_) { if(hws_file.empty()){ return; } std::string line; std::ifstream ifs_hws(hws_file.c_str()); if(!ifs_hws.is_open()){ LOG(ERROR) << "Unable to open hotwords file: " << hws_file << ". If you have not set hotwords, please ignore this message."; return; } LOG(INFO) << "hotwords: "; while (getline(ifs_hws, line)) { Trim(&line); if (line.empty()) { continue; } float score = 1.0f; std::vector text; SplitStringToVector(line, " ", true, &text); if (text.size() > 1) { try{ score = std::stof(text[text.size() - 1]); }catch (std::exception const &e) { LOG(ERROR)<