diff --git a/funasr/runtime/onnxruntime/src/FsmnVad.cc b/funasr/runtime/onnxruntime/src/FsmnVad.cc index f75ead7b5..de63225e5 100644 --- a/funasr/runtime/onnxruntime/src/FsmnVad.cc +++ b/funasr/runtime/onnxruntime/src/FsmnVad.cc @@ -18,6 +18,7 @@ void FsmnVad::init_vad(const std::string &vad_model, const std::string &vad_cmvn read_model(vad_model); load_cmvn(vad_cmvn.c_str()); + init_cache(); fbank_opts.frame_opts.dither = 0; fbank_opts.mel_opts.num_bins = 80; @@ -105,20 +106,18 @@ void FsmnVad::Forward( } Ort::Value vad_feats_ort = Ort::Value::CreateTensor( memory_info, vad_feats.data(), vad_feats.size(), vad_feats_shape, 3); - // cache node {batch,128,19,1} - const int64_t cache_feats_shape[4] = {1, 128, 19, 1}; - std::vector cache_feats(128 * 19 * 1, 0); - Ort::Value cache_feats_ort = Ort::Value::CreateTensor( - memory_info, cache_feats.data(), cache_feats.size(), cache_feats_shape, 4); - + // 3. Put nodes into onnx input vector std::vector vad_inputs; vad_inputs.emplace_back(std::move(vad_feats_ort)); // 4 caches - for (int i = 0; i < 4; i++) { - vad_inputs.emplace_back(std::move(Ort::Value::CreateTensor( - memory_info, cache_feats.data(), cache_feats.size(), cache_feats_shape, 4))); + // cache node {batch,128,19,1} + const int64_t cache_feats_shape[4] = {1, 128, 19, 1}; + for (int i = 0; i < in_cache_.size(); i++) { + vad_inputs.emplace_back(std::move(Ort::Value::CreateTensor( + memory_info, in_cache_[i].data(), in_cache_[i].size(), cache_feats_shape, 4))); } + // 4. Onnx infer std::vector vad_ort_outputs; try { @@ -143,6 +142,12 @@ void FsmnVad::Forward( memcpy((*out_prob)[i].data(), logp_data + i * output_dim, sizeof(float) * output_dim); } + + // get 4 caches outputs,each size is 128*19 + for (int i = 1; i < 5; i++) { + float* data = vad_ort_outputs[i].GetTensorMutableData(); + memcpy(in_cache_[i-1].data(), data, sizeof(float) * 128*19); + } } @@ -252,6 +257,17 @@ FsmnVad::infer(const std::vector &waves) { } +void FsmnVad::init_cache(){ + std::vector cache_feats(128 * 19 * 1, 0); + for (int i=0;i<4;i++){ + in_cache_.emplace_back(cache_feats); + } +}; + +void FsmnVad::Reset(){ + in_cache_.clear(); + init_cache(); +}; void FsmnVad::test() { diff --git a/funasr/runtime/onnxruntime/src/FsmnVad.h b/funasr/runtime/onnxruntime/src/FsmnVad.h index d7ec55475..78302ae6d 100644 --- a/funasr/runtime/onnxruntime/src/FsmnVad.h +++ b/funasr/runtime/onnxruntime/src/FsmnVad.h @@ -16,6 +16,7 @@ public: float vad_speech_noise_thres); std::vector> infer(const std::vector &waves); + void Reset(); private: @@ -35,12 +36,15 @@ private: std::vector> *out_prob); void load_cmvn(const char *filename); + void init_cache(); std::shared_ptr vad_session_ = nullptr; Ort::Env env_; Ort::SessionOptions session_options_; std::vector vad_in_names_; std::vector vad_out_names_; + std::vector> in_cache_; + knf::FbankOptions fbank_opts; std::vector means_list; std::vector vars_list; diff --git a/funasr/runtime/onnxruntime/src/OnlineFeature.cc b/funasr/runtime/onnxruntime/src/OnlineFeature.cc new file mode 100644 index 000000000..a2bbafd02 --- /dev/null +++ b/funasr/runtime/onnxruntime/src/OnlineFeature.cc @@ -0,0 +1,133 @@ +// +// Created by zhuzizyf(China Telecom Shanghai) on 4/22/23. +// + +#include "OnlineFeature.h" + +#include + +OnlineFeature::OnlineFeature(int sample_rate, knf::FbankOptions fbank_opts, int lfr_m, int lfr_n, + std::vector> cmvns) + : sample_rate_(sample_rate), + fbank_opts_(std::move(fbank_opts)), + lfr_m_(lfr_m), + lfr_n_(lfr_n), + cmvns_(std::move(cmvns)) { + frame_sample_length_ = sample_rate_ / 1000 * 25;; + frame_shift_sample_length_ = sample_rate_ / 1000 * 10; +} + +void OnlineFeature::extractFeats(vector> &vad_feats, + vector waves, bool input_finished) { + input_finished_ = input_finished; + onlineFbank(vad_feats, waves); + // cache deal & online lfr,cmvn + if (vad_feats.size() > 0) { + if (!reserve_waveforms_.empty()) { + waves.insert(waves.begin(), reserve_waveforms_.begin(), reserve_waveforms_.end()); + } + if (lfr_splice_cache_.empty()) { + for (int i = 0; i < (lfr_m_ - 1) / 2; i++) { + lfr_splice_cache_.emplace_back(vad_feats[0]); + } + } + if (vad_feats.size() + lfr_splice_cache_.size() >= lfr_m_) { + vad_feats.insert(vad_feats.begin(), lfr_splice_cache_.begin(), lfr_splice_cache_.end()); + int frame_from_waves = (waves.size() - frame_sample_length_) / frame_shift_sample_length_ + 1; + int minus_frame = reserve_waveforms_.empty() ? (lfr_m_ - 1) / 2 : 0; + int lfr_splice_frame_idxs = OnlineLfrCmvn(vad_feats); + int reserve_frame_idx = lfr_splice_frame_idxs - minus_frame; + reserve_waveforms_.clear(); + reserve_waveforms_.insert(reserve_waveforms_.begin(), + waves.begin() + reserve_frame_idx * frame_shift_sample_length_, + waves.begin() + frame_from_waves * frame_shift_sample_length_); + int sample_length = (frame_from_waves - 1) * frame_shift_sample_length_ + frame_sample_length_; + waves.erase(waves.begin() + sample_length, waves.end()); + } else { + reserve_waveforms_.clear(); + reserve_waveforms_.insert(reserve_waveforms_.begin(), + waves.begin() + frame_sample_length_ - frame_shift_sample_length_, waves.end()); + lfr_splice_cache_.insert(lfr_splice_cache_.end(), vad_feats.begin(), vad_feats.end()); + } + + } else { + if (input_finished_) { + if (!reserve_waveforms_.empty()) { + waves = reserve_waveforms_; + } + vad_feats = lfr_splice_cache_; + OnlineLfrCmvn(vad_feats); + reset_cache(); + } + } + +} + +int OnlineFeature::OnlineLfrCmvn(vector> &vad_feats) { + vector> out_feats; + int T = vad_feats.size(); + int T_lrf = ceil((T - (lfr_m_ - 1) / 2) / lfr_n_); + int lfr_splice_frame_idxs = T_lrf; + vector p; + for (int i = 0; i < T_lrf; i++) { + if (lfr_m_ <= T - i * lfr_n_) { + for (int j = 0; j < lfr_m_; j++) { + p.insert(p.end(), vad_feats[i * lfr_n_ + j].begin(), vad_feats[i * lfr_n_ + j].end()); + } + out_feats.emplace_back(p); + p.clear(); + } else { + if (input_finished_) { + int num_padding = lfr_m_ - (T - i * lfr_n_); + for (int j = 0; j < (vad_feats.size() - i * lfr_n_); j++) { + p.insert(p.end(), vad_feats[i * lfr_n_ + j].begin(), vad_feats[i * lfr_n_ + j].end()); + } + for (int j = 0; j < num_padding; j++) { + p.insert(p.end(), vad_feats[vad_feats.size() - 1].begin(), vad_feats[vad_feats.size() - 1].end()); + } + out_feats.emplace_back(p); + } else { + lfr_splice_frame_idxs = i; + break; + } + } + } + lfr_splice_frame_idxs = std::min(T - 1, lfr_splice_frame_idxs * lfr_n_); + lfr_splice_cache_.clear(); + lfr_splice_cache_.insert(lfr_splice_cache_.begin(), vad_feats.begin() + lfr_splice_frame_idxs, vad_feats.end()); + + // Apply cmvn + for (auto &out_feat: out_feats) { + for (int j = 0; j < cmvns_[0].size(); j++) { + out_feat[j] = (out_feat[j] + cmvns_[0][j]) * cmvns_[1][j]; + } + } + vad_feats = out_feats; + return lfr_splice_frame_idxs; +} + +void OnlineFeature::onlineFbank(vector> &vad_feats, + vector &waves) { + + knf::OnlineFbank fbank(fbank_opts_); + // cache merge + waves.insert(waves.begin(), input_cache_.begin(), input_cache_.end()); + int frame_number = compute_frame_num(waves.size(), frame_sample_length_, frame_shift_sample_length_); + // Send the audio after the last frame shift position to the cache + input_cache_.clear(); + input_cache_.insert(input_cache_.begin(), waves.begin() + frame_number * frame_shift_sample_length_, waves.end()); + if (frame_number == 0) { + return; + } + // Delete audio that haven't undergone fbank processing + waves.erase(waves.begin() + (frame_number - 1) * frame_shift_sample_length_ + frame_sample_length_, waves.end()); + + fbank.AcceptWaveform(sample_rate_, &waves[0], waves.size()); + int32_t frames = fbank.NumFramesReady(); + for (int32_t i = 0; i != frames; ++i) { + const float *frame = fbank.GetFrame(i); + vector frame_vector(frame, frame + fbank_opts_.mel_opts.num_bins); + vad_feats.emplace_back(frame_vector); + } + +} diff --git a/funasr/runtime/onnxruntime/src/OnlineFeature.h b/funasr/runtime/onnxruntime/src/OnlineFeature.h new file mode 100644 index 000000000..bd613ab9f --- /dev/null +++ b/funasr/runtime/onnxruntime/src/OnlineFeature.h @@ -0,0 +1,59 @@ +// +// Created by zhuzizyf(China Telecom Shanghai) on 4/22/23. +// + + +#include "kaldi-native-fbank/csrc/feature-fbank.h" +#include "kaldi-native-fbank/csrc/online-feature.h" +#include + +using namespace std; + +class OnlineFeature { + +public: + OnlineFeature(int sample_rate, knf::FbankOptions fbank_opts, int lfr_m_, int lfr_n_, + std::vector> cmvns_); + + void extractFeats(vector> &vad_feats, vector waves, bool input_finished); + + +private: + void onlineFbank(vector> &vad_feats, vector &waves); + + int OnlineLfrCmvn(vector> &vad_feats); + + static int compute_frame_num(int sample_length, int frame_sample_length, int frame_shift_sample_length) { + int frame_num = static_cast((sample_length - frame_sample_length) / frame_shift_sample_length + 1); + + if (frame_num >= 1 && sample_length >= frame_sample_length) + return frame_num; + else + return 0; + } + + void reset_cache() { + reserve_waveforms_.clear(); + input_cache_.clear(); + lfr_splice_cache_.clear(); + input_finished_ = false; + + } + + knf::FbankOptions fbank_opts_; + // The reserved waveforms by fbank + std::vector reserve_waveforms_; + // waveforms reserved after last shift position + std::vector input_cache_; + // lfr reserved cache + std::vector> lfr_splice_cache_; + std::vector> cmvns_; + + int sample_rate_ = 16000; + int frame_sample_length_ = sample_rate_ / 1000 * 25;; + int frame_shift_sample_length_ = sample_rate_ / 1000 * 10; + int lfr_m_; + int lfr_n_; + bool input_finished_ = false; + +};