From 1b6b10ab2db6ec067ef9e1089060449f78aa4e65 Mon Sep 17 00:00:00 2001 From: zhuzizyf <42790740+zhuzizyf@users.noreply.github.com> Date: Thu, 4 May 2023 10:42:25 +0800 Subject: [PATCH] Update e2e-vad.h Fix memory and performance issues caused by long-term use of streaming VAD. 1.Remove unnecessary scores cache and only keep the latest score. 2.Remove the data_buf_all and data_buf cache, and only cache their lengths. --- funasr/runtime/onnxruntime/src/e2e-vad.h | 52 ++++++++++-------------- 1 file changed, 21 insertions(+), 31 deletions(-) diff --git a/funasr/runtime/onnxruntime/src/e2e-vad.h b/funasr/runtime/onnxruntime/src/e2e-vad.h index 90f2635f6..02bae6296 100644 --- a/funasr/runtime/onnxruntime/src/e2e-vad.h +++ b/funasr/runtime/onnxruntime/src/e2e-vad.h @@ -1,6 +1,7 @@ /** * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. * MIT License (https://opensource.org/licenses/MIT) + * Collaborators: zhuzizyf(China Telecom Shanghai) */ #include @@ -381,10 +382,11 @@ private: int max_end_sil_frame_cnt_thresh; float speech_noise_thres; std::vector> scores; + int idx_pre_chunk = 0; bool max_time_out; std::vector decibel; - std::vector data_buf; - std::vector data_buf_all; + int data_buf_size = 0; + int data_buf_all_size = 0; std::vector waveform; void AllResetDetection() { @@ -409,10 +411,11 @@ private: max_end_sil_frame_cnt_thresh = vad_opts.max_end_silence_time - vad_opts.speech_to_sil_time_thres; speech_noise_thres = vad_opts.speech_noise_thres; scores.clear(); + idx_pre_chunk = 0; max_time_out = false; decibel.clear(); - data_buf.clear(); - data_buf_all.clear(); + int data_buf_size = 0; + int data_buf_all_size = 0; waveform.clear(); ResetDetection(); } @@ -432,18 +435,17 @@ private: void ComputeDecibel() { int frame_sample_length = int(vad_opts.frame_length_ms * vad_opts.sample_rate / 1000); int frame_shift_length = int(vad_opts.frame_in_ms * vad_opts.sample_rate / 1000); - if (data_buf_all.empty()) { - data_buf_all = waveform; - data_buf = data_buf_all; + if (data_buf_all_size == 0) { + data_buf_all_size = waveform.size(); + data_buf_size = data_buf_all_size; } else { - data_buf_all.insert(data_buf_all.end(), waveform.begin(), waveform.end()); + data_buf_all_size += waveform.size(); } for (int offset = 0; offset < waveform.size() - frame_sample_length + 1; offset += frame_shift_length) { float sum = 0.0; for (int i = 0; i < frame_sample_length; i++) { sum += waveform[offset + i] * waveform[offset + i]; } -// float decibel = 10 * log10(sum + 0.000001); this->decibel.push_back(10 * log10(sum + 0.000001)); } } @@ -451,30 +453,17 @@ private: void ComputeScores(const std::vector> &scores) { vad_opts.nn_eval_block_size = scores.size(); frm_cnt += scores.size(); - if (this->scores.empty()) { - this->scores = scores; // the first calculation - } else { - this->scores.insert(this->scores.end(), scores.begin(), scores.end()); - } + this->scores = scores; } void PopDataBufTillFrame(int frame_idx) { int frame_sample_length = int(vad_opts.frame_in_ms * vad_opts.sample_rate / 1000); - int start_pos=-1; - int data_length= data_buf.size(); while (data_buf_start_frame < frame_idx) { - if (data_length >= frame_sample_length) { + if (data_buf_size >= frame_sample_length) { data_buf_start_frame += 1; - start_pos= data_buf_start_frame* frame_sample_length; - data_length=data_buf_all.size()-start_pos; - } else { - break; + data_buf_size = data_buf_all_size - data_buf_start_frame * frame_sample_length; } } - if (start_pos!=-1){ - data_buf.resize(data_length); - std::copy(data_buf_all.begin() + start_pos, data_buf_all.end(), data_buf.begin()); - } } void PopDataToOutputBuf(int start_frm, int frm_cnt, bool first_frm_is_start_point, bool last_frm_is_end_point, @@ -487,9 +476,9 @@ private: expected_sample_number += int(extra_sample); } if (end_point_is_sent_end) { - expected_sample_number = std::max(expected_sample_number, int(data_buf.size())); + expected_sample_number = std::max(expected_sample_number, data_buf_size); } - if (data_buf.size() < expected_sample_number) { + if (data_buf_size < expected_sample_number) { std::cout << "error in calling pop data_buf\n"; } if (output_data_buf.size() == 0 || first_frm_is_start_point) { @@ -510,10 +499,10 @@ private: } else { data_to_pop = int(frm_cnt * vad_opts.frame_in_ms * vad_opts.sample_rate / 1000); } - if (data_to_pop > int(data_buf.size())) { + if (data_to_pop > data_buf_size) { std::cout << "VAD data_to_pop is bigger than data_buf.size()!!!\n"; - data_to_pop = (int) data_buf.size(); - expected_sample_number = (int) data_buf.size(); + data_to_pop = data_buf_size; + expected_sample_number = data_buf_size; } cur_seg.doa = 0; for (int sample_cpy_out = 0; sample_cpy_out < data_to_pop; sample_cpy_out++) { @@ -619,7 +608,7 @@ private: if (sil_pdf_ids.size() > 0) { std::vector sil_pdf_scores; for (auto sil_pdf_id: sil_pdf_ids) { - sil_pdf_scores.push_back(scores[t][sil_pdf_id]); + sil_pdf_scores.push_back(scores[t - idx_pre_chunk][sil_pdf_id]); } sum_score = accumulate(sil_pdf_scores.begin(), sil_pdf_scores.end(), 0.0); noise_prob = log(sum_score) * vad_opts.speech_2_noise_ratio; @@ -663,6 +652,7 @@ private: frame_state = GetFrameState(frm_cnt - 1 - i); DetectOneFrame(frame_state, frm_cnt - 1 - i, false); } + idx_pre_chunk += scores.size(); return 0; }