diff --git a/funasr/bin/train.py b/funasr/bin/train.py index 26e0e6a26..9c8f672e7 100755 --- a/funasr/bin/train.py +++ b/funasr/bin/train.py @@ -19,6 +19,7 @@ from funasr.text.phoneme_tokenizer import g2p_choices from funasr.torch_utils.model_summary import model_summary from funasr.torch_utils.pytorch_version import pytorch_cudnn_version from funasr.torch_utils.set_all_random_seed import set_all_random_seed +from funasr.utils.nested_dict_action import NestedDictAction from funasr.utils.prepare_data import prepare_data from funasr.utils.types import str2bool from funasr.utils.types import str_or_none @@ -302,6 +303,32 @@ def get_parser(): help="Apply preprocessing to data or not", ) + # optimization related + parser.add_argument( + "--optim", + type=lambda x: x.lower(), + default="adam", + help="The optimizer type", + ) + parser.add_argument( + "--optim_conf", + action=NestedDictAction, + default=dict(), + help="The keyword arguments for optimizer", + ) + parser.add_argument( + "--scheduler", + type=lambda x: str_or_none(x.lower()), + default=None, + help="The lr scheduler type", + ) + parser.add_argument( + "--scheduler_conf", + action=NestedDictAction, + default=dict(), + help="The keyword arguments for lr scheduler", + ) + # most task related parser.add_argument( "--init",