mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
add timestamp smooth
This commit is contained in:
parent
172e7ac986
commit
d674c29323
@ -55,7 +55,7 @@ void runReg(FUNASR_HANDLE asr_handle, vector<string> wav_list, vector<string> wa
|
||||
for (size_t i = 0; i < 1; i++)
|
||||
{
|
||||
FunOfflineReset(asr_handle, decoder_handle);
|
||||
FUNASR_RESULT result=FunOfflineInfer(asr_handle, wav_list[0].c_str(), RASR_NONE, NULL, hotwords_embedding, 16000, false, decoder_handle);
|
||||
FUNASR_RESULT result=FunOfflineInfer(asr_handle, wav_list[0].c_str(), RASR_NONE, NULL, hotwords_embedding, 16000, true, decoder_handle);
|
||||
if(result){
|
||||
FunASRFreeResult(result);
|
||||
}
|
||||
@ -69,7 +69,7 @@ void runReg(FUNASR_HANDLE asr_handle, vector<string> wav_list, vector<string> wa
|
||||
}
|
||||
|
||||
gettimeofday(&start, NULL);
|
||||
FUNASR_RESULT result=FunOfflineInfer(asr_handle, wav_list[i].c_str(), RASR_NONE, NULL, hotwords_embedding, 16000, false, decoder_handle);
|
||||
FUNASR_RESULT result=FunOfflineInfer(asr_handle, wav_list[i].c_str(), RASR_NONE, NULL, hotwords_embedding, 16000, true, decoder_handle);
|
||||
|
||||
gettimeofday(&end, NULL);
|
||||
seconds = (end.tv_sec - start.tv_sec);
|
||||
|
||||
@ -157,7 +157,7 @@ int main(int argc, char** argv)
|
||||
auto& wav_file = wav_list[i];
|
||||
auto& wav_id = wav_ids[i];
|
||||
gettimeofday(&start, NULL);
|
||||
FUNASR_RESULT result=FunOfflineInfer(asr_hanlde, wav_file.c_str(), RASR_NONE, NULL, hotwords_embedding, 16000, false, decoder_handle);
|
||||
FUNASR_RESULT result=FunOfflineInfer(asr_hanlde, wav_file.c_str(), RASR_NONE, NULL, hotwords_embedding, 16000, true, decoder_handle);
|
||||
gettimeofday(&end, NULL);
|
||||
seconds = (end.tv_sec - start.tv_sec);
|
||||
taking_micros += ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
|
||||
|
||||
@ -294,6 +294,12 @@
|
||||
#if !defined(__APPLE__)
|
||||
if(offline_stream->UseITN() && itn){
|
||||
string msg_itn = offline_stream->itn_handle->Normalize(p_result->msg);
|
||||
if(!(p_result->stamp).empty()){
|
||||
std::string new_stamp = funasr::TimestampSmooth(p_result->msg, msg_itn, p_result->stamp);
|
||||
if(!new_stamp.empty()){
|
||||
p_result->stamp = new_stamp;
|
||||
}
|
||||
}
|
||||
p_result->msg = msg_itn;
|
||||
}
|
||||
#endif
|
||||
@ -384,6 +390,12 @@
|
||||
#if !defined(__APPLE__)
|
||||
if(offline_stream->UseITN() && itn){
|
||||
string msg_itn = offline_stream->itn_handle->Normalize(p_result->msg);
|
||||
if(!(p_result->stamp).empty()){
|
||||
std::string new_stamp = funasr::TimestampSmooth(p_result->msg, msg_itn, p_result->stamp);
|
||||
if(!new_stamp.empty()){
|
||||
p_result->stamp = new_stamp;
|
||||
}
|
||||
}
|
||||
p_result->msg = msg_itn;
|
||||
}
|
||||
#endif
|
||||
@ -524,6 +536,13 @@
|
||||
#if !defined(__APPLE__)
|
||||
if(tpass_stream->UseITN() && itn){
|
||||
string msg_itn = tpass_stream->itn_handle->Normalize(msg_punc);
|
||||
// TimestampSmooth
|
||||
if(!(p_result->stamp).empty()){
|
||||
std::string new_stamp = funasr::TimestampSmooth(p_result->tpass_msg, msg_itn, p_result->stamp);
|
||||
if(!new_stamp.empty()){
|
||||
p_result->stamp = new_stamp;
|
||||
}
|
||||
}
|
||||
p_result->tpass_msg = msg_itn;
|
||||
}
|
||||
#endif
|
||||
|
||||
@ -300,10 +300,15 @@ void Paraformer::InitSegDict(const std::string &seg_dict_model) {
|
||||
|
||||
Paraformer::~Paraformer()
|
||||
{
|
||||
if(vocab)
|
||||
if(vocab){
|
||||
delete vocab;
|
||||
if(seg_dict)
|
||||
}
|
||||
if(seg_dict){
|
||||
delete seg_dict;
|
||||
}
|
||||
if(phone_set_){
|
||||
delete phone_set_;
|
||||
}
|
||||
}
|
||||
|
||||
void Paraformer::StartUtterance()
|
||||
|
||||
@ -247,6 +247,316 @@ void SplitChiEngCharacters(const std::string &input_str,
|
||||
}
|
||||
}
|
||||
|
||||
// Timestamp Smooth
|
||||
void TimestampAdd(std::deque<string> &alignment_str1, std::string str_word){
|
||||
if(!TimestampIsPunctuation(str_word)){
|
||||
alignment_str1.push_front(str_word);
|
||||
}
|
||||
}
|
||||
|
||||
bool TimestampIsPunctuation(const std::string& str) {
|
||||
const std::string punctuation = u8",。?、,.?";
|
||||
for (char ch : str) {
|
||||
if (punctuation.find(ch) == std::string::npos) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
vector<vector<int>> ParseTimestamps(const std::string& str) {
|
||||
vector<vector<int>> timestamps;
|
||||
std::istringstream ss(str);
|
||||
std::string segment;
|
||||
|
||||
// skip first'['
|
||||
ss.ignore(1);
|
||||
|
||||
while (std::getline(ss, segment, ']')) {
|
||||
std::istringstream segmentStream(segment);
|
||||
std::string number;
|
||||
vector<int> ts;
|
||||
|
||||
// skip'['
|
||||
segmentStream.ignore(1);
|
||||
|
||||
while (std::getline(segmentStream, number, ',')) {
|
||||
ts.push_back(std::stoi(number));
|
||||
}
|
||||
if(ts.size() != 2){
|
||||
LOG(ERROR) << "ParseTimestamps Failed";
|
||||
timestamps.clear();
|
||||
return timestamps;
|
||||
}
|
||||
timestamps.push_back(ts);
|
||||
ss.ignore(1);
|
||||
}
|
||||
|
||||
return timestamps;
|
||||
}
|
||||
|
||||
bool TimestampIsDigit(U16CHAR_T &u16) {
|
||||
return u16 >= L'0' && u16 <= L'9';
|
||||
}
|
||||
|
||||
bool TimestampIsAlpha(U16CHAR_T &u16) {
|
||||
return (u16 >= L'A' && u16 <= L'Z') || (u16 >= L'a' && u16 <= L'z');
|
||||
}
|
||||
|
||||
bool TimestampIsPunctuation(U16CHAR_T &u16) {
|
||||
return (u16 >= 0x21 && u16 <= 0x2F) // 标准ASCII标点
|
||||
|| (u16 >= 0x3A && u16 <= 0x40) // 标准ASCII标点
|
||||
|| (u16 >= 0x5B && u16 <= 0x60) // 标准ASCII标点
|
||||
|| (u16 >= 0x7B && u16 <= 0x7E) // 标准ASCII标点
|
||||
|| (u16 >= 0x2000 && u16 <= 0x206F) // 常用的Unicode标点
|
||||
|| (u16 >= 0x3000 && u16 <= 0x303F); // CJK符号和标点
|
||||
}
|
||||
|
||||
void TimestampSplitChiEngCharacters(const std::string &input_str,
|
||||
std::vector<std::string> &characters) {
|
||||
characters.resize(0);
|
||||
std::string eng_word = "";
|
||||
U16CHAR_T space = 0x0020;
|
||||
std::vector<U16CHAR_T> u16_buf;
|
||||
u16_buf.resize(std::max(u16_buf.size(), input_str.size() + 1));
|
||||
U16CHAR_T* pu16 = u16_buf.data();
|
||||
U8CHAR_T * pu8 = (U8CHAR_T*)input_str.data();
|
||||
size_t ilen = input_str.size();
|
||||
size_t len = EncodeConverter::Utf8ToUtf16(pu8, ilen, pu16, ilen + 1);
|
||||
for (size_t i = 0; i < len; i++) {
|
||||
if (EncodeConverter::IsChineseCharacter(pu16[i])) {
|
||||
if(!eng_word.empty()){
|
||||
characters.push_back(eng_word);
|
||||
eng_word = "";
|
||||
}
|
||||
U8CHAR_T u8buf[4];
|
||||
size_t n = EncodeConverter::Utf16ToUtf8(pu16 + i, u8buf);
|
||||
u8buf[n] = '\0';
|
||||
characters.push_back((const char*)u8buf);
|
||||
} else if (TimestampIsDigit(pu16[i]) || TimestampIsPunctuation(pu16[i])){
|
||||
if(!eng_word.empty()){
|
||||
characters.push_back(eng_word);
|
||||
eng_word = "";
|
||||
}
|
||||
U8CHAR_T u8buf[4];
|
||||
size_t n = EncodeConverter::Utf16ToUtf8(pu16 + i, u8buf);
|
||||
u8buf[n] = '\0';
|
||||
characters.push_back((const char*)u8buf);
|
||||
} else if (pu16[i] == space){
|
||||
if(!eng_word.empty()){
|
||||
characters.push_back(eng_word);
|
||||
eng_word = "";
|
||||
}
|
||||
}else{
|
||||
U8CHAR_T u8buf[4];
|
||||
size_t n = EncodeConverter::Utf16ToUtf8(pu16 + i, u8buf);
|
||||
u8buf[n] = '\0';
|
||||
eng_word += (const char*)u8buf;
|
||||
}
|
||||
}
|
||||
if(!eng_word.empty()){
|
||||
characters.push_back(eng_word);
|
||||
eng_word = "";
|
||||
}
|
||||
}
|
||||
|
||||
std::string VectorToString(const std::vector<std::vector<int>>& vec) {
|
||||
if(vec.size() == 0){
|
||||
return "";
|
||||
}
|
||||
std::ostringstream out;
|
||||
out << "[";
|
||||
|
||||
for (size_t i = 0; i < vec.size(); ++i) {
|
||||
out << "[";
|
||||
for (size_t j = 0; j < vec[i].size(); ++j) {
|
||||
out << vec[i][j];
|
||||
if (j < vec[i].size() - 1) {
|
||||
out << ",";
|
||||
}
|
||||
}
|
||||
out << "]";
|
||||
if (i < vec.size() - 1) {
|
||||
out << ",";
|
||||
}
|
||||
}
|
||||
|
||||
out << "]";
|
||||
return out.str();
|
||||
}
|
||||
|
||||
std::string TimestampSmooth(std::string &text, std::string &text_itn, std::string &str_time){
|
||||
vector<vector<int>> timestamps_out;
|
||||
std::string timestamps_str = "";
|
||||
// process string to vector<string>
|
||||
std::vector<std::string> characters;
|
||||
funasr::TimestampSplitChiEngCharacters(text, characters);
|
||||
|
||||
std::vector<std::string> characters_itn;
|
||||
funasr::TimestampSplitChiEngCharacters(text_itn, characters_itn);
|
||||
|
||||
//convert string to vector<vector<int>>
|
||||
vector<vector<int>> timestamps = funasr::ParseTimestamps(str_time);
|
||||
|
||||
if (timestamps.size() == 0){
|
||||
LOG(ERROR) << "Timestamp Smooth Failed: Length of timestamp is zero";
|
||||
return timestamps_str;
|
||||
}
|
||||
|
||||
// edit distance
|
||||
int m = characters.size();
|
||||
int n = characters_itn.size();
|
||||
std::vector<std::vector<int>> dp(m + 1, std::vector<int>(n + 1, 0));
|
||||
|
||||
// init
|
||||
for (int i = 0; i <= m; ++i) {
|
||||
dp[i][0] = i;
|
||||
}
|
||||
for (int j = 0; j <= n; ++j) {
|
||||
dp[0][j] = j;
|
||||
}
|
||||
|
||||
// dp
|
||||
for (int i = 1; i <= m; ++i) {
|
||||
for (int j = 1; j <= n; ++j) {
|
||||
if (characters[i - 1] == characters_itn[j - 1]) {
|
||||
dp[i][j] = dp[i - 1][j - 1];
|
||||
} else {
|
||||
dp[i][j] = std::min({dp[i - 1][j], dp[i][j - 1], dp[i - 1][j - 1]}) + 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// backtrack
|
||||
std::deque<string> alignment_str1, alignment_str2;
|
||||
int i = m, j = n;
|
||||
while (i > 0 || j > 0) {
|
||||
if (i > 0 && j > 0 && dp[i][j] == dp[i - 1][j - 1]) {
|
||||
funasr::TimestampAdd(alignment_str1, characters[i - 1]);
|
||||
funasr::TimestampAdd(alignment_str2, characters_itn[j - 1]);
|
||||
i -= 1;
|
||||
j -= 1;
|
||||
} else if (i > 0 && dp[i][j] == dp[i - 1][j] + 1) {
|
||||
funasr::TimestampAdd(alignment_str1, characters[i - 1]);
|
||||
alignment_str2.push_front("");
|
||||
i -= 1;
|
||||
} else if (j > 0 && dp[i][j] == dp[i][j - 1] + 1) {
|
||||
alignment_str1.push_front("");
|
||||
funasr::TimestampAdd(alignment_str2, characters_itn[j - 1]);
|
||||
j -= 1;
|
||||
} else{
|
||||
funasr::TimestampAdd(alignment_str1, characters[i - 1]);
|
||||
funasr::TimestampAdd(alignment_str2, characters_itn[j - 1]);
|
||||
i -= 1;
|
||||
j -= 1;
|
||||
}
|
||||
}
|
||||
|
||||
// smooth
|
||||
int itn_count = 0;
|
||||
int idx_tp = 0;
|
||||
int idx_itn = 0;
|
||||
vector<vector<int>> timestamps_tmp;
|
||||
for(int index = 0; index < alignment_str1.size(); index++){
|
||||
if (alignment_str1[index] == alignment_str2[index]){
|
||||
bool subsidy = false;
|
||||
if (itn_count > 0 && timestamps_tmp.size() == 0){
|
||||
if(idx_tp >= timestamps.size()){
|
||||
LOG(ERROR) << "Timestamp Smooth Failed: Index of tp is out of range. ";
|
||||
return timestamps_str;
|
||||
}
|
||||
timestamps_tmp.push_back(timestamps[idx_tp]);
|
||||
subsidy = true;
|
||||
itn_count++;
|
||||
}
|
||||
|
||||
if (timestamps_tmp.size() > 0){
|
||||
if (itn_count > 0){
|
||||
int begin = timestamps_tmp[0][0];
|
||||
int end = timestamps_tmp.back()[1];
|
||||
int total_time = end - begin;
|
||||
int interval = total_time / itn_count;
|
||||
for(int idx_cnt=0; idx_cnt < itn_count; idx_cnt++){
|
||||
vector<int> ts;
|
||||
ts.push_back(begin + interval*idx_cnt);
|
||||
if(idx_cnt == itn_count-1){
|
||||
ts.push_back(end);
|
||||
}else {
|
||||
ts.push_back(begin + interval*(idx_cnt + 1));
|
||||
}
|
||||
timestamps_out.push_back(ts);
|
||||
}
|
||||
}
|
||||
timestamps_tmp.clear();
|
||||
}
|
||||
if(!subsidy){
|
||||
if(idx_tp >= timestamps.size()){
|
||||
LOG(ERROR) << "Timestamp Smooth Failed: Index of tp is out of range. ";
|
||||
return timestamps_str;
|
||||
}
|
||||
timestamps_out.push_back(timestamps[idx_tp]);
|
||||
}
|
||||
idx_tp++;
|
||||
itn_count = 0;
|
||||
}else{
|
||||
if (!alignment_str1[index].empty()){
|
||||
if(idx_tp >= timestamps.size()){
|
||||
LOG(ERROR) << "Timestamp Smooth Failed: Index of tp is out of range. ";
|
||||
return timestamps_str;
|
||||
}
|
||||
timestamps_tmp.push_back(timestamps[idx_tp]);
|
||||
idx_tp++;
|
||||
}
|
||||
if (!alignment_str2[index].empty()){
|
||||
itn_count++;
|
||||
}
|
||||
}
|
||||
// count length of itn
|
||||
if (!alignment_str2[index].empty()){
|
||||
idx_itn++;
|
||||
}
|
||||
}
|
||||
{
|
||||
if (itn_count > 0 && timestamps_tmp.size() == 0){
|
||||
if (timestamps_out.size() > 0){
|
||||
timestamps_tmp.push_back(timestamps_out.back());
|
||||
itn_count++;
|
||||
timestamps_out.pop_back();
|
||||
} else{
|
||||
LOG(ERROR) << "Timestamp Smooth Failed: Last itn has no timestamp.";
|
||||
return timestamps_str;
|
||||
}
|
||||
}
|
||||
|
||||
if (timestamps_tmp.size() > 0){
|
||||
if (itn_count > 0){
|
||||
int begin = timestamps_tmp[0][0];
|
||||
int end = timestamps_tmp.back()[1];
|
||||
int total_time = end - begin;
|
||||
int interval = total_time / itn_count;
|
||||
for(int idx_cnt=0; idx_cnt < itn_count; idx_cnt++){
|
||||
vector<int> ts;
|
||||
ts.push_back(begin + interval*idx_cnt);
|
||||
if(idx_cnt == itn_count-1){
|
||||
ts.push_back(end);
|
||||
}else {
|
||||
ts.push_back(begin + interval*(idx_cnt + 1));
|
||||
}
|
||||
timestamps_out.push_back(ts);
|
||||
}
|
||||
}
|
||||
timestamps_tmp.clear();
|
||||
}
|
||||
}
|
||||
if(timestamps_out.size() != idx_itn){
|
||||
LOG(ERROR) << "Timestamp Smooth Failed: Timestamp length does not matched.";
|
||||
return timestamps_str;
|
||||
}
|
||||
|
||||
timestamps_str = VectorToString(timestamps_out);
|
||||
return timestamps_str;
|
||||
}
|
||||
|
||||
std::vector<std::string> split(const std::string &s, char delim) {
|
||||
std::vector<std::string> elems;
|
||||
std::stringstream ss(s);
|
||||
@ -333,12 +643,23 @@ string PostProcess(std::vector<string> &raw_char, std::vector<std::vector<float>
|
||||
int sub_word = !(word.find("@@") == string::npos);
|
||||
// process word start and middle part
|
||||
if (sub_word) {
|
||||
combine += word.erase(word.length() - 2);
|
||||
if(!is_combining){
|
||||
begin = timestamp_list[i][0];
|
||||
// if badcase: lo@@ chinese
|
||||
if (i == raw_char.size()-1 || i<raw_char.size()-1 && IsChinese(raw_char[i+1])){
|
||||
word = word.erase(word.length() - 2) + " ";
|
||||
if (is_combining) {
|
||||
combine += word;
|
||||
is_combining = false;
|
||||
word = combine;
|
||||
combine = "";
|
||||
}
|
||||
}else{
|
||||
combine += word.erase(word.length() - 2);
|
||||
if(!is_combining){
|
||||
begin = timestamp_list[i][0];
|
||||
}
|
||||
is_combining = true;
|
||||
continue;
|
||||
}
|
||||
is_combining = true;
|
||||
continue;
|
||||
}
|
||||
// process word end part
|
||||
else if (is_combining) {
|
||||
@ -669,4 +990,9 @@ void ExtractHws(string hws_file, unordered_map<string, int> &hws_map, string& nn
|
||||
ifs_hws.close();
|
||||
}
|
||||
|
||||
void SmoothTimestamps(std::string &str_punc, std::string &str_itn, std::string &str_timetamp){
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
} // namespace funasr
|
||||
|
||||
@ -3,11 +3,13 @@
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <deque>
|
||||
#include "tensor.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace funasr {
|
||||
typedef unsigned short U16CHAR_T;
|
||||
extern float *LoadParams(const char *filename);
|
||||
|
||||
extern void SaveDataFile(const char *filename, void *data, uint32_t len);
|
||||
@ -35,6 +37,16 @@ void KeepChineseCharacterAndSplit(const std::string &input_str,
|
||||
std::vector<std::string> &chinese_characters);
|
||||
void SplitChiEngCharacters(const std::string &input_str,
|
||||
std::vector<std::string> &characters);
|
||||
void TimestampAdd(std::deque<string> &alignment_str1, std::string str_word);
|
||||
vector<vector<int>> ParseTimestamps(const std::string& str);
|
||||
bool TimestampIsDigit(U16CHAR_T &u16);
|
||||
bool TimestampIsAlpha(U16CHAR_T &u16);
|
||||
bool TimestampIsPunctuation(U16CHAR_T &u16);
|
||||
bool TimestampIsPunctuation(const std::string& str);
|
||||
void TimestampSplitChiEngCharacters(const std::string &input_str,
|
||||
std::vector<std::string> &characters);
|
||||
std::string VectorToString(const std::vector<std::vector<int>>& vec);
|
||||
std::string TimestampSmooth(std::string &text, std::string &text_itn, std::string &str_time);
|
||||
|
||||
std::vector<std::string> split(const std::string &s, char delim);
|
||||
|
||||
|
||||
@ -120,8 +120,8 @@ string Vocab::Vector2StringV2(vector<int> in, std::string language)
|
||||
std::string combine = "";
|
||||
std::string unicodeChar = "▁";
|
||||
|
||||
for (auto it = in.begin(); it != in.end(); it++) {
|
||||
string word = vocab[*it];
|
||||
for (i=0; i<in.size(); i++){
|
||||
string word = vocab[in[i]];
|
||||
// step1 space character skips
|
||||
if (word == "<s>" || word == "</s>" || word == "<unk>")
|
||||
continue;
|
||||
@ -146,9 +146,20 @@ string Vocab::Vector2StringV2(vector<int> in, std::string language)
|
||||
int sub_word = !(word.find("@@") == string::npos);
|
||||
// process word start and middle part
|
||||
if (sub_word) {
|
||||
combine += word.erase(word.length() - 2);
|
||||
is_combining = true;
|
||||
continue;
|
||||
// if badcase: lo@@ chinese
|
||||
if (i == in.size()-1 || i<in.size()-1 && IsChinese(vocab[in[i+1]])){
|
||||
word = word.erase(word.length() - 2) + " ";
|
||||
if (is_combining) {
|
||||
combine += word;
|
||||
is_combining = false;
|
||||
word = combine;
|
||||
combine = "";
|
||||
}
|
||||
}else{
|
||||
combine += word.erase(word.length() - 2);
|
||||
is_combining = true;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
// process word end part
|
||||
else if (is_combining) {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user