FunASR/runtime/onnxruntime/src/paraformer.h
zhifu gao 3b0526e7be
update with main (#1783)
* add cmakelist

* add paraformer-torch

* add debug for funasr-onnx-offline

* fix redefinition of jieba StdExtension.hpp

* add loading torch models

* update funasr-onnx-offline

* add SwitchArg for wss-server

* add SwitchArg for funasr-onnx-offline

* update cmakelist

* update funasr-onnx-offline-rtf

* add define condition

* add gpu define for offline-stream

* update com define

* update offline-stream

* update cmakelist

* update func CompileHotwordEmbedding

* add timestamp for paraformer-torch

* add C10_USE_GLOG for paraformer-torch

* update paraformer-torch

* fix func FunASRWfstDecoderInit

* update model.h

* fix func FunASRWfstDecoderInit

* fix tpass_stream

* update paraformer-torch

* add bladedisc for funasr-onnx-offline

* update com define

* update funasr-wss-server

* add log for torch

* fix GetValue BLADEDISC

* fix log

* update cmakelist

* update warmup to 10

* update funasrruntime

* add batch_size for wss-server

* add batch for bins

* add batch for offline-stream

* add batch for paraformer

* add batch for offline-stream

* fix func SetBatchSize

* add SetBatchSize for model

* add SetBatchSize for model

* fix func Forward

* fix padding

* update funasrruntime

* add dec reset for batch

* set batch default value

* add argv for CutSplit

* sort frame_queue

* sorted msgs

* fix FunOfflineInfer

* add dynamic batch for fetch

* fix FetchDynamic

* update run_server.sh

* update run_server.sh

* cpp http post server support (#1739)

* add cpp http server

* add some comment

* remove some comments

* del debug infos

* restore run_server.sh

* adapt to new model struct

* fix onnxruntime compilation failure on macOS (#1748)

* Add files via upload

Add macOS build support

* Add files via upload

Add macOS support

* Add files via upload

target_link_directories(funasr PUBLIC ${ONNXRUNTIME_DIR}/lib)
target_link_directories(funasr PUBLIC ${FFMPEG_DIR}/lib)
Add an if(APPLE) guard

---------

Co-authored-by: Yabin Li <wucong.lyb@alibaba-inc.com>

* Delete docs/images/wechat.png

* Add files via upload

* fixed the issues with the seaco-onnx timestamp

* fix bug (#1764)

When the speech recognition result contains `http`, punctuation prediction treats it as a URL

* fix empty asr result (#1765)

For speech segments whose decoding result is empty, use an empty string for text

* docs

* docs

* docs

* docs

* docs

* keep empty speech result (#1772)

* docs

* docs

* update wechat QRcode

* Add python funasr api support for websocket srv (#1777)

* add python funasr_api support

* make a small change to README.md

* add core tools stream

* modified a little

* fix bug for timeout

* support for buffer decode

* add ffmpeg decode for buffer

* auto frontend

* auto frontend

---------

Co-authored-by: 雾聪 <wucong.lyb@alibaba-inc.com>
Co-authored-by: zhaomingwork <61895407+zhaomingwork@users.noreply.github.com>
Co-authored-by: szsteven008 <97944818+szsteven008@users.noreply.github.com>
Co-authored-by: Ephemeroptera <605686962@qq.com>
Co-authored-by: 彭震东 <zhendong.peng@qq.com>
Co-authored-by: Shi Xian <40013335+R1ckShi@users.noreply.github.com>
Co-authored-by: 维石 <shixian.shi@alibaba-inc.com>
2024-06-04 11:21:36 +08:00

/**
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
* MIT License (https://opensource.org/licenses/MIT)
*/
#pragma once
#include "precomp.h"
#include "fst/fstlib.h"
#include "fst/symbol-table.h"
#include "bias-lm.h"
#include "phone-set.h"
namespace funasr {
class Paraformer : public Model {
/**
 * Author: Speech Lab of DAMO Academy, Alibaba Group
 * Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
 * https://arxiv.org/pdf/2206.08317.pdf
 */
  private:
    Vocab* vocab = nullptr;
    Vocab* lm_vocab = nullptr;
    SegDict* seg_dict = nullptr;
    PhoneSet* phone_set_ = nullptr;
    // const float scale = 22.6274169979695;
    const float scale = 1.0;

    void LoadConfigFromYaml(const char* filename);
    void LoadOnlineConfigFromYaml(const char* filename);
    void LoadCmvn(const char* filename);
    void LfrCmvn(std::vector<std::vector<float>> &asr_feats);

    // hotword embedding compiler (onnx session)
    std::shared_ptr<Ort::Session> hw_m_session = nullptr;
    Ort::Env hw_env_;
    Ort::SessionOptions hw_session_options;
    vector<string> hw_m_strInputNames, hw_m_strOutputNames;
    vector<const char*> hw_m_szInputNames;
    vector<const char*> hw_m_szOutputNames;
    bool use_hotword;

  public:
    Paraformer();
    ~Paraformer();
    // offline
    void InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num);
    // online
    void InitAsr(const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num);
    // 2pass
    void InitAsr(const std::string &am_model, const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num);
    void InitHwCompiler(const std::string &hw_model, int thread_num);
    void InitSegDict(const std::string &seg_dict_model);
    std::vector<std::vector<float>> CompileHotwordEmbedding(std::string &hotwords);
    void Reset();
    void FbankKaldi(float sample_rate, const float* waves, int len, std::vector<std::vector<float>> &asr_feats);
    std::vector<std::string> Forward(float** din, int* len, bool input_finished=true, const std::vector<std::vector<float>> &hw_emb={{0.0}}, void* wfst_decoder=nullptr, int batch_in=1);
    string GreedySearch(float* in, int n_len, int64_t token_nums,
                        bool is_stamp=false, std::vector<float> us_alphas={0}, std::vector<float> us_cif_peak={0});
    string Rescoring();
    string GetLang() { return language; }
    int GetAsrSampleRate() { return asr_sample_rate; }
    int GetBatchSize() { return batch_size_; }
    void StartUtterance();
    void EndUtterance();
    void InitLm(const std::string &lm_file, const std::string &lm_cfg_file, const std::string &lex_file);
    string BeamSearch(WfstDecoder* &wfst_decoder, float* in, int n_len, int64_t token_nums);
    string FinalizeDecode(WfstDecoder* &wfst_decoder,
                          bool is_stamp=false, std::vector<float> us_alphas={0}, std::vector<float> us_cif_peak={0});
    Vocab* GetVocab();
    Vocab* GetLmVocab();
    PhoneSet* GetPhoneSet();

    // feature extraction: fbank options, CMVN statistics, LFR settings
    knf::FbankOptions fbank_opts_;
    vector<float> means_list_;
    vector<float> vars_list_;
    int lfr_m = PARA_LFR_M;
    int lfr_n = PARA_LFR_N;

    // paraformer-offline
    std::shared_ptr<Ort::Session> m_session_ = nullptr;
    Ort::Env env_;
    Ort::SessionOptions session_options_;
    vector<string> m_strInputNames, m_strOutputNames;
    vector<const char*> m_szInputNames;
    vector<const char*> m_szOutputNames;
    std::string language = "zh-cn";

    // paraformer-online
    std::shared_ptr<Ort::Session> encoder_session_ = nullptr;
    std::shared_ptr<Ort::Session> decoder_session_ = nullptr;
    vector<string> en_strInputNames, en_strOutputNames;
    vector<const char*> en_szInputNames_;
    vector<const char*> en_szOutputNames_;
    vector<string> de_strInputNames, de_strOutputNames;
    vector<const char*> de_szInputNames_;
    vector<const char*> de_szOutputNames_;

    // lm
    std::shared_ptr<fst::Fst<fst::StdArc>> lm_ = nullptr;

    string window_type = "hamming";
    int frame_length = 25;
    int frame_shift = 10;
    int n_mels = 80;
    int encoder_size = 512;
    int fsmn_layers = 16;
    int fsmn_lorder = 10;
    int fsmn_dims = 512;
    float cif_threshold = 1.0;
    float tail_alphas = 0.45;

    int asr_sample_rate = MODEL_SAMPLE_RATE;
    int batch_size_ = 1;
};
} // namespace funasr
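
Usage sketch (not part of the header): a minimal, hypothetical driver for the offline path declared above. It only calls members declared in this class; the file names, the silent PCM buffer, and main() itself are placeholders, and in the actual runtime this class is normally reached through the higher-level offline-stream / FunOfflineInfer wrappers rather than instantiated directly.

// paraformer_offline_sketch.cpp -- illustrative only, paths below are placeholders.
#include "paraformer.h"

#include <cstdio>
#include <string>
#include <vector>

int main() {
  funasr::Paraformer model;

  // Offline initialization: AM onnx model, CMVN stats, YAML config, token list.
  model.InitAsr("model.onnx", "am.mvn", "config.yaml", "tokens.json", /*thread_num=*/4);

  // Optional hotword customization would additionally call InitHwCompiler()
  // and pass the result of CompileHotwordEmbedding() to Forward().

  // One utterance of mono PCM at the model sample rate; two seconds of silence as a stand-in.
  std::vector<float> pcm(model.GetAsrSampleRate() * 2, 0.0f);
  float* data = pcm.data();
  int len = static_cast<int>(pcm.size());

  // Batch of one utterance; hw_emb, wfst_decoder and batch_in keep their declared defaults.
  std::vector<std::string> results = model.Forward(&data, &len, /*input_finished=*/true);

  for (const auto& text : results) {
    std::printf("%s\n", text.c_str());
  }
  return 0;
}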