diff --git a/funasr/bin/train.py b/funasr/bin/train.py index c173167ca..e861199cd 100644 --- a/funasr/bin/train.py +++ b/funasr/bin/train.py @@ -6,18 +6,20 @@ from io import BytesIO import torch -from funasr.torch_utils.model_summary import model_summary -from funasr.torch_utils.pytorch_version import pytorch_cudnn_version -from funasr.torch_utils.set_all_random_seed import set_all_random_seed -from funasr.utils import config_argparse from funasr.build_utils.build_args import build_args from funasr.build_utils.build_dataloader import build_dataloader from funasr.build_utils.build_distributed import build_distributed from funasr.build_utils.build_model import build_model from funasr.build_utils.build_optimizer import build_optimizer from funasr.build_utils.build_scheduler import build_scheduler +from funasr.text.phoneme_tokenizer import g2p_choices +from funasr.torch_utils.model_summary import model_summary +from funasr.torch_utils.pytorch_version import pytorch_cudnn_version +from funasr.torch_utils.set_all_random_seed import set_all_random_seed +from funasr.utils import config_argparse from funasr.utils.prepare_data import prepare_data from funasr.utils.types import str2bool +from funasr.utils.types import str_or_none from funasr.utils.yaml_no_alias_safe_dump import yaml_no_alias_safe_dump @@ -281,6 +283,55 @@ def get_parser(): help="Apply preprocessing to data or not", ) + # most task related + parser.add_argument( + "--init", + type=lambda x: str_or_none(x.lower()), + default=None, + help="The initialization method", + choices=[ + "chainer", + "xavier_uniform", + "xavier_normal", + "kaiming_uniform", + "kaiming_normal", + None, + ], + ) + parser.add_argument( + "--token_list", + type=str_or_none, + default=None, + help="A text mapping int-id to token", + ) + parser.add_argument( + "--token_type", + type=str, + default="bpe", + choices=["bpe", "char", "word", "phn"], + help="The text will be tokenized in the specified level token", + ) + parser.add_argument( + "--bpemodel", + type=str_or_none, + default=None, + help="The 
model file of sentencepiece", + ) + parser.add_argument( + "--cleaner", + type=str_or_none, + choices=[None, "tacotron", "jaconv", "vietnamese"], + default=None, + help="Apply text cleaning", + ) + parser.add_argument( + "--g2p", + type=str_or_none, + choices=g2p_choices, + default=None, + help="Specify g2p method if --token_type=phn", + ) + # pai related parser.add_argument( "--use_pai", diff --git a/funasr/build_utils/build_args.py b/funasr/build_utils/build_args.py index 91f28101f..fc737ba98 100644 --- a/funasr/build_utils/build_args.py +++ b/funasr/build_utils/build_args.py @@ -16,12 +16,6 @@ def build_args(args): # Append -- and --_conf. # e.g. --encoder and --encoder_conf class_choices.add_arguments(parser) - parser.add_argument( - "--token_list", - type=str_or_none, - default=None, - help="A text mapping int-id to token", - ) parser.add_argument( "--split_with_space", type=str2bool, @@ -34,20 +28,6 @@ def build_args(args): default=None, help="seg_dict_file for text processing", ) - parser.add_argument( - "--init", - type=lambda x: str_or_none(x.lower()), - default=None, - help="The initialization method", - choices=[ - "chainer", - "xavier_uniform", - "xavier_normal", - "kaiming_uniform", - "kaiming_normal", - None, - ], - ) parser.add_argument( "--input_size", type=int_or_none, @@ -60,134 +40,40 @@ def build_args(args): default=get_default_kwargs(CTC), help="The keyword arguments for CTC class.", ) - parser.add_argument( - "--token_type", - type=str, - default="bpe", - choices=["bpe", "char", "word", "phn"], - help="The text will be tokenized " "in the specified level token", - ) - parser.add_argument( - "--bpemodel", - type=str_or_none, - default=None, - help="The model file of sentencepiece", - ) - parser.add_argument( - "--cleaner", - type=str_or_none, - choices=[None, "tacotron", "jaconv", "vietnamese"], - default=None, - help="Apply text cleaning", - ) parser.add_argument( "--cmvn_file", type=str_or_none, default=None, help="The file path of noise scp 
file.", ) + elif args.task_name == "pretrain": from funasr.build_utils.build_pretrain_model import class_choices_list for class_choices in class_choices_list: # Append -- and --_conf. # e.g. --encoder and --encoder_conf class_choices.add_arguments(parser) - parser.add_argument( - "--init", - type=lambda x: str_or_none(x.lower()), - default=None, - help="The initialization method", - choices=[ - "chainer", - "xavier_uniform", - "xavier_normal", - "kaiming_uniform", - "kaiming_normal", - None, - ], - ) parser.add_argument( "--input_size", type=int_or_none, default=None, help="The number of input dimension of the feature", ) - parser.add_argument( - "--feats_type", - type=str, - default='fbank', - help="feats type, e.g. fbank, wav, ark_wav(needed to be scale normalization)", - ) - parser.add_argument( - "--noise_db_range", - type=str, - default="13_15", - help="The range of noise decibel level.", - ) - parser.add_argument( - "--pred_masked_weight", - type=float, - default=1.0, - help="weight for predictive loss for masked frames", - ) - parser.add_argument( - "--pred_nomask_weight", - type=float, - default=0.0, - help="weight for predictive loss for unmasked frames", - ) - parser.add_argument( - "--loss_weights", - type=float, - default=0.0, - help="weights for additional loss terms (not first one)", - ) + elif args.task_name == "lm": from funasr.build_utils.build_lm_model import class_choices_list for class_choices in class_choices_list: # Append -- and --_conf. # e.g. 
--encoder and --encoder_conf class_choices.add_arguments(parser) - parser.add_argument( - "--token_list", - type=str_or_none, - default=None, - help="A text mapping int-id to token", - ) - parser.add_argument( - "--init", - type=lambda x: str_or_none(x.lower()), - default=None, - help="The initialization method", - choices=[ - "chainer", - "xavier_uniform", - "xavier_normal", - "kaiming_uniform", - "kaiming_normal", - None, - ], - ) - parser.add_argument( - "--token_type", - type=str, - default="bpe", - choices=["bpe", "char", "word"], - help="", - ) - parser.add_argument( - "--bpemodel", - type=str_or_none, - default=None, - help="The model file fo sentencepiece", - ) - parser.add_argument( - "--cleaner", - type=str_or_none, - choices=[None, "tacotron", "jaconv", "vietnamese"], - default=None, - help="Apply text cleaning", - ) + + elif args.task_name == "punc": + from funasr.build_utils.build_punc_model import class_choices_list + for class_choices in class_choices_list: + # Append -- and --_conf. + # e.g. 
--encoder and --encoder_conf + class_choices.add_arguments(parser) + else: raise NotImplementedError("Not supported task: {}".format(args.task_name)) diff --git a/funasr/build_utils/build_model.py b/funasr/build_utils/build_model.py index 822263191..b1d123020 100644 --- a/funasr/build_utils/build_model.py +++ b/funasr/build_utils/build_model.py @@ -1,6 +1,7 @@ from funasr.build_utils.build_asr_model import build_asr_model from funasr.build_utils.build_pretrain_model import build_pretrain_model from funasr.build_utils.build_lm_model import build_lm_model +from funasr.build_utils.build_punc_model import build_punc_model def build_model(args): @@ -10,6 +11,8 @@ def build_model(args): model = build_pretrain_model(args) elif args.task_name == "lm": model = build_lm_model(args) + elif args.task_name == "punc": + model = build_punc_model(args) else: raise NotImplementedError("Not supported task: {}".format(args.task_name)) diff --git a/funasr/build_utils/build_punc_model.py b/funasr/build_utils/build_punc_model.py new file mode 100644 index 000000000..d098ffca0 --- /dev/null +++ b/funasr/build_utils/build_punc_model.py @@ -0,0 +1,70 @@ +import logging + +from funasr.models.target_delay_transformer import TargetDelayTransformer +from funasr.models.vad_realtime_transformer import VadRealtimeTransformer +from funasr.torch_utils.initialize import initialize +from funasr.train.abs_model import AbsPunctuation +from funasr.train.abs_model import PunctuationModel +from funasr.train.class_choices import ClassChoices + +punc_choices = ClassChoices( + "punctuation", + classes=dict(target_delay=TargetDelayTransformer, vad_realtime=VadRealtimeTransformer), + type_check=AbsPunctuation, + default="target_delay", +) +model_choices = ClassChoices( + "model", + classes=dict( + punc=PunctuationModel, + ), + default="punc", +) +class_choices_list = [ + # --punc and --punc_conf + punc_choices, + # --model and --model_conf + model_choices +] + + +def build_punc_model(args): + """Build a PunctuationModel (and its per-punctuation weight list) from parsed CLI args.""" + # token_list and 
punc list + if isinstance(args.token_list, str): + with open(args.token_list, encoding="utf-8") as f: + token_list = [line.rstrip() for line in f] + args.token_list = token_list.copy() + if isinstance(args.punc_list, str): + with open(args.punc_list, encoding="utf-8") as f2: + pairs = [line.rstrip().split(":") for line in f2] + punc_list = [pair[0] for pair in pairs] + punc_weight_list = [float(pair[1]) for pair in pairs] + args.punc_list = punc_list.copy() + elif isinstance(args.punc_list, list): + punc_list = args.punc_list.copy() + punc_weight_list = [1] * len(punc_list) + else: + raise RuntimeError("punc_list must be str or list") + if isinstance(args.token_list, (tuple, list)): + token_list = args.token_list.copy() + else: + raise RuntimeError("token_list must be str or list") + + vocab_size = len(token_list) + punc_size = len(punc_list) + logging.info(f"Vocabulary size: {vocab_size}") + + # punc + punc_class = punc_choices.get_class(args.punctuation) + punc = punc_class(vocab_size=vocab_size, punc_size=punc_size, **args.punctuation_conf) + + if "punc_weight" in args.model_conf: + args.model_conf.pop("punc_weight") + model = PunctuationModel(punc_model=punc, vocab_size=vocab_size, punc_weight=punc_weight_list, **args.model_conf) + + # initialize + if args.init is not None: + initialize(model, args.init) + + return model