#!/usr/bin/env python3
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)

import argparse
import logging
import os
import sys
from typing import Union, Dict, Any

from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none


def get_parser():
    """Build the command-line parser shared by the diarization decoding modes.

    Returns:
        A ``config_argparse.ArgumentParser`` populated with the logging,
        device, input-data, model and inference options. The ``--mode``
        option is not added here; ``main`` adds it.
    """
    parser = config_argparse.ArgumentParser(
        description="Speaker Diarization",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Note(kamo): Use '_' instead of '-' as separator.
    # '-' is confusing if written in yaml.
    parser.add_argument(
        "--log_level",
        type=lambda x: x.upper(),
        default="INFO",
        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
        help="The verbose level of logging",
    )

    parser.add_argument("--output_dir", type=str, required=False)
    parser.add_argument(
        "--ngpu",
        type=int,
        default=0,
        help="The number of gpus. 0 indicates CPU mode",
    )
    parser.add_argument(
        "--njob",
        type=int,
        default=1,
        help="The number of jobs for each gpu",
    )
    parser.add_argument(
        "--gpuid_list",
        type=str,
        default="",
        help="The visible gpus",
    )
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument(
        "--dtype",
        default="float32",
        choices=["float16", "float32", "float64"],
        help="Data type",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=1,
        help="The number of workers used for DataLoader",
    )

    group = parser.add_argument_group("Input data related")
    group.add_argument(
        "--data_path_and_name_and_type",
        type=str2triple_str,
        required=False,
        action="append",
    )
    group.add_argument("--key_file", type=str_or_none)
    group.add_argument("--allow_variable_data_keys", type=str2bool, default=True)

    group = parser.add_argument_group("The model configuration related")
    group.add_argument(
        "--vad_infer_config",
        type=str,
        help="VAD infer configuration",
    )
    group.add_argument(
        "--vad_model_file",
        type=str,
        help="VAD model parameter file",
    )
    group.add_argument(
        "--diar_train_config",
        type=str,
        help="Diarization training configuration",
    )
    group.add_argument(
        "--diar_model_file",
        type=str,
        help="Diarization model parameter file",
    )
    group.add_argument(
        "--cmvn_file",
        type=str,
        help="Global CMVN file",
    )
    group.add_argument(
        "--model_tag",
        type=str,
        help="Pretrained model tag. If specify this option, *_train_config and "
        "*_file will be overwritten",
    )

    group = parser.add_argument_group("The inference configuration related")
    group.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="The batch size for inference",
    )
    group.add_argument(
        "--diar_smooth_size",
        type=int,
        default=121,
        help="The smoothing size for post-processing"
    )

    return parser


def inference_launch(mode, **kwargs):
    """Dispatch to the ``inference_modelscope`` entry point for ``mode``.

    Args:
        mode: Decoding mode; one of ``"sond"``, ``"sond_demo"`` or
            ``"eend-ola"``.
        **kwargs: Forwarded unchanged to the selected ``inference_modelscope``
            (for ``"sond_demo"`` the ``param_dict`` entry is completed with
            demo defaults first).

    Returns:
        Whatever the selected ``inference_modelscope`` returns, or ``None``
        when ``mode`` is not recognised (an error is logged in that case).
    """
    # Backends are imported lazily so that only the selected one is loaded.
    if mode == "sond":
        from funasr.bin.sond_inference import inference_modelscope

        return inference_modelscope(mode=mode, **kwargs)

    if mode == "sond_demo":
        from funasr.bin.sond_inference import inference_modelscope

        demo_defaults = {
            "extract_profile": True,
            "sv_train_config": "sv.yaml",
            "sv_model_file": "sv.pb",
        }
        user_params = kwargs.get("param_dict")
        if user_params is None:
            user_params = {}
        # Build a merged copy (caller-supplied values win) instead of
        # filling the defaults into the caller's dictionary in place.
        kwargs["param_dict"] = {**demo_defaults, **user_params}
        return inference_modelscope(mode=mode, **kwargs)

    if mode == "eend-ola":
        from funasr.bin.eend_ola_inference import inference_modelscope

        return inference_modelscope(mode=mode, **kwargs)

    # Logged as an error so the problem stays visible at WARNING/ERROR levels.
    logging.error("Unknown decoding mode: {}".format(mode))
    return None


def _select_gpuid(output_dir, gpuid_list, njob):
    """Return the GPU id assigned to the job encoded in ``output_dir``.

    The job id is the integer after the last ``"."`` of ``output_dir``; it is
    mapped onto the comma-separated ``gpuid_list`` with ``njob`` jobs per GPU.

    Raises:
        ValueError: If ``output_dir`` is missing or has no integer suffix.
        IndexError: If the job id points past the end of ``gpuid_list``.
    """
    if not output_dir:
        raise ValueError(
            "--output_dir is required when --ngpu > 0 "
            "(its '.<jobid>' suffix selects the GPU)"
        )

    suffix = output_dir.split(".")[-1]
    try:
        jobid = int(suffix)
    except ValueError as err:
        raise ValueError(
            "--output_dir must end with '.<jobid>' when --ngpu > 0, "
            "got: {}".format(output_dir)
        ) from err

    gpuids = gpuid_list.split(",")
    try:
        return gpuids[(jobid - 1) // njob]
    except IndexError as err:
        raise IndexError(
            "job id {} with --njob {} does not map to any entry of "
            "--gpuid_list '{}'".format(jobid, njob, gpuid_list)
        ) from err


def main(cmd=None):
    """Parse the command line, configure logging and GPUs, and run inference.

    Args:
        cmd: Optional list of argument strings passed to ``parse_args``;
            ``None`` makes argparse read ``sys.argv``.
    """
    print(get_commandline_args(), file=sys.stderr)
    parser = get_parser()
    parser.add_argument(
        "--mode",
        type=str,
        default="sond",
        help="The decoding mode",
    )
    args = parser.parse_args(cmd)
    kwargs = vars(args)
    # "config" is dropped (if present) before the arguments are forwarded.
    kwargs.pop("config", None)

    # set logging messages
    logging.basicConfig(
        level=args.log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    logging.info("Decoding args: {}".format(kwargs))

    # gpu setting
    if args.ngpu > 0:
        gpuid = _select_gpuid(args.output_dir, args.gpuid_list, args.njob)
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        os.environ["CUDA_VISIBLE_DEVICES"] = gpuid

    inference_launch(**kwargs)


if __name__ == "__main__":
    main()