transformer_decoder.py 42 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182
  1. # Copyright 2019 Shigeki Karita
  2. # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
  3. """Decoder definition."""
  4. from typing import Any
  5. from typing import List
  6. from typing import Sequence
  7. from typing import Tuple
  8. import torch
  9. from torch import nn
  10. from funasr.models.decoder.abs_decoder import AbsDecoder
  11. from funasr.modules.attention import MultiHeadedAttention
  12. from funasr.modules.attention import CosineDistanceAttention
  13. from funasr.modules.dynamic_conv import DynamicConvolution
  14. from funasr.modules.dynamic_conv2d import DynamicConvolution2D
  15. from funasr.modules.embedding import PositionalEncoding
  16. from funasr.modules.layer_norm import LayerNorm
  17. from funasr.modules.lightconv import LightweightConvolution
  18. from funasr.modules.lightconv2d import LightweightConvolution2D
  19. from funasr.modules.mask import subsequent_mask
  20. from funasr.modules.nets_utils import make_pad_mask
  21. from funasr.modules.positionwise_feed_forward import (
  22. PositionwiseFeedForward, # noqa: H301
  23. )
  24. from funasr.modules.repeat import repeat
  25. from funasr.modules.scorers.scorer_interface import BatchScorerInterface
  26. class DecoderLayer(nn.Module):
  27. """Single decoder layer module.
  28. Args:
  29. size (int): Input dimension.
  30. self_attn (torch.nn.Module): Self-attention module instance.
  31. `MultiHeadedAttention` instance can be used as the argument.
  32. src_attn (torch.nn.Module): Self-attention module instance.
  33. `MultiHeadedAttention` instance can be used as the argument.
  34. feed_forward (torch.nn.Module): Feed-forward module instance.
  35. `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
  36. can be used as the argument.
  37. dropout_rate (float): Dropout rate.
  38. normalize_before (bool): Whether to use layer_norm before the first block.
  39. concat_after (bool): Whether to concat attention layer's input and output.
  40. if True, additional linear will be applied.
  41. i.e. x -> x + linear(concat(x, att(x)))
  42. if False, no additional linear will be applied. i.e. x -> x + att(x)
  43. """
  44. def __init__(
  45. self,
  46. size,
  47. self_attn,
  48. src_attn,
  49. feed_forward,
  50. dropout_rate,
  51. normalize_before=True,
  52. concat_after=False,
  53. ):
  54. """Construct an DecoderLayer object."""
  55. super(DecoderLayer, self).__init__()
  56. self.size = size
  57. self.self_attn = self_attn
  58. self.src_attn = src_attn
  59. self.feed_forward = feed_forward
  60. self.norm1 = LayerNorm(size)
  61. self.norm2 = LayerNorm(size)
  62. self.norm3 = LayerNorm(size)
  63. self.dropout = nn.Dropout(dropout_rate)
  64. self.normalize_before = normalize_before
  65. self.concat_after = concat_after
  66. if self.concat_after:
  67. self.concat_linear1 = nn.Linear(size + size, size)
  68. self.concat_linear2 = nn.Linear(size + size, size)
  69. def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None):
  70. """Compute decoded features.
  71. Args:
  72. tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
  73. tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
  74. memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
  75. memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
  76. cache (List[torch.Tensor]): List of cached tensors.
  77. Each tensor shape should be (#batch, maxlen_out - 1, size).
  78. Returns:
  79. torch.Tensor: Output tensor(#batch, maxlen_out, size).
  80. torch.Tensor: Mask for output tensor (#batch, maxlen_out).
  81. torch.Tensor: Encoded memory (#batch, maxlen_in, size).
  82. torch.Tensor: Encoded memory mask (#batch, maxlen_in).
  83. """
  84. residual = tgt
  85. if self.normalize_before:
  86. tgt = self.norm1(tgt)
  87. if cache is None:
  88. tgt_q = tgt
  89. tgt_q_mask = tgt_mask
  90. else:
  91. # compute only the last frame query keeping dim: max_time_out -> 1
  92. assert cache.shape == (
  93. tgt.shape[0],
  94. tgt.shape[1] - 1,
  95. self.size,
  96. ), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
  97. tgt_q = tgt[:, -1:, :]
  98. residual = residual[:, -1:, :]
  99. tgt_q_mask = None
  100. if tgt_mask is not None:
  101. tgt_q_mask = tgt_mask[:, -1:, :]
  102. if self.concat_after:
  103. tgt_concat = torch.cat(
  104. (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1
  105. )
  106. x = residual + self.concat_linear1(tgt_concat)
  107. else:
  108. x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask))
  109. if not self.normalize_before:
  110. x = self.norm1(x)
  111. residual = x
  112. if self.normalize_before:
  113. x = self.norm2(x)
  114. if self.concat_after:
  115. x_concat = torch.cat(
  116. (x, self.src_attn(x, memory, memory, memory_mask)), dim=-1
  117. )
  118. x = residual + self.concat_linear2(x_concat)
  119. else:
  120. x = residual + self.dropout(self.src_attn(x, memory, memory, memory_mask))
  121. if not self.normalize_before:
  122. x = self.norm2(x)
  123. residual = x
  124. if self.normalize_before:
  125. x = self.norm3(x)
  126. x = residual + self.dropout(self.feed_forward(x))
  127. if not self.normalize_before:
  128. x = self.norm3(x)
  129. if cache is not None:
  130. x = torch.cat([cache, x], dim=1)
  131. return x, tgt_mask, memory, memory_mask
class BaseTransformerDecoder(AbsDecoder, BatchScorerInterface):
    """Base class of Transformer decoder module.

    Builds the token/feature embedding front-end and the output projection;
    subclasses must assign ``self.decoders`` (the stack of decoder layers).

    Args:
        vocab_size: output dim
        encoder_output_size: dimension of attention
        attention_heads: the number of heads of multi head attention
        linear_units: the number of units of position-wise feed forward
        num_blocks: the number of decoder blocks
        dropout_rate: dropout rate
        self_attention_dropout_rate: dropout rate for attention
        input_layer: input layer type ("embed" for token ids, "linear" for
            dense features)
        use_output_layer: whether to use output layer
        pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
        normalize_before: whether to use layer_norm before the first block
        concat_after: whether to concat attention layer's input and output
            if True, additional linear will be applied.
            i.e. x -> x + linear(concat(x, att(x)))
            if False, no additional linear will be applied.
            i.e. x -> x + att(x)
    """

    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        input_layer: str = "embed",
        use_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
    ):
        super().__init__()
        attention_dim = encoder_output_size
        if input_layer == "embed":
            # Token-id input: embedding lookup + positional encoding.
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(vocab_size, attention_dim),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        elif input_layer == "linear":
            # Dense-feature input: linear projection + norm + dropout + ReLU.
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(vocab_size, attention_dim),
                torch.nn.LayerNorm(attention_dim),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        else:
            raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")
        self.normalize_before = normalize_before
        if self.normalize_before:
            # Final layer norm for the pre-LN configuration.
            self.after_norm = LayerNorm(attention_dim)
        if use_output_layer:
            self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
        else:
            self.output_layer = None
        # Must set by the inheritance
        self.decoders = None

    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward decoder.

        Args:
            hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
            hlens: (batch)
            ys_in_pad:
                input token ids, int64 (batch, maxlen_out)
                if input_layer == "embed"
                input tensor (batch, maxlen_out, #mels) in the other cases
            ys_in_lens: (batch)
        Returns:
            (tuple): tuple containing:
                x: decoded token score before softmax (batch, maxlen_out, token)
                    if use_output_layer is True,
                olens: (batch, )
        """
        tgt = ys_in_pad
        # tgt_mask: (B, 1, L) — True on valid (non-padding) positions
        tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
        # m: (1, L, L) — causal (lower-triangular) mask
        m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
        # tgt_mask: (B, L, L) — padding mask combined with causal mask
        tgt_mask = tgt_mask & m
        memory = hs_pad
        memory_mask = (~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to(
            memory.device
        )
        # Padding for Longformer
        if memory_mask.shape[-1] != memory.shape[1]:
            padlen = memory.shape[1] - memory_mask.shape[-1]
            memory_mask = torch.nn.functional.pad(
                memory_mask, (0, padlen), "constant", False
            )
        x = self.embed(tgt)
        x, tgt_mask, memory, memory_mask = self.decoders(
            x, tgt_mask, memory, memory_mask
        )
        if self.normalize_before:
            x = self.after_norm(x)
        if self.output_layer is not None:
            x = self.output_layer(x)
        # Output lengths are recovered from the combined target mask.
        olens = tgt_mask.sum(1)
        return x, olens

    def forward_one_step(
        self,
        tgt: torch.Tensor,
        tgt_mask: torch.Tensor,
        memory: torch.Tensor,
        cache: List[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Forward one step.

        Args:
            tgt: input token ids, int64 (batch, maxlen_out)
            tgt_mask: input token mask, (batch, maxlen_out)
                dtype=torch.uint8 in PyTorch 1.2-
                dtype=torch.bool in PyTorch 1.2+ (include 1.2)
            memory: encoded memory, float32 (batch, maxlen_in, feat)
            cache: cached output list of (batch, max_time_out-1, size)
        Returns:
            y, cache: NN output value and cache per `self.decoders`.
            `y.shape` is (batch, maxlen_out, token)
        """
        x = self.embed(tgt)
        if cache is None:
            # First decoding step: no per-layer cache exists yet.
            cache = [None] * len(self.decoders)
        new_cache = []
        for c, decoder in zip(cache, self.decoders):
            # The memory mask is passed as None during incremental decoding.
            x, tgt_mask, memory, memory_mask = decoder(
                x, tgt_mask, memory, None, cache=c
            )
            new_cache.append(x)
        if self.normalize_before:
            y = self.after_norm(x[:, -1])
        else:
            y = x[:, -1]
        if self.output_layer is not None:
            # Return log-probabilities so scorers can sum them directly.
            y = torch.log_softmax(self.output_layer(y), dim=-1)
        return y, new_cache

    def score(self, ys, state, x):
        """Score a single hypothesis prefix (ScorerInterface API)."""
        ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0)
        logp, state = self.forward_one_step(
            ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state
        )
        return logp.squeeze(0), state

    def batch_score(
        self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
    ) -> Tuple[torch.Tensor, List[Any]]:
        """Score new token batch.

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            states (List[Any]): Scorer states for prefix tokens.
            xs (torch.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).

        Returns:
            tuple[torch.Tensor, List[Any]]: Tuple of
                batchfied scores for next token with shape of `(n_batch, n_vocab)`
                and next state list for ys.
        """
        # merge states
        n_batch = len(ys)
        n_layers = len(self.decoders)
        if states[0] is None:
            batch_state = None
        else:
            # transpose state of [batch, layer] into [layer, batch]
            batch_state = [
                torch.stack([states[b][i] for b in range(n_batch)])
                for i in range(n_layers)
            ]
        # batch decoding
        ys_mask = subsequent_mask(ys.size(-1), device=xs.device).unsqueeze(0)
        logp, states = self.forward_one_step(ys, ys_mask, xs, cache=batch_state)
        # transpose state of [layer, batch] into [batch, layer]
        state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
        return logp, state_list
  311. class TransformerDecoder(BaseTransformerDecoder):
  312. def __init__(
  313. self,
  314. vocab_size: int,
  315. encoder_output_size: int,
  316. attention_heads: int = 4,
  317. linear_units: int = 2048,
  318. num_blocks: int = 6,
  319. dropout_rate: float = 0.1,
  320. positional_dropout_rate: float = 0.1,
  321. self_attention_dropout_rate: float = 0.0,
  322. src_attention_dropout_rate: float = 0.0,
  323. input_layer: str = "embed",
  324. use_output_layer: bool = True,
  325. pos_enc_class=PositionalEncoding,
  326. normalize_before: bool = True,
  327. concat_after: bool = False,
  328. ):
  329. super().__init__(
  330. vocab_size=vocab_size,
  331. encoder_output_size=encoder_output_size,
  332. dropout_rate=dropout_rate,
  333. positional_dropout_rate=positional_dropout_rate,
  334. input_layer=input_layer,
  335. use_output_layer=use_output_layer,
  336. pos_enc_class=pos_enc_class,
  337. normalize_before=normalize_before,
  338. )
  339. attention_dim = encoder_output_size
  340. self.decoders = repeat(
  341. num_blocks,
  342. lambda lnum: DecoderLayer(
  343. attention_dim,
  344. MultiHeadedAttention(
  345. attention_heads, attention_dim, self_attention_dropout_rate
  346. ),
  347. MultiHeadedAttention(
  348. attention_heads, attention_dim, src_attention_dropout_rate
  349. ),
  350. PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
  351. dropout_rate,
  352. normalize_before,
  353. concat_after,
  354. ),
  355. )
class ParaformerDecoderSAN(BaseTransformerDecoder):
    """Non-autoregressive Paraformer decoder built from Transformer blocks.

    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2006.01713

    Differences from ``BaseTransformerDecoder.forward``: no causal mask is
    applied to the targets, the embedding front-end is bypassed, and the
    hidden states after layer ``embeds_id`` can be returned additionally.
    """

    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = "embed",
        use_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
        embeds_id: int = -1,
    ):
        super().__init__(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            dropout_rate=dropout_rate,
            positional_dropout_rate=positional_dropout_rate,
            input_layer=input_layer,
            use_output_layer=use_output_layer,
            pos_enc_class=pos_enc_class,
            normalize_before=normalize_before,
        )
        attention_dim = encoder_output_size
        self.decoders = repeat(
            num_blocks,
            lambda lnum: DecoderLayer(
                attention_dim,
                MultiHeadedAttention(
                    attention_heads, attention_dim, self_attention_dropout_rate
                ),
                MultiHeadedAttention(
                    attention_heads, attention_dim, src_attention_dropout_rate
                ),
                PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        # Index of the decoder layer whose output is additionally returned
        # from forward(); -1 (default) disables this (no layer_id matches).
        self.embeds_id = embeds_id
        self.attention_dim = attention_dim

    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward decoder.

        Args:
            hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
            hlens: (batch)
            ys_in_pad:
                input token ids, int64 (batch, maxlen_out)
                if input_layer == "embed"
                input tensor (batch, maxlen_out, #mels) in the other cases
            ys_in_lens: (batch)
        Returns:
            (tuple): tuple containing:
                x: decoded token score before softmax (batch, maxlen_out, token)
                    if use_output_layer is True,
                olens: (batch, )
                embeds_outputs: hidden states after layer ``embeds_id``
                    (only present when a layer index matched).
        """
        tgt = ys_in_pad
        # Non-autoregressive decoding: only a padding mask is applied to the
        # targets, no causal/subsequent mask.
        tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
        memory = hs_pad
        memory_mask = (~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to(
            memory.device
        )
        # Padding for Longformer
        if memory_mask.shape[-1] != memory.shape[1]:
            padlen = memory.shape[1] - memory_mask.shape[-1]
            memory_mask = torch.nn.functional.pad(
                memory_mask, (0, padlen), "constant", False
            )
        # x = self.embed(tgt)
        # NOTE(review): the embedding front-end is bypassed — ys_in_pad is
        # presumably already continuous embeddings from the predictor; confirm
        # against the caller.
        x = tgt
        embeds_outputs = None
        for layer_id, decoder in enumerate(self.decoders):
            x, tgt_mask, memory, memory_mask = decoder(
                x, tgt_mask, memory, memory_mask
            )
            if layer_id == self.embeds_id:
                # Capture the hidden states right after the requested layer.
                embeds_outputs = x
        if self.normalize_before:
            x = self.after_norm(x)
        if self.output_layer is not None:
            x = self.output_layer(x)
        olens = tgt_mask.sum(1)
        # NOTE(review): return arity varies (2-tuple vs 3-tuple) depending on
        # whether embeds_id matched a layer — callers must handle both.
        if embeds_outputs is not None:
            return x, olens, embeds_outputs
        else:
            return x, olens
  461. class LightweightConvolutionTransformerDecoder(BaseTransformerDecoder):
  462. def __init__(
  463. self,
  464. vocab_size: int,
  465. encoder_output_size: int,
  466. attention_heads: int = 4,
  467. linear_units: int = 2048,
  468. num_blocks: int = 6,
  469. dropout_rate: float = 0.1,
  470. positional_dropout_rate: float = 0.1,
  471. self_attention_dropout_rate: float = 0.0,
  472. src_attention_dropout_rate: float = 0.0,
  473. input_layer: str = "embed",
  474. use_output_layer: bool = True,
  475. pos_enc_class=PositionalEncoding,
  476. normalize_before: bool = True,
  477. concat_after: bool = False,
  478. conv_wshare: int = 4,
  479. conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
  480. conv_usebias: int = False,
  481. ):
  482. if len(conv_kernel_length) != num_blocks:
  483. raise ValueError(
  484. "conv_kernel_length must have equal number of values to num_blocks: "
  485. f"{len(conv_kernel_length)} != {num_blocks}"
  486. )
  487. super().__init__(
  488. vocab_size=vocab_size,
  489. encoder_output_size=encoder_output_size,
  490. dropout_rate=dropout_rate,
  491. positional_dropout_rate=positional_dropout_rate,
  492. input_layer=input_layer,
  493. use_output_layer=use_output_layer,
  494. pos_enc_class=pos_enc_class,
  495. normalize_before=normalize_before,
  496. )
  497. attention_dim = encoder_output_size
  498. self.decoders = repeat(
  499. num_blocks,
  500. lambda lnum: DecoderLayer(
  501. attention_dim,
  502. LightweightConvolution(
  503. wshare=conv_wshare,
  504. n_feat=attention_dim,
  505. dropout_rate=self_attention_dropout_rate,
  506. kernel_size=conv_kernel_length[lnum],
  507. use_kernel_mask=True,
  508. use_bias=conv_usebias,
  509. ),
  510. MultiHeadedAttention(
  511. attention_heads, attention_dim, src_attention_dropout_rate
  512. ),
  513. PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
  514. dropout_rate,
  515. normalize_before,
  516. concat_after,
  517. ),
  518. )
  519. class LightweightConvolution2DTransformerDecoder(BaseTransformerDecoder):
  520. def __init__(
  521. self,
  522. vocab_size: int,
  523. encoder_output_size: int,
  524. attention_heads: int = 4,
  525. linear_units: int = 2048,
  526. num_blocks: int = 6,
  527. dropout_rate: float = 0.1,
  528. positional_dropout_rate: float = 0.1,
  529. self_attention_dropout_rate: float = 0.0,
  530. src_attention_dropout_rate: float = 0.0,
  531. input_layer: str = "embed",
  532. use_output_layer: bool = True,
  533. pos_enc_class=PositionalEncoding,
  534. normalize_before: bool = True,
  535. concat_after: bool = False,
  536. conv_wshare: int = 4,
  537. conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
  538. conv_usebias: int = False,
  539. ):
  540. if len(conv_kernel_length) != num_blocks:
  541. raise ValueError(
  542. "conv_kernel_length must have equal number of values to num_blocks: "
  543. f"{len(conv_kernel_length)} != {num_blocks}"
  544. )
  545. super().__init__(
  546. vocab_size=vocab_size,
  547. encoder_output_size=encoder_output_size,
  548. dropout_rate=dropout_rate,
  549. positional_dropout_rate=positional_dropout_rate,
  550. input_layer=input_layer,
  551. use_output_layer=use_output_layer,
  552. pos_enc_class=pos_enc_class,
  553. normalize_before=normalize_before,
  554. )
  555. attention_dim = encoder_output_size
  556. self.decoders = repeat(
  557. num_blocks,
  558. lambda lnum: DecoderLayer(
  559. attention_dim,
  560. LightweightConvolution2D(
  561. wshare=conv_wshare,
  562. n_feat=attention_dim,
  563. dropout_rate=self_attention_dropout_rate,
  564. kernel_size=conv_kernel_length[lnum],
  565. use_kernel_mask=True,
  566. use_bias=conv_usebias,
  567. ),
  568. MultiHeadedAttention(
  569. attention_heads, attention_dim, src_attention_dropout_rate
  570. ),
  571. PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
  572. dropout_rate,
  573. normalize_before,
  574. concat_after,
  575. ),
  576. )
  577. class DynamicConvolutionTransformerDecoder(BaseTransformerDecoder):
  578. def __init__(
  579. self,
  580. vocab_size: int,
  581. encoder_output_size: int,
  582. attention_heads: int = 4,
  583. linear_units: int = 2048,
  584. num_blocks: int = 6,
  585. dropout_rate: float = 0.1,
  586. positional_dropout_rate: float = 0.1,
  587. self_attention_dropout_rate: float = 0.0,
  588. src_attention_dropout_rate: float = 0.0,
  589. input_layer: str = "embed",
  590. use_output_layer: bool = True,
  591. pos_enc_class=PositionalEncoding,
  592. normalize_before: bool = True,
  593. concat_after: bool = False,
  594. conv_wshare: int = 4,
  595. conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
  596. conv_usebias: int = False,
  597. ):
  598. if len(conv_kernel_length) != num_blocks:
  599. raise ValueError(
  600. "conv_kernel_length must have equal number of values to num_blocks: "
  601. f"{len(conv_kernel_length)} != {num_blocks}"
  602. )
  603. super().__init__(
  604. vocab_size=vocab_size,
  605. encoder_output_size=encoder_output_size,
  606. dropout_rate=dropout_rate,
  607. positional_dropout_rate=positional_dropout_rate,
  608. input_layer=input_layer,
  609. use_output_layer=use_output_layer,
  610. pos_enc_class=pos_enc_class,
  611. normalize_before=normalize_before,
  612. )
  613. attention_dim = encoder_output_size
  614. self.decoders = repeat(
  615. num_blocks,
  616. lambda lnum: DecoderLayer(
  617. attention_dim,
  618. DynamicConvolution(
  619. wshare=conv_wshare,
  620. n_feat=attention_dim,
  621. dropout_rate=self_attention_dropout_rate,
  622. kernel_size=conv_kernel_length[lnum],
  623. use_kernel_mask=True,
  624. use_bias=conv_usebias,
  625. ),
  626. MultiHeadedAttention(
  627. attention_heads, attention_dim, src_attention_dropout_rate
  628. ),
  629. PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
  630. dropout_rate,
  631. normalize_before,
  632. concat_after,
  633. ),
  634. )
  635. class DynamicConvolution2DTransformerDecoder(BaseTransformerDecoder):
  636. def __init__(
  637. self,
  638. vocab_size: int,
  639. encoder_output_size: int,
  640. attention_heads: int = 4,
  641. linear_units: int = 2048,
  642. num_blocks: int = 6,
  643. dropout_rate: float = 0.1,
  644. positional_dropout_rate: float = 0.1,
  645. self_attention_dropout_rate: float = 0.0,
  646. src_attention_dropout_rate: float = 0.0,
  647. input_layer: str = "embed",
  648. use_output_layer: bool = True,
  649. pos_enc_class=PositionalEncoding,
  650. normalize_before: bool = True,
  651. concat_after: bool = False,
  652. conv_wshare: int = 4,
  653. conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
  654. conv_usebias: int = False,
  655. ):
  656. if len(conv_kernel_length) != num_blocks:
  657. raise ValueError(
  658. "conv_kernel_length must have equal number of values to num_blocks: "
  659. f"{len(conv_kernel_length)} != {num_blocks}"
  660. )
  661. super().__init__(
  662. vocab_size=vocab_size,
  663. encoder_output_size=encoder_output_size,
  664. dropout_rate=dropout_rate,
  665. positional_dropout_rate=positional_dropout_rate,
  666. input_layer=input_layer,
  667. use_output_layer=use_output_layer,
  668. pos_enc_class=pos_enc_class,
  669. normalize_before=normalize_before,
  670. )
  671. attention_dim = encoder_output_size
  672. self.decoders = repeat(
  673. num_blocks,
  674. lambda lnum: DecoderLayer(
  675. attention_dim,
  676. DynamicConvolution2D(
  677. wshare=conv_wshare,
  678. n_feat=attention_dim,
  679. dropout_rate=self_attention_dropout_rate,
  680. kernel_size=conv_kernel_length[lnum],
  681. use_kernel_mask=True,
  682. use_bias=conv_usebias,
  683. ),
  684. MultiHeadedAttention(
  685. attention_heads, attention_dim, src_attention_dropout_rate
  686. ),
  687. PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
  688. dropout_rate,
  689. normalize_before,
  690. concat_after,
  691. ),
  692. )
  693. class BaseSAAsrTransformerDecoder(AbsDecoder, BatchScorerInterface):
    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        spker_embedding_dim: int = 256,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        input_layer: str = "embed",
        use_asr_output_layer: bool = True,
        use_spk_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
    ):
        """Construct the speaker-attributed ASR decoder base.

        Args:
            vocab_size: output vocabulary size (input dim for "linear" mode).
            encoder_output_size: attention/feature dimension.
            spker_embedding_dim: output dim of the speaker projection layer.
            dropout_rate: dropout rate for the "linear" input layer.
            positional_dropout_rate: dropout rate of the positional encoding.
            input_layer: "embed" (token ids) or "linear" (dense features).
            use_asr_output_layer: add a linear projection to vocab_size.
            use_spk_output_layer: add a linear projection to spker_embedding_dim.
            pos_enc_class: PositionalEncoding or ScaledPositionalEncoding.
            normalize_before: apply a final LayerNorm when True (pre-LN style).

        Raises:
            ValueError: if ``input_layer`` is neither "embed" nor "linear".
        """
        super().__init__()
        attention_dim = encoder_output_size
        if input_layer == "embed":
            # Token-id input: embedding lookup + positional encoding.
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(vocab_size, attention_dim),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        elif input_layer == "linear":
            # Dense-feature input: linear projection + norm + dropout + ReLU.
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(vocab_size, attention_dim),
                torch.nn.LayerNorm(attention_dim),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        else:
            raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")
        self.normalize_before = normalize_before
        if self.normalize_before:
            self.after_norm = LayerNorm(attention_dim)
        if use_asr_output_layer:
            self.asr_output_layer = torch.nn.Linear(attention_dim, vocab_size)
        else:
            self.asr_output_layer = None
        if use_spk_output_layer:
            self.spk_output_layer = torch.nn.Linear(attention_dim, spker_embedding_dim)
        else:
            self.spk_output_layer = None
        # Matches decoded speaker embeddings against the enrolled profiles.
        self.cos_distance_att = CosineDistanceAttention()
        # Must be set by the inheriting class (see forward() for their roles).
        self.decoder1 = None
        self.decoder2 = None
        self.decoder3 = None
        self.decoder4 = None
    def forward(
        self,
        asr_hs_pad: torch.Tensor,
        spk_hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
        profile: torch.Tensor,
        profile_lens: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Forward the speaker-attributed decoder.

        Runs the speaker branch (decoder1 -> decoder2 -> cos-distance attention
        over the profiles) and then the ASR branch (decoder3 on decoder1's
        side-output ``z`` plus the speaker vector ``dn`` -> decoder4).

        Args:
            asr_hs_pad: ASR encoder memory (batch, maxlen_in, feat).
            spk_hs_pad: speaker encoder memory (batch, maxlen_in, feat).
            hlens: encoder output lengths (batch,).
            ys_in_pad: input token ids, int64 (batch, maxlen_out).
            ys_in_lens: target lengths (batch,).
            profile: enrolled speaker profile embeddings.
                NOTE(review): assumed (batch, n_spk, spker_embedding_dim) — confirm.
            profile_lens: number of valid profiles per utterance (batch,).

        Returns:
            x: decoded token scores before softmax (batch, maxlen_out, vocab)
                when ``asr_output_layer`` is set.
            weights: speaker attention weights from the cosine-distance attention.
            olens: output lengths (batch,).
        """
        tgt = ys_in_pad
        # tgt_mask: (B, 1, L)
        tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
        # m: (1, L, L) — causal mask
        m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
        # tgt_mask: (B, L, L)
        tgt_mask = tgt_mask & m
        asr_memory = asr_hs_pad
        spk_memory = spk_hs_pad
        memory_mask = (~make_pad_mask(hlens))[:, None, :].to(asr_memory.device)
        # Spk decoder
        x = self.embed(tgt)
        # decoder1 consumes both memories and also yields a side output ``z``
        # that later feeds the ASR branch.
        x, tgt_mask, asr_memory, spk_memory, memory_mask, z = self.decoder1(
            x, tgt_mask, asr_memory, spk_memory, memory_mask
        )
        x, tgt_mask, spk_memory, memory_mask = self.decoder2(
            x, tgt_mask, spk_memory, memory_mask
        )
        if self.normalize_before:
            x = self.after_norm(x)
        if self.spk_output_layer is not None:
            x = self.spk_output_layer(x)
        # dn: profile-weighted speaker embedding; weights: attention over profiles.
        dn, weights = self.cos_distance_att(x, profile, profile_lens)
        # Asr decoder
        x, tgt_mask, asr_memory, memory_mask = self.decoder3(
            z, tgt_mask, asr_memory, memory_mask, dn
        )
        x, tgt_mask, asr_memory, memory_mask = self.decoder4(
            x, tgt_mask, asr_memory, memory_mask
        )
        if self.normalize_before:
            x = self.after_norm(x)
        if self.asr_output_layer is not None:
            x = self.asr_output_layer(x)
        olens = tgt_mask.sum(1)
        return x, weights, olens
  786. def forward_one_step(
  787. self,
  788. tgt: torch.Tensor,
  789. tgt_mask: torch.Tensor,
  790. asr_memory: torch.Tensor,
  791. spk_memory: torch.Tensor,
  792. profile: torch.Tensor,
  793. cache: List[torch.Tensor] = None,
  794. ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
  795. x = self.embed(tgt)
  796. if cache is None:
  797. cache = [None] * (2 + len(self.decoder2) + len(self.decoder4))
  798. new_cache = []
  799. x, tgt_mask, asr_memory, spk_memory, _, z = self.decoder1(
  800. x, tgt_mask, asr_memory, spk_memory, None, cache=cache[0]
  801. )
  802. new_cache.append(x)
  803. for c, decoder in zip(cache[1: len(self.decoder2) + 1], self.decoder2):
  804. x, tgt_mask, spk_memory, _ = decoder(
  805. x, tgt_mask, spk_memory, None, cache=c
  806. )
  807. new_cache.append(x)
  808. if self.normalize_before:
  809. x = self.after_norm(x)
  810. else:
  811. x = x
  812. if self.spk_output_layer is not None:
  813. x = self.spk_output_layer(x)
  814. dn, weights = self.cos_distance_att(x, profile, None)
  815. x, tgt_mask, asr_memory, _ = self.decoder3(
  816. z, tgt_mask, asr_memory, None, dn, cache=cache[len(self.decoder2) + 1]
  817. )
  818. new_cache.append(x)
  819. for c, decoder in zip(cache[len(self.decoder2) + 2: ], self.decoder4):
  820. x, tgt_mask, asr_memory, _ = decoder(
  821. x, tgt_mask, asr_memory, None, cache=c
  822. )
  823. new_cache.append(x)
  824. if self.normalize_before:
  825. y = self.after_norm(x[:, -1])
  826. else:
  827. y = x[:, -1]
  828. if self.asr_output_layer is not None:
  829. y = torch.log_softmax(self.asr_output_layer(y), dim=-1)
  830. return y, weights, new_cache
  831. def score(self, ys, state, asr_enc, spk_enc, profile):
  832. """Score."""
  833. ys_mask = subsequent_mask(len(ys), device=ys.device).unsqueeze(0)
  834. logp, weights, state = self.forward_one_step(
  835. ys.unsqueeze(0), ys_mask, asr_enc.unsqueeze(0), spk_enc.unsqueeze(0), profile.unsqueeze(0), cache=state
  836. )
  837. return logp.squeeze(0), weights.squeeze(), state
class SAAsrTransformerDecoder(BaseSAAsrTransformerDecoder):
    """Speaker-attributed ASR transformer decoder.

    Wires up the four decoder stages consumed by the base class:

    - ``decoder1``: first speaker-decoder layer (self-attention plus a
      source attention mixing ASR memory and speaker memory; also emits
      the intermediate state ``z``).
    - ``decoder2``: remaining ``spk_num_blocks - 1`` standard decoder
      layers of the speaker branch.
    - ``decoder3``: first ASR-decoder layer (source attention over the
      ASR memory with speaker-embedding injection).
    - ``decoder4``: remaining ``asr_num_blocks - 1`` standard decoder
      layers of the ASR branch.
    """

    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        spker_embedding_dim: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        asr_num_blocks: int = 6,
        spk_num_blocks: int = 3,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = "embed",
        use_asr_output_layer: bool = True,
        use_spk_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
    ):
        """Construct an SAAsrTransformerDecoder object.

        Args:
            vocab_size: output vocabulary size.
            encoder_output_size: encoder output dimension; also used as the
                decoder attention dimension.
            spker_embedding_dim: speaker embedding dimension.
            attention_heads: number of attention heads.
            linear_units: hidden units of the position-wise feed-forward.
            asr_num_blocks: total layers in the ASR branch (first layer
                plus ``asr_num_blocks - 1`` standard layers).
            spk_num_blocks: total layers in the speaker branch.
            dropout_rate: dropout rate inside the decoder layers.
            positional_dropout_rate: dropout after positional encoding.
            self_attention_dropout_rate: dropout inside self-attention.
            src_attention_dropout_rate: dropout inside source attention.
            input_layer: input layer type, "embed" or "linear".
            use_asr_output_layer: add a final linear layer to vocab_size.
            use_spk_output_layer: add a linear layer to spker_embedding_dim.
            pos_enc_class: positional encoding class.
            normalize_before: apply layer norm before each sub-block.
            concat_after: concatenate attention input and output instead of
                a plain residual addition.
        """
        super().__init__(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            spker_embedding_dim=spker_embedding_dim,
            dropout_rate=dropout_rate,
            positional_dropout_rate=positional_dropout_rate,
            input_layer=input_layer,
            use_asr_output_layer=use_asr_output_layer,
            use_spk_output_layer=use_spk_output_layer,
            pos_enc_class=pos_enc_class,
            normalize_before=normalize_before,
        )
        attention_dim = encoder_output_size
        # First speaker-decoder layer: self-attn + (ASR, speaker) source attn.
        self.decoder1 = SpeakerAttributeSpkDecoderFirstLayer(
            attention_dim,
            MultiHeadedAttention(
                attention_heads, attention_dim, self_attention_dropout_rate
            ),
            MultiHeadedAttention(
                attention_heads, attention_dim, src_attention_dropout_rate
            ),
            PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
            dropout_rate,
            normalize_before,
            concat_after,
        )
        # Remaining speaker-branch layers (standard transformer decoder layers).
        self.decoder2 = repeat(
            spk_num_blocks - 1,
            lambda lnum: DecoderLayer(
                attention_dim,
                MultiHeadedAttention(
                    attention_heads, attention_dim, self_attention_dropout_rate
                ),
                MultiHeadedAttention(
                    attention_heads, attention_dim, src_attention_dropout_rate
                ),
                PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        # First ASR-decoder layer: source attention + speaker-embedding injection.
        self.decoder3 = SpeakerAttributeAsrDecoderFirstLayer(
            attention_dim,
            spker_embedding_dim,
            MultiHeadedAttention(
                attention_heads, attention_dim, src_attention_dropout_rate
            ),
            PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
            dropout_rate,
            normalize_before,
            concat_after,
        )
        # Remaining ASR-branch layers (standard transformer decoder layers).
        self.decoder4 = repeat(
            asr_num_blocks - 1,
            lambda lnum: DecoderLayer(
                attention_dim,
                MultiHeadedAttention(
                    attention_heads, attention_dim, self_attention_dropout_rate
                ),
                MultiHeadedAttention(
                    attention_heads, attention_dim, src_attention_dropout_rate
                ),
                PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
  928. class SpeakerAttributeSpkDecoderFirstLayer(nn.Module):
  929. def __init__(
  930. self,
  931. size,
  932. self_attn,
  933. src_attn,
  934. feed_forward,
  935. dropout_rate,
  936. normalize_before=True,
  937. concat_after=False,
  938. ):
  939. """Construct an DecoderLayer object."""
  940. super(SpeakerAttributeSpkDecoderFirstLayer, self).__init__()
  941. self.size = size
  942. self.self_attn = self_attn
  943. self.src_attn = src_attn
  944. self.feed_forward = feed_forward
  945. self.norm1 = LayerNorm(size)
  946. self.norm2 = LayerNorm(size)
  947. self.dropout = nn.Dropout(dropout_rate)
  948. self.normalize_before = normalize_before
  949. self.concat_after = concat_after
  950. if self.concat_after:
  951. self.concat_linear1 = nn.Linear(size + size, size)
  952. self.concat_linear2 = nn.Linear(size + size, size)
  953. def forward(self, tgt, tgt_mask, asr_memory, spk_memory, memory_mask, cache=None):
  954. residual = tgt
  955. if self.normalize_before:
  956. tgt = self.norm1(tgt)
  957. if cache is None:
  958. tgt_q = tgt
  959. tgt_q_mask = tgt_mask
  960. else:
  961. # compute only the last frame query keeping dim: max_time_out -> 1
  962. assert cache.shape == (
  963. tgt.shape[0],
  964. tgt.shape[1] - 1,
  965. self.size,
  966. ), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
  967. tgt_q = tgt[:, -1:, :]
  968. residual = residual[:, -1:, :]
  969. tgt_q_mask = None
  970. if tgt_mask is not None:
  971. tgt_q_mask = tgt_mask[:, -1:, :]
  972. if self.concat_after:
  973. tgt_concat = torch.cat(
  974. (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1
  975. )
  976. x = residual + self.concat_linear1(tgt_concat)
  977. else:
  978. x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask))
  979. if not self.normalize_before:
  980. x = self.norm1(x)
  981. z = x
  982. residual = x
  983. if self.normalize_before:
  984. x = self.norm1(x)
  985. skip = self.src_attn(x, asr_memory, spk_memory, memory_mask)
  986. if self.concat_after:
  987. x_concat = torch.cat(
  988. (x, skip), dim=-1
  989. )
  990. x = residual + self.concat_linear2(x_concat)
  991. else:
  992. x = residual + self.dropout(skip)
  993. if not self.normalize_before:
  994. x = self.norm1(x)
  995. residual = x
  996. if self.normalize_before:
  997. x = self.norm2(x)
  998. x = residual + self.dropout(self.feed_forward(x))
  999. if not self.normalize_before:
  1000. x = self.norm2(x)
  1001. if cache is not None:
  1002. x = torch.cat([cache, x], dim=1)
  1003. return x, tgt_mask, asr_memory, spk_memory, memory_mask, z
  1004. class SpeakerAttributeAsrDecoderFirstLayer(nn.Module):
  1005. def __init__(
  1006. self,
  1007. size,
  1008. d_size,
  1009. src_attn,
  1010. feed_forward,
  1011. dropout_rate,
  1012. normalize_before=True,
  1013. concat_after=False,
  1014. ):
  1015. """Construct an DecoderLayer object."""
  1016. super(SpeakerAttributeAsrDecoderFirstLayer, self).__init__()
  1017. self.size = size
  1018. self.src_attn = src_attn
  1019. self.feed_forward = feed_forward
  1020. self.norm1 = LayerNorm(size)
  1021. self.norm2 = LayerNorm(size)
  1022. self.norm3 = LayerNorm(size)
  1023. self.dropout = nn.Dropout(dropout_rate)
  1024. self.normalize_before = normalize_before
  1025. self.concat_after = concat_after
  1026. self.spk_linear = nn.Linear(d_size, size, bias=False)
  1027. if self.concat_after:
  1028. self.concat_linear1 = nn.Linear(size + size, size)
  1029. self.concat_linear2 = nn.Linear(size + size, size)
  1030. def forward(self, tgt, tgt_mask, memory, memory_mask, dn, cache=None):
  1031. residual = tgt
  1032. if self.normalize_before:
  1033. tgt = self.norm1(tgt)
  1034. if cache is None:
  1035. tgt_q = tgt
  1036. tgt_q_mask = tgt_mask
  1037. else:
  1038. tgt_q = tgt[:, -1:, :]
  1039. residual = residual[:, -1:, :]
  1040. tgt_q_mask = None
  1041. if tgt_mask is not None:
  1042. tgt_q_mask = tgt_mask[:, -1:, :]
  1043. x = tgt_q
  1044. if self.normalize_before:
  1045. x = self.norm2(x)
  1046. if self.concat_after:
  1047. x_concat = torch.cat(
  1048. (x, self.src_attn(x, memory, memory, memory_mask)), dim=-1
  1049. )
  1050. x = residual + self.concat_linear2(x_concat)
  1051. else:
  1052. x = residual + self.dropout(self.src_attn(x, memory, memory, memory_mask))
  1053. if not self.normalize_before:
  1054. x = self.norm2(x)
  1055. residual = x
  1056. if dn!=None:
  1057. x = x + self.spk_linear(dn)
  1058. if self.normalize_before:
  1059. x = self.norm3(x)
  1060. x = residual + self.dropout(self.feed_forward(x))
  1061. if not self.normalize_before:
  1062. x = self.norm3(x)
  1063. if cache is not None:
  1064. x = torch.cat([cache, x], dim=1)
  1065. return x, tgt_mask, memory, memory_mask