diff --git a/funasr/tasks/asr.py b/funasr/tasks/asr.py index 0bb056365..5de475f19 100644 --- a/funasr/tasks/asr.py +++ b/funasr/tasks/asr.py @@ -224,6 +224,15 @@ rnnt_decoder_choices = ClassChoices( default="rnnt", ) +joint_network_choices = ClassChoices( + name="joint_network", + classes=dict( + joint_network=JointNetwork, + ), + default="joint_network", + optional=True, +) + predictor_choices = ClassChoices( name="predictor", classes=dict( @@ -353,7 +362,7 @@ class ASRTask(AbsTask): help="The keyword arguments for CTC class.", ) group.add_argument( - "--joint_net_conf", + "--joint_network_conf", action=NestedDictAction, default=None, help="The keyword arguments for joint network class.",