diff --git a/runtime/onnxruntime/include/funasrruntime.h b/runtime/onnxruntime/include/funasrruntime.h
index cff617f38..ba3cbf47a 100644
--- a/runtime/onnxruntime/include/funasrruntime.h
+++ b/runtime/onnxruntime/include/funasrruntime.h
@@ -96,7 +96,7 @@ _FUNASRAPI void CTTransformerFreeResult(FUNASR_RESULT result);
 _FUNASRAPI void CTTransformerUninit(FUNASR_HANDLE handle);
 
 //OfflineStream
-_FUNASRAPI FUNASR_HANDLE FunOfflineInit(std::map<std::string, std::string>& model_path, int thread_num);
+_FUNASRAPI FUNASR_HANDLE FunOfflineInit(std::map<std::string, std::string>& model_path, int thread_num, bool use_gpu=false);
 _FUNASRAPI void FunOfflineReset(FUNASR_HANDLE handle, FUNASR_DEC_HANDLE dec_handle=nullptr);
 // buffer
 _FUNASRAPI FUNASR_RESULT FunOfflineInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len,
diff --git a/runtime/onnxruntime/include/offline-stream.h b/runtime/onnxruntime/include/offline-stream.h
index f63de746f..0bec79771 100644
--- a/runtime/onnxruntime/include/offline-stream.h
+++ b/runtime/onnxruntime/include/offline-stream.h
@@ -14,7 +14,7 @@ namespace funasr {
 class OfflineStream {
   public:
-    OfflineStream(std::map<std::string, std::string>& model_path, int thread_num);
+    OfflineStream(std::map<std::string, std::string>& model_path, int thread_num, bool use_gpu=false);
     ~OfflineStream(){};
 
     std::unique_ptr<VadModel> vad_handle= nullptr;
@@ -33,6 +33,6 @@ class OfflineStream {
     bool use_itn=false;
 };
 
-OfflineStream *CreateOfflineStream(std::map<std::string, std::string>& model_path, int thread_num=1);
+OfflineStream *CreateOfflineStream(std::map<std::string, std::string>& model_path, int thread_num=1, bool use_gpu=false);
 } // namespace funasr
 #endif
diff --git a/runtime/onnxruntime/src/CMakeLists.txt b/runtime/onnxruntime/src/CMakeLists.txt
index 9eac2b616..d6c8a205e 100644
--- a/runtime/onnxruntime/src/CMakeLists.txt
+++ b/runtime/onnxruntime/src/CMakeLists.txt
@@ -25,7 +25,11 @@ else()
     include_directories(${FFMPEG_DIR}/include)
 endif()
 
+if(GPU)
+    set(TORCH_DEPS torch torch_cuda torch_cpu c10 c10_cuda torch_blade ral_base_context)
+endif()
+
 #message("CXX_FLAGS "${CMAKE_CXX_FLAGS})
 include_directories(${CMAKE_SOURCE_DIR}/include)
 include_directories(${CMAKE_SOURCE_DIR}/third_party)
-target_link_libraries(funasr PUBLIC onnxruntime ${EXTRA_LIBS})
+target_link_libraries(funasr PUBLIC onnxruntime ${EXTRA_LIBS} ${TORCH_DEPS})
diff --git a/runtime/onnxruntime/src/funasrruntime.cpp b/runtime/onnxruntime/src/funasrruntime.cpp
index 4bc64aff6..d795cb0de 100644
--- a/runtime/onnxruntime/src/funasrruntime.cpp
+++ b/runtime/onnxruntime/src/funasrruntime.cpp
@@ -33,9 +33,9 @@
 		return mm;
 	}
 
-	_FUNASRAPI FUNASR_HANDLE FunOfflineInit(std::map<std::string, std::string>& model_path, int thread_num)
+	_FUNASRAPI FUNASR_HANDLE FunOfflineInit(std::map<std::string, std::string>& model_path, int thread_num, bool use_gpu)
 	{
-		funasr::OfflineStream* mm = funasr::CreateOfflineStream(model_path, thread_num);
+		funasr::OfflineStream* mm = funasr::CreateOfflineStream(model_path, thread_num, use_gpu);
 		return mm;
 	}
 
diff --git a/runtime/onnxruntime/src/offline-stream.cpp b/runtime/onnxruntime/src/offline-stream.cpp
index ae8cf184f..9cdcdd2ab 100644
--- a/runtime/onnxruntime/src/offline-stream.cpp
+++ b/runtime/onnxruntime/src/offline-stream.cpp
@@ -1,7 +1,7 @@
 #include "precomp.h"
 
 namespace funasr {
-OfflineStream::OfflineStream(std::map<std::string, std::string>& model_path, int thread_num)
+OfflineStream::OfflineStream(std::map<std::string, std::string>& model_path, int thread_num, bool use_gpu)
 {
     // VAD model
     if(model_path.find(VAD_DIR) != model_path.end()){
@@ -35,7 +35,12 @@ OfflineStream::OfflineStream(std::map<std::string, std::string>& model_path, int
         string hw_compile_model_path;
         string seg_dict_path;
 
-        asr_handle = make_unique<Paraformer>();
+        if(use_gpu){
+            asr_handle = make_unique<ParaformerTorch>();
+        }else{
+            asr_handle = make_unique<Paraformer>();
+        }
+
         bool enable_hotword = false;
         hw_compile_model_path = PathAppend(model_path.at(MODEL_DIR), MODEL_EB_NAME);
         seg_dict_path = PathAppend(model_path.at(MODEL_DIR), MODEL_SEG_DICT);
@@ -115,10 +120,10 @@ OfflineStream::OfflineStream(std::map<std::string, std::string>& model_path, int
 #endif
 }
 
-OfflineStream *CreateOfflineStream(std::map<std::string, std::string>& model_path, int thread_num)
+OfflineStream *CreateOfflineStream(std::map<std::string, std::string>& model_path, int thread_num, bool use_gpu)
 {
     OfflineStream *mm;
-    mm = new OfflineStream(model_path, thread_num);
+    mm = new OfflineStream(model_path, thread_num, use_gpu);
     return mm;
 }
 
diff --git a/runtime/onnxruntime/src/paraformer-torch.cpp b/runtime/onnxruntime/src/paraformer-torch.cpp
new file mode 100644
index 000000000..1f15ec7e8
--- /dev/null
+++ b/runtime/onnxruntime/src/paraformer-torch.cpp
@@ -0,0 +1,351 @@
+/**
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+ * MIT License (https://opensource.org/licenses/MIT)
+*/
+
+#include "precomp.h"
+#include "paraformer-torch.h"
+#include "encode_converter.h"
+#include <fstream>
+
+using namespace std;
+namespace funasr {
+
+ParaformerTorch::ParaformerTorch()
+:use_hotword(false){
+}
+
+// offline
+void ParaformerTorch::InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){
+    LoadConfigFromYaml(am_config.c_str());
+    // knf options
+    fbank_opts_.frame_opts.dither = 0;
+    fbank_opts_.mel_opts.num_bins = n_mels;
+    fbank_opts_.frame_opts.samp_freq = asr_sample_rate;
+    fbank_opts_.frame_opts.window_type = window_type;
+    fbank_opts_.frame_opts.frame_shift_ms = frame_shift;
+    fbank_opts_.frame_opts.frame_length_ms = frame_length;
+    fbank_opts_.energy_floor = 0;
+    fbank_opts_.mel_opts.debug_mel = false;
+
+    vocab = new Vocab(am_config.c_str());
+    phone_set_ = new PhoneSet(am_config.c_str());
+    LoadCmvn(am_cmvn.c_str());
+
+    torch::DeviceType device = at::kCPU;
+    #ifdef USE_GPU
+    if (!torch::cuda::is_available()) {
+        LOG(ERROR) << "CUDA is not available! Please check your GPU settings";
+        exit(-1);
+    } else {
+        LOG(INFO) << "CUDA available! Running on GPU";
+        device = at::kCUDA;
+    }
+    #endif
+    #ifdef USE_IPEX
+    torch::jit::setTensorExprFuserEnabled(false);
+    #endif
+    torch::jit::script::Module model = torch::jit::load(am_model, device);
+    model_ = std::make_shared<TorchModule>(std::move(model));
+}
+
+void ParaformerTorch::InitLm(const std::string &lm_file,
+                        const std::string &lm_cfg_file,
+                        const std::string &lex_file) {
+    try {
+        lm_ = std::shared_ptr<fst::Fst<fst::StdArc>>(
+            fst::Fst<fst::StdArc>::Read(lm_file));
+        if (lm_){
+            lm_vocab = new Vocab(lm_cfg_file.c_str(), lex_file.c_str());
+            LOG(INFO) << "Successfully load lm file " << lm_file;
+        }else{
+            LOG(ERROR) << "Failed to load lm file " << lm_file;
+        }
+    } catch (std::exception const &e) {
+        LOG(ERROR) << "Error when load lm file: " << e.what();
+        exit(0);
+    }
+}
+
+void ParaformerTorch::LoadConfigFromYaml(const char* filename){
+
+    YAML::Node config;
+    try{
+        config = YAML::LoadFile(filename);
+    }catch(exception const &e){
+        LOG(ERROR) << "Error loading file, yaml file error or not exist.";
+        exit(-1);
+    }
+
+    try{
+        YAML::Node frontend_conf = config["frontend_conf"];
+        this->asr_sample_rate = frontend_conf["fs"].as<int>();
+
+        YAML::Node lang_conf = config["lang"];
+        if (lang_conf.IsDefined()){
+            language = lang_conf.as<string>();
+        }
+    }catch(exception const &e){
+        LOG(ERROR) << "Error when load argument from vad config YAML.";
+        exit(-1);
+    }
+}
+
+void ParaformerTorch::InitHwCompiler(const std::string &hw_model, int thread_num) {
+    // TODO
+    use_hotword = true;
+}
+
+void ParaformerTorch::InitSegDict(const std::string &seg_dict_model) {
+    seg_dict = new SegDict(seg_dict_model.c_str());
+}
+
+ParaformerTorch::~ParaformerTorch()
+{
+    if(vocab){
+        delete vocab;
+    }
+    if(lm_vocab){
+        delete lm_vocab;
+    }
+    if(seg_dict){
+        delete seg_dict;
+    }
+    if(phone_set_){
+        delete phone_set_;
+    }
+}
+
+void ParaformerTorch::StartUtterance()
+{
+}
+
+void ParaformerTorch::EndUtterance()
+{
+}
+
+void ParaformerTorch::Reset()
+{
+}
+
+void ParaformerTorch::FbankKaldi(float sample_rate, const float* waves, int len, std::vector<std::vector<float>> &asr_feats) {
+    knf::OnlineFbank fbank_(fbank_opts_);
+    std::vector<float> buf(len);
+    for (int32_t i = 0; i != len; ++i) {
+        buf[i] = waves[i] * 32768;
+    }
+    fbank_.AcceptWaveform(sample_rate, buf.data(), buf.size());
+
+    int32_t frames = fbank_.NumFramesReady();
+    for (int32_t i = 0; i != frames; ++i) {
+        const float *frame = fbank_.GetFrame(i);
+        std::vector<float> frame_vector(frame, frame + fbank_opts_.mel_opts.num_bins);
+        asr_feats.emplace_back(frame_vector);
+    }
+}
+
+void ParaformerTorch::LoadCmvn(const char *filename)
+{
+    ifstream cmvn_stream(filename);
+    if (!cmvn_stream.is_open()) {
+        LOG(ERROR) << "Failed to open file: " << filename;
+        exit(-1);
+    }
+    string line;
+
+    while (getline(cmvn_stream, line)) {
+        istringstream iss(line);
+        vector<string> line_item{istream_iterator<string>{iss}, istream_iterator<string>{}};
+        if (line_item[0] == "<AddShift>") {
+            getline(cmvn_stream, line);
+            istringstream means_lines_stream(line);
+            vector<string> means_lines{istream_iterator<string>{means_lines_stream}, istream_iterator<string>{}};
+            if (means_lines[0] == "<LearnRateCoef>") {
+                for (int j = 3; j < means_lines.size() - 1; j++) {
+                    means_list_.push_back(stof(means_lines[j]));
+                }
+                continue;
+            }
+        }
+        else if (line_item[0] == "<Rescale>") {
+            getline(cmvn_stream, line);
+            istringstream vars_lines_stream(line);
+            vector<string> vars_lines{istream_iterator<string>{vars_lines_stream}, istream_iterator<string>{}};
+            if (vars_lines[0] == "<LearnRateCoef>") {
+                for (int j = 3; j < vars_lines.size() - 1; j++) {
+                    vars_list_.push_back(stof(vars_lines[j])*scale);
+                }
+                continue;
+            }
+        }
+    }
+}
+
+string ParaformerTorch::GreedySearch(float * in, int n_len, int64_t token_nums, bool is_stamp, std::vector<float> us_alphas, std::vector<float> us_cif_peak)
+{
+    vector<int> hyps;
+    int Tmax = n_len;
+    for (int i = 0; i < Tmax; i++) {
+        int max_idx;
+        float max_val;
+        FindMax(in + i * token_nums, token_nums, max_val, max_idx);
+        hyps.push_back(max_idx);
+    }
+    if(!is_stamp){
+        return vocab->Vector2StringV2(hyps, language);
+    }else{
+        std::vector<string> char_list;
+        std::vector<std::vector<float>> timestamp_list;
+        std::string res_str;
+        vocab->Vector2String(hyps, char_list);
+        std::vector<string> raw_char(char_list);
+        TimestampOnnx(us_alphas, us_cif_peak, char_list, res_str, timestamp_list);
+
+        return PostProcess(raw_char, timestamp_list);
+    }
+}
+
+string ParaformerTorch::BeamSearch(WfstDecoder* &wfst_decoder, float *in, int len, int64_t token_nums)
+{
+  return wfst_decoder->Search(in, len, token_nums);
+}
+
+string ParaformerTorch::FinalizeDecode(WfstDecoder* &wfst_decoder,
+                                  bool is_stamp, std::vector<float> us_alphas, std::vector<float> us_cif_peak)
+{
+  return wfst_decoder->FinalizeDecode(is_stamp, us_alphas, us_cif_peak);
+}
+
+void ParaformerTorch::LfrCmvn(std::vector<std::vector<float>> &asr_feats) {
+
+    std::vector<std::vector<float>> out_feats;
+    int T = asr_feats.size();
+    int T_lrf = ceil(1.0 * T / lfr_n);
+
+    // Pad frames at start(copy first frame)
+    for (int i = 0; i < (lfr_m - 1) / 2; i++) {
+        asr_feats.insert(asr_feats.begin(), asr_feats[0]);
+    }
+    // Merge lfr_m frames as one,lfr_n frames per window
+    T = T + (lfr_m - 1) / 2;
+    std::vector<float> p;
+    for (int i = 0; i < T_lrf; i++) {
+        if (lfr_m <= T - i * lfr_n) {
+            for (int j = 0; j < lfr_m; j++) {
+                p.insert(p.end(), asr_feats[i * lfr_n + j].begin(), asr_feats[i * lfr_n + j].end());
+            }
+            out_feats.emplace_back(p);
+            p.clear();
+        } else {
+            // Fill to lfr_m frames at last window if less than lfr_m frames  (copy last frame)
+            int num_padding = lfr_m - (T - i * lfr_n);
+            for (int j = 0; j < (asr_feats.size() - i * lfr_n); j++) {
+                p.insert(p.end(), asr_feats[i * lfr_n + j].begin(), asr_feats[i * lfr_n + j].end());
+            }
+            for (int j = 0; j < num_padding; j++) {
+                p.insert(p.end(), asr_feats[asr_feats.size() - 1].begin(), asr_feats[asr_feats.size() - 1].end());
+            }
+            out_feats.emplace_back(p);
+            p.clear();
+        }
+    }
+    // Apply cmvn
+    for (auto &out_feat: out_feats) {
+        for (int j = 0; j < means_list_.size(); j++) {
+            out_feat[j] = (out_feat[j] + means_list_[j]) * vars_list_[j];
+        }
+    }
+    asr_feats = out_feats;
+}
+
+string ParaformerTorch::Forward(float* din, int len, bool input_finished, const std::vector<std::vector<float>> &hw_emb, void* decoder_handle)
+{
+    WfstDecoder* wfst_decoder = (WfstDecoder*)decoder_handle;
+    int32_t in_feat_dim = fbank_opts_.mel_opts.num_bins;
+
+    std::vector<std::vector<float>> asr_feats;
+    FbankKaldi(asr_sample_rate, din, len, asr_feats);
+    if(asr_feats.size() == 0){
+        return "";
+    }
+    LfrCmvn(asr_feats);
+    int32_t feat_dim = lfr_m*in_feat_dim;
+    int32_t num_frames = asr_feats.size();
+
+    std::vector<float> wav_feats;
+    for (const auto &frame_feat: asr_feats) {
+        wav_feats.insert(wav_feats.end(), frame_feat.begin(), frame_feat.end());
+    }
+    std::vector<int32_t> paraformer_length;
+    paraformer_length.emplace_back(num_frames);
+
+    torch::NoGradGuard no_grad;
+    torch::Tensor feats =
+        torch::from_blob(wav_feats.data(),
+                         {1, num_frames, feat_dim}, torch::kFloat).contiguous();
+    torch::Tensor feat_lens = torch::from_blob(paraformer_length.data(),
+                         {1}, torch::kInt32);
+
+    // 2. forward
+    #ifdef USE_GPU
+    feats = feats.to(at::kCUDA);
+    feat_lens = feat_lens.to(at::kCUDA);
+    #endif
+    std::vector<torch::jit::IValue> inputs = {feats, feat_lens};
+
+    string result="";
+    try {
+        auto outputs = model_->forward(inputs).toTuple()->elements();
+        torch::Tensor am_scores;
+        torch::Tensor valid_token_lens;
+        #ifdef USE_GPU
+        am_scores = outputs[0].toTensor().to(at::kCPU);
+        valid_token_lens = outputs[1].toTensor().to(at::kCPU);
+        #else
+        am_scores = outputs[0].toTensor();
+        valid_token_lens = outputs[1].toTensor();
+        #endif
+
+        if (lm_ == nullptr) {
+            result = GreedySearch(am_scores[0].data_ptr<float>(), valid_token_lens[0].item<int>(), am_scores.size(2));
+        } else {
+            result = BeamSearch(wfst_decoder, am_scores[0].data_ptr<float>(), valid_token_lens[0].item<int>(), am_scores.size(2));
+            if (input_finished) {
+                result = FinalizeDecode(wfst_decoder);
+            }
+        }
+    }
+    catch (std::exception const &e)
+    {
+        LOG(ERROR)<<e.what();
+        return result;
+    }
+
+    return result;
+}
+
+std::vector<std::vector<float>> ParaformerTorch::CompileHotwordEmbedding(std::string &hotwords) {
+    std::vector<std::vector<float>> result;
+    return result;
+}
+
+Vocab* ParaformerTorch::GetVocab()
+{
+    return vocab;
+}
+
+Vocab* ParaformerTorch::GetLmVocab()
+{
+    return lm_vocab;
+}
+
+PhoneSet* ParaformerTorch::GetPhoneSet()
+{
+    return phone_set_;
+}
+
+string ParaformerTorch::Rescoring()
+{
+    LOG(ERROR)<<"Not Imp!!!!!!";
+    return "";
+}
+} // namespace funasr
diff --git a/runtime/onnxruntime/src/paraformer-torch.h b/runtime/onnxruntime/src/paraformer-torch.h
new file mode 100644
index 000000000..a5993de24
--- /dev/null
+++ b/runtime/onnxruntime/src/paraformer-torch.h
@@ -0,0 +1,92 @@
+/**
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+ * MIT License (https://opensource.org/licenses/MIT)
+*/
+#pragma once
+#include <torch/script.h>
+#include <torch/torch.h>
+#include <string>
+#include "precomp.h"
+#include "fst/fstlib.h"
+#include "fst/symbol-table.h"
+#include "bias-lm.h"
+#include "phone-set.h"
+
+namespace funasr {
+
+    class ParaformerTorch : public Model {
+    /**
+     * Author: Speech Lab of DAMO Academy, Alibaba Group
+     * Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
+     * https://arxiv.org/pdf/2206.08317.pdf
+    */
+    private:
+        Vocab* vocab = nullptr;
+        Vocab* lm_vocab = nullptr;
+        SegDict* seg_dict = nullptr;
+        PhoneSet* phone_set_ = nullptr;
+        //const float scale = 22.6274169979695;
+        const float scale = 1.0;
+
+        void LoadConfigFromYaml(const char* filename);
+        void LoadCmvn(const char *filename);
+        void LfrCmvn(std::vector<std::vector<float>> &asr_feats);
+
+        using TorchModule = torch::jit::script::Module;
+        std::shared_ptr<TorchModule> model_ = nullptr;
+        std::vector<torch::Tensor> encoder_outs_;
+        bool use_hotword;
+
+    public:
+        ParaformerTorch();
+        ~ParaformerTorch();
+        void InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num);
+        void InitHwCompiler(const std::string &hw_model, int thread_num);
+        void InitSegDict(const std::string &seg_dict_model);
+        std::vector<std::vector<float>> CompileHotwordEmbedding(std::string &hotwords);
+        void Reset();
+        void FbankKaldi(float sample_rate, const float* waves, int len, std::vector<std::vector<float>> &asr_feats);
+        string Forward(float* din, int len, bool input_finished=true, const std::vector<std::vector<float>> &hw_emb={{0.0}}, void* wfst_decoder=nullptr);
+        string GreedySearch( float* in, int n_len, int64_t token_nums,
+                             bool is_stamp=false, std::vector<float> us_alphas={0}, std::vector<float> us_cif_peak={0});
+
+        string Rescoring();
+        string GetLang(){return language;};
+        int GetAsrSampleRate() { return asr_sample_rate; };
+        void StartUtterance();
+        void EndUtterance();
+        void InitLm(const std::string &lm_file, const std::string &lm_cfg_file, const std::string &lex_file);
+        string BeamSearch(WfstDecoder* &wfst_decoder, float* in, int n_len, int64_t token_nums);
+        string FinalizeDecode(WfstDecoder* &wfst_decoder,
+                              bool is_stamp=false, std::vector<float> us_alphas={0}, std::vector<float> us_cif_peak={0});
+        Vocab* GetVocab();
+        Vocab* GetLmVocab();
+        PhoneSet* GetPhoneSet();
+
+        knf::FbankOptions fbank_opts_;
+        vector<float> means_list_;
+        vector<float> vars_list_;
+        int lfr_m = PARA_LFR_M;
+        int lfr_n = PARA_LFR_N;
+
+        // paraformer-offline
+        std::string language="zh-cn";
+
+        // lm
+        std::shared_ptr<fst::Fst<fst::StdArc>> lm_ = nullptr;
+
+        string window_type = "hamming";
+        int frame_length = 25;
+        int frame_shift = 10;
+        int n_mels = 80;
+        int encoder_size = 512;
+        int fsmn_layers = 16;
+        int fsmn_lorder = 10;
+        int fsmn_dims = 512;
+        float cif_threshold = 1.0;
+        float tail_alphas = 0.45;
+        int asr_sample_rate = MODEL_SAMPLE_RATE;
+    };
+
+} // namespace funasr
diff --git a/runtime/onnxruntime/src/precomp.h b/runtime/onnxruntime/src/precomp.h
index 776de8eeb..5513819a6 100644
--- a/runtime/onnxruntime/src/precomp.h
+++ b/runtime/onnxruntime/src/precomp.h
@@ -64,6 +64,7 @@ using namespace std;
 #include "seg_dict.h"
 #include "resample.h"
 #include "paraformer.h"
+#include "paraformer-torch.h"
 #include "paraformer-online.h"
 #include "offline-stream.h"
 #include "tpass-stream.h"