add batch for paraformer

This commit is contained in:
雾聪 2024-03-29 16:47:27 +08:00
parent 445b7ec47c
commit a58c3d4593
4 changed files with 94 additions and 66 deletions

View File

@ -265,34 +265,45 @@ void ParaformerTorch::LfrCmvn(std::vector<std::vector<float>> &asr_feats) {
asr_feats = out_feats;
}
string ParaformerTorch::Forward(float* din, int len, bool input_finished, const std::vector<std::vector<float>> &hw_emb, void* decoder_handle)
std::vector<std::string> ParaformerTorch::Forward(float** din, int* len, bool input_finished, const std::vector<std::vector<float>> &hw_emb, void* decoder_handle, int batch_in)
{
WfstDecoder* wfst_decoder = (WfstDecoder*)decoder_handle;
int32_t in_feat_dim = fbank_opts_.mel_opts.num_bins;
int32_t feature_dim = lfr_m*in_feat_dim;
std::vector<std::vector<float>> asr_feats;
FbankKaldi(asr_sample_rate, din, len, asr_feats);
if(asr_feats.size() == 0){
return "";
}
LfrCmvn(asr_feats);
int32_t feat_dim = lfr_m*in_feat_dim;
int32_t num_frames = asr_feats.size();
std::vector<float> wav_feats;
for (const auto &frame_feat: asr_feats) {
wav_feats.insert(wav_feats.end(), frame_feat.begin(), frame_feat.end());
}
std::vector<vector<float>> feats_batch;
std::vector<int32_t> paraformer_length;
paraformer_length.emplace_back(num_frames);
int max_size = 0;
int max_frames = 0;
for(int index=0; index<batch_in; index++){
std::vector<std::vector<float>> asr_feats;
FbankKaldi(asr_sample_rate, din[index], len[index], asr_feats);
if(asr_feats.size() != 0){
LfrCmvn(asr_feats);
}
feats_batch.emplace_back(asr_feats);
int32_t num_frames = asr_feats.size() / feature_dim;
paraformer_length.emplace_back(num_frames);
if(max_size < asr_feats.size()){
max_size = asr_feats.size();
max_frames = num_frames;
}
}
torch::NoGradGuard no_grad;
model_->eval();
// padding
std::vector<float> all_feats(batch_in * max_frames * feature_dim);
for(int index=0; index<batch_in; index++){
feats_batch[index].resize(max_size);
std::memcpy(&all_feats[index * max_frames * feature_dim], feats_batch[index].data(),
max_frames * feature_dim * sizeof(float));
}
torch::Tensor feats =
torch::from_blob(wav_feats.data(),
{1, num_frames, feat_dim}, torch::kFloat).contiguous();
torch::from_blob(all_feats.data(),
{batch_in, max_frames, feature_dim}, torch::kFloat).contiguous();
torch::Tensor feat_lens = torch::from_blob(paraformer_length.data(),
{1}, torch::kInt32);
{batch_in}, torch::kInt32);
// 2. forward
#ifdef USE_GPU
@ -301,7 +312,7 @@ string ParaformerTorch::Forward(float* din, int len, bool input_finished, const
#endif
std::vector<torch::jit::IValue> inputs = {feats, feat_lens};
string result="";
vector<std::string> results;
try {
auto outputs = model_->forward(inputs).toTuple()->elements();
torch::Tensor am_scores;
@ -314,47 +325,49 @@ string ParaformerTorch::Forward(float* din, int len, bool input_finished, const
valid_token_lens = outputs[1].toTensor();
#endif
// timestamp
if(outputs.size() == 4){
torch::Tensor us_alphas_tensor;
torch::Tensor us_peaks_tensor;
#ifdef USE_GPU
us_alphas_tensor = outputs[2].toTensor().to(at::kCPU);
us_peaks_tensor = outputs[3].toTensor().to(at::kCPU);
#else
us_alphas_tensor = outputs[2].toTensor();
us_peaks_tensor = outputs[3].toTensor();
#endif
for(int index=0; index<batch_in; index++){
string result="";
if(outputs.size() == 4){
torch::Tensor us_alphas_tensor;
torch::Tensor us_peaks_tensor;
#ifdef USE_GPU
us_alphas_tensor = outputs[2].toTensor().to(at::kCPU);
us_peaks_tensor = outputs[3].toTensor().to(at::kCPU);
#else
us_alphas_tensor = outputs[2].toTensor();
us_peaks_tensor = outputs[3].toTensor();
#endif
int us_alphas_shape_1 = us_alphas_tensor.size(1);
float* us_alphas_data = us_alphas_tensor.data_ptr<float>();
std::vector<float> us_alphas(us_alphas_shape_1);
for (int i = 0; i < us_alphas_shape_1; i++) {
us_alphas[i] = us_alphas_data[i];
}
int us_peaks_shape_1 = us_peaks_tensor.size(1);
float* us_peaks_data = us_peaks_tensor.data_ptr<float>();
std::vector<float> us_peaks(us_peaks_shape_1);
for (int i = 0; i < us_peaks_shape_1; i++) {
us_peaks[i] = us_peaks_data[i];
}
if (lm_ == nullptr) {
result = GreedySearch(am_scores[0].data_ptr<float>(), valid_token_lens[0].item<int>(), am_scores.size(2), true, us_alphas, us_peaks);
} else {
result = BeamSearch(wfst_decoder, am_scores[0].data_ptr<float>(), valid_token_lens[0].item<int>(), am_scores.size(2));
if (input_finished) {
result = FinalizeDecode(wfst_decoder, true, us_alphas, us_peaks);
float* us_alphas_data = us_alphas_tensor[index].data_ptr<float>();
std::vector<float> us_alphas(paraformer_length[index]);
for (int i = 0; i < us_alphas.size(); i++) {
us_alphas[i] = us_alphas_data[i];
}
}
}else{
if (lm_ == nullptr) {
result = GreedySearch(am_scores[0].data_ptr<float>(), valid_token_lens[0].item<int>(), am_scores.size(2));
} else {
result = BeamSearch(wfst_decoder, am_scores[0].data_ptr<float>(), valid_token_lens[0].item<int>(), am_scores.size(2));
if (input_finished) {
result = FinalizeDecode(wfst_decoder);
float* us_peaks_data = us_peaks_tensor[index].data_ptr<float>();
std::vector<float> us_peaks(paraformer_length[index]);
for (int i = 0; i < us_peaks.size(); i++) {
us_peaks[i] = us_peaks_data[i];
}
if (lm_ == nullptr) {
result = GreedySearch(am_scores[index].data_ptr<float>(), valid_token_lens[index].item<int>(), am_scores.size(2), true, us_alphas, us_peaks);
} else {
result = BeamSearch(wfst_decoder, am_scores[index].data_ptr<float>(), valid_token_lens[index].item<int>(), am_scores.size(2));
if (input_finished) {
result = FinalizeDecode(wfst_decoder, true, us_alphas, us_peaks);
}
}
}else{
if (lm_ == nullptr) {
result = GreedySearch(am_scores[index].data_ptr<float>(), valid_token_lens[index].item<int>(), am_scores.size(2));
} else {
result = BeamSearch(wfst_decoder, am_scores[index].data_ptr<float>(), valid_token_lens[index].item<int>(), am_scores.size(2));
if (input_finished) {
result = FinalizeDecode(wfst_decoder);
}
}
}
results.push_back(result);
}
}
catch (std::exception const &e)
@ -362,7 +375,7 @@ string ParaformerTorch::Forward(float* din, int len, bool input_finished, const
LOG(ERROR)<<e.what();
}
return result;
return results;
}
std::vector<std::vector<float>> ParaformerTorch::CompileHotwordEmbedding(std::string &hotwords) {

View File

@ -48,13 +48,15 @@ namespace funasr {
std::vector<std::vector<float>> CompileHotwordEmbedding(std::string &hotwords);
void Reset();
void FbankKaldi(float sample_rate, const float* waves, int len, std::vector<std::vector<float>> &asr_feats);
string Forward(float* din, int len, bool input_finished=true, const std::vector<std::vector<float>> &hw_emb={{0.0}}, void* wfst_decoder=nullptr);
std::vector<std::string> Forward(float** din, int* len, bool input_finished=true, const std::vector<std::vector<float>> &hw_emb={{0.0}}, void* wfst_decoder=nullptr, int batch_in=1);
string GreedySearch( float* in, int n_len, int64_t token_nums,
bool is_stamp=false, std::vector<float> us_alphas={0}, std::vector<float> us_cif_peak={0});
string Rescoring();
string GetLang(){return language;};
int GetAsrSampleRate() { return asr_sample_rate; };
void SetBatchSize(int batch_size) {batch_size_ = batch_size};
int GetBatchSize() {return batch_size_;};
void StartUtterance();
void EndUtterance();
void InitLm(const std::string &lm_file, const std::string &lm_cfg_file, const std::string &lex_file);
@ -88,6 +90,7 @@ namespace funasr {
float cif_threshold = 1.0;
float tail_alphas = 0.45;
int asr_sample_rate = MODEL_SAMPLE_RATE;
int batch_size_ = 1;
};
} // namespace funasr

View File

@ -462,15 +462,23 @@ void Paraformer::LfrCmvn(std::vector<std::vector<float>> &asr_feats) {
asr_feats = out_feats;
}
string Paraformer::Forward(float* din, int len, bool input_finished, const std::vector<std::vector<float>> &hw_emb, void* decoder_handle)
std::vector<std::string> Paraformer::Forward(float** din, int* len, bool input_finished, const std::vector<std::vector<float>> &hw_emb, void* decoder_handle, int batch_in)
{
std::vector<std::string> results;
string result="";
WfstDecoder* wfst_decoder = (WfstDecoder*)decoder_handle;
int32_t in_feat_dim = fbank_opts_.mel_opts.num_bins;
if(batch_in != 1){
results.push_back(result);
return results;
}
std::vector<std::vector<float>> asr_feats;
FbankKaldi(asr_sample_rate, din, len, asr_feats);
FbankKaldi(asr_sample_rate, din[0], len[0], asr_feats);
if(asr_feats.size() == 0){
return "";
results.push_back(result);
return results;
}
LfrCmvn(asr_feats);
int32_t feat_dim = lfr_m*in_feat_dim;
@ -509,7 +517,8 @@ string Paraformer::Forward(float* din, int len, bool input_finished, const std::
if (use_hotword) {
if(hw_emb.size()<=0){
LOG(ERROR) << "hw_emb is null";
return "";
results.push_back(result);
return results;
}
//PrintMat(hw_emb, "input_clas_emb");
const int64_t hotword_shape[3] = {1, static_cast<int64_t>(hw_emb.size()), static_cast<int64_t>(hw_emb[0].size())};
@ -526,10 +535,10 @@ string Paraformer::Forward(float* din, int len, bool input_finished, const std::
}catch (std::exception const &e)
{
LOG(ERROR)<<e.what();
return "";
results.push_back(result);
return results;
}
string result="";
try {
auto outputTensor = m_session_->Run(Ort::RunOptions{nullptr}, m_szInputNames.data(), input_onnx.data(), input_onnx.size(), m_szOutputNames.data(), m_szOutputNames.size());
std::vector<int64_t> outputShape = outputTensor[0].GetTensorTypeAndShapeInfo().GetShape();
@ -577,7 +586,8 @@ string Paraformer::Forward(float* din, int len, bool input_finished, const std::
LOG(ERROR)<<e.what();
}
return result;
results.push_back(result);
return results;
}

View File

@ -52,13 +52,14 @@ namespace funasr {
std::vector<std::vector<float>> CompileHotwordEmbedding(std::string &hotwords);
void Reset();
void FbankKaldi(float sample_rate, const float* waves, int len, std::vector<std::vector<float>> &asr_feats);
string Forward(float* din, int len, bool input_finished=true, const std::vector<std::vector<float>> &hw_emb={{0.0}}, void* wfst_decoder=nullptr);
std::vector<std::string> Forward(float** din, int* len, bool input_finished=true, const std::vector<std::vector<float>> &hw_emb={{0.0}}, void* wfst_decoder=nullptr, int batch_in=1);
string GreedySearch( float* in, int n_len, int64_t token_nums,
bool is_stamp=false, std::vector<float> us_alphas={0}, std::vector<float> us_cif_peak={0});
string Rescoring();
string GetLang(){return language;};
int GetAsrSampleRate() { return asr_sample_rate; };
int GetBatchSize() {return batch_size_;};
void StartUtterance();
void EndUtterance();
void InitLm(const std::string &lm_file, const std::string &lm_cfg_file, const std::string &lex_file);
@ -110,6 +111,7 @@ namespace funasr {
float cif_threshold = 1.0;
float tail_alphas = 0.45;
int asr_sample_rate = MODEL_SAMPLE_RATE;
int batch_size_ = 1;
};
} // namespace funasr