FunASR/funasr/models/language_model/rnn/argument.py
zhifu gao 861147c730
Dev gzf exp (#1654)
* sensevoice finetune

* sensevoice finetune

* sensevoice finetune

* sensevoice finetune

* sensevoice finetune

* sensevoice finetune

* sensevoice finetune

* sensevoice finetune

* sensevoice finetune

* sensevoice finetune

* bugfix

* update with main (#1631)

* update seaco finetune

* v1.0.24

---------

Co-authored-by: 维石 <shixian.shi@alibaba-inc.com>

* sensevoice

* sensevoice

* sensevoice

* update with main (#1638)

* update seaco finetune

* v1.0.24

* update rwkv template

---------

Co-authored-by: 维石 <shixian.shi@alibaba-inc.com>

* sensevoice

* sensevoice

* sensevoice

* sensevoice

* sensevoice

* sensevoice

* sensevoice

* sensevoice

* sensevoice

* sensevoice

* sensevoice

* sensevoice

* sensevoice

* sensevoice

* sensevoice

* sense voice

* sense voice

* sense voice

* sense voice

* sense voice

* sense voice

* sense voice

* sense voice

* sense voice

* sense voice

* sense voice

* sense voice

* sense voice

* sense voice

* sense voice

* sense voice

* sense voice

* sense voice

* sense voice

* sense voice

* whisper

* whisper

* update style

* update style

---------

Co-authored-by: 维石 <shixian.shi@alibaba-inc.com>
2024-04-24 16:03:38 +08:00

149 lines
3.9 KiB
Python

# Copyright 2020 Hirofumi Inaguma
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""RNN common arguments."""
def add_arguments_rnn_encoder_common(group):
    """Register the shared RNN encoder command-line options on ``group``.

    Args:
        group: Object exposing an argparse-style ``add_argument`` method
            (for example a parser or an argument group).

    Returns:
        The same ``group`` object that was passed in.
    """
    # Accepted values for --etype.
    encoder_types = [
        "lstm", "blstm", "lstmp", "blstmp",
        "vgglstmp", "vggblstmp", "vgglstm", "vggblstm",
        "gru", "bgru", "grup", "bgrup",
        "vgggrup", "vggbgrup", "vgggru", "vggbgru",
    ]
    subsample_help = (
        "Subsample input frames x_y_z means "
        "subsample every x frame at 1st layer, "
        "every y frame at 2nd layer etc."
    )

    register = group.add_argument
    register(
        "--etype",
        type=str,
        default="blstmp",
        choices=encoder_types,
        help="Type of encoder network architecture",
    )
    register("--elayers", type=int, default=4, help="Number of encoder layers")
    # "-u" is the short alias for the hidden-unit count.
    register("--eunits", "-u", type=int, default=300, help="Number of encoder hidden units")
    register("--eprojs", type=int, default=320, help="Number of encoder projection units")
    # The subsample spec is kept as a string here; it is not parsed in this function.
    register("--subsample", type=str, default="1", help=subsample_help)
    return group
def add_arguments_rnn_decoder_common(group):
    """Register the shared RNN decoder command-line options on ``group``.

    Args:
        group: Object exposing an argparse-style ``add_argument`` method
            (for example a parser or an argument group).

    Returns:
        The same ``group`` object that was passed in.
    """
    # (flag, keyword arguments) pairs, applied in this order.
    option_specs = (
        (
            "--dtype",
            {
                "type": str,
                "default": "lstm",
                "choices": ["lstm", "gru"],
                "help": "Type of decoder network architecture",
            },
        ),
        ("--dlayers", {"type": int, "default": 1, "help": "Number of decoder layers"}),
        ("--dunits", {"type": int, "default": 320, "help": "Number of decoder hidden units"}),
        (
            "--dropout-rate-decoder",
            {"type": float, "default": 0.0, "help": "Dropout rate for the decoder"},
        ),
        (
            "--sampling-probability",
            {
                "type": float,
                "default": 0.0,
                "help": "Ratio of predicted labels fed back to decoder",
            },
        ),
        (
            # nargs="?" with const="" lets the flag be given without a value.
            "--lsm-type",
            {
                "type": str,
                "nargs": "?",
                "const": "",
                "default": "",
                "choices": ["", "unigram"],
                "help": "Apply label smoothing with a specified distribution type",
            },
        ),
    )
    for flag, settings in option_specs:
        group.add_argument(flag, **settings)
    return group
def add_arguments_rnn_attention_common(group):
    """Define common arguments for RNN attention.

    Args:
        group: Object exposing an argparse-style ``add_argument`` method
            (for example a parser or an argument group).

    Returns:
        The same ``group`` object that was passed in.
    """
    group.add_argument(
        "--atype",
        default="dot",
        type=str,
        choices=[
            "noatt",
            "dot",
            "add",
            "location",
            "coverage",
            "coverage_location",
            "location2d",
            "location_recurrent",
            "multi_head_dot",
            "multi_head_add",
            "multi_head_loc",
            "multi_head_multi_res_loc",
        ],
        help="Type of attention architecture",
    )
    group.add_argument(
        "--adim",
        default=320,
        type=int,
        help="Number of attention transformation dimensions",
    )
    group.add_argument("--awin", default=5, type=int, help="Window size for location2d attention")
    group.add_argument(
        "--aheads",
        default=4,
        type=int,
        help="Number of heads for multi head attention",
    )
    # The two help texts below are built from adjacent string literals rather
    # than a backslash continuation inside one literal, so the indentation of
    # the continuation line does not end up inside the help text.
    group.add_argument(
        "--aconv-chans",
        default=-1,
        type=int,
        help="Number of attention convolution channels "
        "(negative value indicates no location-aware attention)",
    )
    group.add_argument(
        "--aconv-filts",
        default=100,
        type=int,
        help="Number of attention convolution filters "
        "(negative value indicates no location-aware attention)",
    )
    # NOTE(review): this flag is registered alongside the attention options
    # although its help text describes the encoder; left here so the set of
    # flags this function registers stays the same.
    group.add_argument(
        "--dropout-rate",
        default=0.0,
        type=float,
        help="Dropout rate for the encoder",
    )
    return group