decoders.py 48 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211
  1. """RNN decoder module."""
  2. import logging
  3. import math
  4. import random
  5. from argparse import Namespace
  6. import numpy as np
  7. import six
  8. import torch
  9. import torch.nn.functional as F
  10. from funasr.modules.scorers.ctc_prefix_score import CTCPrefixScore
  11. from funasr.modules.scorers.ctc_prefix_score import CTCPrefixScoreTH
  12. from funasr.modules.scorers.scorer_interface import ScorerInterface
  13. from funasr.modules.e2e_asr_common import end_detect
  14. from funasr.modules.nets_utils import mask_by_length
  15. from funasr.modules.nets_utils import pad_list
  16. from funasr.modules.nets_utils import th_accuracy
  17. from funasr.modules.nets_utils import to_device
  18. from funasr.modules.rnn.attentions import att_to_numpy
  19. MAX_DECODER_OUTPUT = 5
  20. CTC_SCORING_RATIO = 1.5
  21. class Decoder(torch.nn.Module, ScorerInterface):
    """Decoder module.

    :param int eprojs: encoder projection units
    :param int odim: dimension of outputs
    :param str dtype: gru or lstm
    :param int dlayers: decoder layers
    :param int dunits: decoder units
    :param int sos: start of sequence symbol id
    :param int eos: end of sequence symbol id
    :param torch.nn.Module att: attention module
    :param int verbose: verbose level
    :param list char_list: list of character strings
    :param ndarray labeldist: distribution of label smoothing
    :param float lsm_weight: label smoothing weight
    :param float sampling_probability: scheduled sampling probability
    :param float dropout: dropout rate
    :param bool context_residual: if True, use context vector for token generation
    :param bool replace_sos: use for multilingual (speech/text) translation
    :param int num_encs: number of encoders (> 1 enables hierarchical attention)
    """
  40. def __init__(
  41. self,
  42. eprojs,
  43. odim,
  44. dtype,
  45. dlayers,
  46. dunits,
  47. sos,
  48. eos,
  49. att,
  50. verbose=0,
  51. char_list=None,
  52. labeldist=None,
  53. lsm_weight=0.0,
  54. sampling_probability=0.0,
  55. dropout=0.0,
  56. context_residual=False,
  57. replace_sos=False,
  58. num_encs=1,
  59. ):
  60. torch.nn.Module.__init__(self)
  61. self.dtype = dtype
  62. self.dunits = dunits
  63. self.dlayers = dlayers
  64. self.context_residual = context_residual
  65. self.embed = torch.nn.Embedding(odim, dunits)
  66. self.dropout_emb = torch.nn.Dropout(p=dropout)
  67. self.decoder = torch.nn.ModuleList()
  68. self.dropout_dec = torch.nn.ModuleList()
  69. self.decoder += [
  70. torch.nn.LSTMCell(dunits + eprojs, dunits)
  71. if self.dtype == "lstm"
  72. else torch.nn.GRUCell(dunits + eprojs, dunits)
  73. ]
  74. self.dropout_dec += [torch.nn.Dropout(p=dropout)]
  75. for _ in six.moves.range(1, self.dlayers):
  76. self.decoder += [
  77. torch.nn.LSTMCell(dunits, dunits)
  78. if self.dtype == "lstm"
  79. else torch.nn.GRUCell(dunits, dunits)
  80. ]
  81. self.dropout_dec += [torch.nn.Dropout(p=dropout)]
  82. # NOTE: dropout is applied only for the vertical connections
  83. # see https://arxiv.org/pdf/1409.2329.pdf
  84. self.ignore_id = -1
  85. if context_residual:
  86. self.output = torch.nn.Linear(dunits + eprojs, odim)
  87. else:
  88. self.output = torch.nn.Linear(dunits, odim)
  89. self.loss = None
  90. self.att = att
  91. self.dunits = dunits
  92. self.sos = sos
  93. self.eos = eos
  94. self.odim = odim
  95. self.verbose = verbose
  96. self.char_list = char_list
  97. # for label smoothing
  98. self.labeldist = labeldist
  99. self.vlabeldist = None
  100. self.lsm_weight = lsm_weight
  101. self.sampling_probability = sampling_probability
  102. self.dropout = dropout
  103. self.num_encs = num_encs
  104. # for multilingual E2E-ST
  105. self.replace_sos = replace_sos
  106. self.logzero = -10000000000.0
  107. def zero_state(self, hs_pad):
  108. return hs_pad.new_zeros(hs_pad.size(0), self.dunits)
  109. def rnn_forward(self, ey, z_list, c_list, z_prev, c_prev):
  110. if self.dtype == "lstm":
  111. z_list[0], c_list[0] = self.decoder[0](ey, (z_prev[0], c_prev[0]))
  112. for i in six.moves.range(1, self.dlayers):
  113. z_list[i], c_list[i] = self.decoder[i](
  114. self.dropout_dec[i - 1](z_list[i - 1]), (z_prev[i], c_prev[i])
  115. )
  116. else:
  117. z_list[0] = self.decoder[0](ey, z_prev[0])
  118. for i in six.moves.range(1, self.dlayers):
  119. z_list[i] = self.decoder[i](
  120. self.dropout_dec[i - 1](z_list[i - 1]), z_prev[i]
  121. )
  122. return z_list, c_list
    def forward(self, hs_pad, hlens, ys_pad, strm_idx=0, lang_ids=None):
        """Decoder forward (teacher-forced training pass).

        :param torch.Tensor hs_pad: batch of padded hidden state sequences (B, Tmax, D)
            [in multi-encoder case,
            list of torch.Tensor,
            [(B, Tmax_1, D), (B, Tmax_2, D), ..., ] ]
        :param torch.Tensor hlens: batch of lengths of hidden state sequences (B)
            [in multi-encoder case, list of torch.Tensor,
            [(B), (B), ..., ]
        :param torch.Tensor ys_pad: batch of padded character id sequence tensor
            (B, Lmax)
        :param int strm_idx: stream index indicates the index of decoding stream.
        :param torch.Tensor lang_ids: batch of target language id tensor (B, 1)
        :return: attention loss value, accuracy, and perplexity
        :rtype: tuple (torch.Tensor, float, float)
        """
        # to support multiple-encoder ASR mode: in single-encoder mode,
        # wrap tensors in one-element lists so the code below is uniform
        if self.num_encs == 1:
            hs_pad = [hs_pad]
            hlens = [hlens]
        # TODO(kan-bayashi): need to make more smart way
        ys = [y[y != self.ignore_id] for y in ys_pad]  # parse padded ys
        # attention index for the attention module
        # in SPA (speaker parallel attention),
        # att_idx is used to select attention module. In other cases, it is 0.
        att_idx = min(strm_idx, len(self.att) - 1)
        # hlens should be list of list of integer
        hlens = [list(map(int, hlens[idx])) for idx in range(self.num_encs)]
        self.loss = None
        # prepare input and output word sequences with sos/eos IDs
        eos = ys[0].new([self.eos])
        sos = ys[0].new([self.sos])
        if self.replace_sos:
            # multilingual translation: target-language id replaces <sos>
            ys_in = [torch.cat([idx, y], dim=0) for idx, y in zip(lang_ids, ys)]
        else:
            ys_in = [torch.cat([sos, y], dim=0) for y in ys]
        ys_out = [torch.cat([y, eos], dim=0) for y in ys]
        # padding for ys with -1
        # pys: utt x olen
        ys_in_pad = pad_list(ys_in, self.eos)
        ys_out_pad = pad_list(ys_out, self.ignore_id)
        # get dim, length info
        batch = ys_out_pad.size(0)
        olength = ys_out_pad.size(1)
        for idx in range(self.num_encs):
            logging.info(
                self.__class__.__name__
                + "Number of Encoder:{}; enc{}: input lengths: {}.".format(
                    self.num_encs, idx + 1, hlens[idx]
                )
            )
        logging.info(
            self.__class__.__name__
            + " output lengths: "
            + str([y.size(0) for y in ys_out])
        )
        # initialization: one zero hidden/cell state per decoder layer
        c_list = [self.zero_state(hs_pad[0])]
        z_list = [self.zero_state(hs_pad[0])]
        for _ in six.moves.range(1, self.dlayers):
            c_list.append(self.zero_state(hs_pad[0]))
            z_list.append(self.zero_state(hs_pad[0]))
        z_all = []
        if self.num_encs == 1:
            att_w = None
            self.att[att_idx].reset()  # reset pre-computation of h
        else:
            att_w_list = [None] * (self.num_encs + 1)  # atts + han
            att_c_list = [None] * (self.num_encs)  # atts
            for idx in range(self.num_encs + 1):
                self.att[idx].reset()  # reset pre-computation of h in atts and han
        # pre-computation of embedding
        eys = self.dropout_emb(self.embed(ys_in_pad))  # utt x olen x zdim
        # loop for an output sequence
        for i in six.moves.range(olength):
            if self.num_encs == 1:
                att_c, att_w = self.att[att_idx](
                    hs_pad[0], hlens[0], self.dropout_dec[0](z_list[0]), att_w
                )
            else:
                # per-encoder attention, then hierarchical attention (han)
                # over the stacked per-encoder context vectors
                for idx in range(self.num_encs):
                    att_c_list[idx], att_w_list[idx] = self.att[idx](
                        hs_pad[idx],
                        hlens[idx],
                        self.dropout_dec[0](z_list[0]),
                        att_w_list[idx],
                    )
                hs_pad_han = torch.stack(att_c_list, dim=1)
                hlens_han = [self.num_encs] * len(ys_in)
                att_c, att_w_list[self.num_encs] = self.att[self.num_encs](
                    hs_pad_han,
                    hlens_han,
                    self.dropout_dec[0](z_list[0]),
                    att_w_list[self.num_encs],
                )
            if i > 0 and random.random() < self.sampling_probability:
                # scheduled sampling: feed back the model's own previous
                # prediction instead of the ground-truth token
                logging.info(" scheduled sampling ")
                z_out = self.output(z_all[-1])
                z_out = np.argmax(z_out.detach().cpu(), axis=1)
                z_out = self.dropout_emb(self.embed(to_device(hs_pad[0], z_out)))
                ey = torch.cat((z_out, att_c), dim=1)  # utt x (zdim + hdim)
            else:
                ey = torch.cat((eys[:, i, :], att_c), dim=1)  # utt x (zdim + hdim)
            z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_list, c_list)
            if self.context_residual:
                z_all.append(
                    torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1)
                )  # utt x (zdim + hdim)
            else:
                z_all.append(self.dropout_dec[-1](z_list[-1]))  # utt x (zdim)
        z_all = torch.stack(z_all, dim=1).view(batch * olength, -1)
        # compute loss
        y_all = self.output(z_all)
        self.loss = F.cross_entropy(
            y_all,
            ys_out_pad.view(-1),
            ignore_index=self.ignore_id,
            reduction="mean",
        )
        # compute perplexity (from the token-mean loss, before length scaling)
        ppl = math.exp(self.loss.item())
        # -1: eos, which is removed in the loss computation
        self.loss *= np.mean([len(x) for x in ys_in]) - 1
        acc = th_accuracy(y_all, ys_out_pad, ignore_label=self.ignore_id)
        logging.info("att loss:" + "".join(str(self.loss.item()).split("\n")))
        # show predicted character sequence for debug
        if self.verbose > 0 and self.char_list is not None:
            ys_hat = y_all.view(batch, olength, -1)
            ys_true = ys_out_pad
            for (i, y_hat), y_true in zip(
                enumerate(ys_hat.detach().cpu().numpy()), ys_true.detach().cpu().numpy()
            ):
                if i == MAX_DECODER_OUTPUT:
                    break
                # mask out padded positions before decoding to characters
                idx_hat = np.argmax(y_hat[y_true != self.ignore_id], axis=1)
                idx_true = y_true[y_true != self.ignore_id]
                seq_hat = [self.char_list[int(idx)] for idx in idx_hat]
                seq_true = [self.char_list[int(idx)] for idx in idx_true]
                seq_hat = "".join(seq_hat)
                seq_true = "".join(seq_true)
                logging.info("groundtruth[%d]: " % i + seq_true)
                logging.info("prediction [%d]: " % i + seq_hat)
        if self.labeldist is not None:
            if self.vlabeldist is None:
                # lazily move the label distribution to the encoder's device
                self.vlabeldist = to_device(hs_pad[0], torch.from_numpy(self.labeldist))
            # label-smoothing regularizer: cross-entropy against labeldist,
            # interpolated into the main loss with weight lsm_weight
            loss_reg = -torch.sum(
                (F.log_softmax(y_all, dim=1) * self.vlabeldist).view(-1), dim=0
            ) / len(ys_in)
            self.loss = (1.0 - self.lsm_weight) * self.loss + self.lsm_weight * loss_reg
        return self.loss, acc, ppl
    def recognize_beam(self, h, lpz, recog_args, char_list, rnnlm=None, strm_idx=0):
        """Beam search implementation (single utterance).

        Joint attention/CTC decoding with optional RNNLM shallow fusion.

        :param torch.Tensor h: encoder hidden state (T, eprojs)
            [in multi-encoder case, list of torch.Tensor,
            [(T1, eprojs), (T2, eprojs), ...] ]
        :param torch.Tensor lpz: ctc log softmax output (T, odim)
            [in multi-encoder case, list of torch.Tensor,
            [(T1, odim), (T2, odim), ...] ]
        :param Namespace recog_args: argument Namespace containing options
        :param char_list: list of character strings
        :param torch.nn.Module rnnlm: language module
        :param int strm_idx:
            stream index for speaker parallel attention in multi-speaker case
        :return: N-best decoding results (each a dict with "yseq" and "score";
            "yseq" still starts with the <sos>/language id)
        :rtype: list of dicts
        """
        # to support multiple-encoder ASR mode: in single-encoder mode,
        # wrap tensors in one-element lists so the code below is uniform
        if self.num_encs == 1:
            h = [h]
            lpz = [lpz]
        if self.num_encs > 1 and lpz is None:
            lpz = [lpz] * self.num_encs
        for idx in range(self.num_encs):
            logging.info(
                "Number of Encoder:{}; enc{}: input lengths: {}.".format(
                    self.num_encs, idx + 1, h[0].size(0)
                )
            )
        att_idx = min(strm_idx, len(self.att) - 1)
        # initialization: zero decoder states for a batch of one
        c_list = [self.zero_state(h[0].unsqueeze(0))]
        z_list = [self.zero_state(h[0].unsqueeze(0))]
        for _ in six.moves.range(1, self.dlayers):
            c_list.append(self.zero_state(h[0].unsqueeze(0)))
            z_list.append(self.zero_state(h[0].unsqueeze(0)))
        if self.num_encs == 1:
            a = None
            self.att[att_idx].reset()  # reset pre-computation of h
        else:
            a = [None] * (self.num_encs + 1)  # atts + han
            att_w_list = [None] * (self.num_encs + 1)  # atts + han
            att_c_list = [None] * (self.num_encs)  # atts
            for idx in range(self.num_encs + 1):
                self.att[idx].reset()  # reset pre-computation of h in atts and han
        # search parms
        beam = recog_args.beam_size
        penalty = recog_args.penalty
        ctc_weight = getattr(recog_args, "ctc_weight", False)  # for NMT
        if lpz[0] is not None and self.num_encs > 1:
            # weights-ctc,
            # e.g. ctc_loss = w_1*ctc_1_loss + w_2 * ctc_2_loss + w_N * ctc_N_loss
            weights_ctc_dec = recog_args.weights_ctc_dec / np.sum(
                recog_args.weights_ctc_dec
            )  # normalize
            logging.info(
                "ctc weights (decoding): " + " ".join([str(x) for x in weights_ctc_dec])
            )
        else:
            weights_ctc_dec = [1.0]
        # prepare sos (or the target-language id for multilingual translation)
        if self.replace_sos and recog_args.tgt_lang:
            y = char_list.index(recog_args.tgt_lang)
        else:
            y = self.sos
        logging.info("<sos> index: " + str(y))
        logging.info("<sos> mark: " + char_list[y])
        vy = h[0].new_zeros(1).long()  # reusable 1-element token-id tensor
        # cap output length by the shortest encoder sequence
        maxlen = np.amin([h[idx].size(0) for idx in range(self.num_encs)])
        if recog_args.maxlenratio != 0:
            # maxlen >= 1
            maxlen = max(1, int(recog_args.maxlenratio * maxlen))
        minlen = int(recog_args.minlenratio * maxlen)
        logging.info("max output length: " + str(maxlen))
        logging.info("min output length: " + str(minlen))
        # initialize hypothesis
        if rnnlm:
            hyp = {
                "score": 0.0,
                "yseq": [y],
                "c_prev": c_list,
                "z_prev": z_list,
                "a_prev": a,
                "rnnlm_prev": None,
            }
        else:
            hyp = {
                "score": 0.0,
                "yseq": [y],
                "c_prev": c_list,
                "z_prev": z_list,
                "a_prev": a,
            }
        if lpz[0] is not None:
            # CTC prefix scorers (one per encoder) for joint decoding
            ctc_prefix_score = [
                CTCPrefixScore(lpz[idx].detach().numpy(), 0, self.eos, np)
                for idx in range(self.num_encs)
            ]
            hyp["ctc_state_prev"] = [
                ctc_prefix_score[idx].initial_state() for idx in range(self.num_encs)
            ]
            hyp["ctc_score_prev"] = [0.0] * self.num_encs
            if ctc_weight != 1.0:
                # pre-pruning based on attention scores
                ctc_beam = min(lpz[0].shape[-1], int(beam * CTC_SCORING_RATIO))
            else:
                ctc_beam = lpz[0].shape[-1]
        hyps = [hyp]
        ended_hyps = []
        for i in six.moves.range(maxlen):
            logging.debug("position " + str(i))
            hyps_best_kept = []
            for hyp in hyps:
                # embed the last emitted token of this hypothesis
                vy[0] = hyp["yseq"][i]
                ey = self.dropout_emb(self.embed(vy))  # utt list (1) x zdim
                if self.num_encs == 1:
                    att_c, att_w = self.att[att_idx](
                        h[0].unsqueeze(0),
                        [h[0].size(0)],
                        self.dropout_dec[0](hyp["z_prev"][0]),
                        hyp["a_prev"],
                    )
                else:
                    # per-encoder attention, then hierarchical attention (han)
                    for idx in range(self.num_encs):
                        att_c_list[idx], att_w_list[idx] = self.att[idx](
                            h[idx].unsqueeze(0),
                            [h[idx].size(0)],
                            self.dropout_dec[0](hyp["z_prev"][0]),
                            hyp["a_prev"][idx],
                        )
                    h_han = torch.stack(att_c_list, dim=1)
                    att_c, att_w_list[self.num_encs] = self.att[self.num_encs](
                        h_han,
                        [self.num_encs],
                        self.dropout_dec[0](hyp["z_prev"][0]),
                        hyp["a_prev"][self.num_encs],
                    )
                ey = torch.cat((ey, att_c), dim=1)  # utt(1) x (zdim + hdim)
                z_list, c_list = self.rnn_forward(
                    ey, z_list, c_list, hyp["z_prev"], hyp["c_prev"]
                )
                # get nbest local scores and their ids
                if self.context_residual:
                    logits = self.output(
                        torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1)
                    )
                else:
                    logits = self.output(self.dropout_dec[-1](z_list[-1]))
                local_att_scores = F.log_softmax(logits, dim=1)
                if rnnlm:
                    # RNNLM shallow fusion
                    rnnlm_state, local_lm_scores = rnnlm.predict(hyp["rnnlm_prev"], vy)
                    local_scores = (
                        local_att_scores + recog_args.lm_weight * local_lm_scores
                    )
                else:
                    local_scores = local_att_scores
                if lpz[0] is not None:
                    # joint scoring: pre-prune with attention scores, then
                    # rescore the surviving candidates with CTC prefix scores
                    local_best_scores, local_best_ids = torch.topk(
                        local_att_scores, ctc_beam, dim=1
                    )
                    ctc_scores, ctc_states = (
                        [None] * self.num_encs,
                        [None] * self.num_encs,
                    )
                    for idx in range(self.num_encs):
                        ctc_scores[idx], ctc_states[idx] = ctc_prefix_score[idx](
                            hyp["yseq"], local_best_ids[0], hyp["ctc_state_prev"][idx]
                        )
                    local_scores = (1.0 - ctc_weight) * local_att_scores[
                        :, local_best_ids[0]
                    ]
                    if self.num_encs == 1:
                        local_scores += ctc_weight * torch.from_numpy(
                            ctc_scores[0] - hyp["ctc_score_prev"][0]
                        )
                    else:
                        for idx in range(self.num_encs):
                            local_scores += (
                                ctc_weight
                                * weights_ctc_dec[idx]
                                * torch.from_numpy(
                                    ctc_scores[idx] - hyp["ctc_score_prev"][idx]
                                )
                            )
                    if rnnlm:
                        local_scores += (
                            recog_args.lm_weight * local_lm_scores[:, local_best_ids[0]]
                        )
                    local_best_scores, joint_best_ids = torch.topk(
                        local_scores, beam, dim=1
                    )
                    # map joint ids back to vocabulary ids
                    local_best_ids = local_best_ids[:, joint_best_ids[0]]
                else:
                    local_best_scores, local_best_ids = torch.topk(
                        local_scores, beam, dim=1
                    )
                for j in six.moves.range(beam):
                    new_hyp = {}
                    # [:] is needed!  (shallow copy so hypotheses don't share
                    # mutable state lists)
                    new_hyp["z_prev"] = z_list[:]
                    new_hyp["c_prev"] = c_list[:]
                    if self.num_encs == 1:
                        new_hyp["a_prev"] = att_w[:]
                    else:
                        new_hyp["a_prev"] = [
                            att_w_list[idx][:] for idx in range(self.num_encs + 1)
                        ]
                    new_hyp["score"] = hyp["score"] + local_best_scores[0, j]
                    new_hyp["yseq"] = [0] * (1 + len(hyp["yseq"]))
                    new_hyp["yseq"][: len(hyp["yseq"])] = hyp["yseq"]
                    new_hyp["yseq"][len(hyp["yseq"])] = int(local_best_ids[0, j])
                    if rnnlm:
                        new_hyp["rnnlm_prev"] = rnnlm_state
                    if lpz[0] is not None:
                        new_hyp["ctc_state_prev"] = [
                            ctc_states[idx][joint_best_ids[0, j]]
                            for idx in range(self.num_encs)
                        ]
                        new_hyp["ctc_score_prev"] = [
                            ctc_scores[idx][joint_best_ids[0, j]]
                            for idx in range(self.num_encs)
                        ]
                    # will be (2 x beam) hyps at most
                    hyps_best_kept.append(new_hyp)
                hyps_best_kept = sorted(
                    hyps_best_kept, key=lambda x: x["score"], reverse=True
                )[:beam]
            # sort and get nbest
            hyps = hyps_best_kept
            logging.debug("number of pruned hypotheses: " + str(len(hyps)))
            logging.debug(
                "best hypo: "
                + "".join([char_list[int(x)] for x in hyps[0]["yseq"][1:]])
            )
            # add eos in the final loop to avoid that there are no ended hyps
            if i == maxlen - 1:
                logging.info("adding <eos> in the last position in the loop")
                for hyp in hyps:
                    hyp["yseq"].append(self.eos)
            # add ended hypotheses to a final list,
            # and removed them from current hypotheses
            # (this will be a problem, number of hyps < beam)
            remained_hyps = []
            for hyp in hyps:
                if hyp["yseq"][-1] == self.eos:
                    # only store the sequence that has more than minlen outputs
                    # also add penalty
                    if len(hyp["yseq"]) > minlen:
                        hyp["score"] += (i + 1) * penalty
                        if rnnlm:  # Word LM needs to add final <eos> score
                            hyp["score"] += recog_args.lm_weight * rnnlm.final(
                                hyp["rnnlm_prev"]
                            )
                        ended_hyps.append(hyp)
                else:
                    remained_hyps.append(hyp)
            # end detection
            if end_detect(ended_hyps, i) and recog_args.maxlenratio == 0.0:
                logging.info("end detected at %d", i)
                break
            hyps = remained_hyps
            if len(hyps) > 0:
                logging.debug("remaining hypotheses: " + str(len(hyps)))
            else:
                logging.info("no hypothesis. Finish decoding.")
                break
            for hyp in hyps:
                logging.debug(
                    "hypo: " + "".join([char_list[int(x)] for x in hyp["yseq"][1:]])
                )
            logging.debug("number of ended hypotheses: " + str(len(ended_hyps)))
        nbest_hyps = sorted(ended_hyps, key=lambda x: x["score"], reverse=True)[
            : min(len(ended_hyps), recog_args.nbest)
        ]
        # check number of hypotheses
        if len(nbest_hyps) == 0:
            logging.warning(
                "there is no N-best results, "
                "perform recognition again with smaller minlenratio."
            )
            # should copy because Namespace will be overwritten globally
            recog_args = Namespace(**vars(recog_args))
            recog_args.minlenratio = max(0.0, recog_args.minlenratio - 0.1)
            # retry recursively with a relaxed minimum-length constraint
            if self.num_encs == 1:
                return self.recognize_beam(h[0], lpz[0], recog_args, char_list, rnnlm)
            else:
                return self.recognize_beam(h, lpz, recog_args, char_list, rnnlm)
        logging.info("total log probability: " + str(nbest_hyps[0]["score"]))
        logging.info(
            "normalized log probability: "
            + str(nbest_hyps[0]["score"] / len(nbest_hyps[0]["yseq"]))
        )
        # NOTE(review): despite the original "remove sos" note here, yseq is
        # returned with <sos> still in position 0; callers appear to strip it
        return nbest_hyps
  570. def recognize_beam_batch(
  571. self,
  572. h,
  573. hlens,
  574. lpz,
  575. recog_args,
  576. char_list,
  577. rnnlm=None,
  578. normalize_score=True,
  579. strm_idx=0,
  580. lang_ids=None,
  581. ):
  582. # to support mutiple encoder asr mode, in single encoder mode,
  583. # convert torch.Tensor to List of torch.Tensor
  584. if self.num_encs == 1:
  585. h = [h]
  586. hlens = [hlens]
  587. lpz = [lpz]
  588. if self.num_encs > 1 and lpz is None:
  589. lpz = [lpz] * self.num_encs
  590. att_idx = min(strm_idx, len(self.att) - 1)
  591. for idx in range(self.num_encs):
  592. logging.info(
  593. "Number of Encoder:{}; enc{}: input lengths: {}.".format(
  594. self.num_encs, idx + 1, h[idx].size(1)
  595. )
  596. )
  597. h[idx] = mask_by_length(h[idx], hlens[idx], 0.0)
  598. # search params
  599. batch = len(hlens[0])
  600. beam = recog_args.beam_size
  601. penalty = recog_args.penalty
  602. ctc_weight = getattr(recog_args, "ctc_weight", 0) # for NMT
  603. att_weight = 1.0 - ctc_weight
  604. ctc_margin = getattr(
  605. recog_args, "ctc_window_margin", 0
  606. ) # use getattr to keep compatibility
  607. # weights-ctc,
  608. # e.g. ctc_loss = w_1*ctc_1_loss + w_2 * ctc_2_loss + w_N * ctc_N_loss
  609. if lpz[0] is not None and self.num_encs > 1:
  610. weights_ctc_dec = recog_args.weights_ctc_dec / np.sum(
  611. recog_args.weights_ctc_dec
  612. ) # normalize
  613. logging.info(
  614. "ctc weights (decoding): " + " ".join([str(x) for x in weights_ctc_dec])
  615. )
  616. else:
  617. weights_ctc_dec = [1.0]
  618. n_bb = batch * beam
  619. pad_b = to_device(h[0], torch.arange(batch) * beam).view(-1, 1)
  620. max_hlen = np.amin([max(hlens[idx]) for idx in range(self.num_encs)])
  621. if recog_args.maxlenratio == 0:
  622. maxlen = max_hlen
  623. else:
  624. maxlen = max(1, int(recog_args.maxlenratio * max_hlen))
  625. minlen = int(recog_args.minlenratio * max_hlen)
  626. logging.info("max output length: " + str(maxlen))
  627. logging.info("min output length: " + str(minlen))
  628. # initialization
  629. c_prev = [
  630. to_device(h[0], torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)
  631. ]
  632. z_prev = [
  633. to_device(h[0], torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)
  634. ]
  635. c_list = [
  636. to_device(h[0], torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)
  637. ]
  638. z_list = [
  639. to_device(h[0], torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)
  640. ]
  641. vscores = to_device(h[0], torch.zeros(batch, beam))
  642. rnnlm_state = None
  643. if self.num_encs == 1:
  644. a_prev = [None]
  645. att_w_list, ctc_scorer, ctc_state = [None], [None], [None]
  646. self.att[att_idx].reset() # reset pre-computation of h
  647. else:
  648. a_prev = [None] * (self.num_encs + 1) # atts + han
  649. att_w_list = [None] * (self.num_encs + 1) # atts + han
  650. att_c_list = [None] * (self.num_encs) # atts
  651. ctc_scorer, ctc_state = [None] * (self.num_encs), [None] * (self.num_encs)
  652. for idx in range(self.num_encs + 1):
  653. self.att[idx].reset() # reset pre-computation of h in atts and han
  654. if self.replace_sos and recog_args.tgt_lang:
  655. logging.info("<sos> index: " + str(char_list.index(recog_args.tgt_lang)))
  656. logging.info("<sos> mark: " + recog_args.tgt_lang)
  657. yseq = [
  658. [char_list.index(recog_args.tgt_lang)] for _ in six.moves.range(n_bb)
  659. ]
  660. elif lang_ids is not None:
  661. # NOTE: used for evaluation during training
  662. yseq = [
  663. [lang_ids[b // recog_args.beam_size]] for b in six.moves.range(n_bb)
  664. ]
  665. else:
  666. logging.info("<sos> index: " + str(self.sos))
  667. logging.info("<sos> mark: " + char_list[self.sos])
  668. yseq = [[self.sos] for _ in six.moves.range(n_bb)]
  669. accum_odim_ids = [self.sos for _ in six.moves.range(n_bb)]
  670. stop_search = [False for _ in six.moves.range(batch)]
  671. nbest_hyps = [[] for _ in six.moves.range(batch)]
  672. ended_hyps = [[] for _ in range(batch)]
  673. exp_hlens = [
  674. hlens[idx].repeat(beam).view(beam, batch).transpose(0, 1).contiguous()
  675. for idx in range(self.num_encs)
  676. ]
  677. exp_hlens = [exp_hlens[idx].view(-1).tolist() for idx in range(self.num_encs)]
  678. exp_h = [
  679. h[idx].unsqueeze(1).repeat(1, beam, 1, 1).contiguous()
  680. for idx in range(self.num_encs)
  681. ]
  682. exp_h = [
  683. exp_h[idx].view(n_bb, h[idx].size()[1], h[idx].size()[2])
  684. for idx in range(self.num_encs)
  685. ]
  686. if lpz[0] is not None:
  687. scoring_num = min(
  688. int(beam * CTC_SCORING_RATIO)
  689. if att_weight > 0.0 and not lpz[0].is_cuda
  690. else 0,
  691. lpz[0].size(-1),
  692. )
  693. ctc_scorer = [
  694. CTCPrefixScoreTH(
  695. lpz[idx],
  696. hlens[idx],
  697. 0,
  698. self.eos,
  699. margin=ctc_margin,
  700. )
  701. for idx in range(self.num_encs)
  702. ]
  703. for i in six.moves.range(maxlen):
  704. logging.debug("position " + str(i))
  705. vy = to_device(h[0], torch.LongTensor(self._get_last_yseq(yseq)))
  706. ey = self.dropout_emb(self.embed(vy))
  707. if self.num_encs == 1:
  708. att_c, att_w = self.att[att_idx](
  709. exp_h[0], exp_hlens[0], self.dropout_dec[0](z_prev[0]), a_prev[0]
  710. )
  711. att_w_list = [att_w]
  712. else:
  713. for idx in range(self.num_encs):
  714. att_c_list[idx], att_w_list[idx] = self.att[idx](
  715. exp_h[idx],
  716. exp_hlens[idx],
  717. self.dropout_dec[0](z_prev[0]),
  718. a_prev[idx],
  719. )
  720. exp_h_han = torch.stack(att_c_list, dim=1)
  721. att_c, att_w_list[self.num_encs] = self.att[self.num_encs](
  722. exp_h_han,
  723. [self.num_encs] * n_bb,
  724. self.dropout_dec[0](z_prev[0]),
  725. a_prev[self.num_encs],
  726. )
  727. ey = torch.cat((ey, att_c), dim=1)
  728. # attention decoder
  729. z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_prev, c_prev)
  730. if self.context_residual:
  731. logits = self.output(
  732. torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1)
  733. )
  734. else:
  735. logits = self.output(self.dropout_dec[-1](z_list[-1]))
  736. local_scores = att_weight * F.log_softmax(logits, dim=1)
  737. # rnnlm
  738. if rnnlm:
  739. rnnlm_state, local_lm_scores = rnnlm.buff_predict(rnnlm_state, vy, n_bb)
  740. local_scores = local_scores + recog_args.lm_weight * local_lm_scores
  741. # ctc
  742. if ctc_scorer[0]:
  743. local_scores[:, 0] = self.logzero # avoid choosing blank
  744. part_ids = (
  745. torch.topk(local_scores, scoring_num, dim=-1)[1]
  746. if scoring_num > 0
  747. else None
  748. )
  749. for idx in range(self.num_encs):
  750. att_w = att_w_list[idx]
  751. att_w_ = att_w if isinstance(att_w, torch.Tensor) else att_w[0]
  752. local_ctc_scores, ctc_state[idx] = ctc_scorer[idx](
  753. yseq, ctc_state[idx], part_ids, att_w_
  754. )
  755. local_scores = (
  756. local_scores
  757. + ctc_weight * weights_ctc_dec[idx] * local_ctc_scores
  758. )
  759. local_scores = local_scores.view(batch, beam, self.odim)
  760. if i == 0:
  761. local_scores[:, 1:, :] = self.logzero
  762. # accumulate scores
  763. eos_vscores = local_scores[:, :, self.eos] + vscores
  764. vscores = vscores.view(batch, beam, 1).repeat(1, 1, self.odim)
  765. vscores[:, :, self.eos] = self.logzero
  766. vscores = (vscores + local_scores).view(batch, -1)
  767. # global pruning
  768. accum_best_scores, accum_best_ids = torch.topk(vscores, beam, 1)
  769. accum_odim_ids = (
  770. torch.fmod(accum_best_ids, self.odim).view(-1).data.cpu().tolist()
  771. )
  772. accum_padded_beam_ids = (
  773. (accum_best_ids // self.odim + pad_b).view(-1).data.cpu().tolist()
  774. )
  775. y_prev = yseq[:][:]
  776. yseq = self._index_select_list(yseq, accum_padded_beam_ids)
  777. yseq = self._append_ids(yseq, accum_odim_ids)
  778. vscores = accum_best_scores
  779. vidx = to_device(h[0], torch.LongTensor(accum_padded_beam_ids))
  780. a_prev = []
  781. num_atts = self.num_encs if self.num_encs == 1 else self.num_encs + 1
  782. for idx in range(num_atts):
  783. if isinstance(att_w_list[idx], torch.Tensor):
  784. _a_prev = torch.index_select(
  785. att_w_list[idx].view(n_bb, *att_w_list[idx].shape[1:]), 0, vidx
  786. )
  787. elif isinstance(att_w_list[idx], list):
  788. # handle the case of multi-head attention
  789. _a_prev = [
  790. torch.index_select(att_w_one.view(n_bb, -1), 0, vidx)
  791. for att_w_one in att_w_list[idx]
  792. ]
  793. else:
  794. # handle the case of location_recurrent when return is a tuple
  795. _a_prev_ = torch.index_select(
  796. att_w_list[idx][0].view(n_bb, -1), 0, vidx
  797. )
  798. _h_prev_ = torch.index_select(
  799. att_w_list[idx][1][0].view(n_bb, -1), 0, vidx
  800. )
  801. _c_prev_ = torch.index_select(
  802. att_w_list[idx][1][1].view(n_bb, -1), 0, vidx
  803. )
  804. _a_prev = (_a_prev_, (_h_prev_, _c_prev_))
  805. a_prev.append(_a_prev)
  806. z_prev = [
  807. torch.index_select(z_list[li].view(n_bb, -1), 0, vidx)
  808. for li in range(self.dlayers)
  809. ]
  810. c_prev = [
  811. torch.index_select(c_list[li].view(n_bb, -1), 0, vidx)
  812. for li in range(self.dlayers)
  813. ]
  814. # pick ended hyps
  815. if i >= minlen:
  816. k = 0
  817. penalty_i = (i + 1) * penalty
  818. thr = accum_best_scores[:, -1]
  819. for samp_i in six.moves.range(batch):
  820. if stop_search[samp_i]:
  821. k = k + beam
  822. continue
  823. for beam_j in six.moves.range(beam):
  824. _vscore = None
  825. if eos_vscores[samp_i, beam_j] > thr[samp_i]:
  826. yk = y_prev[k][:]
  827. if len(yk) <= min(
  828. hlens[idx][samp_i] for idx in range(self.num_encs)
  829. ):
  830. _vscore = eos_vscores[samp_i][beam_j] + penalty_i
  831. elif i == maxlen - 1:
  832. yk = yseq[k][:]
  833. _vscore = vscores[samp_i][beam_j] + penalty_i
  834. if _vscore:
  835. yk.append(self.eos)
  836. if rnnlm:
  837. _vscore += recog_args.lm_weight * rnnlm.final(
  838. rnnlm_state, index=k
  839. )
  840. _score = _vscore.data.cpu().numpy()
  841. ended_hyps[samp_i].append(
  842. {"yseq": yk, "vscore": _vscore, "score": _score}
  843. )
  844. k = k + 1
  845. # end detection
  846. stop_search = [
  847. stop_search[samp_i] or end_detect(ended_hyps[samp_i], i)
  848. for samp_i in six.moves.range(batch)
  849. ]
  850. stop_search_summary = list(set(stop_search))
  851. if len(stop_search_summary) == 1 and stop_search_summary[0]:
  852. break
  853. if rnnlm:
  854. rnnlm_state = self._index_select_lm_state(rnnlm_state, 0, vidx)
  855. if ctc_scorer[0]:
  856. for idx in range(self.num_encs):
  857. ctc_state[idx] = ctc_scorer[idx].index_select_state(
  858. ctc_state[idx], accum_best_ids
  859. )
  860. torch.cuda.empty_cache()
  861. dummy_hyps = [
  862. {"yseq": [self.sos, self.eos], "score": np.array([-float("inf")])}
  863. ]
  864. ended_hyps = [
  865. ended_hyps[samp_i] if len(ended_hyps[samp_i]) != 0 else dummy_hyps
  866. for samp_i in six.moves.range(batch)
  867. ]
  868. if normalize_score:
  869. for samp_i in six.moves.range(batch):
  870. for x in ended_hyps[samp_i]:
  871. x["score"] /= len(x["yseq"])
  872. nbest_hyps = [
  873. sorted(ended_hyps[samp_i], key=lambda x: x["score"], reverse=True)[
  874. : min(len(ended_hyps[samp_i]), recog_args.nbest)
  875. ]
  876. for samp_i in six.moves.range(batch)
  877. ]
  878. return nbest_hyps
    def calculate_all_attentions(self, hs_pad, hlen, ys_pad, strm_idx=0, lang_ids=None):
        """Calculate all of attentions

        :param torch.Tensor hs_pad: batch of padded hidden state sequences
            (B, Tmax, D)
            in multi-encoder case, list of torch.Tensor,
            [(B, Tmax_1, D), (B, Tmax_2, D), ..., ]
        :param torch.Tensor hlen: batch of lengths of hidden state sequences (B)
            [in multi-encoder case, list of torch.Tensor,
            [(B), (B), ..., ]
        :param torch.Tensor ys_pad:
            batch of padded character id sequence tensor (B, Lmax)
        :param int strm_idx:
            stream index for parallel speaker attention in multi-speaker case
        :param torch.Tensor lang_ids: batch of target language id tensor (B, 1);
            only read when ``self.replace_sos`` is True
        :return: attention weights with the following shape,
            1) multi-head case => attention weights (B, H, Lmax, Tmax),
            2) multi-encoder case =>
                [(B, Lmax, Tmax1), (B, Lmax, Tmax2), ..., (B, Lmax, NumEncs)]
            3) other case => attention weights (B, Lmax, Tmax).
        :rtype: float ndarray
        """
        # to support multiple-encoder ASR mode: in single encoder mode,
        # wrap torch.Tensor into a one-element list of torch.Tensor
        if self.num_encs == 1:
            hs_pad = [hs_pad]
            hlen = [hlen]

        # TODO(kan-bayashi): need to make more smart way
        ys = [y[y != self.ignore_id] for y in ys_pad]  # parse padded ys
        # clamp the stream index so a single-attention model still works
        att_idx = min(strm_idx, len(self.att) - 1)

        # hlen should be list of list of integer
        hlen = [list(map(int, hlen[idx])) for idx in range(self.num_encs)]

        self.loss = None
        # prepare input and output word sequences with sos/eos IDs
        eos = ys[0].new([self.eos])
        sos = ys[0].new([self.sos])
        if self.replace_sos:
            # replace <sos> with the per-utterance target language id
            ys_in = [torch.cat([idx, y], dim=0) for idx, y in zip(lang_ids, ys)]
        else:
            ys_in = [torch.cat([sos, y], dim=0) for y in ys]
        ys_out = [torch.cat([y, eos], dim=0) for y in ys]

        # padding for ys with -1
        # pys: utt x olen
        ys_in_pad = pad_list(ys_in, self.eos)
        ys_out_pad = pad_list(ys_out, self.ignore_id)

        # get length info
        olength = ys_out_pad.size(1)

        # initialization: one (c, z) state pair per decoder layer
        c_list = [self.zero_state(hs_pad[0])]
        z_list = [self.zero_state(hs_pad[0])]
        for _ in six.moves.range(1, self.dlayers):
            c_list.append(self.zero_state(hs_pad[0]))
            z_list.append(self.zero_state(hs_pad[0]))
        att_ws = []
        if self.num_encs == 1:
            att_w = None
            self.att[att_idx].reset()  # reset pre-computation of h
        else:
            att_w_list = [None] * (self.num_encs + 1)  # atts + han
            att_c_list = [None] * (self.num_encs)  # atts
            for idx in range(self.num_encs + 1):
                self.att[idx].reset()  # reset pre-computation of h in atts and han

        # pre-computation of embedding
        eys = self.dropout_emb(self.embed(ys_in_pad))  # utt x olen x zdim

        # loop for an output sequence
        for i in six.moves.range(olength):
            if self.num_encs == 1:
                att_c, att_w = self.att[att_idx](
                    hs_pad[0], hlen[0], self.dropout_dec[0](z_list[0]), att_w
                )
                att_ws.append(att_w)
            else:
                # per-encoder attentions, then a hierarchical attention
                # (self.att[self.num_encs]) over the stacked context vectors
                for idx in range(self.num_encs):
                    att_c_list[idx], att_w_list[idx] = self.att[idx](
                        hs_pad[idx],
                        hlen[idx],
                        self.dropout_dec[0](z_list[0]),
                        att_w_list[idx],
                    )
                hs_pad_han = torch.stack(att_c_list, dim=1)
                hlen_han = [self.num_encs] * len(ys_in)
                att_c, att_w_list[self.num_encs] = self.att[self.num_encs](
                    hs_pad_han,
                    hlen_han,
                    self.dropout_dec[0](z_list[0]),
                    att_w_list[self.num_encs],
                )
                # .copy() so later in-place updates of att_w_list do not
                # clobber the weights already stored for earlier steps
                att_ws.append(att_w_list.copy())
            ey = torch.cat((eys[:, i, :], att_c), dim=1)  # utt x (zdim + hdim)
            # advance decoder states with the embedded gold token + context
            z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_list, c_list)

        if self.num_encs == 1:
            # convert to numpy array with the shape (B, Lmax, Tmax)
            att_ws = att_to_numpy(att_ws, self.att[att_idx])
        else:
            # regroup per-step lists into per-attention lists before converting
            _att_ws = []
            for idx, ws in enumerate(zip(*att_ws)):
                ws = att_to_numpy(ws, self.att[idx])
                _att_ws.append(ws)
            att_ws = _att_ws
        return att_ws
  978. @staticmethod
  979. def _get_last_yseq(exp_yseq):
  980. last = []
  981. for y_seq in exp_yseq:
  982. last.append(y_seq[-1])
  983. return last
  984. @staticmethod
  985. def _append_ids(yseq, ids):
  986. if isinstance(ids, list):
  987. for i, j in enumerate(ids):
  988. yseq[i].append(j)
  989. else:
  990. for i in range(len(yseq)):
  991. yseq[i].append(ids)
  992. return yseq
  993. @staticmethod
  994. def _index_select_list(yseq, lst):
  995. new_yseq = []
  996. for i in lst:
  997. new_yseq.append(yseq[i][:])
  998. return new_yseq
  999. @staticmethod
  1000. def _index_select_lm_state(rnnlm_state, dim, vidx):
  1001. if isinstance(rnnlm_state, dict):
  1002. new_state = {}
  1003. for k, v in rnnlm_state.items():
  1004. new_state[k] = [torch.index_select(vi, dim, vidx) for vi in v]
  1005. elif isinstance(rnnlm_state, list):
  1006. new_state = []
  1007. for i in vidx:
  1008. new_state.append(rnnlm_state[int(i)][:])
  1009. return new_state
  1010. # scorer interface methods
  1011. def init_state(self, x):
  1012. # to support mutiple encoder asr mode, in single encoder mode,
  1013. # convert torch.Tensor to List of torch.Tensor
  1014. if self.num_encs == 1:
  1015. x = [x]
  1016. c_list = [self.zero_state(x[0].unsqueeze(0))]
  1017. z_list = [self.zero_state(x[0].unsqueeze(0))]
  1018. for _ in six.moves.range(1, self.dlayers):
  1019. c_list.append(self.zero_state(x[0].unsqueeze(0)))
  1020. z_list.append(self.zero_state(x[0].unsqueeze(0)))
  1021. # TODO(karita): support strm_index for `asr_mix`
  1022. strm_index = 0
  1023. att_idx = min(strm_index, len(self.att) - 1)
  1024. if self.num_encs == 1:
  1025. a = None
  1026. self.att[att_idx].reset() # reset pre-computation of h
  1027. else:
  1028. a = [None] * (self.num_encs + 1) # atts + han
  1029. for idx in range(self.num_encs + 1):
  1030. self.att[idx].reset() # reset pre-computation of h in atts and han
  1031. return dict(
  1032. c_prev=c_list[:],
  1033. z_prev=z_list[:],
  1034. a_prev=a,
  1035. workspace=(att_idx, z_list, c_list),
  1036. )
    def score(self, yseq, state, x):
        """Score the next token given the partial hypothesis (ScorerInterface).

        :param torch.Tensor yseq: 1D tensor of token ids decoded so far; only
            the last token is consumed here
        :param dict state: decoder state produced by ``init_state`` or a
            previous ``score`` call (``c_prev``/``z_prev``/``a_prev``/``workspace``)
        :param x: encoder output for one utterance; in multi-encoder mode a
            list of encoder outputs — NOTE(review): assumed shape (T, D) per
            encoder, confirm against caller
        :return: tuple of (log-probabilities over the vocabulary for the next
            token, updated state dict)
        """
        # to support multiple-encoder ASR mode: in single encoder mode,
        # wrap torch.Tensor into a list of torch.Tensor
        if self.num_encs == 1:
            x = [x]

        att_idx, z_list, c_list = state["workspace"]
        vy = yseq[-1].unsqueeze(0)  # last emitted token as a batch of one
        ey = self.dropout_emb(self.embed(vy))  # utt list (1) x zdim
        if self.num_encs == 1:
            att_c, att_w = self.att[att_idx](
                x[0].unsqueeze(0),
                [x[0].size(0)],
                self.dropout_dec[0](state["z_prev"][0]),
                state["a_prev"],
            )
        else:
            # per-encoder attentions followed by a hierarchical attention
            # (self.att[self.num_encs]) over the stacked context vectors
            att_w = [None] * (self.num_encs + 1)  # atts + han
            att_c_list = [None] * (self.num_encs)  # atts
            for idx in range(self.num_encs):
                att_c_list[idx], att_w[idx] = self.att[idx](
                    x[idx].unsqueeze(0),
                    [x[idx].size(0)],
                    self.dropout_dec[0](state["z_prev"][0]),
                    state["a_prev"][idx],
                )
            h_han = torch.stack(att_c_list, dim=1)
            att_c, att_w[self.num_encs] = self.att[self.num_encs](
                h_han,
                [self.num_encs],
                self.dropout_dec[0](state["z_prev"][0]),
                state["a_prev"][self.num_encs],
            )
        ey = torch.cat((ey, att_c), dim=1)  # utt(1) x (zdim + hdim)
        z_list, c_list = self.rnn_forward(
            ey, z_list, c_list, state["z_prev"], state["c_prev"]
        )
        if self.context_residual:
            # concatenate the attention context onto the decoder output
            logits = self.output(
                torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1)
            )
        else:
            logits = self.output(self.dropout_dec[-1](z_list[-1]))
        logp = F.log_softmax(logits, dim=1).squeeze(0)
        return (
            logp,
            dict(
                c_prev=c_list[:],
                z_prev=z_list[:],
                a_prev=att_w,
                workspace=(att_idx, z_list, c_list),
            ),
        )
  1089. def decoder_for(args, odim, sos, eos, att, labeldist):
  1090. return Decoder(
  1091. args.eprojs,
  1092. odim,
  1093. args.dtype,
  1094. args.dlayers,
  1095. args.dunits,
  1096. sos,
  1097. eos,
  1098. att,
  1099. args.verbose,
  1100. args.char_list,
  1101. labeldist,
  1102. args.lsm_weight,
  1103. args.sampling_probability,
  1104. args.dropout_rate_decoder,
  1105. getattr(args, "context_residual", False), # use getattr to keep compatibility
  1106. getattr(args, "replace_sos", False), # use getattr to keep compatibility
  1107. getattr(args, "num_encs", 1),
  1108. ) # use getattr to keep compatibility