fix decoding

This commit is contained in:
jmwang66 2023-02-03 17:37:48 +08:00
parent bd450caec1
commit 998fcc6579

View File

@ -223,6 +223,31 @@ def inference_launch(**kwargs):
logging.info("Unknown decoding mode: {}".format(mode))
return None
def inference_launch_funasr(**kwargs):
if 'mode' in kwargs:
mode = kwargs['mode']
else:
logging.info("Unknown decoding mode.")
return None
if mode == "asr":
from funasr.bin.asr_inference import inference
return inference(**kwargs)
elif mode == "uniasr":
from funasr.bin.asr_inference_uniasr import inference
return inference(**kwargs)
elif mode == "paraformer":
from funasr.bin.asr_inference_paraformer import inference
return inference(**kwargs)
elif mode == "paraformer_vad_punc":
from funasr.bin.asr_inference_paraformer_vad_punc import inference
return inference(**kwargs)
elif mode == "vad":
from funasr.bin.vad_inference import inference
return inference(**kwargs)
else:
logging.info("Unknown decoding mode: {}".format(mode))
return None
def main(cmd=None):
print(get_commandline_args(), file=sys.stderr)
@ -251,7 +276,7 @@ def main(cmd=None):
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
inference_launch(**kwargs)
inference_launch_funasr(**kwargs)
if __name__ == "__main__":