FunASR/funasr/bin/inference_cli.py
2023-11-17 15:19:53 +08:00

140 lines
3.1 KiB
Python

#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
import os
import logging
import torch
import numpy as np
from funasr.utils.download_and_prepare_model import prepare_model
from funasr.utils.types import str2bool
def infer(task_name: str = "asr",
          model: str = None,
          vad_model: str = None,
          disable_vad: bool = False,
          punc_model: str = None,
          disable_punc: bool = False,
          model_hub: str = "ms",
          cache_dir: str = None,
          **kwargs,
          ):
    """Build a callable that runs the inference pipeline for ``task_name``.

    Args:
        task_name: Which pipeline to build. Only ``"asr"`` is implemented.
        model: Forwarded to ``prepare_model``.
        vad_model: Forwarded to ``prepare_model``.
        disable_vad: Accepted for interface compatibility; this function
            does not read it and does not forward it to ``prepare_model``.
        punc_model: Forwarded to ``prepare_model``.
        disable_punc: Accepted for interface compatibility; this function
            does not read it and does not forward it to ``prepare_model``.
        model_hub: Forwarded to ``prepare_model``.
        cache_dir: Forwarded to ``prepare_model``.
        **kwargs: Extra options forwarded to ``prepare_model``; the kwargs it
            returns are then passed to ``inference_launch``.

    Returns:
        A function ``_infer_fn(input, **kwargs)`` that invokes the pipeline.
        ``input`` may be a ``torch.Tensor`` or ``numpy.ndarray`` (handed over
        as ``raw_inputs``) or any other value (handed over as the
        ``[input, 'speech', data_type]`` triple).

    Raises:
        ValueError: If ``task_name`` is not a supported task.
    """
    # Fail fast, before prepare_model is invoked: for any other task no
    # pipeline would ever be bound and the returned callable could only fail
    # with a NameError on its first call.
    if task_name != "asr":
        raise ValueError(
            "Unsupported task_name: {!r} (supported: 'asr')".format(task_name)
        )

    # set logging messages
    logging.basicConfig(
        level=logging.ERROR,
    )

    model, vad_model, punc_model, kwargs = prepare_model(
        model, vad_model, punc_model, model_hub, cache_dir, **kwargs
    )

    from funasr.bin.asr_inference_launch import inference_launch
    inference_pipeline = inference_launch(**kwargs)

    # NOTE: the first parameter must stay named ``input`` because callers may
    # pass it by keyword (e.g. ``fn(**parsed_args)``).
    def _infer_fn(input, **call_kwargs):
        data_type = call_kwargs.get('data_type', 'sound')

        if isinstance(input, torch.Tensor):
            # detach()/cpu() make the conversion valid for tensors that
            # require grad or live on an accelerator; for a plain CPU tensor
            # the resulting array is the same as input.numpy().
            input = input.detach().cpu().numpy()

        if isinstance(input, np.ndarray):
            # In-memory samples go through raw_inputs, with no path triple.
            return inference_pipeline(None, raw_inputs=input, **call_kwargs)

        return inference_pipeline(
            [input, 'speech', data_type], raw_inputs=None, **call_kwargs
        )

    return _infer_fn
def main(cmd=None):
    """Command-line entry point: parse arguments, run inference, print the result.

    The parser returned by ``get_parser()`` is extended with the positional
    ``input`` argument and the CLI-specific options below. The parsed
    arguments (minus ``data_path_and_name_and_type``) are used both to build
    the inference callable and to invoke it.

    Args:
        cmd: Argument list forwarded to ``parser.parse_args``.
    """
    from funasr.bin.argument import get_parser

    parser = get_parser()
    parser.add_argument('input', help='input file to transcribe')

    # (flags, add_argument keyword options), registered in this exact order.
    option_table = (
        (("--task_name",),
         dict(type=str, default="asr", help="The decoding mode")),
        (("-m", "--model"),
         dict(type=str, default="paraformer-zh", help="The asr mode name")),
        (("-v", "--vad_model"),
         dict(type=str, default="fsmn-vad", help="vad model name")),
        (("-dv", "--disable_vad"),
         dict(type=str2bool, default=False, help="")),
        (("-p", "--punc_model"),
         dict(type=str, default="ct-punc", help="")),
        (("-dp", "--disable_punc"),
         dict(type=str2bool, default=False, help="")),
        (("--batch_size_token",),
         dict(type=int, default=5000, help="")),
        (("--batch_size_token_threshold_s",),
         dict(type=int, default=35, help="")),
        (("--max_single_segment_time",),
         dict(type=int, default=5000, help="")),
    )
    for flags, options in option_table:
        parser.add_argument(*flags, **options)

    kwargs = vars(parser.parse_args(cmd))

    # set logging messages
    logging.basicConfig(level=logging.ERROR)
    logging.info("Decoding args: {}".format(kwargs))

    # Dropped before the arguments are forwarded to infer() and to the
    # callable it returns.
    kwargs.pop("data_path_and_name_and_type")
    print("args: {}".format(kwargs))

    run_inference = infer(**kwargs)
    print(run_inference(**kwargs))