mirror of https://github.com/modelscope/FunASR
add warmup for paraformer-torch
This commit is contained in: commit 38c1f6393a (parent b7060884fa)
@@ -55,6 +55,9 @@ void ParaformerTorch::InitAsr(const std::string &am_model, const std::string &am
     torch::jit::setGraphExecutorOptimize(false);
     torch::jit::FusionStrategy static0 = {{torch::jit::FusionBehavior::STATIC, 0}};
     torch::jit::setFusionStrategy(static0);
+    #ifdef USE_GPU
+    WarmUp();
+    #endif
 } catch (std::exception const &e) {
     LOG(ERROR) << "Error when load am model: " << am_model << e.what();
     exit(-1);
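Note (not part of the diff): a minimal standalone sketch of the JIT settings this hunk pairs with the warmup call, assuming a recent LibTorch build and a placeholder model path. Disabling graph-executor optimization and pinning a STATIC fusion strategy with depth 0 means the graph compiled during the warmup pass is the one served to real requests, instead of being re-specialized on the first live input.

// Sketch only: "model.torchscript" is a placeholder path, not from this commit.
#include <torch/script.h>
#include <iostream>

int main() {
    // Disable the graph executor's optimization passes so the first real
    // request does not trigger a re-specialization after warmup.
    torch::jit::setGraphExecutorOptimize(false);

    // STATIC fusion with depth 0: one fixed graph, no shape-dependent re-fusion.
    // FusionStrategy is a vector of {behavior, depth} pairs.
    torch::jit::FusionStrategy static0 = {{torch::jit::FusionBehavior::STATIC, 0}};
    torch::jit::setFusionStrategy(static0);

    try {
        torch::jit::script::Module model = torch::jit::load("model.torchscript");
        model.eval();
    } catch (const std::exception &e) {
        std::cerr << "Error when loading model: " << e.what() << std::endl;
        return -1;
    }
    return 0;
}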
@@ -471,6 +474,51 @@ std::vector<std::string> ParaformerTorch::Forward(float** din, int* len, bool in
     return results;
 }

+void ParaformerTorch::WarmUp()
+{
+    int32_t in_feat_dim = fbank_opts_.mel_opts.num_bins;
+    int32_t feature_dim = lfr_m*in_feat_dim;
+    int batch_in = 1;
+    int max_frames = 10;
+    std::vector<int32_t> paraformer_length;
+    paraformer_length.push_back(max_frames);
+
+    std::vector<float> all_feats(batch_in * max_frames * feature_dim, 0.1);
+    torch::Tensor feats =
+        torch::from_blob(all_feats.data(),
+                         {batch_in, max_frames, feature_dim}, torch::kFloat).contiguous();
+    torch::Tensor feat_lens = torch::from_blob(paraformer_length.data(),
+                                               {batch_in}, torch::kInt32);
+
+    // 2. forward
+    feats = feats.to(at::kCUDA);
+    feat_lens = feat_lens.to(at::kCUDA);
+    std::vector<torch::jit::IValue> inputs = {feats, feat_lens};
+
+    if (use_hotword) {
+        std::string hotwords_wp = "";
+        std::vector<std::vector<float>> hw_emb = CompileHotwordEmbedding(hotwords_wp);
+        std::vector<float> embedding;
+        embedding.reserve(hw_emb.size() * hw_emb[0].size());
+        for (auto item : hw_emb) {
+            embedding.insert(embedding.end(), item.begin(), item.end());
+        }
+        torch::Tensor tensor_hw_emb =
+            torch::from_blob(embedding.data(),
+                             {batch_in, static_cast<int64_t>(hw_emb.size()), static_cast<int64_t>(hw_emb[0].size())}, torch::kFloat).contiguous();
+        tensor_hw_emb = tensor_hw_emb.to(at::kCUDA);
+        inputs.emplace_back(tensor_hw_emb);
+    }
+
+    try {
+        auto outputs = model_->forward(inputs).toTuple()->elements();
+    }
+    catch (std::exception const &e)
+    {
+        LOG(ERROR) << e.what();
+    }
+}
+
 std::vector<std::vector<float>> ParaformerTorch::CompileHotwordEmbedding(std::string &hotwords) {
     int embedding_dim = encoder_size;
     std::vector<std::vector<float>> hw_emb;
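Note (not part of the diff): torch::from_blob, used above to wrap the dummy features, aliases the std::vector's memory rather than copying it, and a freshly built from_blob tensor is already contiguous, so .contiguous() alone still aliases the vector; it is the .to(at::kCUDA) call that finally copies the data off the host buffer. A minimal sketch of the same construction, with shapes as assumptions (80 mel bins with lfr_m = 7 gives the 560-dim features common in Paraformer configs):

// Sketch only: batch/frames/feat_dim values are illustrative assumptions.
#include <torch/script.h>
#include <vector>

int main() {
    const int64_t batch = 1, frames = 10, feat_dim = 7 * 80; // lfr_m * num_bins (assumed)

    // from_blob does not copy: the tensor aliases all_feats, so the vector
    // must outlive every host-side use of the tensor.
    std::vector<float> all_feats(batch * frames * feat_dim, 0.1f);
    torch::Tensor feats = torch::from_blob(all_feats.data(),
                                           {batch, frames, feat_dim}, torch::kFloat);

    std::vector<int32_t> lens = {static_cast<int32_t>(frames)};
    torch::Tensor feat_lens = torch::from_blob(lens.data(), {batch}, torch::kInt32);

    // .clone() (or .to(at::kCUDA) in the GPU path) materializes an owning copy,
    // decoupling the tensors from the local vectors' lifetimes.
    feats = feats.clone();
    feat_lens = feat_lens.clone();

    std::vector<torch::jit::IValue> inputs = {feats, feat_lens};
    // model.forward(inputs) would run here; omitted since no model is loaded.
    return 0;
}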
@@ -49,6 +49,7 @@ namespace funasr {
     std::vector<std::vector<float>> CompileHotwordEmbedding(std::string &hotwords);
     void Reset();
     void FbankKaldi(float sample_rate, const float* waves, int len, std::vector<std::vector<float>> &asr_feats);
+    void WarmUp();
     std::vector<std::string> Forward(float** din, int* len, bool input_finished=true, const std::vector<std::vector<float>> &hw_emb={{0.0}}, void* wfst_decoder=nullptr, int batch_in=1);
     string GreedySearch( float* in, int n_len, int64_t token_nums,
                          bool is_stamp=false, std::vector<float> us_alphas={0}, std::vector<float> us_cif_peak={0});
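Note (not part of the diff): WarmUp()'s hotword branch flattens the 2-D embedding returned by CompileHotwordEmbedding into one contiguous buffer before wrapping it in a tensor. A minimal sketch of that flattening, with a dummy 2 x 2 embedding standing in for the real output; iterating by const reference avoids the per-row copy that `for (auto item : hw_emb)` makes in the committed code:

// Sketch only: hw_emb is a dummy stand-in for CompileHotwordEmbedding()'s output.
#include <torch/script.h>
#include <vector>

int main() {
    std::vector<std::vector<float>> hw_emb = {{0.f, 1.f}, {2.f, 3.f}};
    if (hw_emb.empty() || hw_emb[0].empty()) return 0; // guard: the diff indexes hw_emb[0]

    // Row-major flatten into one contiguous buffer, as WarmUp() does.
    std::vector<float> flat;
    flat.reserve(hw_emb.size() * hw_emb[0].size());
    for (const auto &row : hw_emb) {
        flat.insert(flat.end(), row.begin(), row.end());
    }

    torch::Tensor t = torch::from_blob(flat.data(),
                                       {1,
                                        static_cast<int64_t>(hw_emb.size()),
                                        static_cast<int64_t>(hw_emb[0].size())},
                                       torch::kFloat).clone(); // clone to own the memory
    return 0;
}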