mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
add batch for paraformer
This commit is contained in:
parent
445b7ec47c
commit
a58c3d4593
@ -265,34 +265,45 @@ void ParaformerTorch::LfrCmvn(std::vector<std::vector<float>> &asr_feats) {
|
||||
asr_feats = out_feats;
|
||||
}
|
||||
|
||||
string ParaformerTorch::Forward(float* din, int len, bool input_finished, const std::vector<std::vector<float>> &hw_emb, void* decoder_handle)
|
||||
std::vector<std::string> ParaformerTorch::Forward(float** din, int* len, bool input_finished, const std::vector<std::vector<float>> &hw_emb, void* decoder_handle, int batch_in)
|
||||
{
|
||||
WfstDecoder* wfst_decoder = (WfstDecoder*)decoder_handle;
|
||||
int32_t in_feat_dim = fbank_opts_.mel_opts.num_bins;
|
||||
int32_t feature_dim = lfr_m*in_feat_dim;
|
||||
|
||||
std::vector<std::vector<float>> asr_feats;
|
||||
FbankKaldi(asr_sample_rate, din, len, asr_feats);
|
||||
if(asr_feats.size() == 0){
|
||||
return "";
|
||||
}
|
||||
LfrCmvn(asr_feats);
|
||||
int32_t feat_dim = lfr_m*in_feat_dim;
|
||||
int32_t num_frames = asr_feats.size();
|
||||
|
||||
std::vector<float> wav_feats;
|
||||
for (const auto &frame_feat: asr_feats) {
|
||||
wav_feats.insert(wav_feats.end(), frame_feat.begin(), frame_feat.end());
|
||||
}
|
||||
std::vector<vector<float>> feats_batch;
|
||||
std::vector<int32_t> paraformer_length;
|
||||
paraformer_length.emplace_back(num_frames);
|
||||
int max_size = 0;
|
||||
int max_frames = 0;
|
||||
for(int index=0; index<batch_in; index++){
|
||||
std::vector<std::vector<float>> asr_feats;
|
||||
FbankKaldi(asr_sample_rate, din[index], len[index], asr_feats);
|
||||
if(asr_feats.size() != 0){
|
||||
LfrCmvn(asr_feats);
|
||||
}
|
||||
feats_batch.emplace_back(asr_feats);
|
||||
int32_t num_frames = asr_feats.size() / feature_dim;
|
||||
paraformer_length.emplace_back(num_frames);
|
||||
if(max_size < asr_feats.size()){
|
||||
max_size = asr_feats.size();
|
||||
max_frames = num_frames;
|
||||
}
|
||||
}
|
||||
|
||||
torch::NoGradGuard no_grad;
|
||||
model_->eval();
|
||||
// padding
|
||||
std::vector<float> all_feats(batch_in * max_frames * feature_dim);
|
||||
for(int index=0; index<batch_in; index++){
|
||||
feats_batch[index].resize(max_size);
|
||||
std::memcpy(&all_feats[index * max_frames * feature_dim], feats_batch[index].data(),
|
||||
max_frames * feature_dim * sizeof(float));
|
||||
}
|
||||
torch::Tensor feats =
|
||||
torch::from_blob(wav_feats.data(),
|
||||
{1, num_frames, feat_dim}, torch::kFloat).contiguous();
|
||||
torch::from_blob(all_feats.data(),
|
||||
{batch_in, max_frames, feature_dim}, torch::kFloat).contiguous();
|
||||
torch::Tensor feat_lens = torch::from_blob(paraformer_length.data(),
|
||||
{1}, torch::kInt32);
|
||||
{batch_in}, torch::kInt32);
|
||||
|
||||
// 2. forward
|
||||
#ifdef USE_GPU
|
||||
@ -301,7 +312,7 @@ string ParaformerTorch::Forward(float* din, int len, bool input_finished, const
|
||||
#endif
|
||||
std::vector<torch::jit::IValue> inputs = {feats, feat_lens};
|
||||
|
||||
string result="";
|
||||
vector<std::string> results;
|
||||
try {
|
||||
auto outputs = model_->forward(inputs).toTuple()->elements();
|
||||
torch::Tensor am_scores;
|
||||
@ -314,47 +325,49 @@ string ParaformerTorch::Forward(float* din, int len, bool input_finished, const
|
||||
valid_token_lens = outputs[1].toTensor();
|
||||
#endif
|
||||
// timestamp
|
||||
if(outputs.size() == 4){
|
||||
torch::Tensor us_alphas_tensor;
|
||||
torch::Tensor us_peaks_tensor;
|
||||
#ifdef USE_GPU
|
||||
us_alphas_tensor = outputs[2].toTensor().to(at::kCPU);
|
||||
us_peaks_tensor = outputs[3].toTensor().to(at::kCPU);
|
||||
#else
|
||||
us_alphas_tensor = outputs[2].toTensor();
|
||||
us_peaks_tensor = outputs[3].toTensor();
|
||||
#endif
|
||||
for(int index=0; index<batch_in; index++){
|
||||
string result="";
|
||||
if(outputs.size() == 4){
|
||||
torch::Tensor us_alphas_tensor;
|
||||
torch::Tensor us_peaks_tensor;
|
||||
#ifdef USE_GPU
|
||||
us_alphas_tensor = outputs[2].toTensor().to(at::kCPU);
|
||||
us_peaks_tensor = outputs[3].toTensor().to(at::kCPU);
|
||||
#else
|
||||
us_alphas_tensor = outputs[2].toTensor();
|
||||
us_peaks_tensor = outputs[3].toTensor();
|
||||
#endif
|
||||
|
||||
int us_alphas_shape_1 = us_alphas_tensor.size(1);
|
||||
float* us_alphas_data = us_alphas_tensor.data_ptr<float>();
|
||||
std::vector<float> us_alphas(us_alphas_shape_1);
|
||||
for (int i = 0; i < us_alphas_shape_1; i++) {
|
||||
us_alphas[i] = us_alphas_data[i];
|
||||
}
|
||||
|
||||
int us_peaks_shape_1 = us_peaks_tensor.size(1);
|
||||
float* us_peaks_data = us_peaks_tensor.data_ptr<float>();
|
||||
std::vector<float> us_peaks(us_peaks_shape_1);
|
||||
for (int i = 0; i < us_peaks_shape_1; i++) {
|
||||
us_peaks[i] = us_peaks_data[i];
|
||||
}
|
||||
if (lm_ == nullptr) {
|
||||
result = GreedySearch(am_scores[0].data_ptr<float>(), valid_token_lens[0].item<int>(), am_scores.size(2), true, us_alphas, us_peaks);
|
||||
} else {
|
||||
result = BeamSearch(wfst_decoder, am_scores[0].data_ptr<float>(), valid_token_lens[0].item<int>(), am_scores.size(2));
|
||||
if (input_finished) {
|
||||
result = FinalizeDecode(wfst_decoder, true, us_alphas, us_peaks);
|
||||
float* us_alphas_data = us_alphas_tensor[index].data_ptr<float>();
|
||||
std::vector<float> us_alphas(paraformer_length[index]);
|
||||
for (int i = 0; i < us_alphas.size(); i++) {
|
||||
us_alphas[i] = us_alphas_data[i];
|
||||
}
|
||||
}
|
||||
}else{
|
||||
if (lm_ == nullptr) {
|
||||
result = GreedySearch(am_scores[0].data_ptr<float>(), valid_token_lens[0].item<int>(), am_scores.size(2));
|
||||
} else {
|
||||
result = BeamSearch(wfst_decoder, am_scores[0].data_ptr<float>(), valid_token_lens[0].item<int>(), am_scores.size(2));
|
||||
if (input_finished) {
|
||||
result = FinalizeDecode(wfst_decoder);
|
||||
|
||||
float* us_peaks_data = us_peaks_tensor[index].data_ptr<float>();
|
||||
std::vector<float> us_peaks(paraformer_length[index]);
|
||||
for (int i = 0; i < us_peaks.size(); i++) {
|
||||
us_peaks[i] = us_peaks_data[i];
|
||||
}
|
||||
if (lm_ == nullptr) {
|
||||
result = GreedySearch(am_scores[index].data_ptr<float>(), valid_token_lens[index].item<int>(), am_scores.size(2), true, us_alphas, us_peaks);
|
||||
} else {
|
||||
result = BeamSearch(wfst_decoder, am_scores[index].data_ptr<float>(), valid_token_lens[index].item<int>(), am_scores.size(2));
|
||||
if (input_finished) {
|
||||
result = FinalizeDecode(wfst_decoder, true, us_alphas, us_peaks);
|
||||
}
|
||||
}
|
||||
}else{
|
||||
if (lm_ == nullptr) {
|
||||
result = GreedySearch(am_scores[index].data_ptr<float>(), valid_token_lens[index].item<int>(), am_scores.size(2));
|
||||
} else {
|
||||
result = BeamSearch(wfst_decoder, am_scores[index].data_ptr<float>(), valid_token_lens[index].item<int>(), am_scores.size(2));
|
||||
if (input_finished) {
|
||||
result = FinalizeDecode(wfst_decoder);
|
||||
}
|
||||
}
|
||||
}
|
||||
results.push_back(result);
|
||||
}
|
||||
}
|
||||
catch (std::exception const &e)
|
||||
@ -362,7 +375,7 @@ string ParaformerTorch::Forward(float* din, int len, bool input_finished, const
|
||||
LOG(ERROR)<<e.what();
|
||||
}
|
||||
|
||||
return result;
|
||||
return results;
|
||||
}
|
||||
|
||||
std::vector<std::vector<float>> ParaformerTorch::CompileHotwordEmbedding(std::string &hotwords) {
|
||||
|
||||
@ -48,13 +48,15 @@ namespace funasr {
|
||||
std::vector<std::vector<float>> CompileHotwordEmbedding(std::string &hotwords);
|
||||
void Reset();
|
||||
void FbankKaldi(float sample_rate, const float* waves, int len, std::vector<std::vector<float>> &asr_feats);
|
||||
string Forward(float* din, int len, bool input_finished=true, const std::vector<std::vector<float>> &hw_emb={{0.0}}, void* wfst_decoder=nullptr);
|
||||
std::vector<std::string> Forward(float** din, int* len, bool input_finished=true, const std::vector<std::vector<float>> &hw_emb={{0.0}}, void* wfst_decoder=nullptr, int batch_in=1);
|
||||
string GreedySearch( float* in, int n_len, int64_t token_nums,
|
||||
bool is_stamp=false, std::vector<float> us_alphas={0}, std::vector<float> us_cif_peak={0});
|
||||
|
||||
string Rescoring();
|
||||
string GetLang(){return language;};
|
||||
int GetAsrSampleRate() { return asr_sample_rate; };
|
||||
void SetBatchSize(int batch_size) {batch_size_ = batch_size};
|
||||
int GetBatchSize() {return batch_size_;};
|
||||
void StartUtterance();
|
||||
void EndUtterance();
|
||||
void InitLm(const std::string &lm_file, const std::string &lm_cfg_file, const std::string &lex_file);
|
||||
@ -88,6 +90,7 @@ namespace funasr {
|
||||
float cif_threshold = 1.0;
|
||||
float tail_alphas = 0.45;
|
||||
int asr_sample_rate = MODEL_SAMPLE_RATE;
|
||||
int batch_size_ = 1;
|
||||
};
|
||||
|
||||
} // namespace funasr
|
||||
|
||||
@ -462,15 +462,23 @@ void Paraformer::LfrCmvn(std::vector<std::vector<float>> &asr_feats) {
|
||||
asr_feats = out_feats;
|
||||
}
|
||||
|
||||
string Paraformer::Forward(float* din, int len, bool input_finished, const std::vector<std::vector<float>> &hw_emb, void* decoder_handle)
|
||||
std::vector<std::string> Paraformer::Forward(float** din, int* len, bool input_finished, const std::vector<std::vector<float>> &hw_emb, void* decoder_handle, int batch_in)
|
||||
{
|
||||
std::vector<std::string> results;
|
||||
string result="";
|
||||
WfstDecoder* wfst_decoder = (WfstDecoder*)decoder_handle;
|
||||
int32_t in_feat_dim = fbank_opts_.mel_opts.num_bins;
|
||||
|
||||
if(batch_in != 1){
|
||||
results.push_back(result);
|
||||
return results;
|
||||
}
|
||||
|
||||
std::vector<std::vector<float>> asr_feats;
|
||||
FbankKaldi(asr_sample_rate, din, len, asr_feats);
|
||||
FbankKaldi(asr_sample_rate, din[0], len[0], asr_feats);
|
||||
if(asr_feats.size() == 0){
|
||||
return "";
|
||||
results.push_back(result);
|
||||
return results;
|
||||
}
|
||||
LfrCmvn(asr_feats);
|
||||
int32_t feat_dim = lfr_m*in_feat_dim;
|
||||
@ -509,7 +517,8 @@ string Paraformer::Forward(float* din, int len, bool input_finished, const std::
|
||||
if (use_hotword) {
|
||||
if(hw_emb.size()<=0){
|
||||
LOG(ERROR) << "hw_emb is null";
|
||||
return "";
|
||||
results.push_back(result);
|
||||
return results;
|
||||
}
|
||||
//PrintMat(hw_emb, "input_clas_emb");
|
||||
const int64_t hotword_shape[3] = {1, static_cast<int64_t>(hw_emb.size()), static_cast<int64_t>(hw_emb[0].size())};
|
||||
@ -526,10 +535,10 @@ string Paraformer::Forward(float* din, int len, bool input_finished, const std::
|
||||
}catch (std::exception const &e)
|
||||
{
|
||||
LOG(ERROR)<<e.what();
|
||||
return "";
|
||||
results.push_back(result);
|
||||
return results;
|
||||
}
|
||||
|
||||
string result="";
|
||||
try {
|
||||
auto outputTensor = m_session_->Run(Ort::RunOptions{nullptr}, m_szInputNames.data(), input_onnx.data(), input_onnx.size(), m_szOutputNames.data(), m_szOutputNames.size());
|
||||
std::vector<int64_t> outputShape = outputTensor[0].GetTensorTypeAndShapeInfo().GetShape();
|
||||
@ -577,7 +586,8 @@ string Paraformer::Forward(float* din, int len, bool input_finished, const std::
|
||||
LOG(ERROR)<<e.what();
|
||||
}
|
||||
|
||||
return result;
|
||||
results.push_back(result);
|
||||
return results;
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -52,13 +52,14 @@ namespace funasr {
|
||||
std::vector<std::vector<float>> CompileHotwordEmbedding(std::string &hotwords);
|
||||
void Reset();
|
||||
void FbankKaldi(float sample_rate, const float* waves, int len, std::vector<std::vector<float>> &asr_feats);
|
||||
string Forward(float* din, int len, bool input_finished=true, const std::vector<std::vector<float>> &hw_emb={{0.0}}, void* wfst_decoder=nullptr);
|
||||
std::vector<std::string> Forward(float** din, int* len, bool input_finished=true, const std::vector<std::vector<float>> &hw_emb={{0.0}}, void* wfst_decoder=nullptr, int batch_in=1);
|
||||
string GreedySearch( float* in, int n_len, int64_t token_nums,
|
||||
bool is_stamp=false, std::vector<float> us_alphas={0}, std::vector<float> us_cif_peak={0});
|
||||
|
||||
string Rescoring();
|
||||
string GetLang(){return language;};
|
||||
int GetAsrSampleRate() { return asr_sample_rate; };
|
||||
int GetBatchSize() {return batch_size_;};
|
||||
void StartUtterance();
|
||||
void EndUtterance();
|
||||
void InitLm(const std::string &lm_file, const std::string &lm_cfg_file, const std::string &lex_file);
|
||||
@ -110,6 +111,7 @@ namespace funasr {
|
||||
float cif_threshold = 1.0;
|
||||
float tail_alphas = 0.45;
|
||||
int asr_sample_rate = MODEL_SAMPLE_RATE;
|
||||
int batch_size_ = 1;
|
||||
};
|
||||
|
||||
} // namespace funasr
|
||||
|
||||
Loading…
Reference in New Issue
Block a user