From 244c033fbaeae15faf8b0351365bdb7607b2e2bb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Fri, 17 Nov 2023 15:19:53 +0800 Subject: [PATCH] python cli --- README.md | 9 + README_zh.md | 9 + funasr/__init__.py | 127 +--------- funasr/bin/argument.py | 262 +++++++++++++++++++++ funasr/bin/asr_inference_launch.py | 258 +------------------- funasr/bin/inference_cli.py | 139 +++++++++++ funasr/utils/download_and_prepare_model.py | 93 ++++++++ funasr/version.txt | 2 +- setup.py | 3 + 9 files changed, 522 insertions(+), 380 deletions(-) create mode 100644 funasr/bin/argument.py create mode 100644 funasr/bin/inference_cli.py create mode 100644 funasr/utils/download_and_prepare_model.py diff --git a/README.md b/README.md index f73c0ca92..001ce3f1a 100644 --- a/README.md +++ b/README.md @@ -76,6 +76,15 @@ Quick start for new users([tutorial](https://alibaba-damo-academy.github.io/Fu FunASR supports inference and fine-tuning of models trained on industrial data for tens of thousands of hours. For more details, please refer to [modelscope_egs](https://alibaba-damo-academy.github.io/FunASR/en/modelscope_pipeline/quick_start.html). It also supports training and fine-tuning of models on academic standard datasets. For more information, please refer to [egs](https://alibaba-damo-academy.github.io/FunASR/en/academic_recipe/asr_recipe.html). Below is a quick start tutorial. Test audio files ([Mandarin](https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav), [English]()). 
+ +### Command-line usage + +```shell +funasr --model paraformer-zh asr_example_zh.wav +``` + +Note: Supports recognition of a single audio file, as well as a file list in Kaldi-style wav.scp format: `wav_id wav_path` + ### Speech Recognition (Non-streaming) ```python from funasr import infer diff --git a/README_zh.md b/README_zh.md index 554c0b61d..504c7156c 100644 --- a/README_zh.md +++ b/README_zh.md @@ -70,6 +70,15 @@ FunASR开源了大量在工业数据上预训练模型,您可以在[模型许 FunASR支持数万小时工业数据训练的模型的推理和微调,详细信息可以参阅([modelscope_egs](https://alibaba-damo-academy.github.io/FunASR/en/modelscope_pipeline/quick_start.html));也支持学术标准数据集模型的训练和微调,详细信息可以参阅([egs](https://alibaba-damo-academy.github.io/FunASR/en/academic_recipe/asr_recipe.html))。 下面为快速上手教程,测试音频([中文](https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav),[英文]()) + +### 可执行命令行 + +```shell +funasr --model paraformer-zh asr_example_zh.wav +``` + +注:支持单条音频文件识别,也支持文件列表,列表为kaldi风格wav.scp:`wav_id wav_path` + ### 非实时语音识别 ```python from funasr import infer diff --git a/funasr/__init__.py b/funasr/__init__.py index aab42891d..d0b7aa5b6 100644 --- a/funasr/__init__.py +++ b/funasr/__init__.py @@ -1,135 +1,10 @@ """Initialize funasr package.""" import os -from pathlib import Path -import torch -import numpy as np dirname = os.path.dirname(__file__) version_file = os.path.join(dirname, "version.txt") with open(version_file, "r") as f: __version__ = f.read().strip() - -def prepare_model( - model: str = None, - # mode: str = None, - vad_model: str = None, - punc_model: str = None, - model_hub: str = "ms", - cache_dir: str = None, - **kwargs, -): - if not Path(model).exists(): - if model_hub == "ms" or model_hub == "modelscope": - try: - from modelscope.hub.snapshot_download import snapshot_download as download_tool - model = name_maps_ms[model] if model is not None else None - vad_model = name_maps_ms[vad_model] if vad_model is not None else None - punc_model = name_maps_ms[punc_model] if punc_model is not None else None - 
except: - raise "You are exporting model from modelscope, please install modelscope and try it again. To install modelscope, you could:\n" \ - "\npip3 install -U modelscope\n" \ - "For the users in China, you could install with the command:\n" \ - "\npip3 install -U modelscope -i https://mirror.sjtu.edu.cn/pypi/web/simple" - elif model_hub == "hf" or model_hub == "huggingface": - download_tool = 0 - else: - raise "model_hub must be on of ms or hf, but get {}".format(model_hub) - try: - model = download_tool(model, cache_dir=cache_dir, revision=kwargs.get("revision", None)) - print("model have been downloaded to: {}".format(model)) - except: - raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format( - model) - - if vad_model is not None and not Path(vad_model).exists(): - vad_model = download_tool(vad_model, cache_dir=cache_dir) - print("model have been downloaded to: {}".format(vad_model)) - if punc_model is not None and not Path(punc_model).exists(): - punc_model = download_tool(punc_model, cache_dir=cache_dir) - print("model have been downloaded to: {}".format(punc_model)) - - # asr - kwargs.update({"cmvn_file": None if model is None else os.path.join(model, "am.mvn"), - "asr_model_file": None if model is None else os.path.join(model, "model.pb"), - "asr_train_config": None if model is None else os.path.join(model, "config.yaml"), - }) - mode = kwargs.get("mode", None) - if mode is None: - import json - json_file = os.path.join(model, 'configuration.json') - with open(json_file, 'r') as f: - config_data = json.load(f) - if config_data['task'] == "punctuation": - mode = config_data['model']['punc_model_config']['mode'] - else: - mode = config_data['model']['model_config']['mode'] - if vad_model is not None and "vad" not in mode: - mode = "paraformer_vad" - kwargs["mode"] = mode - # vad - kwargs.update({"vad_cmvn_file": None if vad_model is None else os.path.join(vad_model, "vad.mvn"), - "vad_model_file": None 
if vad_model is None else os.path.join(vad_model, "vad.pb"), - "vad_infer_config": None if vad_model is None else os.path.join(vad_model, "vad.yaml"), - }) - # punc - kwargs.update({ - "punc_model_file": None if punc_model is None else os.path.join(punc_model, "punc.pb"), - "punc_infer_config": None if punc_model is None else os.path.join(punc_model, "punc.yaml"), - }) - - - return model, vad_model, punc_model, kwargs - -name_maps_ms = { - "paraformer-zh": "damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch", - "paraformer-zh-spk": "damo/speech_paraformer-large-vad-punc-spk_asr_nat-zh-cn", - "paraformer-en": "damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020", - "paraformer-en-spk": "damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020", - "paraformer-zh-streaming": "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online", - "fsmn-vad": "damo/speech_fsmn_vad_zh-cn-16k-common-pytorch", - "ct-punc": "damo/punc_ct-transformer_cn-en-common-vocab471067-large", - "fa-zh": "damo/speech_timestamp_prediction-v1-16k-offline", -} - -def infer(task_name: str = "asr", - model: str = None, - # mode: str = None, - vad_model: str = None, - punc_model: str = None, - model_hub: str = "ms", - cache_dir: str = None, - **kwargs, - ): - - model, vad_model, punc_model, kwargs = prepare_model(model, vad_model, punc_model, model_hub, cache_dir, **kwargs) - if task_name == "asr": - from funasr.bin.asr_inference_launch import inference_launch - - inference_pipeline = inference_launch(**kwargs) - elif task_name == "": - pipeline = 1 - elif task_name == "": - pipeline = 2 - elif task_name == "": - pipeline = 2 - - def _infer_fn(input, **kwargs): - data_type = kwargs.get('data_type', 'sound') - data_path_and_name_and_type = [input, 'speech', data_type] - raw_inputs = None - if isinstance(input, torch.Tensor): - input = input.numpy() - if isinstance(input, np.ndarray): - data_path_and_name_and_type = None - 
raw_inputs = input - - - - return inference_pipeline(data_path_and_name_and_type, raw_inputs=raw_inputs, **kwargs) - - return _infer_fn - -if __name__ == '__main__': - pass \ No newline at end of file +from funasr.bin.inference_cli import infer \ No newline at end of file diff --git a/funasr/bin/argument.py b/funasr/bin/argument.py new file mode 100644 index 000000000..0ea4ac91c --- /dev/null +++ b/funasr/bin/argument.py @@ -0,0 +1,262 @@ +#!/usr/bin/env python3 +# -*- encoding: utf-8 -*- +# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. +# MIT License (https://opensource.org/licenses/MIT) + +import sys + +from funasr.utils.types import str2bool +from funasr.utils.types import str2triple_str +from funasr.utils.types import str_or_none +from funasr.utils import config_argparse +import argparse + + +def get_parser(): + parser = config_argparse.ArgumentParser( + description="ASR Decoding", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + # Note(kamo): Use '_' instead of '-' as separator. + # '-' is confusing if written in yaml. + parser.add_argument( + "--log_level", + type=lambda x: x.upper(), + default="INFO", + choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"), + help="The verbose level of logging", + ) + + parser.add_argument("--output_dir", type=str, default=None) + parser.add_argument( + "--ngpu", + type=int, + default=1, + help="The number of gpus. 
0 indicates CPU mode", + ) + parser.add_argument( + "--njob", + type=int, + default=1, + help="The number of jobs for each gpu", + ) + parser.add_argument( + "--gpuid_list", + type=str, + default="", + help="The visible gpus", + ) + parser.add_argument("--seed", type=int, default=0, help="Random seed") + parser.add_argument( + "--dtype", + default="float32", + choices=["float16", "float32", "float64"], + help="Data type", + ) + parser.add_argument( + "--num_workers", + type=int, + default=1, + help="The number of workers used for DataLoader", + ) + + group = parser.add_argument_group("Input data related") + group.add_argument( + "--data_path_and_name_and_type", + type=str2triple_str, + required=False, + action="append", + ) + group.add_argument("--key_file", type=str_or_none) + parser.add_argument( + "--hotword", + type=str_or_none, + default=None, + help="hotword file path or hotwords seperated by space" + ) + group.add_argument("--allow_variable_data_keys", type=str2bool, default=False) + group.add_argument( + "--mc", + type=bool, + default=False, + help="MultiChannel input", + ) + + group = parser.add_argument_group("The model configuration related") + group.add_argument( + "--vad_infer_config", + type=str, + help="VAD infer configuration", + ) + group.add_argument( + "--vad_model_file", + type=str, + help="VAD model parameter file", + ) + group.add_argument( + "--punc_infer_config", + type=str, + help="PUNC infer configuration", + ) + group.add_argument( + "--punc_model_file", + type=str, + help="PUNC model parameter file", + ) + group.add_argument( + "--cmvn_file", + type=str, + help="Global CMVN file", + ) + group.add_argument( + "--asr_train_config", + type=str, + help="ASR training configuration", + ) + group.add_argument( + "--asr_model_file", + type=str, + help="ASR model parameter file", + ) + group.add_argument( + "--sv_model_file", + type=str, + help="SV model parameter file", + ) + group.add_argument( + "--lm_train_config", + type=str, + help="LM 
training configuration", + ) + group.add_argument( + "--lm_file", + type=str, + help="LM parameter file", + ) + group.add_argument( + "--word_lm_train_config", + type=str, + help="Word LM training configuration", + ) + group.add_argument( + "--word_lm_file", + type=str, + help="Word LM parameter file", + ) + group.add_argument( + "--ngram_file", + type=str, + help="N-gram parameter file", + ) + group.add_argument( + "--model_tag", + type=str, + help="Pretrained model tag. If specify this option, *_train_config and " + "*_file will be overwritten", + ) + group.add_argument( + "--beam_search_config", + default={}, + help="The keyword arguments for transducer beam search.", + ) + + group = parser.add_argument_group("Beam-search related") + group.add_argument( + "--batch_size", + type=int, + default=1, + help="The batch size for inference", + ) + group.add_argument("--nbest", type=int, default=5, help="Output N-best hypotheses") + group.add_argument("--beam_size", type=int, default=20, help="Beam size") + group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty") + group.add_argument( + "--maxlenratio", + type=float, + default=0.0, + help="Input length ratio to obtain max output length. " + "If maxlenratio=0.0 (default), it uses a end-detect " + "function " + "to automatically find maximum hypothesis lengths." 
+ "If maxlenratio<0.0, its absolute value is interpreted" + "as a constant max output length", + ) + group.add_argument( + "--minlenratio", + type=float, + default=0.0, + help="Input length ratio to obtain min output length", + ) + group.add_argument( + "--ctc_weight", + type=float, + default=0.0, + help="CTC weight in joint decoding", + ) + group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight") + group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight") + group.add_argument("--streaming", type=str2bool, default=False) + group.add_argument("--fake_streaming", type=str2bool, default=False) + group.add_argument("--full_utt", type=str2bool, default=False) + group.add_argument("--chunk_size", type=int, default=16) + group.add_argument("--left_context", type=int, default=16) + group.add_argument("--right_context", type=int, default=0) + group.add_argument( + "--display_partial_hypotheses", + type=bool, + default=False, + help="Whether to display partial hypotheses during chunk-by-chunk inference.", + ) + + group = parser.add_argument_group("Dynamic quantization related") + group.add_argument( + "--quantize_asr_model", + type=bool, + default=False, + help="Apply dynamic quantization to ASR model.", + ) + group.add_argument( + "--quantize_modules", + nargs="*", + default=None, + help="""Module names to apply dynamic quantization on. + The module names are provided as a list, where each name is separated + by a comma (e.g.: --quantize-config=[Linear,LSTM,GRU]). 
+ Each specified name should be an attribute of 'torch.nn', e.g.: + torch.nn.Linear, torch.nn.LSTM, torch.nn.GRU, ...""", + ) + group.add_argument( + "--quantize_dtype", + type=str, + default="qint8", + choices=["float16", "qint8"], + help="Dtype for dynamic quantization.", + ) + + group = parser.add_argument_group("Text converter related") + group.add_argument( + "--token_type", + type=str_or_none, + default=None, + choices=["char", "bpe", None], + help="The token type for ASR model. " + "If not given, refers from the training args", + ) + group.add_argument( + "--bpemodel", + type=str_or_none, + default=None, + help="The model path of sentencepiece. " + "If not given, refers from the training args", + ) + group.add_argument("--token_num_relax", type=int, default=1, help="") + group.add_argument("--decoding_ind", type=int, default=0, help="") + group.add_argument("--decoding_mode", type=str, default="model1", help="") + group.add_argument( + "--ctc_weight2", + type=float, + default=0.0, + help="CTC weight in joint decoding", + ) + return parser diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py index e93d74037..e1a32c57c 100644 --- a/funasr/bin/asr_inference_launch.py +++ b/funasr/bin/asr_inference_launch.py @@ -675,11 +675,13 @@ def inference_paraformer_vad_punc( beg_idx = end_idx batch = {"speech": speech_j, "speech_lengths": speech_lengths_j} batch = to_device(batch, device=device) - # print("batch: ", speech_j.shape[0]) + beg_asr = time.time() results = speech2text(**batch) end_asr = time.time() - # print("time cost asr: ", end_asr - beg_asr) + if speech2text.device != "cpu": + print("batch: ", speech_j.shape[0]) + print("time cost asr: ", end_asr - beg_asr) if len(results) < 1: results = [["", [], [], [], [], [], []]] @@ -2218,259 +2220,9 @@ def inference_launch(**kwargs): logging.info("Unknown decoding mode: {}".format(mode)) return None - -def get_parser(): - parser = config_argparse.ArgumentParser( - description="ASR 
Decoding", - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - - # Note(kamo): Use '_' instead of '-' as separator. - # '-' is confusing if written in yaml. - parser.add_argument( - "--log_level", - type=lambda x: x.upper(), - default="INFO", - choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"), - help="The verbose level of logging", - ) - - parser.add_argument("--output_dir", type=str, required=True) - parser.add_argument( - "--ngpu", - type=int, - default=0, - help="The number of gpus. 0 indicates CPU mode", - ) - parser.add_argument( - "--njob", - type=int, - default=1, - help="The number of jobs for each gpu", - ) - parser.add_argument( - "--gpuid_list", - type=str, - default="", - help="The visible gpus", - ) - parser.add_argument("--seed", type=int, default=0, help="Random seed") - parser.add_argument( - "--dtype", - default="float32", - choices=["float16", "float32", "float64"], - help="Data type", - ) - parser.add_argument( - "--num_workers", - type=int, - default=1, - help="The number of workers used for DataLoader", - ) - - group = parser.add_argument_group("Input data related") - group.add_argument( - "--data_path_and_name_and_type", - type=str2triple_str, - required=True, - action="append", - ) - group.add_argument("--key_file", type=str_or_none) - parser.add_argument( - "--hotword", - type=str_or_none, - default=None, - help="hotword file path or hotwords seperated by space" - ) - group.add_argument("--allow_variable_data_keys", type=str2bool, default=False) - group.add_argument( - "--mc", - type=bool, - default=False, - help="MultiChannel input", - ) - - group = parser.add_argument_group("The model configuration related") - group.add_argument( - "--vad_infer_config", - type=str, - help="VAD infer configuration", - ) - group.add_argument( - "--vad_model_file", - type=str, - help="VAD model parameter file", - ) - group.add_argument( - "--punc_infer_config", - type=str, - help="PUNC infer configuration", - ) - 
group.add_argument( - "--punc_model_file", - type=str, - help="PUNC model parameter file", - ) - group.add_argument( - "--cmvn_file", - type=str, - help="Global CMVN file", - ) - group.add_argument( - "--asr_train_config", - type=str, - help="ASR training configuration", - ) - group.add_argument( - "--asr_model_file", - type=str, - help="ASR model parameter file", - ) - group.add_argument( - "--sv_model_file", - type=str, - help="SV model parameter file", - ) - group.add_argument( - "--lm_train_config", - type=str, - help="LM training configuration", - ) - group.add_argument( - "--lm_file", - type=str, - help="LM parameter file", - ) - group.add_argument( - "--word_lm_train_config", - type=str, - help="Word LM training configuration", - ) - group.add_argument( - "--word_lm_file", - type=str, - help="Word LM parameter file", - ) - group.add_argument( - "--ngram_file", - type=str, - help="N-gram parameter file", - ) - group.add_argument( - "--model_tag", - type=str, - help="Pretrained model tag. If specify this option, *_train_config and " - "*_file will be overwritten", - ) - group.add_argument( - "--beam_search_config", - default={}, - help="The keyword arguments for transducer beam search.", - ) - - group = parser.add_argument_group("Beam-search related") - group.add_argument( - "--batch_size", - type=int, - default=1, - help="The batch size for inference", - ) - group.add_argument("--nbest", type=int, default=5, help="Output N-best hypotheses") - group.add_argument("--beam_size", type=int, default=20, help="Beam size") - group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty") - group.add_argument( - "--maxlenratio", - type=float, - default=0.0, - help="Input length ratio to obtain max output length. " - "If maxlenratio=0.0 (default), it uses a end-detect " - "function " - "to automatically find maximum hypothesis lengths." 
- "If maxlenratio<0.0, its absolute value is interpreted" - "as a constant max output length", - ) - group.add_argument( - "--minlenratio", - type=float, - default=0.0, - help="Input length ratio to obtain min output length", - ) - group.add_argument( - "--ctc_weight", - type=float, - default=0.0, - help="CTC weight in joint decoding", - ) - group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight") - group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight") - group.add_argument("--streaming", type=str2bool, default=False) - group.add_argument("--fake_streaming", type=str2bool, default=False) - group.add_argument("--full_utt", type=str2bool, default=False) - group.add_argument("--chunk_size", type=int, default=16) - group.add_argument("--left_context", type=int, default=16) - group.add_argument("--right_context", type=int, default=0) - group.add_argument( - "--display_partial_hypotheses", - type=bool, - default=False, - help="Whether to display partial hypotheses during chunk-by-chunk inference.", - ) - - group = parser.add_argument_group("Dynamic quantization related") - group.add_argument( - "--quantize_asr_model", - type=bool, - default=False, - help="Apply dynamic quantization to ASR model.", - ) - group.add_argument( - "--quantize_modules", - nargs="*", - default=None, - help="""Module names to apply dynamic quantization on. - The module names are provided as a list, where each name is separated - by a comma (e.g.: --quantize-config=[Linear,LSTM,GRU]). 
- Each specified name should be an attribute of 'torch.nn', e.g.: - torch.nn.Linear, torch.nn.LSTM, torch.nn.GRU, ...""", - ) - group.add_argument( - "--quantize_dtype", - type=str, - default="qint8", - choices=["float16", "qint8"], - help="Dtype for dynamic quantization.", - ) - - group = parser.add_argument_group("Text converter related") - group.add_argument( - "--token_type", - type=str_or_none, - default=None, - choices=["char", "bpe", None], - help="The token type for ASR model. " - "If not given, refers from the training args", - ) - group.add_argument( - "--bpemodel", - type=str_or_none, - default=None, - help="The model path of sentencepiece. " - "If not given, refers from the training args", - ) - group.add_argument("--token_num_relax", type=int, default=1, help="") - group.add_argument("--decoding_ind", type=int, default=0, help="") - group.add_argument("--decoding_mode", type=str, default="model1", help="") - group.add_argument( - "--ctc_weight2", - type=float, - default=0.0, - help="CTC weight in joint decoding", - ) - return parser - - def main(cmd=None): print(get_commandline_args(), file=sys.stderr) + from funasr.bin.argument import get_parser parser = get_parser() parser.add_argument( "--mode", diff --git a/funasr/bin/inference_cli.py b/funasr/bin/inference_cli.py new file mode 100644 index 000000000..f4c66f141 --- /dev/null +++ b/funasr/bin/inference_cli.py @@ -0,0 +1,139 @@ +#!/usr/bin/env python3 +# -*- encoding: utf-8 -*- +# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. 
+# MIT License (https://opensource.org/licenses/MIT) + +import os + +import logging +import torch +import numpy as np +from funasr.utils.download_and_prepare_model import prepare_model + +from funasr.utils.types import str2bool + +def infer(task_name: str = "asr", + model: str = None, + # mode: str = None, + vad_model: str = None, + disable_vad: bool = False, + punc_model: str = None, + disable_punc: bool = False, + model_hub: str = "ms", + cache_dir: str = None, + **kwargs, + ): + + # set logging messages + logging.basicConfig( + level=logging.ERROR, + ) + + model, vad_model, punc_model, kwargs = prepare_model(model, vad_model, punc_model, model_hub, cache_dir, **kwargs) + if task_name == "asr": + from funasr.bin.asr_inference_launch import inference_launch + + inference_pipeline = inference_launch(**kwargs) + elif task_name == "": + pipeline = 1 + elif task_name == "": + pipeline = 2 + elif task_name == "": + pipeline = 2 + + def _infer_fn(input, **kwargs): + data_type = kwargs.get('data_type', 'sound') + data_path_and_name_and_type = [input, 'speech', data_type] + raw_inputs = None + if isinstance(input, torch.Tensor): + input = input.numpy() + if isinstance(input, np.ndarray): + data_path_and_name_and_type = None + raw_inputs = input + + return inference_pipeline(data_path_and_name_and_type, raw_inputs=raw_inputs, **kwargs) + + return _infer_fn + + +def main(cmd=None): + # print(get_commandline_args(), file=sys.stderr) + from funasr.bin.argument import get_parser + + parser = get_parser() + parser.add_argument('input', help='input file to transcribe') + parser.add_argument( + "--task_name", + type=str, + default="asr", + help="The decoding mode", + ) + parser.add_argument( + "-m", + "--model", + type=str, + default="paraformer-zh", + help="The asr mode name", + ) + parser.add_argument( + "-v", + "--vad_model", + type=str, + default="fsmn-vad", + help="vad model name", + ) + parser.add_argument( + "-dv", + "--disable_vad", + type=str2bool, + default=False, + 
help="", + ) + parser.add_argument( + "-p", + "--punc_model", + type=str, + default="ct-punc", + help="", + ) + parser.add_argument( + "-dp", + "--disable_punc", + type=str2bool, + default=False, + help="", + ) + parser.add_argument( + "--batch_size_token", + type=int, + default=5000, + help="", + ) + parser.add_argument( + "--batch_size_token_threshold_s", + type=int, + default=35, + help="", + ) + parser.add_argument( + "--max_single_segment_time", + type=int, + default=5000, + help="", + ) + args = parser.parse_args(cmd) + kwargs = vars(args) + + # set logging messages + logging.basicConfig( + level=logging.ERROR, + ) + logging.info("Decoding args: {}".format(kwargs)) + + # kwargs["ncpu"] = 2 #os.cpu_count() + kwargs.pop("data_path_and_name_and_type") + print("args: {}".format(kwargs)) + p = infer(**kwargs) + + res = p(**kwargs) + print(res) diff --git a/funasr/utils/download_and_prepare_model.py b/funasr/utils/download_and_prepare_model.py new file mode 100644 index 000000000..af8a3f3dc --- /dev/null +++ b/funasr/utils/download_and_prepare_model.py @@ -0,0 +1,93 @@ + +import os +from pathlib import Path +import logging + +name_maps_ms = { + "paraformer-zh": "damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch", + "paraformer-zh-spk": "damo/speech_paraformer-large-vad-punc-spk_asr_nat-zh-cn", + "paraformer-en": "damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020", + "paraformer-en-spk": "damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020", + "paraformer-zh-streaming": "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online", + "fsmn-vad": "damo/speech_fsmn_vad_zh-cn-16k-common-pytorch", + "ct-punc": "damo/punc_ct-transformer_cn-en-common-vocab471067-large", + "fa-zh": "damo/speech_timestamp_prediction-v1-16k-offline", +} + +def prepare_model( + model: str = None, + # mode: str = None, + vad_model: str = None, + punc_model: str = None, + model_hub: str = "ms", + cache_dir: str = 
None, + **kwargs, +): + if not Path(model).exists(): + if model_hub == "ms" or model_hub == "modelscope": + from modelscope.utils.logger import get_logger + + logger = get_logger(log_level=logging.CRITICAL) + logger.setLevel(logging.CRITICAL) + try: + from modelscope.hub.snapshot_download import snapshot_download as download_tool + model = name_maps_ms[model] if model is not None else None + vad_model = name_maps_ms[vad_model] if vad_model is not None else None + punc_model = name_maps_ms[punc_model] if punc_model is not None else None + except: + raise "You are exporting model from modelscope, please install modelscope and try it again. To install modelscope, you could:\n" \ + "\npip3 install -U modelscope\n" \ + "For the users in China, you could install with the command:\n" \ + "\npip3 install -U modelscope -i https://mirror.sjtu.edu.cn/pypi/web/simple" + + try: + model = download_tool(model, cache_dir=cache_dir, revision=kwargs.get("revision", None)) + print("asr model have been downloaded to: {}".format(model)) + except: + raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format( + model) + + elif model_hub == "hf" or model_hub == "huggingface": + download_tool = 0 + else: + raise "model_hub must be on of ms or hf, but get {}".format(model_hub) + + + if vad_model is not None and not Path(vad_model).exists(): + vad_model = download_tool(vad_model, cache_dir=cache_dir) + print("vad_model have been downloaded to: {}".format(vad_model)) + if punc_model is not None and not Path(punc_model).exists(): + punc_model = download_tool(punc_model, cache_dir=cache_dir) + print("punc_model have been downloaded to: {}".format(punc_model)) + + # asr + kwargs.update({"cmvn_file": None if model is None else os.path.join(model, "am.mvn"), + "asr_model_file": None if model is None else os.path.join(model, "model.pb"), + "asr_train_config": None if model is None else os.path.join(model, "config.yaml"), + }) + mode = 
kwargs.get("mode", None) + if mode is None: + import json + json_file = os.path.join(model, 'configuration.json') + with open(json_file, 'r') as f: + config_data = json.load(f) + if config_data['task'] == "punctuation": + mode = config_data['model']['punc_model_config']['mode'] + else: + mode = config_data['model']['model_config']['mode'] + if vad_model is not None and "vad" not in mode: + mode = "paraformer_vad" + kwargs["mode"] = mode + # vad + kwargs.update({"vad_cmvn_file": None if vad_model is None else os.path.join(vad_model, "vad.mvn"), + "vad_model_file": None if vad_model is None else os.path.join(vad_model, "vad.pb"), + "vad_infer_config": None if vad_model is None else os.path.join(vad_model, "vad.yaml"), + }) + # punc + kwargs.update({ + "punc_model_file": None if punc_model is None else os.path.join(punc_model, "punc.pb"), + "punc_infer_config": None if punc_model is None else os.path.join(punc_model, "punc.yaml"), + }) + + + return model, vad_model, punc_model, kwargs diff --git a/funasr/version.txt b/funasr/version.txt index b60d71966..7ada0d303 100644 --- a/funasr/version.txt +++ b/funasr/version.txt @@ -1 +1 @@ -0.8.4 +0.8.5 diff --git a/setup.py b/setup.py index 069e39408..dd485d3cd 100644 --- a/setup.py +++ b/setup.py @@ -129,4 +129,7 @@ setup( "License :: OSI Approved :: Apache Software License", "Topic :: Software Development :: Libraries :: Python Modules", ], + entry_points={"console_scripts": [ + "funasr = funasr.bin.inference_cli:main", + ]}, )