mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
Merge branch 'dev_knf' of https://github.com/alibaba-damo-academy/FunASR into dev_knf
This commit is contained in:
commit
0bb5d87d1e
@ -18,6 +18,7 @@ void FsmnVad::init_vad(const std::string &vad_model, const std::string &vad_cmvn
|
||||
|
||||
read_model(vad_model);
|
||||
load_cmvn(vad_cmvn.c_str());
|
||||
init_cache();
|
||||
|
||||
fbank_opts.frame_opts.dither = 0;
|
||||
fbank_opts.mel_opts.num_bins = 80;
|
||||
@ -105,20 +106,18 @@ void FsmnVad::Forward(
|
||||
}
|
||||
Ort::Value vad_feats_ort = Ort::Value::CreateTensor<float>(
|
||||
memory_info, vad_feats.data(), vad_feats.size(), vad_feats_shape, 3);
|
||||
// cache node {batch,128,19,1}
|
||||
const int64_t cache_feats_shape[4] = {1, 128, 19, 1};
|
||||
std::vector<float> cache_feats(128 * 19 * 1, 0);
|
||||
Ort::Value cache_feats_ort = Ort::Value::CreateTensor<float>(
|
||||
memory_info, cache_feats.data(), cache_feats.size(), cache_feats_shape, 4);
|
||||
|
||||
// 3. Put nodes into onnx input vector
|
||||
std::vector<Ort::Value> vad_inputs;
|
||||
vad_inputs.emplace_back(std::move(vad_feats_ort));
|
||||
// 4 caches
|
||||
for (int i = 0; i < 4; i++) {
|
||||
// cache node {batch,128,19,1}
|
||||
const int64_t cache_feats_shape[4] = {1, 128, 19, 1};
|
||||
for (int i = 0; i < in_cache_.size(); i++) {
|
||||
vad_inputs.emplace_back(std::move(Ort::Value::CreateTensor<float>(
|
||||
memory_info, cache_feats.data(), cache_feats.size(), cache_feats_shape, 4)));
|
||||
memory_info, in_cache_[i].data(), in_cache_[i].size(), cache_feats_shape, 4)));
|
||||
}
|
||||
|
||||
// 4. Onnx infer
|
||||
std::vector<Ort::Value> vad_ort_outputs;
|
||||
try {
|
||||
@ -143,6 +142,12 @@ void FsmnVad::Forward(
|
||||
memcpy((*out_prob)[i].data(), logp_data + i * output_dim,
|
||||
sizeof(float) * output_dim);
|
||||
}
|
||||
|
||||
// Copy the 4 output caches (128*19 floats each) back into in_cache_ so the next Forward call can feed them in
|
||||
for (int i = 1; i < 5; i++) {
|
||||
float* data = vad_ort_outputs[i].GetTensorMutableData<float>();
|
||||
memcpy(in_cache_[i-1].data(), data, sizeof(float) * 128*19);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -252,6 +257,17 @@ FsmnVad::infer(const std::vector<float> &waves) {
|
||||
|
||||
}
|
||||
|
||||
void FsmnVad::init_cache(){
|
||||
std::vector<float> cache_feats(128 * 19 * 1, 0);
|
||||
for (int i=0;i<4;i++){
|
||||
in_cache_.emplace_back(cache_feats);
|
||||
}
|
||||
};
|
||||
|
||||
// Restores the cache state to its initial value: every buffer accumulated
// in in_cache_ is discarded and init_cache() rebuilds the zeroed set.
void FsmnVad::Reset() {
    in_cache_.clear();
    init_cache();
}
|
||||
|
||||
void FsmnVad::test() {
|
||||
|
||||
|
||||
@ -16,6 +16,7 @@ public:
|
||||
float vad_speech_noise_thres);
|
||||
|
||||
std::vector<std::vector<int>> infer(const std::vector<float> &waves);
|
||||
void Reset();
|
||||
|
||||
private:
|
||||
|
||||
@ -35,12 +36,15 @@ private:
|
||||
std::vector<std::vector<float>> *out_prob);
|
||||
|
||||
void load_cmvn(const char *filename);
|
||||
void init_cache();
|
||||
|
||||
std::shared_ptr<Ort::Session> vad_session_ = nullptr;
|
||||
Ort::Env env_;
|
||||
Ort::SessionOptions session_options_;
|
||||
std::vector<const char *> vad_in_names_;
|
||||
std::vector<const char *> vad_out_names_;
|
||||
std::vector<std::vector<float>> in_cache_;
|
||||
|
||||
knf::FbankOptions fbank_opts;
|
||||
std::vector<float> means_list;
|
||||
std::vector<float> vars_list;
|
||||
|
||||
133
funasr/runtime/onnxruntime/src/OnlineFeature.cc
Normal file
133
funasr/runtime/onnxruntime/src/OnlineFeature.cc
Normal file
@ -0,0 +1,133 @@
|
||||
//
|
||||
// Created by zhuzizyf(China Telecom Shanghai) on 4/22/23.
|
||||
//
|
||||
|
||||
#include "OnlineFeature.h"
|
||||
|
||||
#include <utility>
|
||||
|
||||
// Builds a streaming feature extractor.
//
//   sample_rate  sampling rate of the incoming audio, in Hz
//   fbank_opts   options handed to knf::OnlineFbank for every chunk
//   lfr_m        number of consecutive frames stacked into one output frame
//   lfr_n        hop, in frames, between two stacked output frames
//   cmvns        normalisation table; row 0 is added to and row 1 multiplied
//                with each stacked frame (see OnlineLfrCmvn)
//
// The frame length/shift in samples are derived from fixed 25 ms / 10 ms
// windows. The multiplication is done before the division so that rates which
// are not a multiple of 1000 Hz (e.g. 22050) are not truncated early.
// Members are initialised in declaration order.
OnlineFeature::OnlineFeature(int sample_rate, knf::FbankOptions fbank_opts, int lfr_m, int lfr_n,
                             std::vector<std::vector<float>> cmvns)
        : fbank_opts_(std::move(fbank_opts)),
          cmvns_(std::move(cmvns)),
          sample_rate_(sample_rate),
          frame_sample_length_(sample_rate * 25 / 1000),
          frame_shift_sample_length_(sample_rate * 10 / 1000),
          lfr_m_(lfr_m),
          lfr_n_(lfr_n) {
}
|
||||
|
||||
// Consumes one chunk of audio and produces the stacked (LFR) and normalised
// (CMVN) feature frames that can be emitted so far.
//
//   vad_feats       output; fbank frames are appended to it and, when enough
//                   frames are available, it is replaced by the LFR/CMVN frames
//   waves           the new audio samples (taken by value: it is only used as
//                   scratch space for the internal bookkeeping)
//   input_finished  true when no more audio will follow
//
// State carried between calls: input_cache_ (samples after the last frame
// shift), lfr_splice_cache_ (frames not yet consumed by LFR) and
// reserve_waveforms_ (the samples kept alongside those cached frames).
void OnlineFeature::extractFeats(vector<std::vector<float>> &vad_feats,
                                 vector<float> waves, bool input_finished) {
    input_finished_ = input_finished;
    onlineFbank(vad_feats, waves);

    if (vad_feats.empty()) {
        // No frame could be computed from this chunk. Only a final call has
        // anything left to do: flush whatever is still cached, then reset.
        if (input_finished_) {
            vad_feats = lfr_splice_cache_;
            // OnlineLfrCmvn indexes relative to the frame count, so it must
            // not be handed an empty sequence.
            if (!vad_feats.empty()) {
                OnlineLfrCmvn(vad_feats);
            }
            reset_cache();
        }
        return;
    }

    // Put the samples reserved by the previous call back in front of this chunk.
    if (!reserve_waveforms_.empty()) {
        waves.insert(waves.begin(), reserve_waveforms_.begin(), reserve_waveforms_.end());
    }

    // Very first frames of a stream: left-pad with copies of the first frame.
    if (lfr_splice_cache_.empty()) {
        const int left_context = (lfr_m_ - 1) / 2;
        for (int i = 0; i < left_context; i++) {
            lfr_splice_cache_.emplace_back(vad_feats[0]);
        }
    }

    // Keeps a sample offset inside [low, high] so it is always a valid iterator offset.
    const auto clamp_offset = [](long value, long low, long high) {
        return value < low ? low : (value > high ? high : value);
    };
    const long wave_len = static_cast<long>(waves.size());

    if (vad_feats.size() + lfr_splice_cache_.size() >= static_cast<size_t>(lfr_m_)) {
        vad_feats.insert(vad_feats.begin(), lfr_splice_cache_.begin(), lfr_splice_cache_.end());

        // Same frame-count rule as onlineFbank; yields 0 instead of wrapping
        // around when fewer samples than one frame are present.
        const int frame_from_waves = compute_frame_num(static_cast<int>(waves.size()),
                                                       frame_sample_length_,
                                                       frame_shift_sample_length_);
        // Must be read before reserve_waveforms_ is overwritten below.
        const int minus_frame = reserve_waveforms_.empty() ? (lfr_m_ - 1) / 2 : 0;
        const int lfr_splice_frame_idxs = OnlineLfrCmvn(vad_feats);
        const int reserve_frame_idx = lfr_splice_frame_idxs - minus_frame;

        // reserve_frame_idx can be negative when the first chunk is only just
        // long enough; clamp both ends so the range stays inside `waves`.
        const long last = clamp_offset(
                static_cast<long>(frame_from_waves) * frame_shift_sample_length_, 0, wave_len);
        const long first = clamp_offset(
                static_cast<long>(reserve_frame_idx) * frame_shift_sample_length_, 0, last);
        reserve_waveforms_.assign(waves.begin() + first, waves.begin() + last);
        // NOTE(review): when input_finished is true the caches are not reset on
        // this path; they are only reset by a final call that yields no frame.
        // Confirm callers finish a stream with such a call.
    } else {
        // Not enough frames for one stacked frame yet: keep everything for later.
        const long first = clamp_offset(
                static_cast<long>(frame_sample_length_) - frame_shift_sample_length_, 0, wave_len);
        reserve_waveforms_.assign(waves.begin() + first, waves.end());
        lfr_splice_cache_.insert(lfr_splice_cache_.end(), vad_feats.begin(), vad_feats.end());
        // NOTE(review): vad_feats still holds the raw fbank frames here although
        // they were also cached for the next call - confirm callers expect that.
    }
}
|
||||
|
||||
// Stacks consecutive frames (LFR) and normalises the result (CMVN).
//
// On entry vad_feats holds T frames (cached frames followed by new ones).
// Output frame i is the concatenation of frames [i*lfr_n_, i*lfr_n_ + lfr_m_).
// When fewer than lfr_m_ frames remain for an output frame:
//   - if input_finished_ is set, the missing frames are filled with copies of
//     the last frame;
//   - otherwise stacking stops and the remaining frames are kept in
//     lfr_splice_cache_ for the next call.
// Every output frame is then transformed as (x + cmvns_[0]) * cmvns_[1].
// vad_feats is replaced by the output frames.
//
// Returns the index (into the input sequence) of the first frame that was
// stored in lfr_splice_cache_; 0 when the input is empty.
// Precondition: lfr_n_ > 0.
int OnlineFeature::OnlineLfrCmvn(vector<vector<float>> &vad_feats) {
    const int T = static_cast<int>(vad_feats.size());
    if (T == 0) {
        // Nothing to stack; the index arithmetic below would otherwise go negative.
        lfr_splice_cache_.clear();
        return 0;
    }

    // Number of output frames: ceil((T - (lfr_m_ - 1) / 2) / lfr_n_), computed
    // in integers so the rounding up is not lost to integer division.
    const int usable = T - (lfr_m_ - 1) / 2;
    const int T_lrf = usable > 0 ? (usable + lfr_n_ - 1) / lfr_n_ : 0;
    int lfr_splice_frame_idxs = T_lrf;

    vector<vector<float>> out_feats;
    out_feats.reserve(T_lrf);
    for (int i = 0; i < T_lrf; i++) {
        const int start = i * lfr_n_;
        const int available = T - start;
        if (available < lfr_m_ && !input_finished_) {
            // Incomplete window and more audio may come: resume from here next time.
            lfr_splice_frame_idxs = i;
            break;
        }
        // A fresh buffer per output frame, so nothing leaks from the previous one.
        vector<float> spliced;
        const int real_frames = std::min(available, lfr_m_);
        for (int j = 0; j < real_frames; j++) {
            spliced.insert(spliced.end(), vad_feats[start + j].begin(), vad_feats[start + j].end());
        }
        // Only reached with real_frames < lfr_m_ on the final chunk: repeat the last frame.
        for (int j = real_frames; j < lfr_m_; j++) {
            spliced.insert(spliced.end(), vad_feats.back().begin(), vad_feats.back().end());
        }
        out_feats.emplace_back(std::move(spliced));
    }

    // Always keep at least the last frame; T >= 1 and T_lrf >= 0 keep this in [0, T - 1].
    lfr_splice_frame_idxs = std::min(T - 1, lfr_splice_frame_idxs * lfr_n_);
    lfr_splice_cache_.assign(vad_feats.begin() + lfr_splice_frame_idxs, vad_feats.end());

    // Apply cmvn: row 0 is added, row 1 is multiplied. The loop is bounded by
    // the shortest of the three vectors so it can never index out of range.
    if (cmvns_.size() >= 2) {
        const size_t cmvn_dim = std::min(cmvns_[0].size(), cmvns_[1].size());
        for (auto &out_feat: out_feats) {
            const size_t dim = std::min(out_feat.size(), cmvn_dim);
            for (size_t j = 0; j < dim; j++) {
                out_feat[j] = (out_feat[j] + cmvns_[0][j]) * cmvns_[1][j];
            }
        }
    }
    vad_feats = std::move(out_feats);
    return lfr_splice_frame_idxs;
}
|
||||
|
||||
// Computes fbank frames for the audio that forms complete frames.
//
//   vad_feats  output; one vector of fbank_opts_.mel_opts.num_bins values is
//              appended per frame reported by the fbank computer
//   waves      in/out; on return it starts with the samples carried over from
//              the previous call and, if at least one frame was computed, is
//              trimmed to exactly the samples covered by those frames
//
// Samples after the last frame shift position are stored in input_cache_ and
// prepended to the next chunk.
void OnlineFeature::onlineFbank(vector<std::vector<float>> &vad_feats,
                                vector<float> &waves) {
    // cache merge
    waves.insert(waves.begin(), input_cache_.begin(), input_cache_.end());
    const int frame_number = compute_frame_num(static_cast<int>(waves.size()),
                                               frame_sample_length_,
                                               frame_shift_sample_length_);

    // Send the audio after the last frame shift position to the cache
    input_cache_.assign(waves.begin() + frame_number * frame_shift_sample_length_, waves.end());
    if (frame_number == 0) {
        // Not even one full frame yet; everything was moved to input_cache_.
        return;
    }

    // Delete audio that has not undergone fbank processing
    waves.erase(waves.begin() + (frame_number - 1) * frame_shift_sample_length_ + frame_sample_length_,
                waves.end());

    // Created only once there is something to process, so chunks that are too
    // short do not pay for setting up the fbank computer.
    knf::OnlineFbank fbank(fbank_opts_);
    fbank.AcceptWaveform(sample_rate_, waves.data(), waves.size());

    const int32_t frames = fbank.NumFramesReady();
    const int32_t num_bins = fbank_opts_.mel_opts.num_bins;
    vad_feats.reserve(vad_feats.size() + frames);
    for (int32_t i = 0; i != frames; ++i) {
        const float *frame = fbank.GetFrame(i);
        // Build the row in place instead of copying a temporary vector.
        vad_feats.emplace_back(frame, frame + num_bins);
    }
}
|
||||
59
funasr/runtime/onnxruntime/src/OnlineFeature.h
Normal file
59
funasr/runtime/onnxruntime/src/OnlineFeature.h
Normal file
@ -0,0 +1,59 @@
|
||||
//
|
||||
// Created by zhuzizyf(China Telecom Shanghai) on 4/22/23.
|
||||
//
|
||||
|
||||
|
||||
#include "kaldi-native-fbank/csrc/feature-fbank.h"
|
||||
#include "kaldi-native-fbank/csrc/online-feature.h"
|
||||
#include <vector>
|
||||
|
||||
using namespace std;
|
||||
|
||||
class OnlineFeature {
|
||||
|
||||
public:
|
||||
OnlineFeature(int sample_rate, knf::FbankOptions fbank_opts, int lfr_m_, int lfr_n_,
|
||||
std::vector<std::vector<float>> cmvns_);
|
||||
|
||||
void extractFeats(vector<vector<float>> &vad_feats, vector<float> waves, bool input_finished);
|
||||
|
||||
|
||||
private:
|
||||
void onlineFbank(vector<vector<float>> &vad_feats, vector<float> &waves);
|
||||
|
||||
int OnlineLfrCmvn(vector<vector<float>> &vad_feats);
|
||||
|
||||
static int compute_frame_num(int sample_length, int frame_sample_length, int frame_shift_sample_length) {
|
||||
int frame_num = static_cast<int>((sample_length - frame_sample_length) / frame_shift_sample_length + 1);
|
||||
|
||||
if (frame_num >= 1 && sample_length >= frame_sample_length)
|
||||
return frame_num;
|
||||
else
|
||||
return 0;
|
||||
}
|
||||
|
||||
void reset_cache() {
|
||||
reserve_waveforms_.clear();
|
||||
input_cache_.clear();
|
||||
lfr_splice_cache_.clear();
|
||||
input_finished_ = false;
|
||||
|
||||
}
|
||||
|
||||
knf::FbankOptions fbank_opts_;
|
||||
// The reserved waveforms by fbank
|
||||
std::vector<float> reserve_waveforms_;
|
||||
// waveforms reserved after last shift position
|
||||
std::vector<float> input_cache_;
|
||||
// lfr reserved cache
|
||||
std::vector<std::vector<float>> lfr_splice_cache_;
|
||||
std::vector<std::vector<float>> cmvns_;
|
||||
|
||||
int sample_rate_ = 16000;
|
||||
int frame_sample_length_ = sample_rate_ / 1000 * 25;;
|
||||
int frame_shift_sample_length_ = sample_rate_ / 1000 * 10;
|
||||
int lfr_m_;
|
||||
int lfr_n_;
|
||||
bool input_finished_ = false;
|
||||
|
||||
};
|
||||
Loading…
Reference in New Issue
Block a user