# sanm_decoder.py (74 KB)
  1. from typing import List
  2. from typing import Tuple
  3. import logging
  4. import torch
  5. import torch.nn as nn
  6. import numpy as np
  7. from funasr.modules.streaming_utils import utils as myutils
  8. from funasr.models.decoder.transformer_decoder import BaseTransformerDecoder
  9. from typeguard import check_argument_types
  10. from funasr.modules.attention import MultiHeadedAttentionSANMDecoder, MultiHeadedAttentionCrossAtt
  11. from funasr.modules.embedding import PositionalEncoding
  12. from funasr.modules.layer_norm import LayerNorm
  13. from funasr.modules.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
  14. from funasr.modules.repeat import repeat
  15. class DecoderLayerSANM(nn.Module):
  16. """Single decoder layer module.
  17. Args:
  18. size (int): Input dimension.
  19. self_attn (torch.nn.Module): Self-attention module instance.
  20. `MultiHeadedAttention` instance can be used as the argument.
  21. src_attn (torch.nn.Module): Self-attention module instance.
  22. `MultiHeadedAttention` instance can be used as the argument.
  23. feed_forward (torch.nn.Module): Feed-forward module instance.
  24. `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
  25. can be used as the argument.
  26. dropout_rate (float): Dropout rate.
  27. normalize_before (bool): Whether to use layer_norm before the first block.
  28. concat_after (bool): Whether to concat attention layer's input and output.
  29. if True, additional linear will be applied.
  30. i.e. x -> x + linear(concat(x, att(x)))
  31. if False, no additional linear will be applied. i.e. x -> x + att(x)
  32. """
  33. def __init__(
  34. self,
  35. size,
  36. self_attn,
  37. src_attn,
  38. feed_forward,
  39. dropout_rate,
  40. normalize_before=True,
  41. concat_after=False,
  42. ):
  43. """Construct an DecoderLayer object."""
  44. super(DecoderLayerSANM, self).__init__()
  45. self.size = size
  46. self.self_attn = self_attn
  47. self.src_attn = src_attn
  48. self.feed_forward = feed_forward
  49. self.norm1 = LayerNorm(size)
  50. if self_attn is not None:
  51. self.norm2 = LayerNorm(size)
  52. if src_attn is not None:
  53. self.norm3 = LayerNorm(size)
  54. self.dropout = nn.Dropout(dropout_rate)
  55. self.normalize_before = normalize_before
  56. self.concat_after = concat_after
  57. if self.concat_after:
  58. self.concat_linear1 = nn.Linear(size + size, size)
  59. self.concat_linear2 = nn.Linear(size + size, size)
  60. def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
  61. """Compute decoded features.
  62. Args:
  63. tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
  64. tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
  65. memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
  66. memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
  67. cache (List[torch.Tensor]): List of cached tensors.
  68. Each tensor shape should be (#batch, maxlen_out - 1, size).
  69. Returns:
  70. torch.Tensor: Output tensor(#batch, maxlen_out, size).
  71. torch.Tensor: Mask for output tensor (#batch, maxlen_out).
  72. torch.Tensor: Encoded memory (#batch, maxlen_in, size).
  73. torch.Tensor: Encoded memory mask (#batch, maxlen_in).
  74. """
  75. # tgt = self.dropout(tgt)
  76. residual = tgt
  77. if self.normalize_before:
  78. tgt = self.norm1(tgt)
  79. tgt = self.feed_forward(tgt)
  80. x = tgt
  81. if self.self_attn:
  82. if self.normalize_before:
  83. tgt = self.norm2(tgt)
  84. x, _ = self.self_attn(tgt, tgt_mask)
  85. x = residual + self.dropout(x)
  86. if self.src_attn is not None:
  87. residual = x
  88. if self.normalize_before:
  89. x = self.norm3(x)
  90. x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
  91. return x, tgt_mask, memory, memory_mask, cache
  92. def forward_chunk(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
  93. """Compute decoded features.
  94. Args:
  95. tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
  96. tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
  97. memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
  98. memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
  99. cache (List[torch.Tensor]): List of cached tensors.
  100. Each tensor shape should be (#batch, maxlen_out - 1, size).
  101. Returns:
  102. torch.Tensor: Output tensor(#batch, maxlen_out, size).
  103. torch.Tensor: Mask for output tensor (#batch, maxlen_out).
  104. torch.Tensor: Encoded memory (#batch, maxlen_in, size).
  105. torch.Tensor: Encoded memory mask (#batch, maxlen_in).
  106. """
  107. # tgt = self.dropout(tgt)
  108. residual = tgt
  109. if self.normalize_before:
  110. tgt = self.norm1(tgt)
  111. tgt = self.feed_forward(tgt)
  112. x = tgt
  113. if self.self_attn:
  114. if self.normalize_before:
  115. tgt = self.norm2(tgt)
  116. if self.training:
  117. cache = None
  118. x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
  119. x = residual + self.dropout(x)
  120. if self.src_attn is not None:
  121. residual = x
  122. if self.normalize_before:
  123. x = self.norm3(x)
  124. x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
  125. return x, tgt_mask, memory, memory_mask, cache
  126. class FsmnDecoderSCAMAOpt(BaseTransformerDecoder):
  127. """
  128. Author: Speech Lab of DAMO Academy, Alibaba Group
  129. SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
  130. https://arxiv.org/abs/2006.01713
  131. """
  132. def __init__(
  133. self,
  134. vocab_size: int,
  135. encoder_output_size: int,
  136. attention_heads: int = 4,
  137. linear_units: int = 2048,
  138. num_blocks: int = 6,
  139. dropout_rate: float = 0.1,
  140. positional_dropout_rate: float = 0.1,
  141. self_attention_dropout_rate: float = 0.0,
  142. src_attention_dropout_rate: float = 0.0,
  143. input_layer: str = "embed",
  144. use_output_layer: bool = True,
  145. pos_enc_class=PositionalEncoding,
  146. normalize_before: bool = True,
  147. concat_after: bool = False,
  148. att_layer_num: int = 6,
  149. kernel_size: int = 21,
  150. sanm_shfit: int = None,
  151. concat_embeds: bool = False,
  152. attention_dim: int = None,
  153. tf2torch_tensor_name_prefix_torch: str = "decoder",
  154. tf2torch_tensor_name_prefix_tf: str = "seq2seq/decoder",
  155. embed_tensor_name_prefix_tf: str = None,
  156. ):
  157. assert check_argument_types()
  158. super().__init__(
  159. vocab_size=vocab_size,
  160. encoder_output_size=encoder_output_size,
  161. dropout_rate=dropout_rate,
  162. positional_dropout_rate=positional_dropout_rate,
  163. input_layer=input_layer,
  164. use_output_layer=use_output_layer,
  165. pos_enc_class=pos_enc_class,
  166. normalize_before=normalize_before,
  167. )
  168. if attention_dim is None:
  169. attention_dim = encoder_output_size
  170. if input_layer == "embed":
  171. self.embed = torch.nn.Sequential(
  172. torch.nn.Embedding(vocab_size, attention_dim),
  173. )
  174. elif input_layer == "linear":
  175. self.embed = torch.nn.Sequential(
  176. torch.nn.Linear(vocab_size, attention_dim),
  177. torch.nn.LayerNorm(attention_dim),
  178. torch.nn.Dropout(dropout_rate),
  179. torch.nn.ReLU(),
  180. pos_enc_class(attention_dim, positional_dropout_rate),
  181. )
  182. else:
  183. raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")
  184. self.normalize_before = normalize_before
  185. if self.normalize_before:
  186. self.after_norm = LayerNorm(attention_dim)
  187. if use_output_layer:
  188. self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
  189. else:
  190. self.output_layer = None
  191. self.att_layer_num = att_layer_num
  192. self.num_blocks = num_blocks
  193. if sanm_shfit is None:
  194. sanm_shfit = (kernel_size - 1) // 2
  195. self.decoders = repeat(
  196. att_layer_num,
  197. lambda lnum: DecoderLayerSANM(
  198. attention_dim,
  199. MultiHeadedAttentionSANMDecoder(
  200. attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
  201. ),
  202. MultiHeadedAttentionCrossAtt(
  203. attention_heads, attention_dim, src_attention_dropout_rate, encoder_output_size=encoder_output_size
  204. ),
  205. PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
  206. dropout_rate,
  207. normalize_before,
  208. concat_after,
  209. ),
  210. )
  211. if num_blocks - att_layer_num <= 0:
  212. self.decoders2 = None
  213. else:
  214. self.decoders2 = repeat(
  215. num_blocks - att_layer_num,
  216. lambda lnum: DecoderLayerSANM(
  217. attention_dim,
  218. MultiHeadedAttentionSANMDecoder(
  219. attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
  220. ),
  221. None,
  222. PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
  223. dropout_rate,
  224. normalize_before,
  225. concat_after,
  226. ),
  227. )
  228. self.decoders3 = repeat(
  229. 1,
  230. lambda lnum: DecoderLayerSANM(
  231. attention_dim,
  232. None,
  233. None,
  234. PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
  235. dropout_rate,
  236. normalize_before,
  237. concat_after,
  238. ),
  239. )
  240. if concat_embeds:
  241. self.embed_concat_ffn = repeat(
  242. 1,
  243. lambda lnum: DecoderLayerSANM(
  244. attention_dim + encoder_output_size,
  245. None,
  246. None,
  247. PositionwiseFeedForwardDecoderSANM(attention_dim + encoder_output_size, linear_units, dropout_rate,
  248. adim=attention_dim),
  249. dropout_rate,
  250. normalize_before,
  251. concat_after,
  252. ),
  253. )
  254. else:
  255. self.embed_concat_ffn = None
  256. self.concat_embeds = concat_embeds
  257. self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
  258. self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
  259. self.embed_tensor_name_prefix_tf = embed_tensor_name_prefix_tf
  260. def forward(
  261. self,
  262. hs_pad: torch.Tensor,
  263. hlens: torch.Tensor,
  264. ys_in_pad: torch.Tensor,
  265. ys_in_lens: torch.Tensor,
  266. chunk_mask: torch.Tensor = None,
  267. pre_acoustic_embeds: torch.Tensor = None,
  268. ) -> Tuple[torch.Tensor, torch.Tensor]:
  269. """Forward decoder.
  270. Args:
  271. hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
  272. hlens: (batch)
  273. ys_in_pad:
  274. input token ids, int64 (batch, maxlen_out)
  275. if input_layer == "embed"
  276. input tensor (batch, maxlen_out, #mels) in the other cases
  277. ys_in_lens: (batch)
  278. Returns:
  279. (tuple): tuple containing:
  280. x: decoded token score before softmax (batch, maxlen_out, token)
  281. if use_output_layer is True,
  282. olens: (batch, )
  283. """
  284. tgt = ys_in_pad
  285. tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
  286. memory = hs_pad
  287. memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
  288. if chunk_mask is not None:
  289. memory_mask = memory_mask * chunk_mask
  290. if tgt_mask.size(1) != memory_mask.size(1):
  291. memory_mask = torch.cat((memory_mask, memory_mask[:, -2:-1, :]), dim=1)
  292. x = self.embed(tgt)
  293. if pre_acoustic_embeds is not None and self.concat_embeds:
  294. x = torch.cat((x, pre_acoustic_embeds), dim=-1)
  295. x, _, _, _, _ = self.embed_concat_ffn(x, None, None, None, None)
  296. x, tgt_mask, memory, memory_mask, _ = self.decoders(
  297. x, tgt_mask, memory, memory_mask
  298. )
  299. if self.decoders2 is not None:
  300. x, tgt_mask, memory, memory_mask, _ = self.decoders2(
  301. x, tgt_mask, memory, memory_mask
  302. )
  303. x, tgt_mask, memory, memory_mask, _ = self.decoders3(
  304. x, tgt_mask, memory, memory_mask
  305. )
  306. if self.normalize_before:
  307. x = self.after_norm(x)
  308. if self.output_layer is not None:
  309. x = self.output_layer(x)
  310. olens = tgt_mask.sum(1)
  311. return x, olens
  312. def score(self, ys, state, x, x_mask=None, pre_acoustic_embeds: torch.Tensor = None, ):
  313. """Score."""
  314. ys_mask = myutils.sequence_mask(torch.tensor([len(ys)], dtype=torch.int32), device=x.device)[:, :, None]
  315. logp, state = self.forward_one_step(
  316. ys.unsqueeze(0), ys_mask, x.unsqueeze(0), memory_mask=x_mask, pre_acoustic_embeds=pre_acoustic_embeds,
  317. cache=state
  318. )
  319. return logp.squeeze(0), state
  320. def forward_one_step(
  321. self,
  322. tgt: torch.Tensor,
  323. tgt_mask: torch.Tensor,
  324. memory: torch.Tensor,
  325. memory_mask: torch.Tensor = None,
  326. pre_acoustic_embeds: torch.Tensor = None,
  327. cache: List[torch.Tensor] = None,
  328. ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
  329. """Forward one step.
  330. Args:
  331. tgt: input token ids, int64 (batch, maxlen_out)
  332. tgt_mask: input token mask, (batch, maxlen_out)
  333. dtype=torch.uint8 in PyTorch 1.2-
  334. dtype=torch.bool in PyTorch 1.2+ (include 1.2)
  335. memory: encoded memory, float32 (batch, maxlen_in, feat)
  336. cache: cached output list of (batch, max_time_out-1, size)
  337. Returns:
  338. y, cache: NN output value and cache per `self.decoders`.
  339. y.shape` is (batch, maxlen_out, token)
  340. """
  341. x = tgt[:, -1:]
  342. tgt_mask = None
  343. x = self.embed(x)
  344. if pre_acoustic_embeds is not None and self.concat_embeds:
  345. x = torch.cat((x, pre_acoustic_embeds), dim=-1)
  346. x, _, _, _, _ = self.embed_concat_ffn(x, None, None, None, None)
  347. if cache is None:
  348. cache_layer_num = len(self.decoders)
  349. if self.decoders2 is not None:
  350. cache_layer_num += len(self.decoders2)
  351. cache = [None] * cache_layer_num
  352. new_cache = []
  353. # for c, decoder in zip(cache, self.decoders):
  354. for i in range(self.att_layer_num):
  355. decoder = self.decoders[i]
  356. c = cache[i]
  357. x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
  358. x, tgt_mask, memory, memory_mask, cache=c
  359. )
  360. new_cache.append(c_ret)
  361. if self.num_blocks - self.att_layer_num >= 1:
  362. for i in range(self.num_blocks - self.att_layer_num):
  363. j = i + self.att_layer_num
  364. decoder = self.decoders2[i]
  365. c = cache[j]
  366. x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
  367. x, tgt_mask, memory, memory_mask, cache=c
  368. )
  369. new_cache.append(c_ret)
  370. for decoder in self.decoders3:
  371. x, tgt_mask, memory, memory_mask, _ = decoder.forward_chunk(
  372. x, tgt_mask, memory, None, cache=None
  373. )
  374. if self.normalize_before:
  375. y = self.after_norm(x[:, -1])
  376. else:
  377. y = x[:, -1]
  378. if self.output_layer is not None:
  379. y = self.output_layer(y)
  380. y = torch.log_softmax(y, dim=-1)
  381. return y, new_cache
  382. def gen_tf2torch_map_dict(self):
  383. tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
  384. tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
  385. embed_tensor_name_prefix_tf = self.embed_tensor_name_prefix_tf if self.embed_tensor_name_prefix_tf is not None else tensor_name_prefix_tf
  386. map_dict_local = {
  387. ## decoder
  388. # ffn
  389. "{}.decoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
  390. {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
  391. "squeeze": None,
  392. "transpose": None,
  393. }, # (256,),(256,)
  394. "{}.decoders.layeridx.norm1.bias".format(tensor_name_prefix_torch):
  395. {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
  396. "squeeze": None,
  397. "transpose": None,
  398. }, # (256,),(256,)
  399. "{}.decoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
  400. {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/kernel".format(tensor_name_prefix_tf),
  401. "squeeze": 0,
  402. "transpose": (1, 0),
  403. }, # (1024,256),(1,256,1024)
  404. "{}.decoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
  405. {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/bias".format(tensor_name_prefix_tf),
  406. "squeeze": None,
  407. "transpose": None,
  408. }, # (1024,),(1024,)
  409. "{}.decoders.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
  410. {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
  411. "squeeze": None,
  412. "transpose": None,
  413. }, # (1024,),(1024,)
  414. "{}.decoders.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
  415. {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm_1/beta".format(tensor_name_prefix_tf),
  416. "squeeze": None,
  417. "transpose": None,
  418. }, # (1024,),(1024,)
  419. "{}.decoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
  420. {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
  421. "squeeze": 0,
  422. "transpose": (1, 0),
  423. }, # (256,1024),(1,1024,256)
  424. # fsmn
  425. "{}.decoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
  426. {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/gamma".format(
  427. tensor_name_prefix_tf),
  428. "squeeze": None,
  429. "transpose": None,
  430. }, # (256,),(256,)
  431. "{}.decoders.layeridx.norm2.bias".format(tensor_name_prefix_torch):
  432. {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/beta".format(
  433. tensor_name_prefix_tf),
  434. "squeeze": None,
  435. "transpose": None,
  436. }, # (256,),(256,)
  437. "{}.decoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch):
  438. {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/depth_conv_w".format(
  439. tensor_name_prefix_tf),
  440. "squeeze": 0,
  441. "transpose": (1, 2, 0),
  442. }, # (256,1,31),(1,31,256,1)
  443. # src att
  444. "{}.decoders.layeridx.norm3.weight".format(tensor_name_prefix_torch):
  445. {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
  446. "squeeze": None,
  447. "transpose": None,
  448. }, # (256,),(256,)
  449. "{}.decoders.layeridx.norm3.bias".format(tensor_name_prefix_torch):
  450. {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
  451. "squeeze": None,
  452. "transpose": None,
  453. }, # (256,),(256,)
  454. "{}.decoders.layeridx.src_attn.linear_q.weight".format(tensor_name_prefix_torch):
  455. {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
  456. "squeeze": 0,
  457. "transpose": (1, 0),
  458. }, # (256,256),(1,256,256)
  459. "{}.decoders.layeridx.src_attn.linear_q.bias".format(tensor_name_prefix_torch):
  460. {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
  461. "squeeze": None,
  462. "transpose": None,
  463. }, # (256,),(256,)
  464. "{}.decoders.layeridx.src_attn.linear_k_v.weight".format(tensor_name_prefix_torch):
  465. {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
  466. "squeeze": 0,
  467. "transpose": (1, 0),
  468. }, # (1024,256),(1,256,1024)
  469. "{}.decoders.layeridx.src_attn.linear_k_v.bias".format(tensor_name_prefix_torch):
  470. {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
  471. "squeeze": None,
  472. "transpose": None,
  473. }, # (1024,),(1024,)
  474. "{}.decoders.layeridx.src_attn.linear_out.weight".format(tensor_name_prefix_torch):
  475. {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_2/kernel".format(tensor_name_prefix_tf),
  476. "squeeze": 0,
  477. "transpose": (1, 0),
  478. }, # (256,256),(1,256,256)
  479. "{}.decoders.layeridx.src_attn.linear_out.bias".format(tensor_name_prefix_torch):
  480. {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_2/bias".format(tensor_name_prefix_tf),
  481. "squeeze": None,
  482. "transpose": None,
  483. }, # (256,),(256,)
  484. # dnn
  485. "{}.decoders3.layeridx.norm1.weight".format(tensor_name_prefix_torch):
  486. {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm/gamma".format(tensor_name_prefix_tf),
  487. "squeeze": None,
  488. "transpose": None,
  489. }, # (256,),(256,)
  490. "{}.decoders3.layeridx.norm1.bias".format(tensor_name_prefix_torch):
  491. {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm/beta".format(tensor_name_prefix_tf),
  492. "squeeze": None,
  493. "transpose": None,
  494. }, # (256,),(256,)
  495. "{}.decoders3.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
  496. {"name": "{}/decoder_dnn_layer_layeridx/conv1d/kernel".format(tensor_name_prefix_tf),
  497. "squeeze": 0,
  498. "transpose": (1, 0),
  499. }, # (1024,256),(1,256,1024)
  500. "{}.decoders3.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
  501. {"name": "{}/decoder_dnn_layer_layeridx/conv1d/bias".format(tensor_name_prefix_tf),
  502. "squeeze": None,
  503. "transpose": None,
  504. }, # (1024,),(1024,)
  505. "{}.decoders3.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
  506. {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
  507. "squeeze": None,
  508. "transpose": None,
  509. }, # (1024,),(1024,)
  510. "{}.decoders3.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
  511. {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm_1/beta".format(tensor_name_prefix_tf),
  512. "squeeze": None,
  513. "transpose": None,
  514. }, # (1024,),(1024,)
  515. "{}.decoders3.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
  516. {"name": "{}/decoder_dnn_layer_layeridx/conv1d_1/kernel".format(tensor_name_prefix_tf),
  517. "squeeze": 0,
  518. "transpose": (1, 0),
  519. }, # (256,1024),(1,1024,256)
  520. # embed_concat_ffn
  521. "{}.embed_concat_ffn.layeridx.norm1.weight".format(tensor_name_prefix_torch):
  522. {"name": "{}/cif_concat/LayerNorm/gamma".format(tensor_name_prefix_tf),
  523. "squeeze": None,
  524. "transpose": None,
  525. }, # (256,),(256,)
  526. "{}.embed_concat_ffn.layeridx.norm1.bias".format(tensor_name_prefix_torch):
  527. {"name": "{}/cif_concat/LayerNorm/beta".format(tensor_name_prefix_tf),
  528. "squeeze": None,
  529. "transpose": None,
  530. }, # (256,),(256,)
  531. "{}.embed_concat_ffn.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
  532. {"name": "{}/cif_concat/conv1d/kernel".format(tensor_name_prefix_tf),
  533. "squeeze": 0,
  534. "transpose": (1, 0),
  535. }, # (1024,256),(1,256,1024)
  536. "{}.embed_concat_ffn.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
  537. {"name": "{}/cif_concat/conv1d/bias".format(tensor_name_prefix_tf),
  538. "squeeze": None,
  539. "transpose": None,
  540. }, # (1024,),(1024,)
  541. "{}.embed_concat_ffn.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
  542. {"name": "{}/cif_concat/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
  543. "squeeze": None,
  544. "transpose": None,
  545. }, # (1024,),(1024,)
  546. "{}.embed_concat_ffn.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
  547. {"name": "{}/cif_concat/LayerNorm_1/beta".format(tensor_name_prefix_tf),
  548. "squeeze": None,
  549. "transpose": None,
  550. }, # (1024,),(1024,)
  551. "{}.embed_concat_ffn.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
  552. {"name": "{}/cif_concat/conv1d_1/kernel".format(tensor_name_prefix_tf),
  553. "squeeze": 0,
  554. "transpose": (1, 0),
  555. }, # (256,1024),(1,1024,256)
  556. # out norm
  557. "{}.after_norm.weight".format(tensor_name_prefix_torch):
  558. {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
  559. "squeeze": None,
  560. "transpose": None,
  561. }, # (256,),(256,)
  562. "{}.after_norm.bias".format(tensor_name_prefix_torch):
  563. {"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
  564. "squeeze": None,
  565. "transpose": None,
  566. }, # (256,),(256,)
  567. # in embed
  568. "{}.embed.0.weight".format(tensor_name_prefix_torch):
  569. {"name": "{}/w_embs".format(embed_tensor_name_prefix_tf),
  570. "squeeze": None,
  571. "transpose": None,
  572. }, # (4235,256),(4235,256)
  573. # out layer
  574. "{}.output_layer.weight".format(tensor_name_prefix_torch):
  575. {"name": ["{}/dense/kernel".format(tensor_name_prefix_tf),
  576. "{}/w_embs".format(embed_tensor_name_prefix_tf)],
  577. "squeeze": [None, None],
  578. "transpose": [(1, 0), None],
  579. }, # (4235,256),(256,4235)
  580. "{}.output_layer.bias".format(tensor_name_prefix_torch):
  581. {"name": ["{}/dense/bias".format(tensor_name_prefix_tf),
  582. "seq2seq/2bias" if tensor_name_prefix_tf == "seq2seq/decoder/inputter_1" else "seq2seq/bias"],
  583. "squeeze": [None, None],
  584. "transpose": [None, None],
  585. }, # (4235,),(4235,)
  586. }
  587. return map_dict_local
  588. def convert_tf2torch(self,
  589. var_dict_tf,
  590. var_dict_torch,
  591. ):
  592. map_dict = self.gen_tf2torch_map_dict()
  593. var_dict_torch_update = dict()
  594. decoder_layeridx_sets = set()
  595. for name in sorted(var_dict_torch.keys(), reverse=False):
  596. names = name.split('.')
  597. if names[0] == self.tf2torch_tensor_name_prefix_torch:
  598. if names[1] == "decoders":
  599. layeridx = int(names[2])
  600. name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
  601. layeridx_bias = 0
  602. layeridx += layeridx_bias
  603. decoder_layeridx_sets.add(layeridx)
  604. if name_q in map_dict.keys():
  605. name_v = map_dict[name_q]["name"]
  606. name_tf = name_v.replace("layeridx", "{}".format(layeridx))
  607. data_tf = var_dict_tf[name_tf]
  608. if map_dict[name_q]["squeeze"] is not None:
  609. data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
  610. if map_dict[name_q]["transpose"] is not None:
  611. data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
  612. data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
  613. assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
  614. var_dict_torch[
  615. name].size(),
  616. data_tf.size())
  617. var_dict_torch_update[name] = data_tf
  618. logging.info(
  619. "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
  620. var_dict_tf[name_tf].shape))
  621. elif names[1] == "decoders2":
  622. layeridx = int(names[2])
  623. name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
  624. name_q = name_q.replace("decoders2", "decoders")
  625. layeridx_bias = len(decoder_layeridx_sets)
  626. layeridx += layeridx_bias
  627. if "decoders." in name:
  628. decoder_layeridx_sets.add(layeridx)
  629. if name_q in map_dict.keys():
  630. name_v = map_dict[name_q]["name"]
  631. name_tf = name_v.replace("layeridx", "{}".format(layeridx))
  632. data_tf = var_dict_tf[name_tf]
  633. if map_dict[name_q]["squeeze"] is not None:
  634. data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
  635. if map_dict[name_q]["transpose"] is not None:
  636. data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
  637. data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
  638. assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
  639. var_dict_torch[
  640. name].size(),
  641. data_tf.size())
  642. var_dict_torch_update[name] = data_tf
  643. logging.info(
  644. "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
  645. var_dict_tf[name_tf].shape))
  646. elif names[1] == "decoders3":
  647. layeridx = int(names[2])
  648. name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
  649. layeridx_bias = 0
  650. layeridx += layeridx_bias
  651. if "decoders." in name:
  652. decoder_layeridx_sets.add(layeridx)
  653. if name_q in map_dict.keys():
  654. name_v = map_dict[name_q]["name"]
  655. name_tf = name_v.replace("layeridx", "{}".format(layeridx))
  656. data_tf = var_dict_tf[name_tf]
  657. if map_dict[name_q]["squeeze"] is not None:
  658. data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
  659. if map_dict[name_q]["transpose"] is not None:
  660. data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
  661. data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
  662. assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
  663. var_dict_torch[
  664. name].size(),
  665. data_tf.size())
  666. var_dict_torch_update[name] = data_tf
  667. logging.info(
  668. "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
  669. var_dict_tf[name_tf].shape))
  670. elif names[1] == "embed" or names[1] == "output_layer":
  671. name_tf = map_dict[name]["name"]
  672. if isinstance(name_tf, list):
  673. idx_list = 0
  674. if name_tf[idx_list] in var_dict_tf.keys():
  675. pass
  676. else:
  677. idx_list = 1
  678. data_tf = var_dict_tf[name_tf[idx_list]]
  679. if map_dict[name]["squeeze"][idx_list] is not None:
  680. data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"][idx_list])
  681. if map_dict[name]["transpose"][idx_list] is not None:
  682. data_tf = np.transpose(data_tf, map_dict[name]["transpose"][idx_list])
  683. data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
  684. assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
  685. var_dict_torch[
  686. name].size(),
  687. data_tf.size())
  688. var_dict_torch_update[name] = data_tf
  689. logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(),
  690. name_tf[idx_list],
  691. var_dict_tf[name_tf[
  692. idx_list]].shape))
  693. else:
  694. data_tf = var_dict_tf[name_tf]
  695. if map_dict[name]["squeeze"] is not None:
  696. data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
  697. if map_dict[name]["transpose"] is not None:
  698. data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
  699. data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
  700. assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
  701. var_dict_torch[
  702. name].size(),
  703. data_tf.size())
  704. var_dict_torch_update[name] = data_tf
  705. logging.info(
  706. "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
  707. var_dict_tf[name_tf].shape))
  708. elif names[1] == "after_norm":
  709. name_tf = map_dict[name]["name"]
  710. data_tf = var_dict_tf[name_tf]
  711. data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
  712. var_dict_torch_update[name] = data_tf
  713. logging.info(
  714. "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
  715. var_dict_tf[name_tf].shape))
  716. elif names[1] == "embed_concat_ffn":
  717. layeridx = int(names[2])
  718. name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
  719. layeridx_bias = 0
  720. layeridx += layeridx_bias
  721. if "decoders." in name:
  722. decoder_layeridx_sets.add(layeridx)
  723. if name_q in map_dict.keys():
  724. name_v = map_dict[name_q]["name"]
  725. name_tf = name_v.replace("layeridx", "{}".format(layeridx))
  726. data_tf = var_dict_tf[name_tf]
  727. if map_dict[name_q]["squeeze"] is not None:
  728. data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
  729. if map_dict[name_q]["transpose"] is not None:
  730. data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
  731. data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
  732. assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
  733. var_dict_torch[
  734. name].size(),
  735. data_tf.size())
  736. var_dict_torch_update[name] = data_tf
  737. logging.info(
  738. "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
  739. var_dict_tf[name_tf].shape))
  740. return var_dict_torch_update
  741. class ParaformerSANMDecoder(BaseTransformerDecoder):
  742. """
  743. Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2206.08317
  746. """
  747. def __init__(
  748. self,
  749. vocab_size: int,
  750. encoder_output_size: int,
  751. attention_heads: int = 4,
  752. linear_units: int = 2048,
  753. num_blocks: int = 6,
  754. dropout_rate: float = 0.1,
  755. positional_dropout_rate: float = 0.1,
  756. self_attention_dropout_rate: float = 0.0,
  757. src_attention_dropout_rate: float = 0.0,
  758. input_layer: str = "embed",
  759. use_output_layer: bool = True,
  760. pos_enc_class=PositionalEncoding,
  761. normalize_before: bool = True,
  762. concat_after: bool = False,
  763. att_layer_num: int = 6,
  764. kernel_size: int = 21,
  765. sanm_shfit: int = 0,
  766. tf2torch_tensor_name_prefix_torch: str = "decoder",
  767. tf2torch_tensor_name_prefix_tf: str = "seq2seq/decoder",
  768. ):
  769. assert check_argument_types()
  770. super().__init__(
  771. vocab_size=vocab_size,
  772. encoder_output_size=encoder_output_size,
  773. dropout_rate=dropout_rate,
  774. positional_dropout_rate=positional_dropout_rate,
  775. input_layer=input_layer,
  776. use_output_layer=use_output_layer,
  777. pos_enc_class=pos_enc_class,
  778. normalize_before=normalize_before,
  779. )
  780. attention_dim = encoder_output_size
  781. if input_layer == "embed":
  782. self.embed = torch.nn.Sequential(
  783. torch.nn.Embedding(vocab_size, attention_dim),
  784. # pos_enc_class(attention_dim, positional_dropout_rate),
  785. )
  786. elif input_layer == "linear":
  787. self.embed = torch.nn.Sequential(
  788. torch.nn.Linear(vocab_size, attention_dim),
  789. torch.nn.LayerNorm(attention_dim),
  790. torch.nn.Dropout(dropout_rate),
  791. torch.nn.ReLU(),
  792. pos_enc_class(attention_dim, positional_dropout_rate),
  793. )
  794. else:
  795. raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")
  796. self.normalize_before = normalize_before
  797. if self.normalize_before:
  798. self.after_norm = LayerNorm(attention_dim)
  799. if use_output_layer:
  800. self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
  801. else:
  802. self.output_layer = None
  803. self.att_layer_num = att_layer_num
  804. self.num_blocks = num_blocks
  805. if sanm_shfit is None:
  806. sanm_shfit = (kernel_size - 1) // 2
  807. self.decoders = repeat(
  808. att_layer_num,
  809. lambda lnum: DecoderLayerSANM(
  810. attention_dim,
  811. MultiHeadedAttentionSANMDecoder(
  812. attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
  813. ),
  814. MultiHeadedAttentionCrossAtt(
  815. attention_heads, attention_dim, src_attention_dropout_rate
  816. ),
  817. PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
  818. dropout_rate,
  819. normalize_before,
  820. concat_after,
  821. ),
  822. )
  823. if num_blocks - att_layer_num <= 0:
  824. self.decoders2 = None
  825. else:
  826. self.decoders2 = repeat(
  827. num_blocks - att_layer_num,
  828. lambda lnum: DecoderLayerSANM(
  829. attention_dim,
  830. MultiHeadedAttentionSANMDecoder(
  831. attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=0
  832. ),
  833. None,
  834. PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
  835. dropout_rate,
  836. normalize_before,
  837. concat_after,
  838. ),
  839. )
  840. self.decoders3 = repeat(
  841. 1,
  842. lambda lnum: DecoderLayerSANM(
  843. attention_dim,
  844. None,
  845. None,
  846. PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
  847. dropout_rate,
  848. normalize_before,
  849. concat_after,
  850. ),
  851. )
  852. self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
  853. self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
  854. def forward(
  855. self,
  856. hs_pad: torch.Tensor,
  857. hlens: torch.Tensor,
  858. ys_in_pad: torch.Tensor,
  859. ys_in_lens: torch.Tensor,
  860. chunk_mask: torch.Tensor = None,
  861. ) -> Tuple[torch.Tensor, torch.Tensor]:
  862. """Forward decoder.
  863. Args:
  864. hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
  865. hlens: (batch)
  866. ys_in_pad:
  867. input token ids, int64 (batch, maxlen_out)
  868. if input_layer == "embed"
  869. input tensor (batch, maxlen_out, #mels) in the other cases
  870. ys_in_lens: (batch)
  871. Returns:
  872. (tuple): tuple containing:
  873. x: decoded token score before softmax (batch, maxlen_out, token)
  874. if use_output_layer is True,
  875. olens: (batch, )
  876. """
  877. tgt = ys_in_pad
  878. tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
  879. memory = hs_pad
  880. memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
  881. if chunk_mask is not None:
  882. memory_mask = memory_mask * chunk_mask
  883. if tgt_mask.size(1) != memory_mask.size(1):
  884. memory_mask = torch.cat((memory_mask, memory_mask[:, -2:-1, :]), dim=1)
  885. x = tgt
  886. x, tgt_mask, memory, memory_mask, _ = self.decoders(
  887. x, tgt_mask, memory, memory_mask
  888. )
  889. if self.decoders2 is not None:
  890. x, tgt_mask, memory, memory_mask, _ = self.decoders2(
  891. x, tgt_mask, memory, memory_mask
  892. )
  893. x, tgt_mask, memory, memory_mask, _ = self.decoders3(
  894. x, tgt_mask, memory, memory_mask
  895. )
  896. if self.normalize_before:
  897. x = self.after_norm(x)
  898. if self.output_layer is not None:
  899. x = self.output_layer(x)
  900. olens = tgt_mask.sum(1)
  901. return x, olens
  902. def score(self, ys, state, x):
  903. """Score."""
  904. ys_mask = myutils.sequence_mask(torch.tensor([len(ys)], dtype=torch.int32), device=x.device)[:, :, None]
  905. logp, state = self.forward_one_step(
  906. ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state
  907. )
  908. return logp.squeeze(0), state
  909. def forward_chunk(
  910. self,
  911. memory: torch.Tensor,
  912. tgt: torch.Tensor,
  913. cache: dict = None,
  914. ) -> Tuple[torch.Tensor, torch.Tensor]:
  915. """Forward decoder.
  916. Args:
  917. hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
  918. hlens: (batch)
  919. ys_in_pad:
  920. input token ids, int64 (batch, maxlen_out)
  921. if input_layer == "embed"
  922. input tensor (batch, maxlen_out, #mels) in the other cases
  923. ys_in_lens: (batch)
  924. Returns:
  925. (tuple): tuple containing:
  926. x: decoded token score before softmax (batch, maxlen_out, token)
  927. if use_output_layer is True,
  928. olens: (batch, )
  929. """
  930. x = tgt
  931. if cache["decode_fsmn"] is None:
  932. cache_layer_num = len(self.decoders)
  933. if self.decoders2 is not None:
  934. cache_layer_num += len(self.decoders2)
  935. new_cache = [None] * cache_layer_num
  936. else:
  937. new_cache = cache["decode_fsmn"]
  938. for i in range(self.att_layer_num):
  939. decoder = self.decoders[i]
  940. x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
  941. x, None, memory, None, cache=new_cache[i]
  942. )
  943. new_cache[i] = c_ret
  944. if self.num_blocks - self.att_layer_num > 1:
  945. for i in range(self.num_blocks - self.att_layer_num):
  946. j = i + self.att_layer_num
  947. decoder = self.decoders2[i]
  948. x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
  949. x, None, memory, None, cache=new_cache[j]
  950. )
  951. new_cache[j] = c_ret
  952. for decoder in self.decoders3:
  953. x, tgt_mask, memory, memory_mask, _ = decoder.forward_chunk(
  954. x, None, memory, None, cache=None
  955. )
  956. if self.normalize_before:
  957. x = self.after_norm(x)
  958. if self.output_layer is not None:
  959. x = self.output_layer(x)
  960. cache["decode_fsmn"] = new_cache
  961. return x
  962. def forward_one_step(
  963. self,
  964. tgt: torch.Tensor,
  965. tgt_mask: torch.Tensor,
  966. memory: torch.Tensor,
  967. cache: List[torch.Tensor] = None,
  968. ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
  969. """Forward one step.
  970. Args:
  971. tgt: input token ids, int64 (batch, maxlen_out)
  972. tgt_mask: input token mask, (batch, maxlen_out)
  973. dtype=torch.uint8 in PyTorch 1.2-
  974. dtype=torch.bool in PyTorch 1.2+ (include 1.2)
  975. memory: encoded memory, float32 (batch, maxlen_in, feat)
  976. cache: cached output list of (batch, max_time_out-1, size)
  977. Returns:
  978. y, cache: NN output value and cache per `self.decoders`.
  979. y.shape` is (batch, maxlen_out, token)
  980. """
  981. x = self.embed(tgt)
  982. if cache is None:
  983. cache_layer_num = len(self.decoders)
  984. if self.decoders2 is not None:
  985. cache_layer_num += len(self.decoders2)
  986. cache = [None] * cache_layer_num
  987. new_cache = []
  988. # for c, decoder in zip(cache, self.decoders):
  989. for i in range(self.att_layer_num):
  990. decoder = self.decoders[i]
  991. c = cache[i]
  992. x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
  993. x, tgt_mask, memory, None, cache=c
  994. )
  995. new_cache.append(c_ret)
  996. if self.num_blocks - self.att_layer_num > 1:
  997. for i in range(self.num_blocks - self.att_layer_num):
  998. j = i + self.att_layer_num
  999. decoder = self.decoders2[i]
  1000. c = cache[j]
  1001. x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
  1002. x, tgt_mask, memory, None, cache=c
  1003. )
  1004. new_cache.append(c_ret)
  1005. for decoder in self.decoders3:
  1006. x, tgt_mask, memory, memory_mask, _ = decoder.forward_chunk(
  1007. x, tgt_mask, memory, None, cache=None
  1008. )
  1009. if self.normalize_before:
  1010. y = self.after_norm(x[:, -1])
  1011. else:
  1012. y = x[:, -1]
  1013. if self.output_layer is not None:
  1014. y = torch.log_softmax(self.output_layer(y), dim=-1)
  1015. return y, new_cache
  1016. def gen_tf2torch_map_dict(self):
  1017. tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
  1018. tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
  1019. map_dict_local = {
  1020. ## decoder
  1021. # ffn
  1022. "{}.decoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
  1023. {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
  1024. "squeeze": None,
  1025. "transpose": None,
  1026. }, # (256,),(256,)
  1027. "{}.decoders.layeridx.norm1.bias".format(tensor_name_prefix_torch):
  1028. {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
  1029. "squeeze": None,
  1030. "transpose": None,
  1031. }, # (256,),(256,)
  1032. "{}.decoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
  1033. {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/kernel".format(tensor_name_prefix_tf),
  1034. "squeeze": 0,
  1035. "transpose": (1, 0),
  1036. }, # (1024,256),(1,256,1024)
  1037. "{}.decoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
  1038. {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/bias".format(tensor_name_prefix_tf),
  1039. "squeeze": None,
  1040. "transpose": None,
  1041. }, # (1024,),(1024,)
  1042. "{}.decoders.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
  1043. {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
  1044. "squeeze": None,
  1045. "transpose": None,
  1046. }, # (1024,),(1024,)
  1047. "{}.decoders.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
  1048. {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm_1/beta".format(tensor_name_prefix_tf),
  1049. "squeeze": None,
  1050. "transpose": None,
  1051. }, # (1024,),(1024,)
  1052. "{}.decoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
  1053. {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
  1054. "squeeze": 0,
  1055. "transpose": (1, 0),
  1056. }, # (256,1024),(1,1024,256)
  1057. # fsmn
  1058. "{}.decoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
  1059. {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/gamma".format(
  1060. tensor_name_prefix_tf),
  1061. "squeeze": None,
  1062. "transpose": None,
  1063. }, # (256,),(256,)
  1064. "{}.decoders.layeridx.norm2.bias".format(tensor_name_prefix_torch):
  1065. {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/beta".format(
  1066. tensor_name_prefix_tf),
  1067. "squeeze": None,
  1068. "transpose": None,
  1069. }, # (256,),(256,)
  1070. "{}.decoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch):
  1071. {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/depth_conv_w".format(
  1072. tensor_name_prefix_tf),
  1073. "squeeze": 0,
  1074. "transpose": (1, 2, 0),
  1075. }, # (256,1,31),(1,31,256,1)
  1076. # src att
  1077. "{}.decoders.layeridx.norm3.weight".format(tensor_name_prefix_torch):
  1078. {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
  1079. "squeeze": None,
  1080. "transpose": None,
  1081. }, # (256,),(256,)
  1082. "{}.decoders.layeridx.norm3.bias".format(tensor_name_prefix_torch):
  1083. {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
  1084. "squeeze": None,
  1085. "transpose": None,
  1086. }, # (256,),(256,)
  1087. "{}.decoders.layeridx.src_attn.linear_q.weight".format(tensor_name_prefix_torch):
  1088. {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
  1089. "squeeze": 0,
  1090. "transpose": (1, 0),
  1091. }, # (256,256),(1,256,256)
  1092. "{}.decoders.layeridx.src_attn.linear_q.bias".format(tensor_name_prefix_torch):
  1093. {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
  1094. "squeeze": None,
  1095. "transpose": None,
  1096. }, # (256,),(256,)
  1097. "{}.decoders.layeridx.src_attn.linear_k_v.weight".format(tensor_name_prefix_torch):
  1098. {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
  1099. "squeeze": 0,
  1100. "transpose": (1, 0),
  1101. }, # (1024,256),(1,256,1024)
  1102. "{}.decoders.layeridx.src_attn.linear_k_v.bias".format(tensor_name_prefix_torch):
  1103. {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
  1104. "squeeze": None,
  1105. "transpose": None,
  1106. }, # (1024,),(1024,)
  1107. "{}.decoders.layeridx.src_attn.linear_out.weight".format(tensor_name_prefix_torch):
  1108. {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_2/kernel".format(tensor_name_prefix_tf),
  1109. "squeeze": 0,
  1110. "transpose": (1, 0),
  1111. }, # (256,256),(1,256,256)
  1112. "{}.decoders.layeridx.src_attn.linear_out.bias".format(tensor_name_prefix_torch):
  1113. {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_2/bias".format(tensor_name_prefix_tf),
  1114. "squeeze": None,
  1115. "transpose": None,
  1116. }, # (256,),(256,)
  1117. # dnn
  1118. "{}.decoders3.layeridx.norm1.weight".format(tensor_name_prefix_torch):
  1119. {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm/gamma".format(tensor_name_prefix_tf),
  1120. "squeeze": None,
  1121. "transpose": None,
  1122. }, # (256,),(256,)
  1123. "{}.decoders3.layeridx.norm1.bias".format(tensor_name_prefix_torch):
  1124. {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm/beta".format(tensor_name_prefix_tf),
  1125. "squeeze": None,
  1126. "transpose": None,
  1127. }, # (256,),(256,)
  1128. "{}.decoders3.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
  1129. {"name": "{}/decoder_dnn_layer_layeridx/conv1d/kernel".format(tensor_name_prefix_tf),
  1130. "squeeze": 0,
  1131. "transpose": (1, 0),
  1132. }, # (1024,256),(1,256,1024)
  1133. "{}.decoders3.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
  1134. {"name": "{}/decoder_dnn_layer_layeridx/conv1d/bias".format(tensor_name_prefix_tf),
  1135. "squeeze": None,
  1136. "transpose": None,
  1137. }, # (1024,),(1024,)
  1138. "{}.decoders3.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
  1139. {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
  1140. "squeeze": None,
  1141. "transpose": None,
  1142. }, # (1024,),(1024,)
  1143. "{}.decoders3.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
  1144. {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm_1/beta".format(tensor_name_prefix_tf),
  1145. "squeeze": None,
  1146. "transpose": None,
  1147. }, # (1024,),(1024,)
  1148. "{}.decoders3.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
  1149. {"name": "{}/decoder_dnn_layer_layeridx/conv1d_1/kernel".format(tensor_name_prefix_tf),
  1150. "squeeze": 0,
  1151. "transpose": (1, 0),
  1152. }, # (256,1024),(1,1024,256)
  1153. # embed_concat_ffn
  1154. "{}.embed_concat_ffn.layeridx.norm1.weight".format(tensor_name_prefix_torch):
  1155. {"name": "{}/cif_concat/LayerNorm/gamma".format(tensor_name_prefix_tf),
  1156. "squeeze": None,
  1157. "transpose": None,
  1158. }, # (256,),(256,)
  1159. "{}.embed_concat_ffn.layeridx.norm1.bias".format(tensor_name_prefix_torch):
  1160. {"name": "{}/cif_concat/LayerNorm/beta".format(tensor_name_prefix_tf),
  1161. "squeeze": None,
  1162. "transpose": None,
  1163. }, # (256,),(256,)
  1164. "{}.embed_concat_ffn.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
  1165. {"name": "{}/cif_concat/conv1d/kernel".format(tensor_name_prefix_tf),
  1166. "squeeze": 0,
  1167. "transpose": (1, 0),
  1168. }, # (1024,256),(1,256,1024)
  1169. "{}.embed_concat_ffn.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
  1170. {"name": "{}/cif_concat/conv1d/bias".format(tensor_name_prefix_tf),
  1171. "squeeze": None,
  1172. "transpose": None,
  1173. }, # (1024,),(1024,)
  1174. "{}.embed_concat_ffn.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
  1175. {"name": "{}/cif_concat/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
  1176. "squeeze": None,
  1177. "transpose": None,
  1178. }, # (1024,),(1024,)
  1179. "{}.embed_concat_ffn.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
  1180. {"name": "{}/cif_concat/LayerNorm_1/beta".format(tensor_name_prefix_tf),
  1181. "squeeze": None,
  1182. "transpose": None,
  1183. }, # (1024,),(1024,)
  1184. "{}.embed_concat_ffn.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
  1185. {"name": "{}/cif_concat/conv1d_1/kernel".format(tensor_name_prefix_tf),
  1186. "squeeze": 0,
  1187. "transpose": (1, 0),
  1188. }, # (256,1024),(1,1024,256)
  1189. # out norm
  1190. "{}.after_norm.weight".format(tensor_name_prefix_torch):
  1191. {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
  1192. "squeeze": None,
  1193. "transpose": None,
  1194. }, # (256,),(256,)
  1195. "{}.after_norm.bias".format(tensor_name_prefix_torch):
  1196. {"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
  1197. "squeeze": None,
  1198. "transpose": None,
  1199. }, # (256,),(256,)
  1200. # in embed
  1201. "{}.embed.0.weight".format(tensor_name_prefix_torch):
  1202. {"name": "{}/w_embs".format(tensor_name_prefix_tf),
  1203. "squeeze": None,
  1204. "transpose": None,
  1205. }, # (4235,256),(4235,256)
  1206. # out layer
  1207. "{}.output_layer.weight".format(tensor_name_prefix_torch):
  1208. {"name": ["{}/dense/kernel".format(tensor_name_prefix_tf), "{}/w_embs".format(tensor_name_prefix_tf)],
  1209. "squeeze": [None, None],
  1210. "transpose": [(1, 0), None],
  1211. }, # (4235,256),(256,4235)
  1212. "{}.output_layer.bias".format(tensor_name_prefix_torch):
  1213. {"name": ["{}/dense/bias".format(tensor_name_prefix_tf),
  1214. "seq2seq/2bias" if tensor_name_prefix_tf == "seq2seq/decoder/inputter_1" else "seq2seq/bias"],
  1215. "squeeze": [None, None],
  1216. "transpose": [None, None],
  1217. }, # (4235,),(4235,)
  1218. }
  1219. return map_dict_local
  1220. def convert_tf2torch(self,
  1221. var_dict_tf,
  1222. var_dict_torch,
  1223. ):
  1224. map_dict = self.gen_tf2torch_map_dict()
  1225. var_dict_torch_update = dict()
  1226. decoder_layeridx_sets = set()
  1227. for name in sorted(var_dict_torch.keys(), reverse=False):
  1228. names = name.split('.')
  1229. if names[0] == self.tf2torch_tensor_name_prefix_torch:
  1230. if names[1] == "decoders":
  1231. layeridx = int(names[2])
  1232. name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
  1233. layeridx_bias = 0
  1234. layeridx += layeridx_bias
  1235. decoder_layeridx_sets.add(layeridx)
  1236. if name_q in map_dict.keys():
  1237. name_v = map_dict[name_q]["name"]
  1238. name_tf = name_v.replace("layeridx", "{}".format(layeridx))
  1239. data_tf = var_dict_tf[name_tf]
  1240. if map_dict[name_q]["squeeze"] is not None:
  1241. data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
  1242. if map_dict[name_q]["transpose"] is not None:
  1243. data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
  1244. data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
  1245. assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
  1246. var_dict_torch[
  1247. name].size(),
  1248. data_tf.size())
  1249. var_dict_torch_update[name] = data_tf
  1250. logging.info(
  1251. "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
  1252. var_dict_tf[name_tf].shape))
  1253. elif names[1] == "decoders2":
  1254. layeridx = int(names[2])
  1255. name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
  1256. name_q = name_q.replace("decoders2", "decoders")
  1257. layeridx_bias = len(decoder_layeridx_sets)
  1258. layeridx += layeridx_bias
  1259. if "decoders." in name:
  1260. decoder_layeridx_sets.add(layeridx)
  1261. if name_q in map_dict.keys():
  1262. name_v = map_dict[name_q]["name"]
  1263. name_tf = name_v.replace("layeridx", "{}".format(layeridx))
  1264. data_tf = var_dict_tf[name_tf]
  1265. if map_dict[name_q]["squeeze"] is not None:
  1266. data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
  1267. if map_dict[name_q]["transpose"] is not None:
  1268. data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
  1269. data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
  1270. assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
  1271. var_dict_torch[
  1272. name].size(),
  1273. data_tf.size())
  1274. var_dict_torch_update[name] = data_tf
  1275. logging.info(
  1276. "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
  1277. var_dict_tf[name_tf].shape))
  1278. elif names[1] == "decoders3":
  1279. layeridx = int(names[2])
  1280. name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
  1281. layeridx_bias = 0
  1282. layeridx += layeridx_bias
  1283. if "decoders." in name:
  1284. decoder_layeridx_sets.add(layeridx)
  1285. if name_q in map_dict.keys():
  1286. name_v = map_dict[name_q]["name"]
  1287. name_tf = name_v.replace("layeridx", "{}".format(layeridx))
  1288. data_tf = var_dict_tf[name_tf]
  1289. if map_dict[name_q]["squeeze"] is not None:
  1290. data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
  1291. if map_dict[name_q]["transpose"] is not None:
  1292. data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
  1293. data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
  1294. assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
  1295. var_dict_torch[
  1296. name].size(),
  1297. data_tf.size())
  1298. var_dict_torch_update[name] = data_tf
  1299. logging.info(
  1300. "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
  1301. var_dict_tf[name_tf].shape))
  1302. elif names[1] == "embed" or names[1] == "output_layer":
  1303. name_tf = map_dict[name]["name"]
  1304. if isinstance(name_tf, list):
  1305. idx_list = 0
  1306. if name_tf[idx_list] in var_dict_tf.keys():
  1307. pass
  1308. else:
  1309. idx_list = 1
  1310. data_tf = var_dict_tf[name_tf[idx_list]]
  1311. if map_dict[name]["squeeze"][idx_list] is not None:
  1312. data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"][idx_list])
  1313. if map_dict[name]["transpose"][idx_list] is not None:
  1314. data_tf = np.transpose(data_tf, map_dict[name]["transpose"][idx_list])
  1315. data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
  1316. assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
  1317. var_dict_torch[
  1318. name].size(),
  1319. data_tf.size())
  1320. var_dict_torch_update[name] = data_tf
  1321. logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(),
  1322. name_tf[idx_list],
  1323. var_dict_tf[name_tf[
  1324. idx_list]].shape))
  1325. else:
  1326. data_tf = var_dict_tf[name_tf]
  1327. if map_dict[name]["squeeze"] is not None:
  1328. data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
  1329. if map_dict[name]["transpose"] is not None:
  1330. data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
  1331. data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
  1332. assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
  1333. var_dict_torch[
  1334. name].size(),
  1335. data_tf.size())
  1336. var_dict_torch_update[name] = data_tf
  1337. logging.info(
  1338. "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
  1339. var_dict_tf[name_tf].shape))
  1340. elif names[1] == "after_norm":
  1341. name_tf = map_dict[name]["name"]
  1342. data_tf = var_dict_tf[name_tf]
  1343. data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
  1344. var_dict_torch_update[name] = data_tf
  1345. logging.info(
  1346. "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
  1347. var_dict_tf[name_tf].shape))
  1348. elif names[1] == "embed_concat_ffn":
  1349. layeridx = int(names[2])
  1350. name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
  1351. layeridx_bias = 0
  1352. layeridx += layeridx_bias
  1353. if "decoders." in name:
  1354. decoder_layeridx_sets.add(layeridx)
  1355. if name_q in map_dict.keys():
  1356. name_v = map_dict[name_q]["name"]
  1357. name_tf = name_v.replace("layeridx", "{}".format(layeridx))
  1358. data_tf = var_dict_tf[name_tf]
  1359. if map_dict[name_q]["squeeze"] is not None:
  1360. data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
  1361. if map_dict[name_q]["transpose"] is not None:
  1362. data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
  1363. data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
  1364. assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
  1365. var_dict_torch[
  1366. name].size(),
  1367. data_tf.size())
  1368. var_dict_torch_update[name] = data_tf
  1369. logging.info(
  1370. "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
  1371. var_dict_tf[name_tf].shape))
  1372. return var_dict_torch_update