# sanm_decoder.py

from typing import List
from typing import Tuple

import logging
import torch
import torch.nn as nn
import numpy as np

from funasr.modules.streaming_utils import utils as myutils
from funasr.models.decoder.transformer_decoder import BaseTransformerDecoder
from funasr.modules.attention import MultiHeadedAttentionSANMDecoder, MultiHeadedAttentionCrossAtt
from funasr.modules.embedding import PositionalEncoding
from funasr.modules.layer_norm import LayerNorm
from funasr.modules.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
from funasr.modules.repeat import repeat
class DecoderLayerSANM(nn.Module):
    """Single decoder layer module.

    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttention` instance can be used as the argument.
        src_attn (torch.nn.Module): Source-attention (cross-attention) module instance.
            `MultiHeadedAttention` instance can be used as the argument.
        feed_forward (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
            can be used as the argument.
        dropout_rate (float): Dropout rate.
        normalize_before (bool): Whether to apply layer_norm before the first block.
        concat_after (bool): Whether to concatenate the attention layer's input and output.
            If True, an additional linear layer is applied,
            i.e. x -> x + linear(concat(x, att(x))).
            If False, no additional linear layer is applied, i.e. x -> x + att(x).
    """

    def __init__(
        self,
        size,
        self_attn,
        src_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
    ):
        """Construct a DecoderLayerSANM object."""
        super(DecoderLayerSANM, self).__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.norm1 = LayerNorm(size)
        if self_attn is not None:
            self.norm2 = LayerNorm(size)
        if src_attn is not None:
            self.norm3 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear1 = nn.Linear(size + size, size)
            self.concat_linear2 = nn.Linear(size + size, size)
    def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
        """Compute decoded features.

        Args:
            tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
            tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
            memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
            memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
            cache (List[torch.Tensor]): List of cached tensors.
                Each tensor shape should be (#batch, maxlen_out - 1, size).

        Returns:
            torch.Tensor: Output tensor (#batch, maxlen_out, size).
            torch.Tensor: Mask for output tensor (#batch, maxlen_out).
            torch.Tensor: Encoded memory (#batch, maxlen_in, size).
            torch.Tensor: Encoded memory mask (#batch, maxlen_in).
        """
        # tgt = self.dropout(tgt)
        residual = tgt
        if self.normalize_before:
            tgt = self.norm1(tgt)
        tgt = self.feed_forward(tgt)

        x = tgt
        if self.self_attn:
            if self.normalize_before:
                tgt = self.norm2(tgt)
            x, _ = self.self_attn(tgt, tgt_mask)
            x = residual + self.dropout(x)

        if self.src_attn is not None:
            residual = x
            if self.normalize_before:
                x = self.norm3(x)
            x = residual + self.dropout(self.src_attn(x, memory, memory_mask))

        return x, tgt_mask, memory, memory_mask, cache
    def forward_chunk(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
        """Compute decoded features chunk by chunk, reusing the FSMN cache.

        Args:
            tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
            tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
            memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
            memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
            cache (List[torch.Tensor]): List of cached tensors.
                Each tensor shape should be (#batch, maxlen_out - 1, size).

        Returns:
            torch.Tensor: Output tensor (#batch, maxlen_out, size).
            torch.Tensor: Mask for output tensor (#batch, maxlen_out).
            torch.Tensor: Encoded memory (#batch, maxlen_in, size).
            torch.Tensor: Encoded memory mask (#batch, maxlen_in).
        """
        # tgt = self.dropout(tgt)
        residual = tgt
        if self.normalize_before:
            tgt = self.norm1(tgt)
        tgt = self.feed_forward(tgt)

        x = tgt
        if self.self_attn:
            if self.normalize_before:
                tgt = self.norm2(tgt)
            if self.training:
                cache = None
            x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
            x = residual + self.dropout(x)

        if self.src_attn is not None:
            residual = x
            if self.normalize_before:
                x = self.norm3(x)
            x = residual + self.dropout(self.src_attn(x, memory, memory_mask))

        return x, tgt_mask, memory, memory_mask, cache
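

# Illustrative usage sketch (not part of the original file; sizes are hypothetical).
# A DecoderLayerSANM stacks a SANM feed-forward block, an FSMN memory block acting as
# "self-attention", and cross attention over the encoder output, mirroring the
# constructor calls used in the decoder classes below:
#
#     layer = DecoderLayerSANM(
#         size=256,
#         self_attn=MultiHeadedAttentionSANMDecoder(256, 0.1, 11, sanm_shfit=0),
#         src_attn=MultiHeadedAttentionCrossAtt(4, 256, 0.1),
#         feed_forward=PositionwiseFeedForwardDecoderSANM(256, 2048, 0.1),
#         dropout_rate=0.1,
#     )
#     x, tgt_mask, memory, memory_mask, _ = layer(tgt, tgt_mask, memory, memory_mask)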
class FsmnDecoderSCAMAOpt(BaseTransformerDecoder):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
    https://arxiv.org/abs/2006.01713
    """

    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = "embed",
        use_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
        att_layer_num: int = 6,
        kernel_size: int = 21,
        sanm_shfit: int = None,
        concat_embeds: bool = False,
        attention_dim: int = None,
        tf2torch_tensor_name_prefix_torch: str = "decoder",
        tf2torch_tensor_name_prefix_tf: str = "seq2seq/decoder",
        embed_tensor_name_prefix_tf: str = None,
    ):
        super().__init__(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            dropout_rate=dropout_rate,
            positional_dropout_rate=positional_dropout_rate,
            input_layer=input_layer,
            use_output_layer=use_output_layer,
            pos_enc_class=pos_enc_class,
            normalize_before=normalize_before,
        )
        if attention_dim is None:
            attention_dim = encoder_output_size

        if input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(vocab_size, attention_dim),
            )
        elif input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(vocab_size, attention_dim),
                torch.nn.LayerNorm(attention_dim),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        else:
            raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")

        self.normalize_before = normalize_before
        if self.normalize_before:
            self.after_norm = LayerNorm(attention_dim)
        if use_output_layer:
            self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
        else:
            self.output_layer = None

        self.att_layer_num = att_layer_num
        self.num_blocks = num_blocks
        if sanm_shfit is None:
            sanm_shfit = (kernel_size - 1) // 2
        self.decoders = repeat(
            att_layer_num,
            lambda lnum: DecoderLayerSANM(
                attention_dim,
                MultiHeadedAttentionSANMDecoder(
                    attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
                ),
                MultiHeadedAttentionCrossAtt(
                    attention_heads, attention_dim, src_attention_dropout_rate,
                    encoder_output_size=encoder_output_size
                ),
                PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        if num_blocks - att_layer_num <= 0:
            self.decoders2 = None
        else:
            self.decoders2 = repeat(
                num_blocks - att_layer_num,
                lambda lnum: DecoderLayerSANM(
                    attention_dim,
                    MultiHeadedAttentionSANMDecoder(
                        attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
                    ),
                    None,
                    PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
                    dropout_rate,
                    normalize_before,
                    concat_after,
                ),
            )
        self.decoders3 = repeat(
            1,
            lambda lnum: DecoderLayerSANM(
                attention_dim,
                None,
                None,
                PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        if concat_embeds:
            self.embed_concat_ffn = repeat(
                1,
                lambda lnum: DecoderLayerSANM(
                    attention_dim + encoder_output_size,
                    None,
                    None,
                    PositionwiseFeedForwardDecoderSANM(
                        attention_dim + encoder_output_size, linear_units, dropout_rate,
                        adim=attention_dim
                    ),
                    dropout_rate,
                    normalize_before,
                    concat_after,
                ),
            )
        else:
            self.embed_concat_ffn = None
        self.concat_embeds = concat_embeds
        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
        self.embed_tensor_name_prefix_tf = embed_tensor_name_prefix_tf
    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
        chunk_mask: torch.Tensor = None,
        pre_acoustic_embeds: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward decoder.

        Args:
            hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
            hlens: (batch)
            ys_in_pad:
                input token ids, int64 (batch, maxlen_out)
                if input_layer == "embed";
                input tensor (batch, maxlen_out, #mels) in the other cases
            ys_in_lens: (batch)
        Returns:
            (tuple): tuple containing:
                x: decoded token scores before softmax (batch, maxlen_out, token)
                    if use_output_layer is True
                olens: (batch,)
        """
        tgt = ys_in_pad
        tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]

        memory = hs_pad
        memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
        if chunk_mask is not None:
            memory_mask = memory_mask * chunk_mask
            if tgt_mask.size(1) != memory_mask.size(1):
                memory_mask = torch.cat((memory_mask, memory_mask[:, -2:-1, :]), dim=1)

        x = self.embed(tgt)
        if pre_acoustic_embeds is not None and self.concat_embeds:
            x = torch.cat((x, pre_acoustic_embeds), dim=-1)
            x, _, _, _, _ = self.embed_concat_ffn(x, None, None, None, None)

        x, tgt_mask, memory, memory_mask, _ = self.decoders(
            x, tgt_mask, memory, memory_mask
        )
        if self.decoders2 is not None:
            x, tgt_mask, memory, memory_mask, _ = self.decoders2(
                x, tgt_mask, memory, memory_mask
            )
        x, tgt_mask, memory, memory_mask, _ = self.decoders3(
            x, tgt_mask, memory, memory_mask
        )
        if self.normalize_before:
            x = self.after_norm(x)
        if self.output_layer is not None:
            x = self.output_layer(x)

        olens = tgt_mask.sum(1)
        return x, olens
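
    # Illustrative call sketch (hypothetical shapes, not part of the original file):
    # with batch=2, 50 encoder frames of dim 256 and at most 12 target tokens,
    #
    #     hs_pad = torch.randn(2, 50, 256)
    #     hlens = torch.tensor([50, 42])
    #     ys_in_pad = torch.randint(0, vocab_size, (2, 12))
    #     ys_in_lens = torch.tensor([12, 9])
    #     logits, olens = decoder(hs_pad, hlens, ys_in_pad, ys_in_lens)
    #     # logits: (2, 12, vocab_size), olens: (2,)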
    def score(self, ys, state, x, x_mask=None, pre_acoustic_embeds: torch.Tensor = None):
        """Score."""
        ys_mask = myutils.sequence_mask(
            torch.tensor([len(ys)], dtype=torch.int32), device=x.device
        )[:, :, None]
        logp, state = self.forward_one_step(
            ys.unsqueeze(0), ys_mask, x.unsqueeze(0),
            memory_mask=x_mask, pre_acoustic_embeds=pre_acoustic_embeds,
            cache=state,
        )
        return logp.squeeze(0), state
    def forward_one_step(
        self,
        tgt: torch.Tensor,
        tgt_mask: torch.Tensor,
        memory: torch.Tensor,
        memory_mask: torch.Tensor = None,
        pre_acoustic_embeds: torch.Tensor = None,
        cache: List[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Forward one step.

        Args:
            tgt: input token ids, int64 (batch, maxlen_out)
            tgt_mask: input token mask, (batch, maxlen_out)
                dtype=torch.uint8 in PyTorch 1.2-
                dtype=torch.bool in PyTorch 1.2+ (including 1.2)
            memory: encoded memory, float32 (batch, maxlen_in, feat)
            cache: cached output list of (batch, max_time_out-1, size)
        Returns:
            y, cache: NN output value and cache per `self.decoders`.
                `y.shape` is (batch, maxlen_out, token)
        """
        x = tgt[:, -1:]
        tgt_mask = None
        x = self.embed(x)
        if pre_acoustic_embeds is not None and self.concat_embeds:
            x = torch.cat((x, pre_acoustic_embeds), dim=-1)
            x, _, _, _, _ = self.embed_concat_ffn(x, None, None, None, None)

        if cache is None:
            cache_layer_num = len(self.decoders)
            if self.decoders2 is not None:
                cache_layer_num += len(self.decoders2)
            cache = [None] * cache_layer_num
        new_cache = []
        # for c, decoder in zip(cache, self.decoders):
        for i in range(self.att_layer_num):
            decoder = self.decoders[i]
            c = cache[i]
            x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
                x, tgt_mask, memory, memory_mask, cache=c
            )
            new_cache.append(c_ret)
        if self.num_blocks - self.att_layer_num >= 1:
            for i in range(self.num_blocks - self.att_layer_num):
                j = i + self.att_layer_num
                decoder = self.decoders2[i]
                c = cache[j]
                x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
                    x, tgt_mask, memory, memory_mask, cache=c
                )
                new_cache.append(c_ret)
        for decoder in self.decoders3:
            x, tgt_mask, memory, memory_mask, _ = decoder.forward_chunk(
                x, tgt_mask, memory, None, cache=None
            )
        if self.normalize_before:
            y = self.after_norm(x[:, -1])
        else:
            y = x[:, -1]
        if self.output_layer is not None:
            y = self.output_layer(y)
            y = torch.log_softmax(y, dim=-1)
        return y, new_cache
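
    # Illustrative streaming decode sketch (hypothetical, not part of the original
    # file): score() is called once per emitted token and threads the per-layer
    # FSMN cache through `state`; `enc_out[0]`, `sos_id` and `eos_id` are assumed
    # to come from the surrounding beam-search / greedy-search code:
    #
    #     state = None
    #     ys = torch.tensor([sos_id], dtype=torch.int64)
    #     for _ in range(max_len):
    #         logp, state = decoder.score(ys, state, enc_out[0])
    #         next_id = int(logp.argmax())
    #         ys = torch.cat([ys, torch.tensor([next_id])])
    #         if next_id == eos_id:
    #             break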
    def gen_tf2torch_map_dict(self):
        tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
        tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
        embed_tensor_name_prefix_tf = (
            self.embed_tensor_name_prefix_tf
            if self.embed_tensor_name_prefix_tf is not None
            else tensor_name_prefix_tf
        )
        map_dict_local = {
            ## decoder
            # ffn
            "{}.decoders.layeridx.norm1.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
                "squeeze": None, "transpose": None,
            },  # (256,),(256,)
            "{}.decoders.layeridx.norm1.bias".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
                "squeeze": None, "transpose": None,
            },  # (256,),(256,)
            "{}.decoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/kernel".format(tensor_name_prefix_tf),
                "squeeze": 0, "transpose": (1, 0),
            },  # (1024,256),(1,256,1024)
            "{}.decoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/bias".format(tensor_name_prefix_tf),
                "squeeze": None, "transpose": None,
            },  # (1024,),(1024,)
            "{}.decoders.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
                "squeeze": None, "transpose": None,
            },  # (1024,),(1024,)
            "{}.decoders.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm_1/beta".format(tensor_name_prefix_tf),
                "squeeze": None, "transpose": None,
            },  # (1024,),(1024,)
            "{}.decoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
                "squeeze": 0, "transpose": (1, 0),
            },  # (256,1024),(1,1024,256)
            # fsmn
            "{}.decoders.layeridx.norm2.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/gamma".format(
                    tensor_name_prefix_tf),
                "squeeze": None, "transpose": None,
            },  # (256,),(256,)
            "{}.decoders.layeridx.norm2.bias".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/beta".format(
                    tensor_name_prefix_tf),
                "squeeze": None, "transpose": None,
            },  # (256,),(256,)
            "{}.decoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/depth_conv_w".format(
                    tensor_name_prefix_tf),
                "squeeze": 0, "transpose": (1, 2, 0),
            },  # (256,1,31),(1,31,256,1)
            # src att
            "{}.decoders.layeridx.norm3.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
                "squeeze": None, "transpose": None,
            },  # (256,),(256,)
            "{}.decoders.layeridx.norm3.bias".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
                "squeeze": None, "transpose": None,
            },  # (256,),(256,)
            "{}.decoders.layeridx.src_attn.linear_q.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
                "squeeze": 0, "transpose": (1, 0),
            },  # (256,256),(1,256,256)
            "{}.decoders.layeridx.src_attn.linear_q.bias".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
                "squeeze": None, "transpose": None,
            },  # (256,),(256,)
            "{}.decoders.layeridx.src_attn.linear_k_v.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
                "squeeze": 0, "transpose": (1, 0),
            },  # (1024,256),(1,256,1024)
            "{}.decoders.layeridx.src_attn.linear_k_v.bias".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
                "squeeze": None, "transpose": None,
            },  # (1024,),(1024,)
            "{}.decoders.layeridx.src_attn.linear_out.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_2/kernel".format(tensor_name_prefix_tf),
                "squeeze": 0, "transpose": (1, 0),
            },  # (256,256),(1,256,256)
            "{}.decoders.layeridx.src_attn.linear_out.bias".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_2/bias".format(tensor_name_prefix_tf),
                "squeeze": None, "transpose": None,
            },  # (256,),(256,)
            # dnn
            "{}.decoders3.layeridx.norm1.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_dnn_layer_layeridx/LayerNorm/gamma".format(tensor_name_prefix_tf),
                "squeeze": None, "transpose": None,
            },  # (256,),(256,)
            "{}.decoders3.layeridx.norm1.bias".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_dnn_layer_layeridx/LayerNorm/beta".format(tensor_name_prefix_tf),
                "squeeze": None, "transpose": None,
            },  # (256,),(256,)
            "{}.decoders3.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_dnn_layer_layeridx/conv1d/kernel".format(tensor_name_prefix_tf),
                "squeeze": 0, "transpose": (1, 0),
            },  # (1024,256),(1,256,1024)
            "{}.decoders3.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_dnn_layer_layeridx/conv1d/bias".format(tensor_name_prefix_tf),
                "squeeze": None, "transpose": None,
            },  # (1024,),(1024,)
            "{}.decoders3.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_dnn_layer_layeridx/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
                "squeeze": None, "transpose": None,
            },  # (1024,),(1024,)
            "{}.decoders3.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_dnn_layer_layeridx/LayerNorm_1/beta".format(tensor_name_prefix_tf),
                "squeeze": None, "transpose": None,
            },  # (1024,),(1024,)
            "{}.decoders3.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_dnn_layer_layeridx/conv1d_1/kernel".format(tensor_name_prefix_tf),
                "squeeze": 0, "transpose": (1, 0),
            },  # (256,1024),(1,1024,256)
            # embed_concat_ffn
            "{}.embed_concat_ffn.layeridx.norm1.weight".format(tensor_name_prefix_torch): {
                "name": "{}/cif_concat/LayerNorm/gamma".format(tensor_name_prefix_tf),
                "squeeze": None, "transpose": None,
            },  # (256,),(256,)
            "{}.embed_concat_ffn.layeridx.norm1.bias".format(tensor_name_prefix_torch): {
                "name": "{}/cif_concat/LayerNorm/beta".format(tensor_name_prefix_tf),
                "squeeze": None, "transpose": None,
            },  # (256,),(256,)
            "{}.embed_concat_ffn.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch): {
                "name": "{}/cif_concat/conv1d/kernel".format(tensor_name_prefix_tf),
                "squeeze": 0, "transpose": (1, 0),
            },  # (1024,256),(1,256,1024)
            "{}.embed_concat_ffn.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch): {
                "name": "{}/cif_concat/conv1d/bias".format(tensor_name_prefix_tf),
                "squeeze": None, "transpose": None,
            },  # (1024,),(1024,)
            "{}.embed_concat_ffn.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch): {
                "name": "{}/cif_concat/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
                "squeeze": None, "transpose": None,
            },  # (1024,),(1024,)
            "{}.embed_concat_ffn.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch): {
                "name": "{}/cif_concat/LayerNorm_1/beta".format(tensor_name_prefix_tf),
                "squeeze": None, "transpose": None,
            },  # (1024,),(1024,)
            "{}.embed_concat_ffn.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch): {
                "name": "{}/cif_concat/conv1d_1/kernel".format(tensor_name_prefix_tf),
                "squeeze": 0, "transpose": (1, 0),
            },  # (256,1024),(1,1024,256)
            # out norm
            "{}.after_norm.weight".format(tensor_name_prefix_torch): {
                "name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
                "squeeze": None, "transpose": None,
            },  # (256,),(256,)
            "{}.after_norm.bias".format(tensor_name_prefix_torch): {
                "name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
                "squeeze": None, "transpose": None,
            },  # (256,),(256,)
            # in embed
            "{}.embed.0.weight".format(tensor_name_prefix_torch): {
                "name": "{}/w_embs".format(embed_tensor_name_prefix_tf),
                "squeeze": None, "transpose": None,
            },  # (4235,256),(4235,256)
            # out layer
            "{}.output_layer.weight".format(tensor_name_prefix_torch): {
                "name": ["{}/dense/kernel".format(tensor_name_prefix_tf),
                         "{}/w_embs".format(embed_tensor_name_prefix_tf)],
                "squeeze": [None, None],
                "transpose": [(1, 0), None],
            },  # (4235,256),(256,4235)
            "{}.output_layer.bias".format(tensor_name_prefix_torch): {
                "name": ["{}/dense/bias".format(tensor_name_prefix_tf),
                         "seq2seq/2bias" if tensor_name_prefix_tf == "seq2seq/decoder/inputter_1" else "seq2seq/bias"],
                "squeeze": [None, None],
                "transpose": [None, None],
            },  # (4235,),(4235,)
        }
        return map_dict_local
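
    # Example of how one map entry is applied (illustrative only): a TF 1-D conv
    # kernel stored as (1, 256, 1024) becomes the torch Linear weight (1024, 256)
    # via np.squeeze(axis=0) followed by np.transpose((1, 0)), matching the
    # "squeeze"/"transpose" fields and shape comments above:
    #
    #     kernel_tf = np.zeros((1, 256, 1024), dtype=np.float32)
    #     w = np.transpose(np.squeeze(kernel_tf, axis=0), (1, 0))  # -> (1024, 256)
    #     assert w.shape == (1024, 256)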
    def convert_tf2torch(self, var_dict_tf, var_dict_torch):
        map_dict = self.gen_tf2torch_map_dict()
        var_dict_torch_update = dict()
        decoder_layeridx_sets = set()
        for name in sorted(var_dict_torch.keys(), reverse=False):
            names = name.split('.')
            if names[0] == self.tf2torch_tensor_name_prefix_torch:
                if names[1] == "decoders":
                    layeridx = int(names[2])
                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
                    layeridx_bias = 0
                    layeridx += layeridx_bias
                    decoder_layeridx_sets.add(layeridx)
                    if name_q in map_dict.keys():
                        name_v = map_dict[name_q]["name"]
                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name_q]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                        if map_dict[name_q]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), \
                            "{}, {}, {} != {}".format(
                                name, name_tf, var_dict_torch[name].size(), data_tf.size()
                            )
                        var_dict_torch_update[name] = data_tf
                        logging.info(
                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                                name, data_tf.size(), name_v, var_dict_tf[name_tf].shape
                            )
                        )
                elif names[1] == "decoders2":
                    layeridx = int(names[2])
                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
                    name_q = name_q.replace("decoders2", "decoders")
                    layeridx_bias = len(decoder_layeridx_sets)
                    layeridx += layeridx_bias
                    if "decoders." in name:
                        decoder_layeridx_sets.add(layeridx)
                    if name_q in map_dict.keys():
                        name_v = map_dict[name_q]["name"]
                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name_q]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                        if map_dict[name_q]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), \
                            "{}, {}, {} != {}".format(
                                name, name_tf, var_dict_torch[name].size(), data_tf.size()
                            )
                        var_dict_torch_update[name] = data_tf
                        logging.info(
                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                                name, data_tf.size(), name_v, var_dict_tf[name_tf].shape
                            )
                        )
                elif names[1] == "decoders3":
                    layeridx = int(names[2])
                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
                    layeridx_bias = 0
                    layeridx += layeridx_bias
                    if "decoders." in name:
                        decoder_layeridx_sets.add(layeridx)
                    if name_q in map_dict.keys():
                        name_v = map_dict[name_q]["name"]
                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name_q]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                        if map_dict[name_q]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), \
                            "{}, {}, {} != {}".format(
                                name, name_tf, var_dict_torch[name].size(), data_tf.size()
                            )
                        var_dict_torch_update[name] = data_tf
                        logging.info(
                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                                name, data_tf.size(), name_v, var_dict_tf[name_tf].shape
                            )
                        )
                elif names[1] == "embed" or names[1] == "output_layer":
                    name_tf = map_dict[name]["name"]
                    if isinstance(name_tf, list):
                        idx_list = 0
                        if name_tf[idx_list] in var_dict_tf.keys():
                            pass
                        else:
                            idx_list = 1
                        data_tf = var_dict_tf[name_tf[idx_list]]
                        if map_dict[name]["squeeze"][idx_list] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"][idx_list])
                        if map_dict[name]["transpose"][idx_list] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name]["transpose"][idx_list])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), \
                            "{}, {}, {} != {}".format(
                                name, name_tf, var_dict_torch[name].size(), data_tf.size()
                            )
                        var_dict_torch_update[name] = data_tf
                        logging.info(
                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                                name, data_tf.size(), name_tf[idx_list], var_dict_tf[name_tf[idx_list]].shape
                            )
                        )
                    else:
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
                        if map_dict[name]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), \
                            "{}, {}, {} != {}".format(
                                name, name_tf, var_dict_torch[name].size(), data_tf.size()
                            )
                        var_dict_torch_update[name] = data_tf
                        logging.info(
                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                                name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
                            )
                        )
                elif names[1] == "after_norm":
                    name_tf = map_dict[name]["name"]
                    data_tf = var_dict_tf[name_tf]
                    data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                    var_dict_torch_update[name] = data_tf
                    logging.info(
                        "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                            name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
                        )
                    )
                elif names[1] == "embed_concat_ffn":
                    layeridx = int(names[2])
                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
                    layeridx_bias = 0
                    layeridx += layeridx_bias
                    if "decoders." in name:
                        decoder_layeridx_sets.add(layeridx)
                    if name_q in map_dict.keys():
                        name_v = map_dict[name_q]["name"]
                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name_q]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                        if map_dict[name_q]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), \
                            "{}, {}, {} != {}".format(
                                name, name_tf, var_dict_torch[name].size(), data_tf.size()
                            )
                        var_dict_torch_update[name] = data_tf
                        logging.info(
                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                                name, data_tf.size(), name_v, var_dict_tf[name_tf].shape
                            )
                        )
        return var_dict_torch_update
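

# Illustrative conversion sketch (hypothetical names, not part of the original
# file): `var_dict_tf` maps TF variable names to numpy arrays, and `model` is the
# torch ASR model that owns this decoder as `model.decoder`, so its state-dict
# keys carry the "decoder." prefix that convert_tf2torch expects:
#
#     var_dict_torch = model.state_dict()
#     update = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
#     var_dict_torch.update(update)
#     model.load_state_dict(var_dict_torch)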
class ParaformerSANMDecoder(BaseTransformerDecoder):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2206.08317
    """
    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = "embed",
        use_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
        att_layer_num: int = 6,
        kernel_size: int = 21,
        sanm_shfit: int = 0,
        lora_list: List[str] = None,
        lora_rank: int = 8,
        lora_alpha: int = 16,
        lora_dropout: float = 0.1,
        tf2torch_tensor_name_prefix_torch: str = "decoder",
        tf2torch_tensor_name_prefix_tf: str = "seq2seq/decoder",
    ):
        super().__init__(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            dropout_rate=dropout_rate,
            positional_dropout_rate=positional_dropout_rate,
            input_layer=input_layer,
            use_output_layer=use_output_layer,
            pos_enc_class=pos_enc_class,
            normalize_before=normalize_before,
        )
        attention_dim = encoder_output_size

        if input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(vocab_size, attention_dim),
                # pos_enc_class(attention_dim, positional_dropout_rate),
            )
        elif input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(vocab_size, attention_dim),
                torch.nn.LayerNorm(attention_dim),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        else:
            raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")

        self.normalize_before = normalize_before
        if self.normalize_before:
            self.after_norm = LayerNorm(attention_dim)
        if use_output_layer:
            self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
        else:
            self.output_layer = None

        self.att_layer_num = att_layer_num
        self.num_blocks = num_blocks
        if sanm_shfit is None:
            sanm_shfit = (kernel_size - 1) // 2
        self.decoders = repeat(
            att_layer_num,
            lambda lnum: DecoderLayerSANM(
                attention_dim,
                MultiHeadedAttentionSANMDecoder(
                    attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
                ),
                MultiHeadedAttentionCrossAtt(
                    attention_heads, attention_dim, src_attention_dropout_rate,
                    lora_list, lora_rank, lora_alpha, lora_dropout
                ),
                PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        if num_blocks - att_layer_num <= 0:
            self.decoders2 = None
        else:
            self.decoders2 = repeat(
                num_blocks - att_layer_num,
                lambda lnum: DecoderLayerSANM(
                    attention_dim,
                    MultiHeadedAttentionSANMDecoder(
                        attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=0
                    ),
                    None,
                    PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
                    dropout_rate,
                    normalize_before,
                    concat_after,
                ),
            )
        self.decoders3 = repeat(
            1,
            lambda lnum: DecoderLayerSANM(
                attention_dim,
                None,
                None,
                PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
        chunk_mask: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward decoder.

        Args:
            hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
            hlens: (batch)
            ys_in_pad:
                input token ids, int64 (batch, maxlen_out)
                if input_layer == "embed";
                input tensor (batch, maxlen_out, #mels) in the other cases
            ys_in_lens: (batch)
        Returns:
            (tuple): tuple containing:
                x: decoded token scores before softmax (batch, maxlen_out, token)
                    if use_output_layer is True
                olens: (batch,)
        """
        tgt = ys_in_pad
        tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]

        memory = hs_pad
        memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
        if chunk_mask is not None:
            memory_mask = memory_mask * chunk_mask
            if tgt_mask.size(1) != memory_mask.size(1):
                memory_mask = torch.cat((memory_mask, memory_mask[:, -2:-1, :]), dim=1)

        x = tgt
        x, tgt_mask, memory, memory_mask, _ = self.decoders(
            x, tgt_mask, memory, memory_mask
        )
        if self.decoders2 is not None:
            x, tgt_mask, memory, memory_mask, _ = self.decoders2(
                x, tgt_mask, memory, memory_mask
            )
        x, tgt_mask, memory, memory_mask, _ = self.decoders3(
            x, tgt_mask, memory, memory_mask
        )
        if self.normalize_before:
            x = self.after_norm(x)
        if self.output_layer is not None:
            x = self.output_layer(x)

        olens = tgt_mask.sum(1)
        return x, olens
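
    # Note (illustrative, not part of the original file): unlike the step-by-step
    # path below, this forward consumes `ys_in_pad` directly as continuous decoder
    # inputs (x = tgt) without running them through `self.embed`, which is how the
    # Paraformer acoustic embeddings are decoded non-autoregressively in one pass.
    # A hypothetical call, with `pre_acoustic_embeds` of shape
    # (batch, maxlen_out, attention_dim) produced by the predictor:
    #
    #     logits, olens = decoder(enc_out, enc_lens, pre_acoustic_embeds, token_lens)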
    def score(self, ys, state, x):
        """Score."""
        ys_mask = myutils.sequence_mask(
            torch.tensor([len(ys)], dtype=torch.int32), device=x.device
        )[:, :, None]
        logp, state = self.forward_one_step(
            ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state
        )
        return logp.squeeze(0), state
    def forward_chunk(
        self,
        memory: torch.Tensor,
        tgt: torch.Tensor,
        cache: dict = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward decoder for one streaming chunk.

        Args:
            memory: encoded memory for the current chunk, float32 (batch, maxlen_in, feat)
            tgt: decoder input for the current chunk (batch, maxlen_out, feat)
            cache: dict holding the per-layer FSMN caches under the key "decode_fsmn"
        Returns:
            x: decoded token scores before softmax (batch, maxlen_out, token)
                if use_output_layer is True
        """
        x = tgt
        if cache["decode_fsmn"] is None:
            cache_layer_num = len(self.decoders)
            if self.decoders2 is not None:
                cache_layer_num += len(self.decoders2)
            new_cache = [None] * cache_layer_num
        else:
            new_cache = cache["decode_fsmn"]
        for i in range(self.att_layer_num):
            decoder = self.decoders[i]
            x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
                x, None, memory, None, cache=new_cache[i]
            )
            new_cache[i] = c_ret
        if self.num_blocks - self.att_layer_num > 1:
            for i in range(self.num_blocks - self.att_layer_num):
                j = i + self.att_layer_num
                decoder = self.decoders2[i]
                x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
                    x, None, memory, None, cache=new_cache[j]
                )
                new_cache[j] = c_ret
        for decoder in self.decoders3:
            x, tgt_mask, memory, memory_mask, _ = decoder.forward_chunk(
                x, None, memory, None, cache=None
            )
        if self.normalize_before:
            x = self.after_norm(x)
        if self.output_layer is not None:
            x = self.output_layer(x)
        cache["decode_fsmn"] = new_cache
        return x
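
    # Illustrative streaming-cache sketch (hypothetical, not part of the original
    # file): the caller owns the cache dict and passes it back on every chunk so the
    # FSMN memory blocks can see the previous chunks:
    #
    #     cache = {"decode_fsmn": None}
    #     for enc_chunk, embeds_chunk in chunks:
    #         logits = decoder.forward_chunk(enc_chunk, embeds_chunk, cache=cache)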
    def forward_one_step(
        self,
        tgt: torch.Tensor,
        tgt_mask: torch.Tensor,
        memory: torch.Tensor,
        cache: List[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Forward one step.

        Args:
            tgt: input token ids, int64 (batch, maxlen_out)
            tgt_mask: input token mask, (batch, maxlen_out)
                dtype=torch.uint8 in PyTorch 1.2-
                dtype=torch.bool in PyTorch 1.2+ (including 1.2)
            memory: encoded memory, float32 (batch, maxlen_in, feat)
            cache: cached output list of (batch, max_time_out-1, size)
        Returns:
            y, cache: NN output value and cache per `self.decoders`.
                `y.shape` is (batch, maxlen_out, token)
        """
        x = self.embed(tgt)
        if cache is None:
            cache_layer_num = len(self.decoders)
            if self.decoders2 is not None:
                cache_layer_num += len(self.decoders2)
            cache = [None] * cache_layer_num
        new_cache = []
        # for c, decoder in zip(cache, self.decoders):
        for i in range(self.att_layer_num):
            decoder = self.decoders[i]
            c = cache[i]
            x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
                x, tgt_mask, memory, None, cache=c
            )
            new_cache.append(c_ret)
        if self.num_blocks - self.att_layer_num > 1:
            for i in range(self.num_blocks - self.att_layer_num):
                j = i + self.att_layer_num
                decoder = self.decoders2[i]
                c = cache[j]
                x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
                    x, tgt_mask, memory, None, cache=c
                )
                new_cache.append(c_ret)
        for decoder in self.decoders3:
            x, tgt_mask, memory, memory_mask, _ = decoder.forward_chunk(
                x, tgt_mask, memory, None, cache=None
            )
        if self.normalize_before:
            y = self.after_norm(x[:, -1])
        else:
            y = x[:, -1]
        if self.output_layer is not None:
            y = torch.log_softmax(self.output_layer(y), dim=-1)
        return y, new_cache
    def gen_tf2torch_map_dict(self):
        tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
        tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
        map_dict_local = {
            ## decoder
            # ffn
            "{}.decoders.layeridx.norm1.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
                "squeeze": None, "transpose": None,
            },  # (256,),(256,)
            "{}.decoders.layeridx.norm1.bias".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
                "squeeze": None, "transpose": None,
            },  # (256,),(256,)
            "{}.decoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/kernel".format(tensor_name_prefix_tf),
                "squeeze": 0, "transpose": (1, 0),
            },  # (1024,256),(1,256,1024)
            "{}.decoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/bias".format(tensor_name_prefix_tf),
                "squeeze": None, "transpose": None,
            },  # (1024,),(1024,)
            "{}.decoders.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
                "squeeze": None, "transpose": None,
            },  # (1024,),(1024,)
            "{}.decoders.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm_1/beta".format(tensor_name_prefix_tf),
                "squeeze": None, "transpose": None,
            },  # (1024,),(1024,)
            "{}.decoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
                "squeeze": 0, "transpose": (1, 0),
            },  # (256,1024),(1,1024,256)
            # fsmn
            "{}.decoders.layeridx.norm2.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/gamma".format(
                    tensor_name_prefix_tf),
                "squeeze": None, "transpose": None,
            },  # (256,),(256,)
            "{}.decoders.layeridx.norm2.bias".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/beta".format(
                    tensor_name_prefix_tf),
                "squeeze": None, "transpose": None,
            },  # (256,),(256,)
            "{}.decoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/depth_conv_w".format(
                    tensor_name_prefix_tf),
                "squeeze": 0, "transpose": (1, 2, 0),
            },  # (256,1,31),(1,31,256,1)
            # src att
            "{}.decoders.layeridx.norm3.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
                "squeeze": None, "transpose": None,
            },  # (256,),(256,)
            "{}.decoders.layeridx.norm3.bias".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
                "squeeze": None, "transpose": None,
            },  # (256,),(256,)
            "{}.decoders.layeridx.src_attn.linear_q.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
                "squeeze": 0, "transpose": (1, 0),
            },  # (256,256),(1,256,256)
            "{}.decoders.layeridx.src_attn.linear_q.bias".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
                "squeeze": None, "transpose": None,
            },  # (256,),(256,)
            "{}.decoders.layeridx.src_attn.linear_k_v.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
                "squeeze": 0, "transpose": (1, 0),
            },  # (1024,256),(1,256,1024)
            "{}.decoders.layeridx.src_attn.linear_k_v.bias".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
                "squeeze": None, "transpose": None,
            },  # (1024,),(1024,)
            "{}.decoders.layeridx.src_attn.linear_out.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_2/kernel".format(tensor_name_prefix_tf),
                "squeeze": 0, "transpose": (1, 0),
            },  # (256,256),(1,256,256)
            "{}.decoders.layeridx.src_attn.linear_out.bias".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_2/bias".format(tensor_name_prefix_tf),
                "squeeze": None, "transpose": None,
            },  # (256,),(256,)
            # dnn
            "{}.decoders3.layeridx.norm1.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_dnn_layer_layeridx/LayerNorm/gamma".format(tensor_name_prefix_tf),
                "squeeze": None, "transpose": None,
            },  # (256,),(256,)
            "{}.decoders3.layeridx.norm1.bias".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_dnn_layer_layeridx/LayerNorm/beta".format(tensor_name_prefix_tf),
                "squeeze": None, "transpose": None,
            },  # (256,),(256,)
            "{}.decoders3.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_dnn_layer_layeridx/conv1d/kernel".format(tensor_name_prefix_tf),
                "squeeze": 0, "transpose": (1, 0),
            },  # (1024,256),(1,256,1024)
            "{}.decoders3.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_dnn_layer_layeridx/conv1d/bias".format(tensor_name_prefix_tf),
                "squeeze": None, "transpose": None,
            },  # (1024,),(1024,)
            "{}.decoders3.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_dnn_layer_layeridx/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
                "squeeze": None, "transpose": None,
            },  # (1024,),(1024,)
            "{}.decoders3.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_dnn_layer_layeridx/LayerNorm_1/beta".format(tensor_name_prefix_tf),
                "squeeze": None, "transpose": None,
            },  # (1024,),(1024,)
            "{}.decoders3.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_dnn_layer_layeridx/conv1d_1/kernel".format(tensor_name_prefix_tf),
                "squeeze": 0, "transpose": (1, 0),
            },  # (256,1024),(1,1024,256)
            # embed_concat_ffn
            "{}.embed_concat_ffn.layeridx.norm1.weight".format(tensor_name_prefix_torch): {
                "name": "{}/cif_concat/LayerNorm/gamma".format(tensor_name_prefix_tf),
                "squeeze": None, "transpose": None,
            },  # (256,),(256,)
            "{}.embed_concat_ffn.layeridx.norm1.bias".format(tensor_name_prefix_torch): {
                "name": "{}/cif_concat/LayerNorm/beta".format(tensor_name_prefix_tf),
                "squeeze": None, "transpose": None,
            },  # (256,),(256,)
            "{}.embed_concat_ffn.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch): {
                "name": "{}/cif_concat/conv1d/kernel".format(tensor_name_prefix_tf),
                "squeeze": 0, "transpose": (1, 0),
            },  # (1024,256),(1,256,1024)
            "{}.embed_concat_ffn.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch): {
                "name": "{}/cif_concat/conv1d/bias".format(tensor_name_prefix_tf),
                "squeeze": None, "transpose": None,
            },  # (1024,),(1024,)
            "{}.embed_concat_ffn.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch): {
                "name": "{}/cif_concat/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
                "squeeze": None, "transpose": None,
            },  # (1024,),(1024,)
            "{}.embed_concat_ffn.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch): {
                "name": "{}/cif_concat/LayerNorm_1/beta".format(tensor_name_prefix_tf),
                "squeeze": None, "transpose": None,
            },  # (1024,),(1024,)
            "{}.embed_concat_ffn.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch): {
                "name": "{}/cif_concat/conv1d_1/kernel".format(tensor_name_prefix_tf),
                "squeeze": 0, "transpose": (1, 0),
            },  # (256,1024),(1,1024,256)
            # out norm
            "{}.after_norm.weight".format(tensor_name_prefix_torch): {
                "name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
                "squeeze": None, "transpose": None,
            },  # (256,),(256,)
            "{}.after_norm.bias".format(tensor_name_prefix_torch): {
                "name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
                "squeeze": None, "transpose": None,
            },  # (256,),(256,)
            # in embed
            "{}.embed.0.weight".format(tensor_name_prefix_torch): {
                "name": "{}/w_embs".format(tensor_name_prefix_tf),
                "squeeze": None, "transpose": None,
            },  # (4235,256),(4235,256)
            # out layer
            "{}.output_layer.weight".format(tensor_name_prefix_torch): {
                "name": ["{}/dense/kernel".format(tensor_name_prefix_tf),
                         "{}/w_embs".format(tensor_name_prefix_tf)],
                "squeeze": [None, None],
                "transpose": [(1, 0), None],
            },  # (4235,256),(256,4235)
            "{}.output_layer.bias".format(tensor_name_prefix_torch): {
                "name": ["{}/dense/bias".format(tensor_name_prefix_tf),
                         "seq2seq/2bias" if tensor_name_prefix_tf == "seq2seq/decoder/inputter_1" else "seq2seq/bias"],
                "squeeze": [None, None],
                "transpose": [None, None],
            },  # (4235,),(4235,)
        }
        return map_dict_local
    def convert_tf2torch(self,
                         var_dict_tf,
                         var_dict_torch,
                         ):
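        """Convert TF checkpoint tensors into this decoder's torch parameters.

        For every torch parameter name in ``var_dict_torch`` that has an entry in
        ``gen_tf2torch_map_dict()``, the corresponding TF tensor is looked up in
        ``var_dict_tf``, optionally squeezed and transposed into the torch layout
        (as specified by the map entry), checked against the torch shape, and
        returned as a float32 CPU tensor.

        Args:
            var_dict_tf: mapping of TF variable name -> numpy array.
            var_dict_torch: mapping of torch parameter name -> torch tensor,
                used for name enumeration and shape checking.

        Returns:
            dict mapping torch parameter names to the converted tensors.
        """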
        map_dict = self.gen_tf2torch_map_dict()
        var_dict_torch_update = dict()
        decoder_layeridx_sets = set()
        # sorted order guarantees the "decoders" layers are visited before "decoders2",
        # which the layer-index offset below relies on
        for name in sorted(var_dict_torch.keys(), reverse=False):
            names = name.split('.')
            if names[0] == self.tf2torch_tensor_name_prefix_torch:
                if names[1] == "decoders":
                    # self-attention decoder layers: the torch layer index maps directly
                    # onto the TF layer index (no offset)
                    layeridx = int(names[2])
                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
                    layeridx_bias = 0
                    layeridx += layeridx_bias
                    decoder_layeridx_sets.add(layeridx)
                    if name_q in map_dict.keys():
                        name_v = map_dict[name_q]["name"]
                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name_q]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                        if map_dict[name_q]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), \
                            "{}, {}, {} != {}".format(name, name_tf, var_dict_torch[name].size(), data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                            name, data_tf.size(), name_v, var_dict_tf[name_tf].shape))
                elif names[1] == "decoders2":
                    # decoders2 layers continue the TF layer numbering, so offset the torch
                    # index by the number of decoders layers collected above
                    layeridx = int(names[2])
                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
                    name_q = name_q.replace("decoders2", "decoders")
                    layeridx_bias = len(decoder_layeridx_sets)
                    layeridx += layeridx_bias
                    if "decoders." in name:
                        decoder_layeridx_sets.add(layeridx)
                    if name_q in map_dict.keys():
                        name_v = map_dict[name_q]["name"]
                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name_q]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                        if map_dict[name_q]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), \
                            "{}, {}, {} != {}".format(name, name_tf, var_dict_torch[name].size(), data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                            name, data_tf.size(), name_v, var_dict_tf[name_tf].shape))
                elif names[1] == "decoders3":
                    # decoders3 layers map onto the TF decoder_dnn_layer variables; no index offset
                    layeridx = int(names[2])
                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
                    layeridx_bias = 0
                    layeridx += layeridx_bias
                    if "decoders." in name:
                        decoder_layeridx_sets.add(layeridx)
                    if name_q in map_dict.keys():
                        name_v = map_dict[name_q]["name"]
                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name_q]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                        if map_dict[name_q]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), \
                            "{}, {}, {} != {}".format(name, name_tf, var_dict_torch[name].size(), data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                            name, data_tf.size(), name_v, var_dict_tf[name_tf].shape))
                elif names[1] == "embed" or names[1] == "output_layer":
                    name_tf = map_dict[name]["name"]
                    if isinstance(name_tf, list):
                        # several candidate TF names are listed; use the first one that is
                        # actually present in the checkpoint
                        idx_list = 0 if name_tf[0] in var_dict_tf else 1
                        data_tf = var_dict_tf[name_tf[idx_list]]
                        if map_dict[name]["squeeze"][idx_list] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"][idx_list])
                        if map_dict[name]["transpose"][idx_list] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name]["transpose"][idx_list])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), \
                            "{}, {}, {} != {}".format(name, name_tf, var_dict_torch[name].size(), data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                            name, data_tf.size(), name_tf[idx_list], var_dict_tf[name_tf[idx_list]].shape))
                    else:
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
                        if map_dict[name]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), \
                            "{}, {}, {} != {}".format(name, name_tf, var_dict_torch[name].size(), data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                            name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape))
                elif names[1] == "after_norm":
                    name_tf = map_dict[name]["name"]
                    data_tf = var_dict_tf[name_tf]
                    data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                    var_dict_torch_update[name] = data_tf
                    logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                        name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape))
                elif names[1] == "embed_concat_ffn":
                    # embed_concat_ffn maps onto the TF cif_concat variables; no index offset
                    layeridx = int(names[2])
                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
                    layeridx_bias = 0
                    layeridx += layeridx_bias
                    if "decoders." in name:
                        decoder_layeridx_sets.add(layeridx)
                    if name_q in map_dict.keys():
                        name_v = map_dict[name_q]["name"]
                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name_q]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                        if map_dict[name_q]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), \
                            "{}, {}, {} != {}".format(name, name_tf, var_dict_torch[name].size(), data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                            name, data_tf.size(), name_v, var_dict_tf[name_tf].shape))
        return var_dict_torch_update
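
    # Illustrative usage sketch (not part of the original API): assuming `decoder` is an
    # instance of this class with the tf2torch name prefixes configured, and `reader` is a
    # tf.train.load_checkpoint(...) reader for a matching checkpoint, the conversion could
    # look roughly like this:
    #
    #     var_dict_tf = {name: reader.get_tensor(name)
    #                    for name in reader.get_variable_to_shape_map()}
    #     var_dict_torch = {k: v for k, v in decoder.named_parameters()}
    #     updated = decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
    #     decoder.load_state_dict(updated, strict=False)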