transformer_decoder.py 42 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192
  1. # Copyright 2019 Shigeki Karita
  2. # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
  3. """Decoder definition."""
  4. from typing import Any
  5. from typing import List
  6. from typing import Sequence
  7. from typing import Tuple
  8. import torch
  9. from torch import nn
  10. from typeguard import check_argument_types
  11. from funasr.models.decoder.abs_decoder import AbsDecoder
  12. from funasr.modules.attention import MultiHeadedAttention
  13. from funasr.modules.attention import CosineDistanceAttention
  14. from funasr.modules.dynamic_conv import DynamicConvolution
  15. from funasr.modules.dynamic_conv2d import DynamicConvolution2D
  16. from funasr.modules.embedding import PositionalEncoding
  17. from funasr.modules.layer_norm import LayerNorm
  18. from funasr.modules.lightconv import LightweightConvolution
  19. from funasr.modules.lightconv2d import LightweightConvolution2D
  20. from funasr.modules.mask import subsequent_mask
  21. from funasr.modules.nets_utils import make_pad_mask
  22. from funasr.modules.positionwise_feed_forward import (
  23. PositionwiseFeedForward, # noqa: H301
  24. )
  25. from funasr.modules.repeat import repeat
  26. from funasr.modules.scorers.scorer_interface import BatchScorerInterface
class DecoderLayer(nn.Module):
    """Single decoder layer module.

    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttention` instance can be used as the argument.
        src_attn (torch.nn.Module): Source-attention module instance.
            `MultiHeadedAttention` instance can be used as the argument.
        feed_forward (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
            can be used as the argument.
        dropout_rate (float): Dropout rate.
        normalize_before (bool): Whether to use layer_norm before the first block.
        concat_after (bool): Whether to concat attention layer's input and output.
            If True, an additional linear will be applied,
            i.e. x -> x + linear(concat(x, att(x))).
            If False, no additional linear will be applied, i.e. x -> x + att(x).
    """

    def __init__(
        self,
        size,
        self_attn,
        src_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
    ):
        """Construct an DecoderLayer object."""
        super(DecoderLayer, self).__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        # One LayerNorm per sub-block: self-attention, source-attention, feed-forward.
        self.norm1 = LayerNorm(size)
        self.norm2 = LayerNorm(size)
        self.norm3 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            # Projects concat(block input, attention output) back to `size`.
            self.concat_linear1 = nn.Linear(size + size, size)
            self.concat_linear2 = nn.Linear(size + size, size)

    def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None):
        """Compute decoded features.

        Args:
            tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
            tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
            memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
            memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
            cache (List[torch.Tensor]): List of cached tensors.
                Each tensor shape should be (#batch, maxlen_out - 1, size).

        Returns:
            torch.Tensor: Output tensor (#batch, maxlen_out, size).
            torch.Tensor: Mask for output tensor (#batch, maxlen_out).
            torch.Tensor: Encoded memory (#batch, maxlen_in, size).
            torch.Tensor: Encoded memory mask (#batch, maxlen_in).
        """
        # ---- self-attention sub-block (pre-norm or post-norm) ----
        residual = tgt
        if self.normalize_before:
            tgt = self.norm1(tgt)
        if cache is None:
            tgt_q = tgt
            tgt_q_mask = tgt_mask
        else:
            # Incremental decoding: compute only the last frame query,
            # keeping dim: max_time_out -> 1.
            assert cache.shape == (
                tgt.shape[0],
                tgt.shape[1] - 1,
                self.size,
            ), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
            tgt_q = tgt[:, -1:, :]
            residual = residual[:, -1:, :]
            tgt_q_mask = None
            if tgt_mask is not None:
                tgt_q_mask = tgt_mask[:, -1:, :]
        if self.concat_after:
            tgt_concat = torch.cat(
                (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1
            )
            x = residual + self.concat_linear1(tgt_concat)
        else:
            x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask))
        if not self.normalize_before:
            x = self.norm1(x)
        # ---- source-attention sub-block over the encoder memory ----
        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        if self.concat_after:
            x_concat = torch.cat(
                (x, self.src_attn(x, memory, memory, memory_mask)), dim=-1
            )
            x = residual + self.concat_linear2(x_concat)
        else:
            x = residual + self.dropout(self.src_attn(x, memory, memory, memory_mask))
        if not self.normalize_before:
            x = self.norm2(x)
        # ---- position-wise feed-forward sub-block ----
        residual = x
        if self.normalize_before:
            x = self.norm3(x)
        x = residual + self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm3(x)
        if cache is not None:
            # Re-attach the cached prefix so the output covers the full sequence.
            x = torch.cat([cache, x], dim=1)
        return x, tgt_mask, memory, memory_mask
class BaseTransformerDecoder(AbsDecoder, BatchScorerInterface):
    """Base class of Transformer decoder module.

    Args:
        vocab_size: output dim
        encoder_output_size: dimension of attention
        attention_heads: the number of heads of multi head attention
        linear_units: the number of units of position-wise feed forward
        num_blocks: the number of decoder blocks
        dropout_rate: dropout rate
        self_attention_dropout_rate: dropout rate for attention
        input_layer: input layer type
        use_output_layer: whether to use output layer
        pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
        normalize_before: whether to use layer_norm before the first block
        concat_after: whether to concat attention layer's input and output
            if True, additional linear will be applied.
            i.e. x -> x + linear(concat(x, att(x)))
            if False, no additional linear will be applied.
            i.e. x -> x + att(x)
    """

    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        input_layer: str = "embed",
        use_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
    ):
        assert check_argument_types()
        super().__init__()
        attention_dim = encoder_output_size
        if input_layer == "embed":
            # Token-id input: embedding lookup + positional encoding.
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(vocab_size, attention_dim),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        elif input_layer == "linear":
            # Continuous-feature input: linear projection + norm/dropout/ReLU.
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(vocab_size, attention_dim),
                torch.nn.LayerNorm(attention_dim),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        else:
            raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")
        self.normalize_before = normalize_before
        if self.normalize_before:
            self.after_norm = LayerNorm(attention_dim)
        if use_output_layer:
            self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
        else:
            self.output_layer = None
        # Must set by the inheritance
        self.decoders = None

    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward decoder.

        Args:
            hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
            hlens: (batch)
            ys_in_pad:
                input token ids, int64 (batch, maxlen_out)
                if input_layer == "embed"
                input tensor (batch, maxlen_out, #mels) in the other cases
            ys_in_lens: (batch)

        Returns:
            (tuple): tuple containing:
                x: decoded token score before softmax (batch, maxlen_out, token)
                    if use_output_layer is True,
                olens: (batch, )
        """
        tgt = ys_in_pad
        # tgt_mask: (B, 1, L) — True at valid (non-padded) target positions.
        tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
        # m: (1, L, L) — lower-triangular causal mask.
        m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
        # tgt_mask: (B, L, L) — padding mask AND causal mask combined.
        tgt_mask = tgt_mask & m
        memory = hs_pad
        memory_mask = (~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to(
            memory.device
        )
        # Padding for Longformer
        if memory_mask.shape[-1] != memory.shape[1]:
            padlen = memory.shape[1] - memory_mask.shape[-1]
            memory_mask = torch.nn.functional.pad(
                memory_mask, (0, padlen), "constant", False
            )
        x = self.embed(tgt)
        x, tgt_mask, memory, memory_mask = self.decoders(
            x, tgt_mask, memory, memory_mask
        )
        if self.normalize_before:
            x = self.after_norm(x)
        if self.output_layer is not None:
            x = self.output_layer(x)
        # Output lengths derived from the combined target mask.
        olens = tgt_mask.sum(1)
        return x, olens

    def forward_one_step(
        self,
        tgt: torch.Tensor,
        tgt_mask: torch.Tensor,
        memory: torch.Tensor,
        cache: List[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Forward one step.

        Args:
            tgt: input token ids, int64 (batch, maxlen_out)
            tgt_mask: input token mask, (batch, maxlen_out)
                dtype=torch.uint8 in PyTorch 1.2-
                dtype=torch.bool in PyTorch 1.2+ (include 1.2)
            memory: encoded memory, float32 (batch, maxlen_in, feat)
            cache: cached output list of (batch, max_time_out-1, size)

        Returns:
            y, cache: NN output value and cache per `self.decoders`.
                `y.shape` is (batch, maxlen_out, token)
        """
        x = self.embed(tgt)
        if cache is None:
            # First step: no per-layer cache yet.
            cache = [None] * len(self.decoders)
        new_cache = []
        for c, decoder in zip(cache, self.decoders):
            # memory_mask is None during incremental decoding.
            x, tgt_mask, memory, memory_mask = decoder(
                x, tgt_mask, memory, None, cache=c
            )
            new_cache.append(x)
        if self.normalize_before:
            # Score only the last position of the sequence.
            y = self.after_norm(x[:, -1])
        else:
            y = x[:, -1]
        if self.output_layer is not None:
            y = torch.log_softmax(self.output_layer(y), dim=-1)
        return y, new_cache

    def score(self, ys, state, x):
        """Score a single hypothesis (BatchScorerInterface)."""
        ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0)
        logp, state = self.forward_one_step(
            ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state
        )
        return logp.squeeze(0), state

    def batch_score(
        self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
    ) -> Tuple[torch.Tensor, List[Any]]:
        """Score new token batch.

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            states (List[Any]): Scorer states for prefix tokens.
            xs (torch.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).

        Returns:
            tuple[torch.Tensor, List[Any]]: Tuple of
                batchfied scores for next token with shape of `(n_batch, n_vocab)`
                and next state list for ys.
        """
        # merge states
        n_batch = len(ys)
        n_layers = len(self.decoders)
        if states[0] is None:
            batch_state = None
        else:
            # transpose state of [batch, layer] into [layer, batch]
            batch_state = [
                torch.stack([states[b][i] for b in range(n_batch)])
                for i in range(n_layers)
            ]
        # batch decoding
        ys_mask = subsequent_mask(ys.size(-1), device=xs.device).unsqueeze(0)
        logp, states = self.forward_one_step(ys, ys_mask, xs, cache=batch_state)
        # transpose state of [layer, batch] into [batch, layer]
        state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
        return logp, state_list
  313. class TransformerDecoder(BaseTransformerDecoder):
  314. def __init__(
  315. self,
  316. vocab_size: int,
  317. encoder_output_size: int,
  318. attention_heads: int = 4,
  319. linear_units: int = 2048,
  320. num_blocks: int = 6,
  321. dropout_rate: float = 0.1,
  322. positional_dropout_rate: float = 0.1,
  323. self_attention_dropout_rate: float = 0.0,
  324. src_attention_dropout_rate: float = 0.0,
  325. input_layer: str = "embed",
  326. use_output_layer: bool = True,
  327. pos_enc_class=PositionalEncoding,
  328. normalize_before: bool = True,
  329. concat_after: bool = False,
  330. ):
  331. assert check_argument_types()
  332. super().__init__(
  333. vocab_size=vocab_size,
  334. encoder_output_size=encoder_output_size,
  335. dropout_rate=dropout_rate,
  336. positional_dropout_rate=positional_dropout_rate,
  337. input_layer=input_layer,
  338. use_output_layer=use_output_layer,
  339. pos_enc_class=pos_enc_class,
  340. normalize_before=normalize_before,
  341. )
  342. attention_dim = encoder_output_size
  343. self.decoders = repeat(
  344. num_blocks,
  345. lambda lnum: DecoderLayer(
  346. attention_dim,
  347. MultiHeadedAttention(
  348. attention_heads, attention_dim, self_attention_dropout_rate
  349. ),
  350. MultiHeadedAttention(
  351. attention_heads, attention_dim, src_attention_dropout_rate
  352. ),
  353. PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
  354. dropout_rate,
  355. normalize_before,
  356. concat_after,
  357. ),
  358. )
class ParaformerDecoderSAN(BaseTransformerDecoder):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2006.01713
    """

    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = "embed",
        use_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
        embeds_id: int = -1,
    ):
        """Construct the decoder.

        `embeds_id` selects which block's hidden states are additionally
        returned by forward(); the default -1 never matches a layer index,
        so by default no intermediate embedding is returned.
        """
        assert check_argument_types()
        super().__init__(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            dropout_rate=dropout_rate,
            positional_dropout_rate=positional_dropout_rate,
            input_layer=input_layer,
            use_output_layer=use_output_layer,
            pos_enc_class=pos_enc_class,
            normalize_before=normalize_before,
        )
        attention_dim = encoder_output_size
        self.decoders = repeat(
            num_blocks,
            lambda lnum: DecoderLayer(
                attention_dim,
                MultiHeadedAttention(
                    attention_heads, attention_dim, self_attention_dropout_rate
                ),
                MultiHeadedAttention(
                    attention_heads, attention_dim, src_attention_dropout_rate
                ),
                PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        self.embeds_id = embeds_id
        self.attention_dim = attention_dim

    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward decoder (non-autoregressive: no causal mask is applied).

        Args:
            hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
            hlens: (batch)
            ys_in_pad:
                input token ids, int64 (batch, maxlen_out)
                if input_layer == "embed"
                input tensor (batch, maxlen_out, #mels) in the other cases
            ys_in_lens: (batch)

        Returns:
            (tuple): tuple containing:
                x: decoded token score before softmax (batch, maxlen_out, token)
                    if use_output_layer is True,
                olens: (batch, )
                embeds_outputs: hidden states of block `embeds_id`; only present
                    (making this a 3-tuple) when embeds_id matched a layer.
        """
        tgt = ys_in_pad
        # Padding mask only — unlike the autoregressive decoder, no causal
        # subsequent_mask is combined in here.
        tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
        memory = hs_pad
        memory_mask = (~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to(
            memory.device
        )
        # Padding for Longformer
        if memory_mask.shape[-1] != memory.shape[1]:
            padlen = memory.shape[1] - memory_mask.shape[-1]
            memory_mask = torch.nn.functional.pad(
                memory_mask, (0, padlen), "constant", False
            )
        # x = self.embed(tgt)
        # NOTE(review): the embedding layer is bypassed — tgt appears to be
        # continuous decoder-input features already; confirm against callers.
        x = tgt
        embeds_outputs = None
        for layer_id, decoder in enumerate(self.decoders):
            x, tgt_mask, memory, memory_mask = decoder(
                x, tgt_mask, memory, memory_mask
            )
            if layer_id == self.embeds_id:
                embeds_outputs = x
        if self.normalize_before:
            x = self.after_norm(x)
        if self.output_layer is not None:
            x = self.output_layer(x)
        olens = tgt_mask.sum(1)
        # NOTE(review): return arity varies (3-tuple when embeds_id matched a
        # layer, 2-tuple otherwise), which differs from the annotated type.
        if embeds_outputs is not None:
            return x, olens, embeds_outputs
        else:
            return x, olens
  465. class LightweightConvolutionTransformerDecoder(BaseTransformerDecoder):
  466. def __init__(
  467. self,
  468. vocab_size: int,
  469. encoder_output_size: int,
  470. attention_heads: int = 4,
  471. linear_units: int = 2048,
  472. num_blocks: int = 6,
  473. dropout_rate: float = 0.1,
  474. positional_dropout_rate: float = 0.1,
  475. self_attention_dropout_rate: float = 0.0,
  476. src_attention_dropout_rate: float = 0.0,
  477. input_layer: str = "embed",
  478. use_output_layer: bool = True,
  479. pos_enc_class=PositionalEncoding,
  480. normalize_before: bool = True,
  481. concat_after: bool = False,
  482. conv_wshare: int = 4,
  483. conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
  484. conv_usebias: int = False,
  485. ):
  486. assert check_argument_types()
  487. if len(conv_kernel_length) != num_blocks:
  488. raise ValueError(
  489. "conv_kernel_length must have equal number of values to num_blocks: "
  490. f"{len(conv_kernel_length)} != {num_blocks}"
  491. )
  492. super().__init__(
  493. vocab_size=vocab_size,
  494. encoder_output_size=encoder_output_size,
  495. dropout_rate=dropout_rate,
  496. positional_dropout_rate=positional_dropout_rate,
  497. input_layer=input_layer,
  498. use_output_layer=use_output_layer,
  499. pos_enc_class=pos_enc_class,
  500. normalize_before=normalize_before,
  501. )
  502. attention_dim = encoder_output_size
  503. self.decoders = repeat(
  504. num_blocks,
  505. lambda lnum: DecoderLayer(
  506. attention_dim,
  507. LightweightConvolution(
  508. wshare=conv_wshare,
  509. n_feat=attention_dim,
  510. dropout_rate=self_attention_dropout_rate,
  511. kernel_size=conv_kernel_length[lnum],
  512. use_kernel_mask=True,
  513. use_bias=conv_usebias,
  514. ),
  515. MultiHeadedAttention(
  516. attention_heads, attention_dim, src_attention_dropout_rate
  517. ),
  518. PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
  519. dropout_rate,
  520. normalize_before,
  521. concat_after,
  522. ),
  523. )
  524. class LightweightConvolution2DTransformerDecoder(BaseTransformerDecoder):
  525. def __init__(
  526. self,
  527. vocab_size: int,
  528. encoder_output_size: int,
  529. attention_heads: int = 4,
  530. linear_units: int = 2048,
  531. num_blocks: int = 6,
  532. dropout_rate: float = 0.1,
  533. positional_dropout_rate: float = 0.1,
  534. self_attention_dropout_rate: float = 0.0,
  535. src_attention_dropout_rate: float = 0.0,
  536. input_layer: str = "embed",
  537. use_output_layer: bool = True,
  538. pos_enc_class=PositionalEncoding,
  539. normalize_before: bool = True,
  540. concat_after: bool = False,
  541. conv_wshare: int = 4,
  542. conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
  543. conv_usebias: int = False,
  544. ):
  545. assert check_argument_types()
  546. if len(conv_kernel_length) != num_blocks:
  547. raise ValueError(
  548. "conv_kernel_length must have equal number of values to num_blocks: "
  549. f"{len(conv_kernel_length)} != {num_blocks}"
  550. )
  551. super().__init__(
  552. vocab_size=vocab_size,
  553. encoder_output_size=encoder_output_size,
  554. dropout_rate=dropout_rate,
  555. positional_dropout_rate=positional_dropout_rate,
  556. input_layer=input_layer,
  557. use_output_layer=use_output_layer,
  558. pos_enc_class=pos_enc_class,
  559. normalize_before=normalize_before,
  560. )
  561. attention_dim = encoder_output_size
  562. self.decoders = repeat(
  563. num_blocks,
  564. lambda lnum: DecoderLayer(
  565. attention_dim,
  566. LightweightConvolution2D(
  567. wshare=conv_wshare,
  568. n_feat=attention_dim,
  569. dropout_rate=self_attention_dropout_rate,
  570. kernel_size=conv_kernel_length[lnum],
  571. use_kernel_mask=True,
  572. use_bias=conv_usebias,
  573. ),
  574. MultiHeadedAttention(
  575. attention_heads, attention_dim, src_attention_dropout_rate
  576. ),
  577. PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
  578. dropout_rate,
  579. normalize_before,
  580. concat_after,
  581. ),
  582. )
  583. class DynamicConvolutionTransformerDecoder(BaseTransformerDecoder):
  584. def __init__(
  585. self,
  586. vocab_size: int,
  587. encoder_output_size: int,
  588. attention_heads: int = 4,
  589. linear_units: int = 2048,
  590. num_blocks: int = 6,
  591. dropout_rate: float = 0.1,
  592. positional_dropout_rate: float = 0.1,
  593. self_attention_dropout_rate: float = 0.0,
  594. src_attention_dropout_rate: float = 0.0,
  595. input_layer: str = "embed",
  596. use_output_layer: bool = True,
  597. pos_enc_class=PositionalEncoding,
  598. normalize_before: bool = True,
  599. concat_after: bool = False,
  600. conv_wshare: int = 4,
  601. conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
  602. conv_usebias: int = False,
  603. ):
  604. assert check_argument_types()
  605. if len(conv_kernel_length) != num_blocks:
  606. raise ValueError(
  607. "conv_kernel_length must have equal number of values to num_blocks: "
  608. f"{len(conv_kernel_length)} != {num_blocks}"
  609. )
  610. super().__init__(
  611. vocab_size=vocab_size,
  612. encoder_output_size=encoder_output_size,
  613. dropout_rate=dropout_rate,
  614. positional_dropout_rate=positional_dropout_rate,
  615. input_layer=input_layer,
  616. use_output_layer=use_output_layer,
  617. pos_enc_class=pos_enc_class,
  618. normalize_before=normalize_before,
  619. )
  620. attention_dim = encoder_output_size
  621. self.decoders = repeat(
  622. num_blocks,
  623. lambda lnum: DecoderLayer(
  624. attention_dim,
  625. DynamicConvolution(
  626. wshare=conv_wshare,
  627. n_feat=attention_dim,
  628. dropout_rate=self_attention_dropout_rate,
  629. kernel_size=conv_kernel_length[lnum],
  630. use_kernel_mask=True,
  631. use_bias=conv_usebias,
  632. ),
  633. MultiHeadedAttention(
  634. attention_heads, attention_dim, src_attention_dropout_rate
  635. ),
  636. PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
  637. dropout_rate,
  638. normalize_before,
  639. concat_after,
  640. ),
  641. )
  642. class DynamicConvolution2DTransformerDecoder(BaseTransformerDecoder):
  643. def __init__(
  644. self,
  645. vocab_size: int,
  646. encoder_output_size: int,
  647. attention_heads: int = 4,
  648. linear_units: int = 2048,
  649. num_blocks: int = 6,
  650. dropout_rate: float = 0.1,
  651. positional_dropout_rate: float = 0.1,
  652. self_attention_dropout_rate: float = 0.0,
  653. src_attention_dropout_rate: float = 0.0,
  654. input_layer: str = "embed",
  655. use_output_layer: bool = True,
  656. pos_enc_class=PositionalEncoding,
  657. normalize_before: bool = True,
  658. concat_after: bool = False,
  659. conv_wshare: int = 4,
  660. conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
  661. conv_usebias: int = False,
  662. ):
  663. assert check_argument_types()
  664. if len(conv_kernel_length) != num_blocks:
  665. raise ValueError(
  666. "conv_kernel_length must have equal number of values to num_blocks: "
  667. f"{len(conv_kernel_length)} != {num_blocks}"
  668. )
  669. super().__init__(
  670. vocab_size=vocab_size,
  671. encoder_output_size=encoder_output_size,
  672. dropout_rate=dropout_rate,
  673. positional_dropout_rate=positional_dropout_rate,
  674. input_layer=input_layer,
  675. use_output_layer=use_output_layer,
  676. pos_enc_class=pos_enc_class,
  677. normalize_before=normalize_before,
  678. )
  679. attention_dim = encoder_output_size
  680. self.decoders = repeat(
  681. num_blocks,
  682. lambda lnum: DecoderLayer(
  683. attention_dim,
  684. DynamicConvolution2D(
  685. wshare=conv_wshare,
  686. n_feat=attention_dim,
  687. dropout_rate=self_attention_dropout_rate,
  688. kernel_size=conv_kernel_length[lnum],
  689. use_kernel_mask=True,
  690. use_bias=conv_usebias,
  691. ),
  692. MultiHeadedAttention(
  693. attention_heads, attention_dim, src_attention_dropout_rate
  694. ),
  695. PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
  696. dropout_rate,
  697. normalize_before,
  698. concat_after,
  699. ),
  700. )
  701. class BaseSAAsrTransformerDecoder(AbsDecoder, BatchScorerInterface):
    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        spker_embedding_dim: int = 256,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        input_layer: str = "embed",
        use_asr_output_layer: bool = True,
        use_spk_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
    ):
        """Construct the base speaker-attributed ASR decoder.

        Builds the shared token embedding, the optional final LayerNorm, and
        two output heads: one over the vocabulary (ASR) and one projecting to
        `spker_embedding_dim` (speaker). The four decoder stages are left as
        None and must be set by subclasses.
        """
        assert check_argument_types()
        super().__init__()
        attention_dim = encoder_output_size
        if input_layer == "embed":
            # Token-id input: embedding lookup + positional encoding.
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(vocab_size, attention_dim),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        elif input_layer == "linear":
            # Continuous-feature input: linear projection + norm/dropout/ReLU.
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(vocab_size, attention_dim),
                torch.nn.LayerNorm(attention_dim),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        else:
            raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")
        self.normalize_before = normalize_before
        if self.normalize_before:
            self.after_norm = LayerNorm(attention_dim)
        if use_asr_output_layer:
            self.asr_output_layer = torch.nn.Linear(attention_dim, vocab_size)
        else:
            self.asr_output_layer = None
        if use_spk_output_layer:
            self.spk_output_layer = torch.nn.Linear(attention_dim, spker_embedding_dim)
        else:
            self.spk_output_layer = None
        # Matches decoded speaker embeddings against enrolled speaker profiles.
        self.cos_distance_att = CosineDistanceAttention()
        # Must be set by the inheritance: the four decoder stages used in forward().
        self.decoder1 = None
        self.decoder2 = None
        self.decoder3 = None
        self.decoder4 = None
    def forward(
        self,
        asr_hs_pad: torch.Tensor,
        spk_hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
        profile: torch.Tensor,
        profile_lens: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Joint speaker-attributed ASR forward pass.

        Args:
            asr_hs_pad: ASR encoder memory (batch, maxlen_in, feat).
            spk_hs_pad: speaker encoder memory (batch, maxlen_in, feat).
            hlens: encoder output lengths (batch).
            ys_in_pad: input token ids, int64 (batch, maxlen_out).
            ys_in_lens: target lengths (batch).
            profile: enrolled speaker profiles scored by cosine-distance
                attention.  # assumes (batch, n_spk, spker_embedding_dim) — TODO confirm
            profile_lens: profile lengths (batch).

        Returns:
            x: ASR output (logits if asr_output_layer is set).
            weights: speaker-profile attention weights from cos_distance_att.
            olens: output lengths derived from the target mask.
        """
        tgt = ys_in_pad
        # tgt_mask: (B, 1, L) — True at valid (non-padded) target positions.
        tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
        # m: (1, L, L) — lower-triangular causal mask.
        m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
        # tgt_mask: (B, L, L)
        tgt_mask = tgt_mask & m
        asr_memory = asr_hs_pad
        spk_memory = spk_hs_pad
        memory_mask = (~make_pad_mask(hlens))[:, None, :].to(asr_memory.device)
        # Spk decoder: decoder1 attends over both memories and also returns an
        # intermediate `z` that is fed to the ASR branch below.
        x = self.embed(tgt)
        x, tgt_mask, asr_memory, spk_memory, memory_mask, z = self.decoder1(
            x, tgt_mask, asr_memory, spk_memory, memory_mask
        )
        x, tgt_mask, spk_memory, memory_mask = self.decoder2(
            x, tgt_mask, spk_memory, memory_mask
        )
        if self.normalize_before:
            x = self.after_norm(x)
        if self.spk_output_layer is not None:
            x = self.spk_output_layer(x)
        # dn: profile-weighted speaker representation; weights: attention over profiles.
        dn, weights = self.cos_distance_att(x, profile, profile_lens)
        # Asr decoder: decoder3 consumes decoder1's `z` plus the speaker vector `dn`.
        x, tgt_mask, asr_memory, memory_mask = self.decoder3(
            z, tgt_mask, asr_memory, memory_mask, dn
        )
        x, tgt_mask, asr_memory, memory_mask = self.decoder4(
            x, tgt_mask, asr_memory, memory_mask
        )
        if self.normalize_before:
            x = self.after_norm(x)
        if self.asr_output_layer is not None:
            x = self.asr_output_layer(x)
        olens = tgt_mask.sum(1)
        return x, weights, olens
  795. def forward_one_step(
  796. self,
  797. tgt: torch.Tensor,
  798. tgt_mask: torch.Tensor,
  799. asr_memory: torch.Tensor,
  800. spk_memory: torch.Tensor,
  801. profile: torch.Tensor,
  802. cache: List[torch.Tensor] = None,
  803. ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
  804. x = self.embed(tgt)
  805. if cache is None:
  806. cache = [None] * (2 + len(self.decoder2) + len(self.decoder4))
  807. new_cache = []
  808. x, tgt_mask, asr_memory, spk_memory, _, z = self.decoder1(
  809. x, tgt_mask, asr_memory, spk_memory, None, cache=cache[0]
  810. )
  811. new_cache.append(x)
  812. for c, decoder in zip(cache[1: len(self.decoder2) + 1], self.decoder2):
  813. x, tgt_mask, spk_memory, _ = decoder(
  814. x, tgt_mask, spk_memory, None, cache=c
  815. )
  816. new_cache.append(x)
  817. if self.normalize_before:
  818. x = self.after_norm(x)
  819. else:
  820. x = x
  821. if self.spk_output_layer is not None:
  822. x = self.spk_output_layer(x)
  823. dn, weights = self.cos_distance_att(x, profile, None)
  824. x, tgt_mask, asr_memory, _ = self.decoder3(
  825. z, tgt_mask, asr_memory, None, dn, cache=cache[len(self.decoder2) + 1]
  826. )
  827. new_cache.append(x)
  828. for c, decoder in zip(cache[len(self.decoder2) + 2: ], self.decoder4):
  829. x, tgt_mask, asr_memory, _ = decoder(
  830. x, tgt_mask, asr_memory, None, cache=c
  831. )
  832. new_cache.append(x)
  833. if self.normalize_before:
  834. y = self.after_norm(x[:, -1])
  835. else:
  836. y = x[:, -1]
  837. if self.asr_output_layer is not None:
  838. y = torch.log_softmax(self.asr_output_layer(y), dim=-1)
  839. return y, weights, new_cache
  840. def score(self, ys, state, asr_enc, spk_enc, profile):
  841. """Score."""
  842. ys_mask = subsequent_mask(len(ys), device=ys.device).unsqueeze(0)
  843. logp, weights, state = self.forward_one_step(
  844. ys.unsqueeze(0), ys_mask, asr_enc.unsqueeze(0), spk_enc.unsqueeze(0), profile.unsqueeze(0), cache=state
  845. )
  846. return logp.squeeze(0), weights.squeeze(), state
  847. class SAAsrTransformerDecoder(BaseSAAsrTransformerDecoder):
  848. def __init__(
  849. self,
  850. vocab_size: int,
  851. encoder_output_size: int,
  852. spker_embedding_dim: int = 256,
  853. attention_heads: int = 4,
  854. linear_units: int = 2048,
  855. asr_num_blocks: int = 6,
  856. spk_num_blocks: int = 3,
  857. dropout_rate: float = 0.1,
  858. positional_dropout_rate: float = 0.1,
  859. self_attention_dropout_rate: float = 0.0,
  860. src_attention_dropout_rate: float = 0.0,
  861. input_layer: str = "embed",
  862. use_asr_output_layer: bool = True,
  863. use_spk_output_layer: bool = True,
  864. pos_enc_class=PositionalEncoding,
  865. normalize_before: bool = True,
  866. concat_after: bool = False,
  867. ):
  868. assert check_argument_types()
  869. super().__init__(
  870. vocab_size=vocab_size,
  871. encoder_output_size=encoder_output_size,
  872. spker_embedding_dim=spker_embedding_dim,
  873. dropout_rate=dropout_rate,
  874. positional_dropout_rate=positional_dropout_rate,
  875. input_layer=input_layer,
  876. use_asr_output_layer=use_asr_output_layer,
  877. use_spk_output_layer=use_spk_output_layer,
  878. pos_enc_class=pos_enc_class,
  879. normalize_before=normalize_before,
  880. )
  881. attention_dim = encoder_output_size
  882. self.decoder1 = SpeakerAttributeSpkDecoderFirstLayer(
  883. attention_dim,
  884. MultiHeadedAttention(
  885. attention_heads, attention_dim, self_attention_dropout_rate
  886. ),
  887. MultiHeadedAttention(
  888. attention_heads, attention_dim, src_attention_dropout_rate
  889. ),
  890. PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
  891. dropout_rate,
  892. normalize_before,
  893. concat_after,
  894. )
  895. self.decoder2 = repeat(
  896. spk_num_blocks - 1,
  897. lambda lnum: DecoderLayer(
  898. attention_dim,
  899. MultiHeadedAttention(
  900. attention_heads, attention_dim, self_attention_dropout_rate
  901. ),
  902. MultiHeadedAttention(
  903. attention_heads, attention_dim, src_attention_dropout_rate
  904. ),
  905. PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
  906. dropout_rate,
  907. normalize_before,
  908. concat_after,
  909. ),
  910. )
  911. self.decoder3 = SpeakerAttributeAsrDecoderFirstLayer(
  912. attention_dim,
  913. spker_embedding_dim,
  914. MultiHeadedAttention(
  915. attention_heads, attention_dim, src_attention_dropout_rate
  916. ),
  917. PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
  918. dropout_rate,
  919. normalize_before,
  920. concat_after,
  921. )
  922. self.decoder4 = repeat(
  923. asr_num_blocks - 1,
  924. lambda lnum: DecoderLayer(
  925. attention_dim,
  926. MultiHeadedAttention(
  927. attention_heads, attention_dim, self_attention_dropout_rate
  928. ),
  929. MultiHeadedAttention(
  930. attention_heads, attention_dim, src_attention_dropout_rate
  931. ),
  932. PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
  933. dropout_rate,
  934. normalize_before,
  935. concat_after,
  936. ),
  937. )
class SpeakerAttributeSpkDecoderFirstLayer(nn.Module):
    """First layer of the speaker decoder branch.

    A decoder layer whose source attention queries the decoder state with
    keys taken from the ASR memory and values taken from the speaker
    memory. It also exports ``z``, the post-self-attention state, which the
    caller later feeds to the ASR branch.
    """

    def __init__(
        self,
        size,
        self_attn,
        src_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
    ):
        """Construct a SpeakerAttributeSpkDecoderFirstLayer object.

        Args:
            size: model (attention) dimension.
            self_attn: self-attention module.
            src_attn: source-attention module (keys from ASR memory, values
                from speaker memory -- see forward()).
            feed_forward: position-wise feed-forward module.
            dropout_rate: dropout probability.
            normalize_before: apply LayerNorm before (True) or after each
                sub-block.
            concat_after: concatenate attention input and output and project
                them, instead of the residual-plus-dropout path.
        """
        super(SpeakerAttributeSpkDecoderFirstLayer, self).__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        # NOTE(review): norm1 is reused for both the self-attention and the
        # source-attention sub-blocks (norm2 normalizes the feed-forward);
        # a standard DecoderLayer uses three norms -- confirm this sharing
        # is intended.
        self.norm1 = LayerNorm(size)
        self.norm2 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear1 = nn.Linear(size + size, size)
            self.concat_linear2 = nn.Linear(size + size, size)

    def forward(self, tgt, tgt_mask, asr_memory, spk_memory, memory_mask, cache=None):
        """Compute decoded features.

        Returns:
            Tuple (x, tgt_mask, asr_memory, spk_memory, memory_mask, z)
            where ``z`` is the state right after the self-attention
            sub-block, consumed by the ASR branch.
        """
        residual = tgt
        if self.normalize_before:
            tgt = self.norm1(tgt)
        if cache is None:
            tgt_q = tgt
            tgt_q_mask = tgt_mask
        else:
            # compute only the last frame query keeping dim: max_time_out -> 1
            assert cache.shape == (
                tgt.shape[0],
                tgt.shape[1] - 1,
                self.size,
            ), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
            tgt_q = tgt[:, -1:, :]
            residual = residual[:, -1:, :]
            tgt_q_mask = None
            if tgt_mask is not None:
                tgt_q_mask = tgt_mask[:, -1:, :]
        if self.concat_after:
            tgt_concat = torch.cat(
                (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1
            )
            x = residual + self.concat_linear1(tgt_concat)
        else:
            x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask))
        if not self.normalize_before:
            x = self.norm1(x)
        # Export the post-self-attention state for the ASR branch.
        z = x
        residual = x
        if self.normalize_before:
            x = self.norm1(x)
        # Cross-attention: keys come from the ASR memory, values from the
        # speaker memory.
        skip = self.src_attn(x, asr_memory, spk_memory, memory_mask)
        if self.concat_after:
            x_concat = torch.cat(
                (x, skip), dim=-1
            )
            x = residual + self.concat_linear2(x_concat)
        else:
            x = residual + self.dropout(skip)
        if not self.normalize_before:
            x = self.norm1(x)
        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        x = residual + self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm2(x)
        if cache is not None:
            # Re-attach the cached prefix so the output covers all positions.
            x = torch.cat([cache, x], dim=1)
        return x, tgt_mask, asr_memory, spk_memory, memory_mask, z
  1014. class SpeakerAttributeAsrDecoderFirstLayer(nn.Module):
  1015. def __init__(
  1016. self,
  1017. size,
  1018. d_size,
  1019. src_attn,
  1020. feed_forward,
  1021. dropout_rate,
  1022. normalize_before=True,
  1023. concat_after=False,
  1024. ):
  1025. """Construct an DecoderLayer object."""
  1026. super(SpeakerAttributeAsrDecoderFirstLayer, self).__init__()
  1027. self.size = size
  1028. self.src_attn = src_attn
  1029. self.feed_forward = feed_forward
  1030. self.norm1 = LayerNorm(size)
  1031. self.norm2 = LayerNorm(size)
  1032. self.norm3 = LayerNorm(size)
  1033. self.dropout = nn.Dropout(dropout_rate)
  1034. self.normalize_before = normalize_before
  1035. self.concat_after = concat_after
  1036. self.spk_linear = nn.Linear(d_size, size, bias=False)
  1037. if self.concat_after:
  1038. self.concat_linear1 = nn.Linear(size + size, size)
  1039. self.concat_linear2 = nn.Linear(size + size, size)
  1040. def forward(self, tgt, tgt_mask, memory, memory_mask, dn, cache=None):
  1041. residual = tgt
  1042. if self.normalize_before:
  1043. tgt = self.norm1(tgt)
  1044. if cache is None:
  1045. tgt_q = tgt
  1046. tgt_q_mask = tgt_mask
  1047. else:
  1048. tgt_q = tgt[:, -1:, :]
  1049. residual = residual[:, -1:, :]
  1050. tgt_q_mask = None
  1051. if tgt_mask is not None:
  1052. tgt_q_mask = tgt_mask[:, -1:, :]
  1053. x = tgt_q
  1054. if self.normalize_before:
  1055. x = self.norm2(x)
  1056. if self.concat_after:
  1057. x_concat = torch.cat(
  1058. (x, self.src_attn(x, memory, memory, memory_mask)), dim=-1
  1059. )
  1060. x = residual + self.concat_linear2(x_concat)
  1061. else:
  1062. x = residual + self.dropout(self.src_attn(x, memory, memory, memory_mask))
  1063. if not self.normalize_before:
  1064. x = self.norm2(x)
  1065. residual = x
  1066. if dn!=None:
  1067. x = x + self.spk_linear(dn)
  1068. if self.normalize_before:
  1069. x = self.norm3(x)
  1070. x = residual + self.dropout(self.feed_forward(x))
  1071. if not self.normalize_before:
  1072. x = self.norm3(x)
  1073. if cache is not None:
  1074. x = torch.cat([cache, x], dim=1)
  1075. return x, tgt_mask, memory, memory_mask