diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py index 67a85d242..d72fd4b5b 100644 --- a/funasr/bin/asr_inference_launch.py +++ b/funasr/bin/asr_inference_launch.py @@ -223,6 +223,31 @@ def inference_launch(**kwargs): logging.info("Unknown decoding mode: {}".format(mode)) return None +def inference_launch_funasr(**kwargs): + if 'mode' in kwargs: + mode = kwargs['mode'] + else: + logging.info("Unknown decoding mode.") + return None + if mode == "asr": + from funasr.bin.asr_inference import inference + return inference(**kwargs) + elif mode == "uniasr": + from funasr.bin.asr_inference_uniasr import inference + return inference(**kwargs) + elif mode == "paraformer": + from funasr.bin.asr_inference_paraformer import inference + return inference(**kwargs) + elif mode == "paraformer_vad_punc": + from funasr.bin.asr_inference_paraformer_vad_punc import inference + return inference(**kwargs) + elif mode == "vad": + from funasr.bin.vad_inference import inference + return inference(**kwargs) + else: + logging.info("Unknown decoding mode: {}".format(mode)) + return None + def main(cmd=None): print(get_commandline_args(), file=sys.stderr) @@ -251,7 +276,7 @@ def main(cmd=None): os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" os.environ["CUDA_VISIBLE_DEVICES"] = gpuid - inference_launch(**kwargs) + inference_launch_funasr(**kwargs) if __name__ == "__main__":