fix func FunASRWfstDecoderInit

This commit is contained in:
雾聪 2024-03-21 15:38:59 +08:00
parent 733930b2aa
commit 462355c002
2 changed files with 37 additions and 10 deletions

View File

@ -31,8 +31,6 @@ class Model {
virtual Vocab* GetVocab() {return nullptr;};
virtual Vocab* GetLmVocab() {return nullptr;};
virtual PhoneSet* GetPhoneSet() {return nullptr;};
std::shared_ptr<fst::Fst<fst::StdArc>> lm_ = nullptr;
};
Model *CreateModel(std::map<std::string, std::string>& model_path, int thread_num=1, ASR_TYPE type=ASR_OFFLINE);

View File

@ -767,16 +767,45 @@
funasr::WfstDecoder* mm = nullptr;
if (asr_type == ASR_OFFLINE) {
funasr::OfflineStream* offline_stream = (funasr::OfflineStream*)handle;
funasr::Model* paraformer = offline_stream->asr_handle.get();
if (paraformer->lm_)
mm = new funasr::WfstDecoder(paraformer->lm_.get(),
paraformer->GetPhoneSet(), paraformer->GetLmVocab(), glob_beam, lat_beam, am_scale);
auto paraformer = dynamic_cast<funasr::Paraformer*>(offline_stream->asr_handle.get());
if(paraformer !=nullptr){
if (paraformer->lm_){
mm = new funasr::WfstDecoder(paraformer->lm_.get(),
paraformer->GetPhoneSet(), paraformer->GetLmVocab(), glob_beam, lat_beam, am_scale);
}
return mm;
}
#ifdef USE_GPU
auto paraformer_torch = dynamic_cast<funasr::ParaformerTorch*>(offline_stream->asr_handle.get());
if(paraformer_torch !=nullptr){
if (paraformer_torch->lm_){
mm = new funasr::WfstDecoder(paraformer_torch->lm_.get(),
paraformer_torch->GetPhoneSet(), paraformer_torch->GetLmVocab(), glob_beam, lat_beam, am_scale);
}
return mm;
}
#endif
} else if (asr_type == ASR_TWO_PASS){
funasr::TpassStream* tpass_stream = (funasr::TpassStream*)handle;
funasr::Model* paraformer = tpass_stream->asr_handle.get();
if (paraformer->lm_)
mm = new funasr::WfstDecoder(paraformer->lm_.get(),
paraformer->GetPhoneSet(), paraformer->GetLmVocab(), glob_beam, lat_beam, am_scale);
auto paraformer = dynamic_cast<funasr::Paraformer*>(tpass_stream->asr_handle.get());
if(paraformer !=nullptr){
if (paraformer->lm_){
mm = new funasr::WfstDecoder(paraformer->lm_.get(),
paraformer->GetPhoneSet(), paraformer->GetLmVocab(), glob_beam, lat_beam, am_scale);
}
return mm;
}
#ifdef USE_GPU
auto paraformer_torch = dynamic_cast<funasr::ParaformerTorch*>(offline_stream->asr_handle.get());
if(paraformer_torch !=nullptr){
if (paraformer_torch->lm_){
mm = new funasr::WfstDecoder(paraformer_torch->lm_.get(),
paraformer_torch->GetPhoneSet(), paraformer_torch->GetLmVocab(), glob_beam, lat_beam, am_scale);
}
return mm;
}
#endif
}
return mm;
}