From 9e02d145054dd2be4da01904c7cb13a5790eb7be Mon Sep 17 00:00:00 2001 From: zhuzizyf <42790740+zhuzizyf@users.noreply.github.com> Date: Fri, 21 Apr 2023 21:36:57 +0800 Subject: [PATCH] Update e2e_vad.h Add e2e_vad online support, fix online bugs. --- funasr/runtime/onnxruntime/src/e2e_vad.h | 62 +++++++++++++++--------- 1 file changed, 40 insertions(+), 22 deletions(-) diff --git a/funasr/runtime/onnxruntime/src/e2e_vad.h b/funasr/runtime/onnxruntime/src/e2e_vad.h index f0c4975a6..e029dc35f 100644 --- a/funasr/runtime/onnxruntime/src/e2e_vad.h +++ b/funasr/runtime/onnxruntime/src/e2e_vad.h @@ -294,8 +294,8 @@ public: std::vector> operator()(const std::vector> &score, const std::vector &waveform, bool is_final = false, - int max_end_sil = 800, int max_single_segment_time = 15000, float speech_noise_thres = 0.9, - int sample_rate = 16000) { + bool online = false, int max_end_sil = 800, int max_single_segment_time = 15000, + float speech_noise_thres = 0.8, int sample_rate = 16000) { max_end_sil_frame_cnt_thresh = max_end_sil - vad_opts.speech_to_sil_time_thres; this->waveform = waveform; this->vad_opts.max_single_segment_time = max_single_segment_time; @@ -309,33 +309,44 @@ public: } else { DetectLastFrames(); } - // std::vector> segments; - // for (size_t batch_num = 0; batch_num < score.size(); batch_num++) { + std::vector> segment_batch; if (output_data_buf.size() > 0) { for (size_t i = output_data_buf_offset; i < output_data_buf.size(); i++) { + int start_ms; + int end_ms; + if (online) { + if (!output_data_buf[i].contain_seg_start_point) { - continue; + continue; } if (!next_seg && !output_data_buf[i].contain_seg_end_point) { - continue; + continue; } - int start_ms = next_seg ? output_data_buf[i].start_ms : -1; - int end_ms; + start_ms = next_seg ? output_data_buf[i].start_ms : -1; + if (output_data_buf[i].contain_seg_end_point) { - end_ms = output_data_buf[i].end_ms; - next_seg = true; - output_data_buf_offset += 1; + end_ms = output_data_buf[i].end_ms; + next_seg = true; + output_data_buf_offset += 1; } else { - end_ms = -1; - next_seg = false; + end_ms = -1; + next_seg = false; } + } else { + if (!is_final && + (!output_data_buf[i].contain_seg_start_point || !output_data_buf[i].contain_seg_end_point)) { + continue; + } + start_ms = output_data_buf[i].start_ms; + end_ms = output_data_buf[i].end_ms; + output_data_buf_offset += 1; + } std::vector segment = {start_ms, end_ms}; segment_batch.push_back(segment); } } - // } if (is_final) { AllResetDetection(); } @@ -444,15 +455,22 @@ private: } void PopDataBufTillFrame(int frame_idx) { - while (data_buf_start_frame < frame_idx) { - int frame_sample_length = int(vad_opts.frame_in_ms * vad_opts.sample_rate / 1000); - if (data_buf.size() >= frame_sample_length) { - data_buf_start_frame += 1; - data_buf.erase(data_buf.begin(), data_buf.begin() + frame_sample_length); - } else { - break; - } + int frame_sample_length = int(vad_opts.frame_in_ms * vad_opts.sample_rate / 1000); + int start_pos=-1; + int data_length= data_buf.size(); + while (data_buf_start_frame < frame_idx) { + if (data_length >= frame_sample_length) { + data_buf_start_frame += 1; + start_pos= data_buf_start_frame* frame_sample_length; + data_length=data_buf_all.size()-start_pos; + } else { + break; } + } + if (start_pos!=-1){ + data_buf.resize(data_length); + std::copy(data_buf_all.begin() + start_pos, data_buf_all.end(), data_buf.begin()); + } } void PopDataToOutputBuf(int start_frm, int frm_cnt, bool first_frm_is_start_point, bool last_frm_is_end_point,