| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156 |
# Copyright 2020 Hirofumi Inaguma
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""RNN common arguments.

Helpers that register the shared RNN encoder, decoder and attention
command-line options on a parser or argument group (any object exposing
``add_argument``).
"""
def add_arguments_rnn_encoder_common(group):
    """Register the shared RNN encoder options on ``group``.

    Adds ``--etype``, ``--elayers``, ``--eunits``/``-u``, ``--eprojs`` and
    ``--subsample``, in that order.

    Args:
        group: Parser or argument group exposing ``add_argument``; it is
            modified in place.

    Returns:
        The same ``group`` object, so calls can be chained.
    """
    add = group.add_argument

    # Accepted values for --etype, in the order argparse will list them.
    encoder_types = [
        "lstm", "blstm", "lstmp", "blstmp",
        "vgglstmp", "vggblstmp", "vgglstm", "vggblstm",
        "gru", "bgru", "grup", "bgrup",
        "vgggrup", "vggbgrup", "vgggru", "vggbgru",
    ]
    add(
        "--etype",
        type=str,
        default="blstmp",
        choices=encoder_types,
        help="Type of encoder network architecture",
    )
    add("--elayers", type=int, default=4, help="Number of encoder layers")
    add(
        "--eunits",
        "-u",
        type=int,
        default=300,
        help="Number of encoder hidden units",
    )
    add(
        "--eprojs",
        type=int,
        default=320,
        help="Number of encoder projection units",
    )

    subsample_help = (
        "Subsample input frames x_y_z means "
        "subsample every x frame at 1st layer, "
        "every y frame at 2nd layer etc."
    )
    add("--subsample", type=str, default="1", help=subsample_help)
    return group
def add_arguments_rnn_decoder_common(group):
    """Register the shared RNN decoder options on ``group``.

    Adds ``--dtype``, ``--dlayers``, ``--dunits``, ``--dropout-rate-decoder``,
    ``--sampling-probability`` and ``--lsm-type``, in that order.

    Args:
        group: Parser or argument group exposing ``add_argument``; it is
            modified in place.

    Returns:
        The same ``group`` object, so calls can be chained.
    """
    add = group.add_argument

    add(
        "--dtype",
        type=str,
        default="lstm",
        choices=["lstm", "gru"],
        help="Type of decoder network architecture",
    )
    add("--dlayers", type=int, default=1, help="Number of decoder layers")
    add("--dunits", type=int, default=320, help="Number of decoder hidden units")
    add(
        "--dropout-rate-decoder",
        type=float,
        default=0.0,
        help="Dropout rate for the decoder",
    )
    add(
        "--sampling-probability",
        type=float,
        default=0.0,
        help="Ratio of predicted labels fed back to decoder",
    )
    # nargs="?" with const="" lets a bare ``--lsm-type`` (no value) parse to "",
    # the same value as the default used when the flag is omitted.
    add(
        "--lsm-type",
        nargs="?",
        const="",
        default="",
        type=str,
        choices=["", "unigram"],
        help="Apply label smoothing with a specified distribution type",
    )
    return group
def add_arguments_rnn_attention_common(group):
    """Register the shared RNN attention options on ``group``.

    Adds ``--atype``, ``--adim``, ``--awin``, ``--aheads``, ``--aconv-chans``,
    ``--aconv-filts`` and ``--dropout-rate``, in that order.

    Args:
        group: Parser or argument group exposing ``add_argument``; it is
            modified in place.

    Returns:
        The same ``group`` object, so calls can be chained.
    """
    group.add_argument(
        "--atype",
        default="dot",
        type=str,
        choices=[
            "noatt",
            "dot",
            "add",
            "location",
            "coverage",
            "coverage_location",
            "location2d",
            "location_recurrent",
            "multi_head_dot",
            "multi_head_add",
            "multi_head_loc",
            "multi_head_multi_res_loc",
        ],
        help="Type of attention architecture",
    )
    group.add_argument(
        "--adim",
        default=320,
        type=int,
        help="Number of attention transformation dimensions",
    )
    group.add_argument(
        "--awin", default=5, type=int, help="Window size for location2d attention"
    )
    group.add_argument(
        "--aheads",
        default=4,
        type=int,
        help="Number of heads for multi head attention",
    )
    # Adjacent string literals are concatenated at compile time. A backslash
    # continuation inside the literal would instead splice the next source
    # line's leading whitespace into the help text.
    group.add_argument(
        "--aconv-chans",
        default=-1,
        type=int,
        help=(
            "Number of attention convolution channels "
            "(negative value indicates no location-aware attention)"
        ),
    )
    group.add_argument(
        "--aconv-filts",
        default=100,
        type=int,
        help=(
            "Number of attention convolution filters "
            "(negative value indicates no location-aware attention)"
        ),
    )
    # NOTE(review): the help text describes this as the *encoder* dropout
    # rate even though the attention helper registers it. Moving it would
    # change which helper callers must invoke to get the option.
    group.add_argument(
        "--dropout-rate",
        default=0.0,
        type=float,
        help="Dropout rate for the encoder",
    )
    return group
|