add warmup for paraformer-torch

This commit is contained in:
雾聪 2024-06-26 11:39:19 +08:00
parent b7060884fa
commit 38c1f6393a
2 changed files with 49 additions and 0 deletions

View File

@ -55,6 +55,9 @@ void ParaformerTorch::InitAsr(const std::string &am_model, const std::string &am
torch::jit::setGraphExecutorOptimize(false);
torch::jit::FusionStrategy static0 = {{torch::jit::FusionBehavior::STATIC, 0}};
torch::jit::setFusionStrategy(static0);
#ifdef USE_GPU
WarmUp();
#endif
} catch (std::exception const &e) {
LOG(ERROR) << "Error when load am model: " << am_model << e.what();
exit(-1);
@ -471,6 +474,51 @@ std::vector<std::string> ParaformerTorch::Forward(float** din, int* len, bool in
return results;
}
void ParaformerTorch::WarmUp()
{
int32_t in_feat_dim = fbank_opts_.mel_opts.num_bins;
int32_t feature_dim = lfr_m*in_feat_dim;
int batch_in = 1;
int max_frames = 10;
std::vector<int32_t> paraformer_length;
paraformer_length.push_back(max_frames);
std::vector<float> all_feats(batch_in * max_frames * feature_dim, 0.1);
torch::Tensor feats =
torch::from_blob(all_feats.data(),
{batch_in, max_frames, feature_dim}, torch::kFloat).contiguous();
torch::Tensor feat_lens = torch::from_blob(paraformer_length.data(),
{batch_in}, torch::kInt32);
// 2. forward
feats = feats.to(at::kCUDA);
feat_lens = feat_lens.to(at::kCUDA);
std::vector<torch::jit::IValue> inputs = {feats, feat_lens};
if (use_hotword) {
std::string hotwords_wp = "";
std::vector<std::vector<float>> hw_emb = CompileHotwordEmbedding(hotwords_wp);
std::vector<float> embedding;
embedding.reserve(hw_emb.size() * hw_emb[0].size());
for (auto item : hw_emb) {
embedding.insert(embedding.end(), item.begin(), item.end());
}
torch::Tensor tensor_hw_emb =
torch::from_blob(embedding.data(),
{batch_in, static_cast<int64_t>(hw_emb.size()), static_cast<int64_t>(hw_emb[0].size())}, torch::kFloat).contiguous();
tensor_hw_emb = tensor_hw_emb.to(at::kCUDA);
inputs.emplace_back(tensor_hw_emb);
}
try {
auto outputs = model_->forward(inputs).toTuple()->elements();
}
catch (std::exception const &e)
{
LOG(ERROR)<<e.what();
}
}
std::vector<std::vector<float>> ParaformerTorch::CompileHotwordEmbedding(std::string &hotwords) {
int embedding_dim = encoder_size;
std::vector<std::vector<float>> hw_emb;

View File

@ -49,6 +49,7 @@ namespace funasr {
std::vector<std::vector<float>> CompileHotwordEmbedding(std::string &hotwords);
void Reset();
void FbankKaldi(float sample_rate, const float* waves, int len, std::vector<std::vector<float>> &asr_feats);
void WarmUp();
std::vector<std::string> Forward(float** din, int* len, bool input_finished=true, const std::vector<std::vector<float>> &hw_emb={{0.0}}, void* wfst_decoder=nullptr, int batch_in=1);
string GreedySearch( float* in, int n_len, int64_t token_nums,
bool is_stamp=false, std::vector<float> us_alphas={0}, std::vector<float> us_cif_peak={0});