argument.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. # Copyright 2020 Hirofumi Inaguma
  2. # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
  3. """RNN common arguments."""
  4. def add_arguments_rnn_encoder_common(group):
  5. """Define common arguments for RNN encoder."""
  6. group.add_argument(
  7. "--etype",
  8. default="blstmp",
  9. type=str,
  10. choices=[
  11. "lstm",
  12. "blstm",
  13. "lstmp",
  14. "blstmp",
  15. "vgglstmp",
  16. "vggblstmp",
  17. "vgglstm",
  18. "vggblstm",
  19. "gru",
  20. "bgru",
  21. "grup",
  22. "bgrup",
  23. "vgggrup",
  24. "vggbgrup",
  25. "vgggru",
  26. "vggbgru",
  27. ],
  28. help="Type of encoder network architecture",
  29. )
  30. group.add_argument(
  31. "--elayers",
  32. default=4,
  33. type=int,
  34. help="Number of encoder layers",
  35. )
  36. group.add_argument(
  37. "--eunits",
  38. "-u",
  39. default=300,
  40. type=int,
  41. help="Number of encoder hidden units",
  42. )
  43. group.add_argument(
  44. "--eprojs", default=320, type=int, help="Number of encoder projection units"
  45. )
  46. group.add_argument(
  47. "--subsample",
  48. default="1",
  49. type=str,
  50. help="Subsample input frames x_y_z means "
  51. "subsample every x frame at 1st layer, "
  52. "every y frame at 2nd layer etc.",
  53. )
  54. return group
  55. def add_arguments_rnn_decoder_common(group):
  56. """Define common arguments for RNN decoder."""
  57. group.add_argument(
  58. "--dtype",
  59. default="lstm",
  60. type=str,
  61. choices=["lstm", "gru"],
  62. help="Type of decoder network architecture",
  63. )
  64. group.add_argument(
  65. "--dlayers", default=1, type=int, help="Number of decoder layers"
  66. )
  67. group.add_argument(
  68. "--dunits", default=320, type=int, help="Number of decoder hidden units"
  69. )
  70. group.add_argument(
  71. "--dropout-rate-decoder",
  72. default=0.0,
  73. type=float,
  74. help="Dropout rate for the decoder",
  75. )
  76. group.add_argument(
  77. "--sampling-probability",
  78. default=0.0,
  79. type=float,
  80. help="Ratio of predicted labels fed back to decoder",
  81. )
  82. group.add_argument(
  83. "--lsm-type",
  84. const="",
  85. default="",
  86. type=str,
  87. nargs="?",
  88. choices=["", "unigram"],
  89. help="Apply label smoothing with a specified distribution type",
  90. )
  91. return group
  92. def add_arguments_rnn_attention_common(group):
  93. """Define common arguments for RNN attention."""
  94. group.add_argument(
  95. "--atype",
  96. default="dot",
  97. type=str,
  98. choices=[
  99. "noatt",
  100. "dot",
  101. "add",
  102. "location",
  103. "coverage",
  104. "coverage_location",
  105. "location2d",
  106. "location_recurrent",
  107. "multi_head_dot",
  108. "multi_head_add",
  109. "multi_head_loc",
  110. "multi_head_multi_res_loc",
  111. ],
  112. help="Type of attention architecture",
  113. )
  114. group.add_argument(
  115. "--adim",
  116. default=320,
  117. type=int,
  118. help="Number of attention transformation dimensions",
  119. )
  120. group.add_argument(
  121. "--awin", default=5, type=int, help="Window size for location2d attention"
  122. )
  123. group.add_argument(
  124. "--aheads",
  125. default=4,
  126. type=int,
  127. help="Number of heads for multi head attention",
  128. )
  129. group.add_argument(
  130. "--aconv-chans",
  131. default=-1,
  132. type=int,
  133. help="Number of attention convolution channels \
  134. (negative value indicates no location-aware attention)",
  135. )
  136. group.add_argument(
  137. "--aconv-filts",
  138. default=100,
  139. type=int,
  140. help="Number of attention convolution filters \
  141. (negative value indicates no location-aware attention)",
  142. )
  143. group.add_argument(
  144. "--dropout-rate",
  145. default=0.0,
  146. type=float,
  147. help="Dropout rate for the encoder",
  148. )
  149. return group