diff --git a/funasr/bin/train.py b/funasr/bin/train.py index d3ebaacbf..3e3f5987d 100755 --- a/funasr/bin/train.py +++ b/funasr/bin/train.py @@ -241,6 +241,12 @@ def get_parser(): default=False, help="Enable resuming if checkpoint is existing", ) + parser.add_argument( + "--train_dtype", + default="float32", + choices=["float16", "float32", "float64"], + help="Data type for training.", + ) parser.add_argument( "--use_amp", type=str2bool,