# sanm_decoder.py
  1. from typing import List
  2. from typing import Tuple
  3. import logging
  4. import torch
  5. import torch.nn as nn
  6. import numpy as np
  7. from funasr.modules.streaming_utils import utils as myutils
  8. from funasr.models.decoder.transformer_decoder import BaseTransformerDecoder
  9. from typeguard import check_argument_types
  10. from funasr.modules.attention import MultiHeadedAttentionSANMDecoder, MultiHeadedAttentionCrossAtt
  11. from funasr.modules.embedding import PositionalEncoding
  12. from funasr.modules.layer_norm import LayerNorm
  13. from funasr.modules.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
  14. from funasr.modules.repeat import repeat
class DecoderLayerSANM(nn.Module):
    """Single decoder layer module.

    One SANM decoder layer = feed-forward block + optional FSMN self-attention
    block + optional source (cross) attention block, each with pre-norm and a
    residual connection.

    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttention` instance can be used as the argument.
            May be ``None`` (layer then has no self-attention block).
        src_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttention` instance can be used as the argument.
            May be ``None`` (layer then has no cross-attention block).
        feed_forward (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
            can be used as the argument.
        dropout_rate (float): Dropout rate.
        normalize_before (bool): Whether to use layer_norm before the first block.
        concat_after (bool): Whether to concat attention layer's input and output.
            if True, additional linear will be applied.
            i.e. x -> x + linear(concat(x, att(x)))
            if False, no additional linear will be applied. i.e. x -> x + att(x)
    """

    def __init__(
        self,
        size,
        self_attn,
        src_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
    ):
        """Construct an DecoderLayer object."""
        super(DecoderLayerSANM, self).__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.norm1 = LayerNorm(size)
        # norm2/norm3 are created only when the corresponding sub-block exists,
        # so layers without self/cross attention carry no unused parameters.
        if self_attn is not None:
            self.norm2 = LayerNorm(size)
        if src_attn is not None:
            self.norm3 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear1 = nn.Linear(size + size, size)
            self.concat_linear2 = nn.Linear(size + size, size)

    def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
        """Compute decoded features.

        Args:
            tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
            tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
            memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
            memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
            cache (List[torch.Tensor]): List of cached tensors.
                Each tensor shape should be (#batch, maxlen_out - 1, size).

        Returns:
            torch.Tensor: Output tensor(#batch, maxlen_out, size).
            torch.Tensor: Mask for output tensor (#batch, maxlen_out).
            torch.Tensor: Encoded memory (#batch, maxlen_in, size).
            torch.Tensor: Encoded memory mask (#batch, maxlen_in).
        """
        # tgt = self.dropout(tgt)
        residual = tgt
        if self.normalize_before:
            tgt = self.norm1(tgt)
        tgt = self.feed_forward(tgt)
        x = tgt
        if self.self_attn:
            if self.normalize_before:
                tgt = self.norm2(tgt)
            # the FSMN cache is only meaningful for incremental decoding;
            # during training the full sequence is visible, so drop it
            if self.training:
                cache = None
            x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
            x = residual + self.dropout(x)
        if self.src_attn is not None:
            residual = x
            if self.normalize_before:
                x = self.norm3(x)
            x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
        return x, tgt_mask, memory, memory_mask, cache
  94. class FsmnDecoderSCAMAOpt(BaseTransformerDecoder):
  95. """
  96. author: Speech Lab, Alibaba Group, China
  97. SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
  98. https://arxiv.org/abs/2006.01713
  99. """
    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = "embed",
        use_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
        att_layer_num: int = 6,
        kernel_size: int = 21,
        sanm_shfit: int = None,
        concat_embeds: bool = False,
        attention_dim: int = None,
        tf2torch_tensor_name_prefix_torch: str = "decoder",
        tf2torch_tensor_name_prefix_tf: str = "seq2seq/decoder",
        embed_tensor_name_prefix_tf: str = None,
    ):
        """Construct an FsmnDecoderSCAMAOpt object.

        Args:
            vocab_size: Output vocabulary size.
            encoder_output_size: Dimension of the encoder hidden states.
            attention_heads: Number of cross-attention heads.
            linear_units: Hidden units of the position-wise feed-forward layers.
            num_blocks: Total number of decoder layers (decoders + decoders2).
            dropout_rate: Dropout applied inside the decoder layers.
            positional_dropout_rate: Dropout of the positional encoding.
            self_attention_dropout_rate: Dropout inside the FSMN blocks.
            src_attention_dropout_rate: Dropout inside the cross-attention blocks.
            input_layer: "embed" (token embedding) or "linear" (dense projection).
            use_output_layer: Whether to add a final vocab projection.
            pos_enc_class: Positional-encoding class (only used for "linear").
            normalize_before: Pre-norm if True.
            concat_after: Concat-then-project variant of the residual connection.
            att_layer_num: Number of layers that include cross attention.
            kernel_size: FSMN memory-block kernel size.
            sanm_shfit: FSMN kernel shift; defaults to a centred window
                (note: name keeps the historical "shfit" spelling used by callers).
            concat_embeds: Concatenate acoustic embeddings with token embeddings.
            attention_dim: Decoder width; defaults to encoder_output_size.
            tf2torch_tensor_name_prefix_torch: Torch-side prefix for TF conversion.
            tf2torch_tensor_name_prefix_tf: TF-side prefix for TF conversion.
            embed_tensor_name_prefix_tf: Optional TF prefix for the embedding table.
        """
        assert check_argument_types()
        super().__init__(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            dropout_rate=dropout_rate,
            positional_dropout_rate=positional_dropout_rate,
            input_layer=input_layer,
            use_output_layer=use_output_layer,
            pos_enc_class=pos_enc_class,
            normalize_before=normalize_before,
        )
        if attention_dim is None:
            attention_dim = encoder_output_size
        if input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(vocab_size, attention_dim),
            )
        elif input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(vocab_size, attention_dim),
                torch.nn.LayerNorm(attention_dim),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        else:
            raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")
        self.normalize_before = normalize_before
        if self.normalize_before:
            self.after_norm = LayerNorm(attention_dim)
        if use_output_layer:
            self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
        else:
            self.output_layer = None
        self.att_layer_num = att_layer_num
        self.num_blocks = num_blocks
        # Default: centre the FSMN look-back/look-ahead window on the current frame.
        if sanm_shfit is None:
            sanm_shfit = (kernel_size - 1) // 2
        # First att_layer_num layers: FSMN self-attention + cross attention + FFN.
        self.decoders = repeat(
            att_layer_num,
            lambda lnum: DecoderLayerSANM(
                attention_dim,
                MultiHeadedAttentionSANMDecoder(
                    attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
                ),
                MultiHeadedAttentionCrossAtt(
                    attention_heads, attention_dim, src_attention_dropout_rate, encoder_output_size=encoder_output_size
                ),
                PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        # Remaining layers (if any): FSMN self-attention + FFN, no cross attention.
        if num_blocks - att_layer_num <= 0:
            self.decoders2 = None
        else:
            self.decoders2 = repeat(
                num_blocks - att_layer_num,
                lambda lnum: DecoderLayerSANM(
                    attention_dim,
                    MultiHeadedAttentionSANMDecoder(
                        attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
                    ),
                    None,
                    PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
                    dropout_rate,
                    normalize_before,
                    concat_after,
                ),
            )
        # Final DNN layer: feed-forward only (no self/cross attention).
        self.decoders3 = repeat(
            1,
            lambda lnum: DecoderLayerSANM(
                attention_dim,
                None,
                None,
                PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        if concat_embeds:
            # FFN mapping [token embedding; acoustic embedding] back to attention_dim.
            self.embed_concat_ffn = repeat(
                1,
                lambda lnum: DecoderLayerSANM(
                    attention_dim + encoder_output_size,
                    None,
                    None,
                    PositionwiseFeedForwardDecoderSANM(attention_dim + encoder_output_size, linear_units, dropout_rate,
                                                       adim=attention_dim),
                    dropout_rate,
                    normalize_before,
                    concat_after,
                ),
            )
        else:
            self.embed_concat_ffn = None
        self.concat_embeds = concat_embeds
        # Prefixes used by gen_tf2torch_map_dict/convert_tf2torch for TF checkpoints.
        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
        self.embed_tensor_name_prefix_tf = embed_tensor_name_prefix_tf
    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
        chunk_mask: torch.Tensor = None,
        pre_acoustic_embeds: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward decoder.

        Args:
            hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
            hlens: (batch)
            ys_in_pad:
                input token ids, int64 (batch, maxlen_out)
                if input_layer == "embed"
                input tensor (batch, maxlen_out, #mels) in the other cases
            ys_in_lens: (batch)
            chunk_mask: optional streaming chunk mask multiplied into the memory mask
            pre_acoustic_embeds: acoustic embeddings concatenated to the token
                embeddings when ``concat_embeds`` is enabled

        Returns:
            (tuple): tuple containing:
                x: decoded token score before softmax (batch, maxlen_out, token)
                    if use_output_layer is True,
                olens: (batch, )
        """
        tgt = ys_in_pad
        # (batch, maxlen_out, 1) boolean mask over target positions
        tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
        memory = hs_pad
        # (batch, 1, maxlen_in) boolean mask over encoder frames
        memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
        if chunk_mask is not None:
            memory_mask = memory_mask * chunk_mask
            if tgt_mask.size(1) != memory_mask.size(1):
                # Pad the memory mask by one row (copy of the second-to-last one)
                # so its time axis lines up with the target mask.
                memory_mask = torch.cat((memory_mask, memory_mask[:, -2:-1, :]), dim=1)
        x = self.embed(tgt)
        if pre_acoustic_embeds is not None and self.concat_embeds:
            # Fuse token and acoustic embeddings, then project back via the FFN.
            x = torch.cat((x, pre_acoustic_embeds), dim=-1)
            x, _, _, _, _ = self.embed_concat_ffn(x, None, None, None, None)
        x, tgt_mask, memory, memory_mask, _ = self.decoders(
            x, tgt_mask, memory, memory_mask
        )
        if self.decoders2 is not None:
            x, tgt_mask, memory, memory_mask, _ = self.decoders2(
                x, tgt_mask, memory, memory_mask
            )
        x, tgt_mask, memory, memory_mask, _ = self.decoders3(
            x, tgt_mask, memory, memory_mask
        )
        if self.normalize_before:
            x = self.after_norm(x)
        if self.output_layer is not None:
            x = self.output_layer(x)
        # Valid output lengths recovered from the target mask.
        olens = tgt_mask.sum(1)
        return x, olens
  280. def score(self, ys, state, x, x_mask=None, pre_acoustic_embeds: torch.Tensor = None, ):
  281. """Score."""
  282. ys_mask = myutils.sequence_mask(torch.tensor([len(ys)], dtype=torch.int32), device=x.device)[:, :, None]
  283. logp, state = self.forward_one_step(
  284. ys.unsqueeze(0), ys_mask, x.unsqueeze(0), memory_mask=x_mask, pre_acoustic_embeds=pre_acoustic_embeds,
  285. cache=state
  286. )
  287. return logp.squeeze(0), state
    def forward_one_step(
        self,
        tgt: torch.Tensor,
        tgt_mask: torch.Tensor,
        memory: torch.Tensor,
        memory_mask: torch.Tensor = None,
        pre_acoustic_embeds: torch.Tensor = None,
        cache: List[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Forward one step.

        Args:
            tgt: input token ids, int64 (batch, maxlen_out)
            tgt_mask: input token mask, (batch, maxlen_out)
                dtype=torch.uint8 in PyTorch 1.2-
                dtype=torch.bool in PyTorch 1.2+ (include 1.2)
            memory: encoded memory, float32 (batch, maxlen_in, feat)
            memory_mask: optional mask over the encoder frames
            pre_acoustic_embeds: acoustic embedding fused with the token
                embedding when ``concat_embeds`` is enabled
            cache: cached output list of (batch, max_time_out-1, size);
                one slot per layer in ``decoders`` followed by ``decoders2``

        Returns:
            y, cache: NN output value and cache per `self.decoders`.
            y.shape` is (batch, maxlen_out, token)
        """
        # Only the newest token goes through the network; the per-layer FSMN
        # caches carry the history, so the target mask is not needed.
        x = tgt[:, -1:]
        tgt_mask = None
        x = self.embed(x)
        if pre_acoustic_embeds is not None and self.concat_embeds:
            x = torch.cat((x, pre_acoustic_embeds), dim=-1)
            x, _, _, _, _ = self.embed_concat_ffn(x, None, None, None, None)
        if cache is None:
            # First step: allocate one empty cache slot per FSMN layer.
            cache_layer_num = len(self.decoders)
            if self.decoders2 is not None:
                cache_layer_num += len(self.decoders2)
            cache = [None] * cache_layer_num
        new_cache = []
        # for c, decoder in zip(cache, self.decoders):
        for i in range(self.att_layer_num):
            decoder = self.decoders[i]
            c = cache[i]
            x, tgt_mask, memory, memory_mask, c_ret = decoder(
                x, tgt_mask, memory, memory_mask, cache=c
            )
            new_cache.append(c_ret)
        # decoders2 slots follow the decoders slots in the flat cache list.
        if self.num_blocks - self.att_layer_num >= 1:
            for i in range(self.num_blocks - self.att_layer_num):
                j = i + self.att_layer_num
                decoder = self.decoders2[i]
                c = cache[j]
                x, tgt_mask, memory, memory_mask, c_ret = decoder(
                    x, tgt_mask, memory, memory_mask, cache=c
                )
                new_cache.append(c_ret)
        # Final FFN-only layer is stateless: no memory mask, no cache.
        for decoder in self.decoders3:
            x, tgt_mask, memory, memory_mask, _ = decoder(
                x, tgt_mask, memory, None, cache=None
            )
        if self.normalize_before:
            y = self.after_norm(x[:, -1])
        else:
            y = x[:, -1]
        if self.output_layer is not None:
            y = self.output_layer(y)
            y = torch.log_softmax(y, dim=-1)
        return y, new_cache
    def gen_tf2torch_map_dict(self):
        """Build the TF-variable -> torch-parameter name mapping.

        Returns a dict keyed by torch parameter names (with the literal token
        ``layeridx`` as a per-layer placeholder); each value gives the TF tensor
        name plus the ``squeeze`` axis and ``transpose`` permutation needed to
        reshape the TF value into the torch layout. Trailing comments show
        (torch shape),(tf shape) for a 256-dim reference model.
        """
        tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
        tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
        # The embedding table may live under its own TF prefix (shared embeddings).
        embed_tensor_name_prefix_tf = self.embed_tensor_name_prefix_tf if self.embed_tensor_name_prefix_tf is not None else tensor_name_prefix_tf
        map_dict_local = {
            ## decoder
            # ffn
            "{}.decoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.decoders.layeridx.norm1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.decoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (1024,256),(1,256,1024)
            "{}.decoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.decoders.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.decoders.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm_1/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.decoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,1024),(1,1024,256)
            # fsmn
            "{}.decoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/gamma".format(
                    tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.decoders.layeridx.norm2.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/beta".format(
                    tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.decoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/depth_conv_w".format(
                    tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 2, 0),
                 },  # (256,1,31),(1,31,256,1)
            # src att
            "{}.decoders.layeridx.norm3.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.decoders.layeridx.norm3.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.decoders.layeridx.src_attn.linear_q.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,256),(1,256,256)
            "{}.decoders.layeridx.src_attn.linear_q.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.decoders.layeridx.src_attn.linear_k_v.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (1024,256),(1,256,1024)
            "{}.decoders.layeridx.src_attn.linear_k_v.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.decoders.layeridx.src_attn.linear_out.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_2/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,256),(1,256,256)
            "{}.decoders.layeridx.src_attn.linear_out.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_2/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            # dnn
            "{}.decoders3.layeridx.norm1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.decoders3.layeridx.norm1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.decoders3.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_dnn_layer_layeridx/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (1024,256),(1,256,1024)
            "{}.decoders3.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_dnn_layer_layeridx/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.decoders3.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.decoders3.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm_1/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.decoders3.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_dnn_layer_layeridx/conv1d_1/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,1024),(1,1024,256)
            # embed_concat_ffn
            "{}.embed_concat_ffn.layeridx.norm1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/cif_concat/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.embed_concat_ffn.layeridx.norm1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/cif_concat/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.embed_concat_ffn.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/cif_concat/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (1024,256),(1,256,1024)
            "{}.embed_concat_ffn.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/cif_concat/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.embed_concat_ffn.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
                {"name": "{}/cif_concat/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.embed_concat_ffn.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
                {"name": "{}/cif_concat/LayerNorm_1/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.embed_concat_ffn.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
                {"name": "{}/cif_concat/conv1d_1/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,1024),(1,1024,256)
            # out norm
            "{}.after_norm.weight".format(tensor_name_prefix_torch):
                {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.after_norm.bias".format(tensor_name_prefix_torch):
                {"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            # in embed
            "{}.embed.0.weight".format(tensor_name_prefix_torch):
                {"name": "{}/w_embs".format(embed_tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (4235,256),(4235,256)
            # out layer
            # The output layer lists two candidate TF sources (dense kernel or the
            # tied embedding table); convert_tf2torch picks whichever exists.
            "{}.output_layer.weight".format(tensor_name_prefix_torch):
                {"name": ["{}/dense/kernel".format(tensor_name_prefix_tf),
                          "{}/w_embs".format(embed_tensor_name_prefix_tf)],
                 "squeeze": [None, None],
                 "transpose": [(1, 0), None],
                 },  # (4235,256),(256,4235)
            "{}.output_layer.bias".format(tensor_name_prefix_torch):
                {"name": ["{}/dense/bias".format(tensor_name_prefix_tf),
                          "seq2seq/2bias" if tensor_name_prefix_tf == "seq2seq/decoder/inputter_1" else "seq2seq/bias"],
                 "squeeze": [None, None],
                 "transpose": [None, None],
                 },  # (4235,),(4235,)
        }
        return map_dict_local
  556. def convert_tf2torch(self,
  557. var_dict_tf,
  558. var_dict_torch,
  559. ):
  560. map_dict = self.gen_tf2torch_map_dict()
  561. var_dict_torch_update = dict()
  562. decoder_layeridx_sets = set()
  563. for name in sorted(var_dict_torch.keys(), reverse=False):
  564. names = name.split('.')
  565. if names[0] == self.tf2torch_tensor_name_prefix_torch:
  566. if names[1] == "decoders":
  567. layeridx = int(names[2])
  568. name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
  569. layeridx_bias = 0
  570. layeridx += layeridx_bias
  571. decoder_layeridx_sets.add(layeridx)
  572. if name_q in map_dict.keys():
  573. name_v = map_dict[name_q]["name"]
  574. name_tf = name_v.replace("layeridx", "{}".format(layeridx))
  575. data_tf = var_dict_tf[name_tf]
  576. if map_dict[name_q]["squeeze"] is not None:
  577. data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
  578. if map_dict[name_q]["transpose"] is not None:
  579. data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
  580. data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
  581. assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
  582. var_dict_torch[
  583. name].size(),
  584. data_tf.size())
  585. var_dict_torch_update[name] = data_tf
  586. logging.info(
  587. "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
  588. var_dict_tf[name_tf].shape))
  589. elif names[1] == "decoders2":
  590. layeridx = int(names[2])
  591. name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
  592. name_q = name_q.replace("decoders2", "decoders")
  593. layeridx_bias = len(decoder_layeridx_sets)
  594. layeridx += layeridx_bias
  595. if "decoders." in name:
  596. decoder_layeridx_sets.add(layeridx)
  597. if name_q in map_dict.keys():
  598. name_v = map_dict[name_q]["name"]
  599. name_tf = name_v.replace("layeridx", "{}".format(layeridx))
  600. data_tf = var_dict_tf[name_tf]
  601. if map_dict[name_q]["squeeze"] is not None:
  602. data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
  603. if map_dict[name_q]["transpose"] is not None:
  604. data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
  605. data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
  606. assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
  607. var_dict_torch[
  608. name].size(),
  609. data_tf.size())
  610. var_dict_torch_update[name] = data_tf
  611. logging.info(
  612. "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
  613. var_dict_tf[name_tf].shape))
  614. elif names[1] == "decoders3":
  615. layeridx = int(names[2])
  616. name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
  617. layeridx_bias = 0
  618. layeridx += layeridx_bias
  619. if "decoders." in name:
  620. decoder_layeridx_sets.add(layeridx)
  621. if name_q in map_dict.keys():
  622. name_v = map_dict[name_q]["name"]
  623. name_tf = name_v.replace("layeridx", "{}".format(layeridx))
  624. data_tf = var_dict_tf[name_tf]
  625. if map_dict[name_q]["squeeze"] is not None:
  626. data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
  627. if map_dict[name_q]["transpose"] is not None:
  628. data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
  629. data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
  630. assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
  631. var_dict_torch[
  632. name].size(),
  633. data_tf.size())
  634. var_dict_torch_update[name] = data_tf
  635. logging.info(
  636. "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
  637. var_dict_tf[name_tf].shape))
  638. elif names[1] == "embed" or names[1] == "output_layer":
  639. name_tf = map_dict[name]["name"]
  640. if isinstance(name_tf, list):
  641. idx_list = 0
  642. if name_tf[idx_list] in var_dict_tf.keys():
  643. pass
  644. else:
  645. idx_list = 1
  646. data_tf = var_dict_tf[name_tf[idx_list]]
  647. if map_dict[name]["squeeze"][idx_list] is not None:
  648. data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"][idx_list])
  649. if map_dict[name]["transpose"][idx_list] is not None:
  650. data_tf = np.transpose(data_tf, map_dict[name]["transpose"][idx_list])
  651. data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
  652. assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
  653. var_dict_torch[
  654. name].size(),
  655. data_tf.size())
  656. var_dict_torch_update[name] = data_tf
  657. logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(),
  658. name_tf[idx_list],
  659. var_dict_tf[name_tf[
  660. idx_list]].shape))
  661. else:
  662. data_tf = var_dict_tf[name_tf]
  663. if map_dict[name]["squeeze"] is not None:
  664. data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
  665. if map_dict[name]["transpose"] is not None:
  666. data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
  667. data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
  668. assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
  669. var_dict_torch[
  670. name].size(),
  671. data_tf.size())
  672. var_dict_torch_update[name] = data_tf
  673. logging.info(
  674. "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
  675. var_dict_tf[name_tf].shape))
  676. elif names[1] == "after_norm":
  677. name_tf = map_dict[name]["name"]
  678. data_tf = var_dict_tf[name_tf]
  679. data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
  680. var_dict_torch_update[name] = data_tf
  681. logging.info(
  682. "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
  683. var_dict_tf[name_tf].shape))
  684. elif names[1] == "embed_concat_ffn":
  685. layeridx = int(names[2])
  686. name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
  687. layeridx_bias = 0
  688. layeridx += layeridx_bias
  689. if "decoders." in name:
  690. decoder_layeridx_sets.add(layeridx)
  691. if name_q in map_dict.keys():
  692. name_v = map_dict[name_q]["name"]
  693. name_tf = name_v.replace("layeridx", "{}".format(layeridx))
  694. data_tf = var_dict_tf[name_tf]
  695. if map_dict[name_q]["squeeze"] is not None:
  696. data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
  697. if map_dict[name_q]["transpose"] is not None:
  698. data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
  699. data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
  700. assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
  701. var_dict_torch[
  702. name].size(),
  703. data_tf.size())
  704. var_dict_torch_update[name] = data_tf
  705. logging.info(
  706. "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
  707. var_dict_tf[name_tf].shape))
  708. return var_dict_torch_update
  709. class ParaformerSANMDecoder(BaseTransformerDecoder):
  710. """
  711. author: Speech Lab, Alibaba Group, China
  712. Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
https://arxiv.org/abs/2206.08317
  714. """
    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = "embed",
        use_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
        att_layer_num: int = 6,
        kernel_size: int = 21,
        sanm_shfit: int = 0,
        tf2torch_tensor_name_prefix_torch: str = "decoder",
        tf2torch_tensor_name_prefix_tf: str = "seq2seq/decoder",
    ):
        """Build the SANM decoder stack.

        Args:
            vocab_size: size of the output vocabulary.
            encoder_output_size: encoder output dimension; also used as the
                decoder attention dimension.
            attention_heads: number of heads in the cross-attention blocks.
            linear_units: hidden units of the position-wise feed-forward layers.
            num_blocks: total number of decoder blocks (``decoders`` plus
                ``decoders2``; ``decoders3`` is an extra FFN-only block).
            dropout_rate: dropout rate inside decoder layers.
            positional_dropout_rate: dropout for the positional encoding used
                by the "linear" input layer.
            self_attention_dropout_rate: dropout of the FSMN self-attention blocks.
            src_attention_dropout_rate: dropout of the cross-attention blocks.
            input_layer: "embed" (token embedding) or "linear" (projection front-end).
            use_output_layer: if True, add a final Linear projection to vocab_size.
            pos_enc_class: positional-encoding class (only instantiated for the
                "linear" input layer; commented out for "embed").
            normalize_before: pre-LayerNorm style; also enables ``after_norm``.
            concat_after: concatenate attention input and output in each layer.
            att_layer_num: number of blocks that carry cross attention ("decoders").
            kernel_size: FSMN memory-block kernel size.
            sanm_shfit: FSMN shift; if None, defaults to a centered kernel.
                NOTE: the misspelling ("shfit") is kept intentionally — it is part
                of the public keyword interface used by configs/checkpoints.
            tf2torch_tensor_name_prefix_torch: torch-side name prefix used by
                the TF->torch checkpoint conversion helpers.
            tf2torch_tensor_name_prefix_tf: TF-side name prefix used by the
                TF->torch checkpoint conversion helpers.
        """
        assert check_argument_types()
        super().__init__(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            dropout_rate=dropout_rate,
            positional_dropout_rate=positional_dropout_rate,
            input_layer=input_layer,
            use_output_layer=use_output_layer,
            pos_enc_class=pos_enc_class,
            normalize_before=normalize_before,
        )
        attention_dim = encoder_output_size
        if input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(vocab_size, attention_dim),
                # pos_enc_class(attention_dim, positional_dropout_rate),
            )
        elif input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(vocab_size, attention_dim),
                torch.nn.LayerNorm(attention_dim),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        else:
            raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")
        self.normalize_before = normalize_before
        if self.normalize_before:
            # final LayerNorm applied after the last decoder block (pre-norm style)
            self.after_norm = LayerNorm(attention_dim)
        if use_output_layer:
            self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
        else:
            self.output_layer = None
        self.att_layer_num = att_layer_num
        self.num_blocks = num_blocks
        if sanm_shfit is None:
            # center the FSMN kernel when no explicit shift is given
            sanm_shfit = (kernel_size - 1) // 2
        # First att_layer_num blocks: FSMN self-attention + cross attention + FFN.
        self.decoders = repeat(
            att_layer_num,
            lambda lnum: DecoderLayerSANM(
                attention_dim,
                MultiHeadedAttentionSANMDecoder(
                    attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
                ),
                MultiHeadedAttentionCrossAtt(
                    attention_heads, attention_dim, src_attention_dropout_rate
                ),
                PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        # Remaining blocks: FSMN self-attention + FFN only (no cross attention,
        # FSMN shift forced to 0).
        if num_blocks - att_layer_num <= 0:
            self.decoders2 = None
        else:
            self.decoders2 = repeat(
                num_blocks - att_layer_num,
                lambda lnum: DecoderLayerSANM(
                    attention_dim,
                    MultiHeadedAttentionSANMDecoder(
                        attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=0
                    ),
                    None,
                    PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
                    dropout_rate,
                    normalize_before,
                    concat_after,
                ),
            )
        # Single trailing FFN-only block (no attention at all).
        self.decoders3 = repeat(
            1,
            lambda lnum: DecoderLayerSANM(
                attention_dim,
                None,
                None,
                PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
  822. def forward(
  823. self,
  824. hs_pad: torch.Tensor,
  825. hlens: torch.Tensor,
  826. ys_in_pad: torch.Tensor,
  827. ys_in_lens: torch.Tensor,
  828. ) -> Tuple[torch.Tensor, torch.Tensor]:
  829. """Forward decoder.
  830. Args:
  831. hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
  832. hlens: (batch)
  833. ys_in_pad:
  834. input token ids, int64 (batch, maxlen_out)
  835. if input_layer == "embed"
  836. input tensor (batch, maxlen_out, #mels) in the other cases
  837. ys_in_lens: (batch)
  838. Returns:
  839. (tuple): tuple containing:
  840. x: decoded token score before softmax (batch, maxlen_out, token)
  841. if use_output_layer is True,
  842. olens: (batch, )
  843. """
  844. tgt = ys_in_pad
  845. tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
  846. memory = hs_pad
  847. memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
  848. x = tgt
  849. x, tgt_mask, memory, memory_mask, _ = self.decoders(
  850. x, tgt_mask, memory, memory_mask
  851. )
  852. if self.decoders2 is not None:
  853. x, tgt_mask, memory, memory_mask, _ = self.decoders2(
  854. x, tgt_mask, memory, memory_mask
  855. )
  856. x, tgt_mask, memory, memory_mask, _ = self.decoders3(
  857. x, tgt_mask, memory, memory_mask
  858. )
  859. if self.normalize_before:
  860. x = self.after_norm(x)
  861. if self.output_layer is not None:
  862. x = self.output_layer(x)
  863. olens = tgt_mask.sum(1)
  864. return x, olens
  865. def score(self, ys, state, x):
  866. """Score."""
  867. ys_mask = myutils.sequence_mask(torch.tensor([len(ys)], dtype=torch.int32), device=x.device)[:, :, None]
  868. logp, state = self.forward_one_step(
  869. ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state
  870. )
  871. return logp.squeeze(0), state
  872. def forward_one_step(
  873. self,
  874. tgt: torch.Tensor,
  875. tgt_mask: torch.Tensor,
  876. memory: torch.Tensor,
  877. cache: List[torch.Tensor] = None,
  878. ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
  879. """Forward one step.
  880. Args:
  881. tgt: input token ids, int64 (batch, maxlen_out)
  882. tgt_mask: input token mask, (batch, maxlen_out)
  883. dtype=torch.uint8 in PyTorch 1.2-
  884. dtype=torch.bool in PyTorch 1.2+ (include 1.2)
  885. memory: encoded memory, float32 (batch, maxlen_in, feat)
  886. cache: cached output list of (batch, max_time_out-1, size)
  887. Returns:
  888. y, cache: NN output value and cache per `self.decoders`.
  889. y.shape` is (batch, maxlen_out, token)
  890. """
  891. x = self.embed(tgt)
  892. if cache is None:
  893. cache_layer_num = len(self.decoders)
  894. if self.decoders2 is not None:
  895. cache_layer_num += len(self.decoders2)
  896. cache = [None] * cache_layer_num
  897. new_cache = []
  898. # for c, decoder in zip(cache, self.decoders):
  899. for i in range(self.att_layer_num):
  900. decoder = self.decoders[i]
  901. c = cache[i]
  902. x, tgt_mask, memory, memory_mask, c_ret = decoder(
  903. x, tgt_mask, memory, None, cache=c
  904. )
  905. new_cache.append(c_ret)
  906. if self.num_blocks - self.att_layer_num > 1:
  907. for i in range(self.num_blocks - self.att_layer_num):
  908. j = i + self.att_layer_num
  909. decoder = self.decoders2[i]
  910. c = cache[j]
  911. x, tgt_mask, memory, memory_mask, c_ret = decoder(
  912. x, tgt_mask, memory, None, cache=c
  913. )
  914. new_cache.append(c_ret)
  915. for decoder in self.decoders3:
  916. x, tgt_mask, memory, memory_mask, _ = decoder(
  917. x, tgt_mask, memory, None, cache=None
  918. )
  919. if self.normalize_before:
  920. y = self.after_norm(x[:, -1])
  921. else:
  922. y = x[:, -1]
  923. if self.output_layer is not None:
  924. y = torch.log_softmax(self.output_layer(y), dim=-1)
  925. return y, new_cache
  926. def gen_tf2torch_map_dict(self):
  927. tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
  928. tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
  929. map_dict_local = {
  930. ## decoder
  931. # ffn
  932. "{}.decoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
  933. {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
  934. "squeeze": None,
  935. "transpose": None,
  936. }, # (256,),(256,)
  937. "{}.decoders.layeridx.norm1.bias".format(tensor_name_prefix_torch):
  938. {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
  939. "squeeze": None,
  940. "transpose": None,
  941. }, # (256,),(256,)
  942. "{}.decoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
  943. {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/kernel".format(tensor_name_prefix_tf),
  944. "squeeze": 0,
  945. "transpose": (1, 0),
  946. }, # (1024,256),(1,256,1024)
  947. "{}.decoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
  948. {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/bias".format(tensor_name_prefix_tf),
  949. "squeeze": None,
  950. "transpose": None,
  951. }, # (1024,),(1024,)
  952. "{}.decoders.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
  953. {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
  954. "squeeze": None,
  955. "transpose": None,
  956. }, # (1024,),(1024,)
  957. "{}.decoders.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
  958. {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm_1/beta".format(tensor_name_prefix_tf),
  959. "squeeze": None,
  960. "transpose": None,
  961. }, # (1024,),(1024,)
  962. "{}.decoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
  963. {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
  964. "squeeze": 0,
  965. "transpose": (1, 0),
  966. }, # (256,1024),(1,1024,256)
  967. # fsmn
  968. "{}.decoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
  969. {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/gamma".format(
  970. tensor_name_prefix_tf),
  971. "squeeze": None,
  972. "transpose": None,
  973. }, # (256,),(256,)
  974. "{}.decoders.layeridx.norm2.bias".format(tensor_name_prefix_torch):
  975. {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/beta".format(
  976. tensor_name_prefix_tf),
  977. "squeeze": None,
  978. "transpose": None,
  979. }, # (256,),(256,)
  980. "{}.decoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch):
  981. {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/depth_conv_w".format(
  982. tensor_name_prefix_tf),
  983. "squeeze": 0,
  984. "transpose": (1, 2, 0),
  985. }, # (256,1,31),(1,31,256,1)
  986. # src att
  987. "{}.decoders.layeridx.norm3.weight".format(tensor_name_prefix_torch):
  988. {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
  989. "squeeze": None,
  990. "transpose": None,
  991. }, # (256,),(256,)
  992. "{}.decoders.layeridx.norm3.bias".format(tensor_name_prefix_torch):
  993. {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
  994. "squeeze": None,
  995. "transpose": None,
  996. }, # (256,),(256,)
  997. "{}.decoders.layeridx.src_attn.linear_q.weight".format(tensor_name_prefix_torch):
  998. {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
  999. "squeeze": 0,
  1000. "transpose": (1, 0),
  1001. }, # (256,256),(1,256,256)
  1002. "{}.decoders.layeridx.src_attn.linear_q.bias".format(tensor_name_prefix_torch):
  1003. {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
  1004. "squeeze": None,
  1005. "transpose": None,
  1006. }, # (256,),(256,)
  1007. "{}.decoders.layeridx.src_attn.linear_k_v.weight".format(tensor_name_prefix_torch):
  1008. {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
  1009. "squeeze": 0,
  1010. "transpose": (1, 0),
  1011. }, # (1024,256),(1,256,1024)
  1012. "{}.decoders.layeridx.src_attn.linear_k_v.bias".format(tensor_name_prefix_torch):
  1013. {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
  1014. "squeeze": None,
  1015. "transpose": None,
  1016. }, # (1024,),(1024,)
  1017. "{}.decoders.layeridx.src_attn.linear_out.weight".format(tensor_name_prefix_torch):
  1018. {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_2/kernel".format(tensor_name_prefix_tf),
  1019. "squeeze": 0,
  1020. "transpose": (1, 0),
  1021. }, # (256,256),(1,256,256)
  1022. "{}.decoders.layeridx.src_attn.linear_out.bias".format(tensor_name_prefix_torch):
  1023. {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_2/bias".format(tensor_name_prefix_tf),
  1024. "squeeze": None,
  1025. "transpose": None,
  1026. }, # (256,),(256,)
  1027. # dnn
  1028. "{}.decoders3.layeridx.norm1.weight".format(tensor_name_prefix_torch):
  1029. {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm/gamma".format(tensor_name_prefix_tf),
  1030. "squeeze": None,
  1031. "transpose": None,
  1032. }, # (256,),(256,)
  1033. "{}.decoders3.layeridx.norm1.bias".format(tensor_name_prefix_torch):
  1034. {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm/beta".format(tensor_name_prefix_tf),
  1035. "squeeze": None,
  1036. "transpose": None,
  1037. }, # (256,),(256,)
  1038. "{}.decoders3.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
  1039. {"name": "{}/decoder_dnn_layer_layeridx/conv1d/kernel".format(tensor_name_prefix_tf),
  1040. "squeeze": 0,
  1041. "transpose": (1, 0),
  1042. }, # (1024,256),(1,256,1024)
  1043. "{}.decoders3.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
  1044. {"name": "{}/decoder_dnn_layer_layeridx/conv1d/bias".format(tensor_name_prefix_tf),
  1045. "squeeze": None,
  1046. "transpose": None,
  1047. }, # (1024,),(1024,)
  1048. "{}.decoders3.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
  1049. {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
  1050. "squeeze": None,
  1051. "transpose": None,
  1052. }, # (1024,),(1024,)
  1053. "{}.decoders3.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
  1054. {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm_1/beta".format(tensor_name_prefix_tf),
  1055. "squeeze": None,
  1056. "transpose": None,
  1057. }, # (1024,),(1024,)
  1058. "{}.decoders3.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
  1059. {"name": "{}/decoder_dnn_layer_layeridx/conv1d_1/kernel".format(tensor_name_prefix_tf),
  1060. "squeeze": 0,
  1061. "transpose": (1, 0),
  1062. }, # (256,1024),(1,1024,256)
  1063. # embed_concat_ffn
  1064. "{}.embed_concat_ffn.layeridx.norm1.weight".format(tensor_name_prefix_torch):
  1065. {"name": "{}/cif_concat/LayerNorm/gamma".format(tensor_name_prefix_tf),
  1066. "squeeze": None,
  1067. "transpose": None,
  1068. }, # (256,),(256,)
  1069. "{}.embed_concat_ffn.layeridx.norm1.bias".format(tensor_name_prefix_torch):
  1070. {"name": "{}/cif_concat/LayerNorm/beta".format(tensor_name_prefix_tf),
  1071. "squeeze": None,
  1072. "transpose": None,
  1073. }, # (256,),(256,)
  1074. "{}.embed_concat_ffn.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
  1075. {"name": "{}/cif_concat/conv1d/kernel".format(tensor_name_prefix_tf),
  1076. "squeeze": 0,
  1077. "transpose": (1, 0),
  1078. }, # (1024,256),(1,256,1024)
  1079. "{}.embed_concat_ffn.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
  1080. {"name": "{}/cif_concat/conv1d/bias".format(tensor_name_prefix_tf),
  1081. "squeeze": None,
  1082. "transpose": None,
  1083. }, # (1024,),(1024,)
  1084. "{}.embed_concat_ffn.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
  1085. {"name": "{}/cif_concat/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
  1086. "squeeze": None,
  1087. "transpose": None,
  1088. }, # (1024,),(1024,)
  1089. "{}.embed_concat_ffn.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
  1090. {"name": "{}/cif_concat/LayerNorm_1/beta".format(tensor_name_prefix_tf),
  1091. "squeeze": None,
  1092. "transpose": None,
  1093. }, # (1024,),(1024,)
  1094. "{}.embed_concat_ffn.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
  1095. {"name": "{}/cif_concat/conv1d_1/kernel".format(tensor_name_prefix_tf),
  1096. "squeeze": 0,
  1097. "transpose": (1, 0),
  1098. }, # (256,1024),(1,1024,256)
  1099. # out norm
  1100. "{}.after_norm.weight".format(tensor_name_prefix_torch):
  1101. {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
  1102. "squeeze": None,
  1103. "transpose": None,
  1104. }, # (256,),(256,)
  1105. "{}.after_norm.bias".format(tensor_name_prefix_torch):
  1106. {"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
  1107. "squeeze": None,
  1108. "transpose": None,
  1109. }, # (256,),(256,)
  1110. # in embed
  1111. "{}.embed.0.weight".format(tensor_name_prefix_torch):
  1112. {"name": "{}/w_embs".format(tensor_name_prefix_tf),
  1113. "squeeze": None,
  1114. "transpose": None,
  1115. }, # (4235,256),(4235,256)
  1116. # out layer
  1117. "{}.output_layer.weight".format(tensor_name_prefix_torch):
  1118. {"name": ["{}/dense/kernel".format(tensor_name_prefix_tf), "{}/w_embs".format(tensor_name_prefix_tf)],
  1119. "squeeze": [None, None],
  1120. "transpose": [(1, 0), None],
  1121. }, # (4235,256),(256,4235)
  1122. "{}.output_layer.bias".format(tensor_name_prefix_torch):
  1123. {"name": ["{}/dense/bias".format(tensor_name_prefix_tf),
  1124. "seq2seq/2bias" if tensor_name_prefix_tf == "seq2seq/decoder/inputter_1" else "seq2seq/bias"],
  1125. "squeeze": [None, None],
  1126. "transpose": [None, None],
  1127. }, # (4235,),(4235,)
  1128. }
  1129. return map_dict_local
  1130. def convert_tf2torch(self,
  1131. var_dict_tf,
  1132. var_dict_torch,
  1133. ):
  1134. map_dict = self.gen_tf2torch_map_dict()
  1135. var_dict_torch_update = dict()
  1136. decoder_layeridx_sets = set()
  1137. for name in sorted(var_dict_torch.keys(), reverse=False):
  1138. names = name.split('.')
  1139. if names[0] == self.tf2torch_tensor_name_prefix_torch:
  1140. if names[1] == "decoders":
  1141. layeridx = int(names[2])
  1142. name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
  1143. layeridx_bias = 0
  1144. layeridx += layeridx_bias
  1145. decoder_layeridx_sets.add(layeridx)
  1146. if name_q in map_dict.keys():
  1147. name_v = map_dict[name_q]["name"]
  1148. name_tf = name_v.replace("layeridx", "{}".format(layeridx))
  1149. data_tf = var_dict_tf[name_tf]
  1150. if map_dict[name_q]["squeeze"] is not None:
  1151. data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
  1152. if map_dict[name_q]["transpose"] is not None:
  1153. data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
  1154. data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
  1155. assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
  1156. var_dict_torch[
  1157. name].size(),
  1158. data_tf.size())
  1159. var_dict_torch_update[name] = data_tf
  1160. logging.info(
  1161. "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
  1162. var_dict_tf[name_tf].shape))
  1163. elif names[1] == "decoders2":
  1164. layeridx = int(names[2])
  1165. name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
  1166. name_q = name_q.replace("decoders2", "decoders")
  1167. layeridx_bias = len(decoder_layeridx_sets)
  1168. layeridx += layeridx_bias
  1169. if "decoders." in name:
  1170. decoder_layeridx_sets.add(layeridx)
  1171. if name_q in map_dict.keys():
  1172. name_v = map_dict[name_q]["name"]
  1173. name_tf = name_v.replace("layeridx", "{}".format(layeridx))
  1174. data_tf = var_dict_tf[name_tf]
  1175. if map_dict[name_q]["squeeze"] is not None:
  1176. data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
  1177. if map_dict[name_q]["transpose"] is not None:
  1178. data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
  1179. data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
  1180. assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
  1181. var_dict_torch[
  1182. name].size(),
  1183. data_tf.size())
  1184. var_dict_torch_update[name] = data_tf
  1185. logging.info(
  1186. "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
  1187. var_dict_tf[name_tf].shape))
  1188. elif names[1] == "decoders3":
  1189. layeridx = int(names[2])
  1190. name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
  1191. layeridx_bias = 0
  1192. layeridx += layeridx_bias
  1193. if "decoders." in name:
  1194. decoder_layeridx_sets.add(layeridx)
  1195. if name_q in map_dict.keys():
  1196. name_v = map_dict[name_q]["name"]
  1197. name_tf = name_v.replace("layeridx", "{}".format(layeridx))
  1198. data_tf = var_dict_tf[name_tf]
  1199. if map_dict[name_q]["squeeze"] is not None:
  1200. data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
  1201. if map_dict[name_q]["transpose"] is not None:
  1202. data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
  1203. data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
  1204. assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
  1205. var_dict_torch[
  1206. name].size(),
  1207. data_tf.size())
  1208. var_dict_torch_update[name] = data_tf
  1209. logging.info(
  1210. "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
  1211. var_dict_tf[name_tf].shape))
  1212. elif names[1] == "embed" or names[1] == "output_layer":
  1213. name_tf = map_dict[name]["name"]
  1214. if isinstance(name_tf, list):
  1215. idx_list = 0
  1216. if name_tf[idx_list] in var_dict_tf.keys():
  1217. pass
  1218. else:
  1219. idx_list = 1
  1220. data_tf = var_dict_tf[name_tf[idx_list]]
  1221. if map_dict[name]["squeeze"][idx_list] is not None:
  1222. data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"][idx_list])
  1223. if map_dict[name]["transpose"][idx_list] is not None:
  1224. data_tf = np.transpose(data_tf, map_dict[name]["transpose"][idx_list])
  1225. data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
  1226. assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
  1227. var_dict_torch[
  1228. name].size(),
  1229. data_tf.size())
  1230. var_dict_torch_update[name] = data_tf
  1231. logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(),
  1232. name_tf[idx_list],
  1233. var_dict_tf[name_tf[
  1234. idx_list]].shape))
  1235. else:
  1236. data_tf = var_dict_tf[name_tf]
  1237. if map_dict[name]["squeeze"] is not None:
  1238. data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
  1239. if map_dict[name]["transpose"] is not None:
  1240. data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
  1241. data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
  1242. assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
  1243. var_dict_torch[
  1244. name].size(),
  1245. data_tf.size())
  1246. var_dict_torch_update[name] = data_tf
  1247. logging.info(
  1248. "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
  1249. var_dict_tf[name_tf].shape))
  1250. elif names[1] == "after_norm":
  1251. name_tf = map_dict[name]["name"]
  1252. data_tf = var_dict_tf[name_tf]
  1253. data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
  1254. var_dict_torch_update[name] = data_tf
  1255. logging.info(
  1256. "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
  1257. var_dict_tf[name_tf].shape))
  1258. elif names[1] == "embed_concat_ffn":
  1259. layeridx = int(names[2])
  1260. name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
  1261. layeridx_bias = 0
  1262. layeridx += layeridx_bias
  1263. if "decoders." in name:
  1264. decoder_layeridx_sets.add(layeridx)
  1265. if name_q in map_dict.keys():
  1266. name_v = map_dict[name_q]["name"]
  1267. name_tf = name_v.replace("layeridx", "{}".format(layeridx))
  1268. data_tf = var_dict_tf[name_tf]
  1269. if map_dict[name_q]["squeeze"] is not None:
  1270. data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
  1271. if map_dict[name_q]["transpose"] is not None:
  1272. data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
  1273. data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
  1274. assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
  1275. var_dict_torch[
  1276. name].size(),
  1277. data_tf.size())
  1278. var_dict_torch_update[name] = data_tf
  1279. logging.info(
  1280. "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
  1281. var_dict_tf[name_tf].shape))
  1282. return var_dict_torch_update