diff --git a/funasr/bin/asr_inference_paraformer.py b/funasr/bin/asr_inference_paraformer.py index c1f08642a..a50e03869 100644 --- a/funasr/bin/asr_inference_paraformer.py +++ b/funasr/bin/asr_inference_paraformer.py @@ -95,10 +95,13 @@ class Speech2Text: logging.info("asr_train_args: {}".format(asr_train_args)) asr_model.to(dtype=getattr(torch, dtype)).eval() - ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos) + if asr_model.ctc != None: + ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos) + scorers.update( + ctc=ctc + ) token_list = asr_model.token_list scorers.update( - ctc=ctc, length_bonus=LengthBonus(len(token_list)), ) diff --git a/funasr/bin/asr_inference_paraformer_timestamp.py b/funasr/bin/asr_inference_paraformer_timestamp.py index 3fb87643b..b6469871a 100644 --- a/funasr/bin/asr_inference_paraformer_timestamp.py +++ b/funasr/bin/asr_inference_paraformer_timestamp.py @@ -98,10 +98,13 @@ class Speech2Text: logging.info("asr_train_args: {}".format(asr_train_args)) asr_model.to(dtype=getattr(torch, dtype)).eval() - ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos) + if asr_model.ctc != None: + ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos) + scorers.update( + ctc=ctc + ) token_list = asr_model.token_list scorers.update( - ctc=ctc, length_bonus=LengthBonus(len(token_list)), ) diff --git a/funasr/bin/asr_inference_paraformer_vad_punc.py b/funasr/bin/asr_inference_paraformer_vad_punc.py index 4b5b31682..85838aa79 100644 --- a/funasr/bin/asr_inference_paraformer_vad_punc.py +++ b/funasr/bin/asr_inference_paraformer_vad_punc.py @@ -100,10 +100,13 @@ class Speech2Text: # logging.info("asr_train_args: {}".format(asr_train_args)) asr_model.to(dtype=getattr(torch, dtype)).eval() - ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos) + if asr_model.ctc != None: + ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos) + scorers.update( + ctc=ctc + ) token_list = asr_model.token_list scorers.update( - ctc=ctc, length_bonus=LengthBonus(len(token_list)), ) @@ -663,7 +666,7 @@ def inference_modelscope( time_stamp_postprocessed)) logging.info("decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}". - format(length_total, forward_time_total, 100 * forward_time_total / (length_total * lfr_factor+1e-6))) + format(length_total, forward_time_total, 100 * forward_time_total / (length_total * lfr_factor))) return asr_result_list return _forward diff --git a/funasr/bin/asr_inference_uniasr.py b/funasr/bin/asr_inference_uniasr.py index 515c0d4c1..d386ff13f 100644 --- a/funasr/bin/asr_inference_uniasr.py +++ b/funasr/bin/asr_inference_uniasr.py @@ -96,11 +96,14 @@ class Speech2Text: else: decoder = asr_model.decoder2 - ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos) + if asr_model.ctc != None: + ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos) + scorers.update( + ctc=ctc + ) token_list = asr_model.token_list scorers.update( decoder=decoder, - ctc=ctc, length_bonus=LengthBonus(len(token_list)), )