diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index 8acd37cca..c6f19b66b 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -1,3 +1,4 @@
+import argparse
 import logging
 import os
 import sys
@@ -9,6 +10,7 @@ from funasr.torch_utils.model_summary import model_summary
 from funasr.torch_utils.pytorch_version import pytorch_cudnn_version
 from funasr.torch_utils.set_all_random_seed import set_all_random_seed
 from funasr.utils import config_argparse
+from funasr.utils.build_args import build_args
 from funasr.utils.build_dataloader import build_dataloader
 from funasr.utils.build_distributed import build_distributed
 from funasr.utils.build_model import build_model
@@ -272,6 +274,12 @@ def get_parser():
         action="append",
         default=[],
     )
+    parser.add_argument(
+        "--use_preprocessor",
+        type=str2bool,
+        default=True,
+        help="Apply preprocessing to data or not",
+    )
 
     # pai related
     parser.add_argument(
@@ -330,6 +338,8 @@ def get_parser():
 if __name__ == '__main__':
     parser = get_parser()
     args = parser.parse_args()
+    task_args = build_args(args)
+    args = argparse.Namespace(**vars(args), **vars(task_args))
 
     # set random seed
     set_all_random_seed(args.seed)
diff --git a/funasr/utils/build_args.py b/funasr/utils/build_args.py
new file mode 100644
index 000000000..1baf2d681
--- /dev/null
+++ b/funasr/utils/build_args.py
@@ -0,0 +1,103 @@
+import argparse
+
+from funasr.models.ctc import CTC
+from funasr.utils.get_default_kwargs import get_default_kwargs
+from funasr.utils.nested_dict_action import NestedDictAction
+from funasr.utils.types import int_or_none
+from funasr.utils.types import str2bool
+from funasr.utils.types import str_or_none
+
+
+def build_args(args):
+    """Build and parse the task-specific options for ``args.task_name``.
+
+    Args:
+        args: Namespace produced by the general training parser; only its
+            ``task_name`` attribute is read here.
+
+    Returns:
+        argparse.Namespace with the task-specific options declared below,
+        parsed from ``sys.argv``.
+
+    Raises:
+        NotImplementedError: If ``args.task_name`` is not ``"asr"``.
+    """
+    parser = argparse.ArgumentParser("Task related config")
+    if args.task_name == "asr":
+        from funasr.utils.build_asr_model import class_choices_list
+        for class_choices in class_choices_list:
+            # Append --{name} and --{name}_conf.
+            # e.g. --encoder and --encoder_conf
+            class_choices.add_arguments(parser)
+        parser.add_argument(
+            "--token_list",
+            type=str_or_none,
+            default=None,
+            help="A text mapping int-id to token",
+        )
+        parser.add_argument(
+            "--split_with_space",
+            type=str2bool,
+            default=True,
+            help="whether to split text using space",
+        )
+        parser.add_argument(
+            "--seg_dict_file",
+            type=str,
+            default=None,
+            help="seg_dict_file for text processing",
+        )
+        parser.add_argument(
+            "--init",
+            type=lambda x: str_or_none(x.lower()),
+            default=None,
+            help="The initialization method",
+            choices=[
+                "chainer",
+                "xavier_uniform",
+                "xavier_normal",
+                "kaiming_uniform",
+                "kaiming_normal",
+                None,
+            ],
+        )
+        parser.add_argument(
+            "--input_size",
+            type=int_or_none,
+            default=None,
+            help="The number of input dimension of the feature",
+        )
+        parser.add_argument(
+            "--ctc_conf",
+            action=NestedDictAction,
+            default=get_default_kwargs(CTC),
+            help="The keyword arguments for CTC class.",
+        )
+        parser.add_argument(
+            "--token_type",
+            type=str,
+            default="bpe",
+            choices=["bpe", "char", "word", "phn"],
+            help="The text will be tokenized " "in the specified level token",
+        )
+        parser.add_argument(
+            "--bpemodel",
+            type=str_or_none,
+            default=None,
+            help="The model file of sentencepiece",
+        )
+        parser.add_argument(
+            "--cmvn_file",
+            type=str_or_none,
+            default=None,
+            help="The file path of the cmvn file.",
+        )
+
+    else:
+        raise NotImplementedError("Not supported task: {}".format(args.task_name))
+
+    # The command line is shared with the general training parser, so it
+    # also carries options this parser does not declare. parse_args() would
+    # exit with "unrecognized arguments" on those; keep only the known ones.
+    task_args, _ = parser.parse_known_args()
+    return task_args