rnn_decoder.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336
  1. #!/usr/bin/env python3
  2. # -*- encoding: utf-8 -*-
  3. # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
  4. # MIT License (https://opensource.org/licenses/MIT)
  5. import torch
  6. import random
  7. import numpy as np
  8. import torch.nn as nn
  9. import torch.nn.functional as F
  10. from funasr.register import tables
  11. from funasr.models.transformer.utils.nets_utils import make_pad_mask
  12. from funasr.models.transformer.utils.nets_utils import to_device
  13. from funasr.models.language_model.rnn.attentions import initial_att
  14. def build_attention_list(
  15. eprojs: int,
  16. dunits: int,
  17. atype: str = "location",
  18. num_att: int = 1,
  19. num_encs: int = 1,
  20. aheads: int = 4,
  21. adim: int = 320,
  22. awin: int = 5,
  23. aconv_chans: int = 10,
  24. aconv_filts: int = 100,
  25. han_mode: bool = False,
  26. han_type=None,
  27. han_heads: int = 4,
  28. han_dim: int = 320,
  29. han_conv_chans: int = -1,
  30. han_conv_filts: int = 100,
  31. han_win: int = 5,
  32. ):
  33. att_list = torch.nn.ModuleList()
  34. if num_encs == 1:
  35. for i in range(num_att):
  36. att = initial_att(
  37. atype,
  38. eprojs,
  39. dunits,
  40. aheads,
  41. adim,
  42. awin,
  43. aconv_chans,
  44. aconv_filts,
  45. )
  46. att_list.append(att)
  47. elif num_encs > 1: # no multi-speaker mode
  48. if han_mode:
  49. att = initial_att(
  50. han_type,
  51. eprojs,
  52. dunits,
  53. han_heads,
  54. han_dim,
  55. han_win,
  56. han_conv_chans,
  57. han_conv_filts,
  58. han_mode=True,
  59. )
  60. return att
  61. else:
  62. att_list = torch.nn.ModuleList()
  63. for idx in range(num_encs):
  64. att = initial_att(
  65. atype[idx],
  66. eprojs,
  67. dunits,
  68. aheads[idx],
  69. adim[idx],
  70. awin[idx],
  71. aconv_chans[idx],
  72. aconv_filts[idx],
  73. )
  74. att_list.append(att)
  75. else:
  76. raise ValueError(
  77. "Number of encoders needs to be more than one. {}".format(num_encs)
  78. )
  79. return att_list
  80. @tables.register("decoder_classes", "rnn_decoder")
  81. class RNNDecoder(nn.Module):
  82. def __init__(
  83. self,
  84. vocab_size: int,
  85. encoder_output_size: int,
  86. rnn_type: str = "lstm",
  87. num_layers: int = 1,
  88. hidden_size: int = 320,
  89. sampling_probability: float = 0.0,
  90. dropout: float = 0.0,
  91. context_residual: bool = False,
  92. replace_sos: bool = False,
  93. num_encs: int = 1,
  94. att_conf: dict = None,
  95. ):
  96. # FIXME(kamo): The parts of num_spk should be refactored more more more
  97. if rnn_type not in {"lstm", "gru"}:
  98. raise ValueError(f"Not supported: rnn_type={rnn_type}")
  99. super().__init__()
  100. eprojs = encoder_output_size
  101. self.dtype = rnn_type
  102. self.dunits = hidden_size
  103. self.dlayers = num_layers
  104. self.context_residual = context_residual
  105. self.sos = vocab_size - 1
  106. self.eos = vocab_size - 1
  107. self.odim = vocab_size
  108. self.sampling_probability = sampling_probability
  109. self.dropout = dropout
  110. self.num_encs = num_encs
  111. # for multilingual translation
  112. self.replace_sos = replace_sos
  113. self.embed = torch.nn.Embedding(vocab_size, hidden_size)
  114. self.dropout_emb = torch.nn.Dropout(p=dropout)
  115. self.decoder = torch.nn.ModuleList()
  116. self.dropout_dec = torch.nn.ModuleList()
  117. self.decoder += [
  118. torch.nn.LSTMCell(hidden_size + eprojs, hidden_size)
  119. if self.dtype == "lstm"
  120. else torch.nn.GRUCell(hidden_size + eprojs, hidden_size)
  121. ]
  122. self.dropout_dec += [torch.nn.Dropout(p=dropout)]
  123. for _ in range(1, self.dlayers):
  124. self.decoder += [
  125. torch.nn.LSTMCell(hidden_size, hidden_size)
  126. if self.dtype == "lstm"
  127. else torch.nn.GRUCell(hidden_size, hidden_size)
  128. ]
  129. self.dropout_dec += [torch.nn.Dropout(p=dropout)]
  130. # NOTE: dropout is applied only for the vertical connections
  131. # see https://arxiv.org/pdf/1409.2329.pdf
  132. if context_residual:
  133. self.output = torch.nn.Linear(hidden_size + eprojs, vocab_size)
  134. else:
  135. self.output = torch.nn.Linear(hidden_size, vocab_size)
  136. self.att_list = build_attention_list(
  137. eprojs=eprojs, dunits=hidden_size, **att_conf
  138. )
  139. def zero_state(self, hs_pad):
  140. return hs_pad.new_zeros(hs_pad.size(0), self.dunits)
  141. def rnn_forward(self, ey, z_list, c_list, z_prev, c_prev):
  142. if self.dtype == "lstm":
  143. z_list[0], c_list[0] = self.decoder[0](ey, (z_prev[0], c_prev[0]))
  144. for i in range(1, self.dlayers):
  145. z_list[i], c_list[i] = self.decoder[i](
  146. self.dropout_dec[i - 1](z_list[i - 1]),
  147. (z_prev[i], c_prev[i]),
  148. )
  149. else:
  150. z_list[0] = self.decoder[0](ey, z_prev[0])
  151. for i in range(1, self.dlayers):
  152. z_list[i] = self.decoder[i](
  153. self.dropout_dec[i - 1](z_list[i - 1]), z_prev[i]
  154. )
  155. return z_list, c_list
  156. def forward(self, hs_pad, hlens, ys_in_pad, ys_in_lens, strm_idx=0):
  157. # to support mutiple encoder asr mode, in single encoder mode,
  158. # convert torch.Tensor to List of torch.Tensor
  159. if self.num_encs == 1:
  160. hs_pad = [hs_pad]
  161. hlens = [hlens]
  162. # attention index for the attention module
  163. # in SPA (speaker parallel attention),
  164. # att_idx is used to select attention module. In other cases, it is 0.
  165. att_idx = min(strm_idx, len(self.att_list) - 1)
  166. # hlens should be list of list of integer
  167. hlens = [list(map(int, hlens[idx])) for idx in range(self.num_encs)]
  168. # get dim, length info
  169. olength = ys_in_pad.size(1)
  170. # initialization
  171. c_list = [self.zero_state(hs_pad[0])]
  172. z_list = [self.zero_state(hs_pad[0])]
  173. for _ in range(1, self.dlayers):
  174. c_list.append(self.zero_state(hs_pad[0]))
  175. z_list.append(self.zero_state(hs_pad[0]))
  176. z_all = []
  177. if self.num_encs == 1:
  178. att_w = None
  179. self.att_list[att_idx].reset() # reset pre-computation of h
  180. else:
  181. att_w_list = [None] * (self.num_encs + 1) # atts + han
  182. att_c_list = [None] * self.num_encs # atts
  183. for idx in range(self.num_encs + 1):
  184. # reset pre-computation of h in atts and han
  185. self.att_list[idx].reset()
  186. # pre-computation of embedding
  187. eys = self.dropout_emb(self.embed(ys_in_pad)) # utt x olen x zdim
  188. # loop for an output sequence
  189. for i in range(olength):
  190. if self.num_encs == 1:
  191. att_c, att_w = self.att_list[att_idx](
  192. hs_pad[0], hlens[0], self.dropout_dec[0](z_list[0]), att_w
  193. )
  194. else:
  195. for idx in range(self.num_encs):
  196. att_c_list[idx], att_w_list[idx] = self.att_list[idx](
  197. hs_pad[idx],
  198. hlens[idx],
  199. self.dropout_dec[0](z_list[0]),
  200. att_w_list[idx],
  201. )
  202. hs_pad_han = torch.stack(att_c_list, dim=1)
  203. hlens_han = [self.num_encs] * len(ys_in_pad)
  204. att_c, att_w_list[self.num_encs] = self.att_list[self.num_encs](
  205. hs_pad_han,
  206. hlens_han,
  207. self.dropout_dec[0](z_list[0]),
  208. att_w_list[self.num_encs],
  209. )
  210. if i > 0 and random.random() < self.sampling_probability:
  211. z_out = self.output(z_all[-1])
  212. z_out = np.argmax(z_out.detach().cpu(), axis=1)
  213. z_out = self.dropout_emb(self.embed(to_device(self, z_out)))
  214. ey = torch.cat((z_out, att_c), dim=1) # utt x (zdim + hdim)
  215. else:
  216. # utt x (zdim + hdim)
  217. ey = torch.cat((eys[:, i, :], att_c), dim=1)
  218. z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_list, c_list)
  219. if self.context_residual:
  220. z_all.append(
  221. torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1)
  222. ) # utt x (zdim + hdim)
  223. else:
  224. z_all.append(self.dropout_dec[-1](z_list[-1])) # utt x (zdim)
  225. z_all = torch.stack(z_all, dim=1)
  226. z_all = self.output(z_all)
  227. z_all.masked_fill_(
  228. make_pad_mask(ys_in_lens, z_all, 1),
  229. 0,
  230. )
  231. return z_all, ys_in_lens
  232. def init_state(self, x):
  233. # to support mutiple encoder asr mode, in single encoder mode,
  234. # convert torch.Tensor to List of torch.Tensor
  235. if self.num_encs == 1:
  236. x = [x]
  237. c_list = [self.zero_state(x[0].unsqueeze(0))]
  238. z_list = [self.zero_state(x[0].unsqueeze(0))]
  239. for _ in range(1, self.dlayers):
  240. c_list.append(self.zero_state(x[0].unsqueeze(0)))
  241. z_list.append(self.zero_state(x[0].unsqueeze(0)))
  242. # TODO(karita): support strm_index for `asr_mix`
  243. strm_index = 0
  244. att_idx = min(strm_index, len(self.att_list) - 1)
  245. if self.num_encs == 1:
  246. a = None
  247. self.att_list[att_idx].reset() # reset pre-computation of h
  248. else:
  249. a = [None] * (self.num_encs + 1) # atts + han
  250. for idx in range(self.num_encs + 1):
  251. # reset pre-computation of h in atts and han
  252. self.att_list[idx].reset()
  253. return dict(
  254. c_prev=c_list[:],
  255. z_prev=z_list[:],
  256. a_prev=a,
  257. workspace=(att_idx, z_list, c_list),
  258. )
  259. def score(self, yseq, state, x):
  260. # to support mutiple encoder asr mode, in single encoder mode,
  261. # convert torch.Tensor to List of torch.Tensor
  262. if self.num_encs == 1:
  263. x = [x]
  264. att_idx, z_list, c_list = state["workspace"]
  265. vy = yseq[-1].unsqueeze(0)
  266. ey = self.dropout_emb(self.embed(vy)) # utt list (1) x zdim
  267. if self.num_encs == 1:
  268. att_c, att_w = self.att_list[att_idx](
  269. x[0].unsqueeze(0),
  270. [x[0].size(0)],
  271. self.dropout_dec[0](state["z_prev"][0]),
  272. state["a_prev"],
  273. )
  274. else:
  275. att_w = [None] * (self.num_encs + 1) # atts + han
  276. att_c_list = [None] * self.num_encs # atts
  277. for idx in range(self.num_encs):
  278. att_c_list[idx], att_w[idx] = self.att_list[idx](
  279. x[idx].unsqueeze(0),
  280. [x[idx].size(0)],
  281. self.dropout_dec[0](state["z_prev"][0]),
  282. state["a_prev"][idx],
  283. )
  284. h_han = torch.stack(att_c_list, dim=1)
  285. att_c, att_w[self.num_encs] = self.att_list[self.num_encs](
  286. h_han,
  287. [self.num_encs],
  288. self.dropout_dec[0](state["z_prev"][0]),
  289. state["a_prev"][self.num_encs],
  290. )
  291. ey = torch.cat((ey, att_c), dim=1) # utt(1) x (zdim + hdim)
  292. z_list, c_list = self.rnn_forward(
  293. ey, z_list, c_list, state["z_prev"], state["c_prev"]
  294. )
  295. if self.context_residual:
  296. logits = self.output(
  297. torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1)
  298. )
  299. else:
  300. logits = self.output(self.dropout_dec[-1](z_list[-1]))
  301. logp = F.log_softmax(logits, dim=1).squeeze(0)
  302. return (
  303. logp,
  304. dict(
  305. c_prev=c_list[:],
  306. z_prev=z_list[:],
  307. a_prev=att_w,
  308. workspace=(att_idx, z_list, c_list),
  309. ),
  310. )