# sanm_decoder.py

from typing import List
from typing import Tuple

import logging
import torch
import torch.nn as nn
import numpy as np

from funasr.modules.streaming_utils import utils as myutils
from funasr.models.decoder.transformer_decoder import BaseTransformerDecoder
from funasr.modules.attention import MultiHeadedAttentionSANMDecoder, MultiHeadedAttentionCrossAtt
from funasr.modules.embedding import PositionalEncoding
from funasr.modules.layer_norm import LayerNorm
from funasr.modules.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
from funasr.modules.repeat import repeat


class DecoderLayerSANM(nn.Module):
    """Single decoder layer module.

    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttention` instance can be used as the argument.
        src_attn (torch.nn.Module): Source-attention module instance.
            `MultiHeadedAttention` instance can be used as the argument.
        feed_forward (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
            can be used as the argument.
        dropout_rate (float): Dropout rate.
        normalize_before (bool): Whether to apply layer_norm before the first block.
        concat_after (bool): Whether to concat the attention layer's input and output.
            If True, an additional linear layer is applied,
            i.e. x -> x + linear(concat(x, att(x)));
            if False, no additional linear layer is applied, i.e. x -> x + att(x).
    """

    def __init__(
        self,
        size,
        self_attn,
        src_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
    ):
        """Construct a DecoderLayerSANM object."""
        super(DecoderLayerSANM, self).__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.norm1 = LayerNorm(size)
        if self_attn is not None:
            self.norm2 = LayerNorm(size)
        if src_attn is not None:
            self.norm3 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear1 = nn.Linear(size + size, size)
            self.concat_linear2 = nn.Linear(size + size, size)

    def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
        """Compute decoded features.

        Args:
            tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
            tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
            memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
            memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
            cache (List[torch.Tensor]): List of cached tensors.
                Each tensor shape should be (#batch, maxlen_out - 1, size).

        Returns:
            torch.Tensor: Output tensor (#batch, maxlen_out, size).
            torch.Tensor: Mask for output tensor (#batch, maxlen_out).
            torch.Tensor: Encoded memory (#batch, maxlen_in, size).
            torch.Tensor: Encoded memory mask (#batch, maxlen_in).
        """
        # Note: unlike a standard transformer decoder layer, the feed-forward
        # block runs first, followed by the FSMN self-attention block and the
        # cross-attention block.
        # tgt = self.dropout(tgt)
        residual = tgt
        if self.normalize_before:
            tgt = self.norm1(tgt)
        tgt = self.feed_forward(tgt)

        x = tgt
        if self.self_attn:
            if self.normalize_before:
                tgt = self.norm2(tgt)
            x, _ = self.self_attn(tgt, tgt_mask)
            x = residual + self.dropout(x)

        if self.src_attn is not None:
            residual = x
            if self.normalize_before:
                x = self.norm3(x)
            x = residual + self.dropout(self.src_attn(x, memory, memory_mask))

        return x, tgt_mask, memory, memory_mask, cache

    def forward_one_step(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
        """Compute decoded features for a single decoding step.

        Args:
            tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
            tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
            memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
            memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
            cache (List[torch.Tensor]): List of cached tensors.
                Each tensor shape should be (#batch, maxlen_out - 1, size).

        Returns:
            torch.Tensor: Output tensor (#batch, maxlen_out, size).
            torch.Tensor: Mask for output tensor (#batch, maxlen_out).
            torch.Tensor: Encoded memory (#batch, maxlen_in, size).
            torch.Tensor: Encoded memory mask (#batch, maxlen_in).
        """
        # tgt = self.dropout(tgt)
        residual = tgt
        if self.normalize_before:
            tgt = self.norm1(tgt)
        tgt = self.feed_forward(tgt)

        x = tgt
        if self.self_attn:
            if self.normalize_before:
                tgt = self.norm2(tgt)
            if self.training:
                cache = None
            x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
            x = residual + self.dropout(x)

        if self.src_attn is not None:
            residual = x
            if self.normalize_before:
                x = self.norm3(x)
            x = residual + self.dropout(self.src_attn(x, memory, memory_mask))

        return x, tgt_mask, memory, memory_mask, cache

    def forward_chunk(self, tgt, memory, fsmn_cache=None, opt_cache=None, chunk_size=None, look_back=0):
        """Compute decoded features for one streaming chunk.

        Args:
            tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
            memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
            fsmn_cache (torch.Tensor): Cache of the FSMN (self-attention) block.
            opt_cache: Cache of the cross-attention block.
            chunk_size (int): Decoding chunk size.
            look_back (int): Number of encoder chunks to attend back to.

        Returns:
            torch.Tensor: Output tensor (#batch, maxlen_out, size).
            torch.Tensor: Encoded memory (#batch, maxlen_in, size).
            Updated FSMN cache.
            Updated cross-attention cache.
        """
        residual = tgt
        if self.normalize_before:
            tgt = self.norm1(tgt)
        tgt = self.feed_forward(tgt)

        x = tgt
        if self.self_attn:
            if self.normalize_before:
                tgt = self.norm2(tgt)
            x, fsmn_cache = self.self_attn(tgt, None, fsmn_cache)
            x = residual + self.dropout(x)

        if self.src_attn is not None:
            residual = x
            if self.normalize_before:
                x = self.norm3(x)
            x, opt_cache = self.src_attn.forward_chunk(x, memory, opt_cache, chunk_size, look_back)
            x = residual + x

        return x, memory, fsmn_cache, opt_cache
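

# A minimal usage sketch for DecoderLayerSANM, not part of the original file.
# It wires the layer the same way FsmnDecoderSCAMAOpt below does; batch size,
# sequence lengths, and attention_dim=256 are illustrative assumptions, and
# MultiHeadedAttentionCrossAtt is constructed with only its first three
# arguments on the assumption that the remaining ones have defaults.
def _demo_decoder_layer_sanm():
    attention_dim, attention_heads = 256, 4
    linear_units, kernel_size = 2048, 21
    layer = DecoderLayerSANM(
        attention_dim,
        MultiHeadedAttentionSANMDecoder(
            attention_dim, 0.1, kernel_size, sanm_shfit=(kernel_size - 1) // 2
        ),
        MultiHeadedAttentionCrossAtt(attention_heads, attention_dim, 0.1),
        PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, 0.1),
        dropout_rate=0.1,
    )
    layer.eval()  # disable dropout for a deterministic sketch
    tgt = torch.randn(2, 10, attention_dim)     # decoder states, 10 steps
    tgt_mask = torch.ones(2, 10, 1)             # (batch, maxlen_out, 1)
    memory = torch.randn(2, 50, attention_dim)  # encoder output, 50 frames
    memory_mask = torch.ones(2, 1, 50)          # (batch, 1, maxlen_in)
    x, *_ = layer(tgt, tgt_mask, memory, memory_mask)
    return x.shape  # expected: (2, 10, attention_dim)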


class FsmnDecoderSCAMAOpt(BaseTransformerDecoder):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
    https://arxiv.org/abs/2006.01713
    """

    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = "embed",
        use_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
        att_layer_num: int = 6,
        kernel_size: int = 21,
        sanm_shfit: int = None,
        concat_embeds: bool = False,
        attention_dim: int = None,
        tf2torch_tensor_name_prefix_torch: str = "decoder",
        tf2torch_tensor_name_prefix_tf: str = "seq2seq/decoder",
        embed_tensor_name_prefix_tf: str = None,
    ):
        super().__init__(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            dropout_rate=dropout_rate,
            positional_dropout_rate=positional_dropout_rate,
            input_layer=input_layer,
            use_output_layer=use_output_layer,
            pos_enc_class=pos_enc_class,
            normalize_before=normalize_before,
        )
        if attention_dim is None:
            attention_dim = encoder_output_size

        if input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(vocab_size, attention_dim),
            )
        elif input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(vocab_size, attention_dim),
                torch.nn.LayerNorm(attention_dim),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        else:
            raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")

        self.normalize_before = normalize_before
        if self.normalize_before:
            self.after_norm = LayerNorm(attention_dim)
        if use_output_layer:
            self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
        else:
            self.output_layer = None

        self.att_layer_num = att_layer_num
        self.num_blocks = num_blocks
        if sanm_shfit is None:
            sanm_shfit = (kernel_size - 1) // 2
        self.decoders = repeat(
            att_layer_num,
            lambda lnum: DecoderLayerSANM(
                attention_dim,
                MultiHeadedAttentionSANMDecoder(
                    attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
                ),
                MultiHeadedAttentionCrossAtt(
                    attention_heads, attention_dim, src_attention_dropout_rate,
                    encoder_output_size=encoder_output_size
                ),
                PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        if num_blocks - att_layer_num <= 0:
            self.decoders2 = None
        else:
            self.decoders2 = repeat(
                num_blocks - att_layer_num,
                lambda lnum: DecoderLayerSANM(
                    attention_dim,
                    MultiHeadedAttentionSANMDecoder(
                        attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
                    ),
                    None,
                    PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
                    dropout_rate,
                    normalize_before,
                    concat_after,
                ),
            )
        self.decoders3 = repeat(
            1,
            lambda lnum: DecoderLayerSANM(
                attention_dim,
                None,
                None,
                PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        if concat_embeds:
            self.embed_concat_ffn = repeat(
                1,
                lambda lnum: DecoderLayerSANM(
                    attention_dim + encoder_output_size,
                    None,
                    None,
                    PositionwiseFeedForwardDecoderSANM(
                        attention_dim + encoder_output_size, linear_units, dropout_rate,
                        adim=attention_dim,
                    ),
                    dropout_rate,
                    normalize_before,
                    concat_after,
                ),
            )
        else:
            self.embed_concat_ffn = None
        self.concat_embeds = concat_embeds
        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
        self.embed_tensor_name_prefix_tf = embed_tensor_name_prefix_tf

    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
        chunk_mask: torch.Tensor = None,
        pre_acoustic_embeds: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward decoder.

        Args:
            hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
            hlens: (batch)
            ys_in_pad:
                input token ids, int64 (batch, maxlen_out)
                    if input_layer == "embed";
                input tensor (batch, maxlen_out, #mels) in the other cases
            ys_in_lens: (batch)

        Returns:
            (tuple): tuple containing:
                x: decoded token scores before softmax (batch, maxlen_out, token)
                    if use_output_layer is True
                olens: (batch,)
        """
        tgt = ys_in_pad
        tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]

        memory = hs_pad
        memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
        if chunk_mask is not None:
            memory_mask = memory_mask * chunk_mask
            if tgt_mask.size(1) != memory_mask.size(1):
                memory_mask = torch.cat((memory_mask, memory_mask[:, -2:-1, :]), dim=1)

        x = self.embed(tgt)
        if pre_acoustic_embeds is not None and self.concat_embeds:
            x = torch.cat((x, pre_acoustic_embeds), dim=-1)
            x, _, _, _, _ = self.embed_concat_ffn(x, None, None, None, None)

        x, tgt_mask, memory, memory_mask, _ = self.decoders(
            x, tgt_mask, memory, memory_mask
        )
        if self.decoders2 is not None:
            x, tgt_mask, memory, memory_mask, _ = self.decoders2(
                x, tgt_mask, memory, memory_mask
            )
        x, tgt_mask, memory, memory_mask, _ = self.decoders3(
            x, tgt_mask, memory, memory_mask
        )
        if self.normalize_before:
            x = self.after_norm(x)
        if self.output_layer is not None:
            x = self.output_layer(x)
        olens = tgt_mask.sum(1)
        return x, olens
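
    # Shape note for forward() above (a descriptive addition, not original
    # code): tgt_mask is (batch, maxlen_out, 1) and memory_mask is
    # (batch, 1, maxlen_in), so each broadcasts against the attention scores.
    # A SCAMA chunk_mask restricts every target step to its chunk of encoder
    # frames; when the multiplied mask comes out one row shorter than
    # tgt_mask, the torch.cat branch pads it by repeating the second-to-last
    # row.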

    def score(self, ys, state, x, x_mask=None, pre_acoustic_embeds: torch.Tensor = None):
        """Score."""
        ys_mask = myutils.sequence_mask(torch.tensor([len(ys)], dtype=torch.int32), device=x.device)[:, :, None]
        logp, state = self.forward_one_step(
            ys.unsqueeze(0), ys_mask, x.unsqueeze(0), memory_mask=x_mask,
            pre_acoustic_embeds=pre_acoustic_embeds, cache=state
        )
        return logp.squeeze(0), state

    def forward_one_step(
        self,
        tgt: torch.Tensor,
        tgt_mask: torch.Tensor,
        memory: torch.Tensor,
        memory_mask: torch.Tensor = None,
        pre_acoustic_embeds: torch.Tensor = None,
        cache: List[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Forward one step.

        Args:
            tgt: input token ids, int64 (batch, maxlen_out)
            tgt_mask: input token mask, (batch, maxlen_out)
                dtype=torch.uint8 before PyTorch 1.2,
                dtype=torch.bool in PyTorch 1.2+ (including 1.2)
            memory: encoded memory, float32 (batch, maxlen_in, feat)
            cache: cached output list of (batch, max_time_out-1, size)

        Returns:
            y, cache: NN output value and cache per `self.decoders`.
                `y.shape` is (batch, maxlen_out, token)
        """
        x = tgt[:, -1:]
        tgt_mask = None
        x = self.embed(x)
        if pre_acoustic_embeds is not None and self.concat_embeds:
            x = torch.cat((x, pre_acoustic_embeds), dim=-1)
            x, _, _, _, _ = self.embed_concat_ffn(x, None, None, None, None)

        if cache is None:
            cache_layer_num = len(self.decoders)
            if self.decoders2 is not None:
                cache_layer_num += len(self.decoders2)
            cache = [None] * cache_layer_num
        new_cache = []
        # for c, decoder in zip(cache, self.decoders):
        for i in range(self.att_layer_num):
            decoder = self.decoders[i]
            c = cache[i]
            x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_one_step(
                x, tgt_mask, memory, memory_mask, cache=c
            )
            new_cache.append(c_ret)
        if self.num_blocks - self.att_layer_num >= 1:
            for i in range(self.num_blocks - self.att_layer_num):
                j = i + self.att_layer_num
                decoder = self.decoders2[i]
                c = cache[j]
                x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_one_step(
                    x, tgt_mask, memory, memory_mask, cache=c
                )
                new_cache.append(c_ret)
        for decoder in self.decoders3:
            x, tgt_mask, memory, memory_mask, _ = decoder.forward_one_step(
                x, tgt_mask, memory, None, cache=None
            )
        if self.normalize_before:
            y = self.after_norm(x[:, -1])
        else:
            y = x[:, -1]
        if self.output_layer is not None:
            y = self.output_layer(y)
        y = torch.log_softmax(y, dim=-1)
        return y, new_cache
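
    # Cache layout for score()/forward_one_step() above (a descriptive
    # addition, not original code): `cache` is a flat list with one FSMN
    # cache per layer of self.decoders followed by self.decoders2;
    # self.decoders3 keeps no state here (it is called with cache=None) and
    # contributes no entries. Passing cache=None on the first step allocates
    # [None] * (len(self.decoders) + len(self.decoders2)).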

    def gen_tf2torch_map_dict(self):
        tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
        tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
        embed_tensor_name_prefix_tf = (
            self.embed_tensor_name_prefix_tf
            if self.embed_tensor_name_prefix_tf is not None
            else tensor_name_prefix_tf
        )
        map_dict_local = {
            ## decoder
            # ffn
            "{}.decoders.layeridx.norm1.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (256,),(256,)
            "{}.decoders.layeridx.norm1.bias".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (256,),(256,)
            "{}.decoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/kernel".format(tensor_name_prefix_tf),
                "squeeze": 0,
                "transpose": (1, 0),
            },  # (1024,256),(1,256,1024)
            "{}.decoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/bias".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (1024,),(1024,)
            "{}.decoders.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (1024,),(1024,)
            "{}.decoders.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm_1/beta".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (1024,),(1024,)
            "{}.decoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
                "squeeze": 0,
                "transpose": (1, 0),
            },  # (256,1024),(1,1024,256)
            # fsmn
            "{}.decoders.layeridx.norm2.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/gamma".format(
                    tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (256,),(256,)
            "{}.decoders.layeridx.norm2.bias".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/beta".format(
                    tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (256,),(256,)
            "{}.decoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/depth_conv_w".format(
                    tensor_name_prefix_tf),
                "squeeze": 0,
                "transpose": (1, 2, 0),
            },  # (256,1,31),(1,31,256,1)
            # src att
            "{}.decoders.layeridx.norm3.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (256,),(256,)
            "{}.decoders.layeridx.norm3.bias".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (256,),(256,)
            "{}.decoders.layeridx.src_attn.linear_q.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
                "squeeze": 0,
                "transpose": (1, 0),
            },  # (256,256),(1,256,256)
            "{}.decoders.layeridx.src_attn.linear_q.bias".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (256,),(256,)
            "{}.decoders.layeridx.src_attn.linear_k_v.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
                "squeeze": 0,
                "transpose": (1, 0),
            },  # (1024,256),(1,256,1024)
            "{}.decoders.layeridx.src_attn.linear_k_v.bias".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (1024,),(1024,)
            "{}.decoders.layeridx.src_attn.linear_out.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_2/kernel".format(tensor_name_prefix_tf),
                "squeeze": 0,
                "transpose": (1, 0),
            },  # (256,256),(1,256,256)
            "{}.decoders.layeridx.src_attn.linear_out.bias".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_2/bias".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (256,),(256,)
            # dnn
            "{}.decoders3.layeridx.norm1.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_dnn_layer_layeridx/LayerNorm/gamma".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (256,),(256,)
            "{}.decoders3.layeridx.norm1.bias".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_dnn_layer_layeridx/LayerNorm/beta".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (256,),(256,)
            "{}.decoders3.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_dnn_layer_layeridx/conv1d/kernel".format(tensor_name_prefix_tf),
                "squeeze": 0,
                "transpose": (1, 0),
            },  # (1024,256),(1,256,1024)
            "{}.decoders3.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_dnn_layer_layeridx/conv1d/bias".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (1024,),(1024,)
            "{}.decoders3.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_dnn_layer_layeridx/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (1024,),(1024,)
            "{}.decoders3.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_dnn_layer_layeridx/LayerNorm_1/beta".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (1024,),(1024,)
            "{}.decoders3.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_dnn_layer_layeridx/conv1d_1/kernel".format(tensor_name_prefix_tf),
                "squeeze": 0,
                "transpose": (1, 0),
            },  # (256,1024),(1,1024,256)
            # embed_concat_ffn
            "{}.embed_concat_ffn.layeridx.norm1.weight".format(tensor_name_prefix_torch): {
                "name": "{}/cif_concat/LayerNorm/gamma".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (256,),(256,)
            "{}.embed_concat_ffn.layeridx.norm1.bias".format(tensor_name_prefix_torch): {
                "name": "{}/cif_concat/LayerNorm/beta".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (256,),(256,)
            "{}.embed_concat_ffn.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch): {
                "name": "{}/cif_concat/conv1d/kernel".format(tensor_name_prefix_tf),
                "squeeze": 0,
                "transpose": (1, 0),
            },  # (1024,256),(1,256,1024)
            "{}.embed_concat_ffn.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch): {
                "name": "{}/cif_concat/conv1d/bias".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (1024,),(1024,)
            "{}.embed_concat_ffn.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch): {
                "name": "{}/cif_concat/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (1024,),(1024,)
            "{}.embed_concat_ffn.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch): {
                "name": "{}/cif_concat/LayerNorm_1/beta".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (1024,),(1024,)
            "{}.embed_concat_ffn.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch): {
                "name": "{}/cif_concat/conv1d_1/kernel".format(tensor_name_prefix_tf),
                "squeeze": 0,
                "transpose": (1, 0),
            },  # (256,1024),(1,1024,256)
            # out norm
            "{}.after_norm.weight".format(tensor_name_prefix_torch): {
                "name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (256,),(256,)
            "{}.after_norm.bias".format(tensor_name_prefix_torch): {
                "name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (256,),(256,)
            # in embed
            "{}.embed.0.weight".format(tensor_name_prefix_torch): {
                "name": "{}/w_embs".format(embed_tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (4235,256),(4235,256)
            # out layer
            "{}.output_layer.weight".format(tensor_name_prefix_torch): {
                "name": ["{}/dense/kernel".format(tensor_name_prefix_tf),
                         "{}/w_embs".format(embed_tensor_name_prefix_tf)],
                "squeeze": [None, None],
                "transpose": [(1, 0), None],
            },  # (4235,256),(256,4235)
            "{}.output_layer.bias".format(tensor_name_prefix_torch): {
                "name": ["{}/dense/bias".format(tensor_name_prefix_tf),
                         "seq2seq/2bias" if tensor_name_prefix_tf == "seq2seq/decoder/inputter_1" else "seq2seq/bias"],
                "squeeze": [None, None],
                "transpose": [None, None],
            },  # (4235,),(4235,)
        }
        return map_dict_local
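
    # Worked example for one mapping above (illustrative, not original code):
    # the TF tensor ".../decoder_ffn/conv1d/kernel" is stored as
    # (1, 256, 1024); convert_tf2torch() below applies np.squeeze(axis=0) to
    # get (256, 1024) and then np.transpose((1, 0)) to get (1024, 256), the
    # shape of the torch feed_forward.w_1.weight, which is exactly what the
    # "(1024,256),(1,256,1024)" shape comments record.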

    def convert_tf2torch(self, var_dict_tf, var_dict_torch):
        map_dict = self.gen_tf2torch_map_dict()
        var_dict_torch_update = dict()
        decoder_layeridx_sets = set()
        for name in sorted(var_dict_torch.keys(), reverse=False):
            names = name.split('.')
            if names[0] == self.tf2torch_tensor_name_prefix_torch:
                if names[1] == "decoders":
                    layeridx = int(names[2])
                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
                    layeridx_bias = 0
                    layeridx += layeridx_bias
                    decoder_layeridx_sets.add(layeridx)
                    if name_q in map_dict.keys():
                        name_v = map_dict[name_q]["name"]
                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name_q]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                        if map_dict[name_q]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), \
                            "{}, {}, {} != {}".format(name, name_tf, var_dict_torch[name].size(), data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info(
                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                                name, data_tf.size(), name_v, var_dict_tf[name_tf].shape))
                elif names[1] == "decoders2":
                    layeridx = int(names[2])
                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
                    name_q = name_q.replace("decoders2", "decoders")
                    layeridx_bias = len(decoder_layeridx_sets)
                    layeridx += layeridx_bias
                    if "decoders." in name:
                        decoder_layeridx_sets.add(layeridx)
                    if name_q in map_dict.keys():
                        name_v = map_dict[name_q]["name"]
                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name_q]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                        if map_dict[name_q]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), \
                            "{}, {}, {} != {}".format(name, name_tf, var_dict_torch[name].size(), data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info(
                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                                name, data_tf.size(), name_v, var_dict_tf[name_tf].shape))
                elif names[1] == "decoders3":
                    layeridx = int(names[2])
                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
                    layeridx_bias = 0
                    layeridx += layeridx_bias
                    if "decoders." in name:
                        decoder_layeridx_sets.add(layeridx)
                    if name_q in map_dict.keys():
                        name_v = map_dict[name_q]["name"]
                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name_q]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                        if map_dict[name_q]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), \
                            "{}, {}, {} != {}".format(name, name_tf, var_dict_torch[name].size(), data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info(
                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                                name, data_tf.size(), name_v, var_dict_tf[name_tf].shape))
                elif names[1] == "embed" or names[1] == "output_layer":
                    name_tf = map_dict[name]["name"]
                    if isinstance(name_tf, list):
                        idx_list = 0
                        if name_tf[idx_list] in var_dict_tf.keys():
                            pass
                        else:
                            idx_list = 1
                        data_tf = var_dict_tf[name_tf[idx_list]]
                        if map_dict[name]["squeeze"][idx_list] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"][idx_list])
                        if map_dict[name]["transpose"][idx_list] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name]["transpose"][idx_list])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), \
                            "{}, {}, {} != {}".format(name, name_tf, var_dict_torch[name].size(), data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info(
                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                                name, data_tf.size(), name_tf[idx_list], var_dict_tf[name_tf[idx_list]].shape))
                    else:
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
                        if map_dict[name]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), \
                            "{}, {}, {} != {}".format(name, name_tf, var_dict_torch[name].size(), data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info(
                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                                name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape))
                elif names[1] == "after_norm":
                    name_tf = map_dict[name]["name"]
                    data_tf = var_dict_tf[name_tf]
                    data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                    var_dict_torch_update[name] = data_tf
                    logging.info(
                        "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                            name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape))
                elif names[1] == "embed_concat_ffn":
                    layeridx = int(names[2])
                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
                    layeridx_bias = 0
                    layeridx += layeridx_bias
                    if "decoders." in name:
                        decoder_layeridx_sets.add(layeridx)
                    if name_q in map_dict.keys():
                        name_v = map_dict[name_q]["name"]
                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name_q]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                        if map_dict[name_q]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), \
                            "{}, {}, {} != {}".format(name, name_tf, var_dict_torch[name].size(), data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info(
                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                                name, data_tf.size(), name_v, var_dict_tf[name_tf].shape))
        return var_dict_torch_update
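

# A hedged end-to-end sketch for FsmnDecoderSCAMAOpt.forward(), not part of
# the original file. vocab_size=4235 echoes the shape comments in
# gen_tf2torch_map_dict(); all other sizes are illustrative assumptions.
def _demo_fsmn_decoder_scama_opt():
    decoder = FsmnDecoderSCAMAOpt(vocab_size=4235, encoder_output_size=256)
    decoder.eval()
    hs_pad = torch.randn(2, 50, 256)                    # encoder memory
    hlens = torch.tensor([50, 38], dtype=torch.int32)   # memory lengths
    ys_in_pad = torch.randint(0, 4235, (2, 10))         # target token ids
    ys_in_lens = torch.tensor([10, 7], dtype=torch.int32)
    x, olens = decoder(hs_pad, hlens, ys_in_pad, ys_in_lens)
    return x.shape  # expected (2, 10, 4235): token scores before softmax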


class ParaformerSANMDecoder(BaseTransformerDecoder):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2206.08317
    """

    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = "embed",
        use_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
        att_layer_num: int = 6,
        kernel_size: int = 21,
        sanm_shfit: int = 0,
        lora_list: List[str] = None,
        lora_rank: int = 8,
        lora_alpha: int = 16,
        lora_dropout: float = 0.1,
        chunk_multiply_factor: tuple = (1,),
        tf2torch_tensor_name_prefix_torch: str = "decoder",
        tf2torch_tensor_name_prefix_tf: str = "seq2seq/decoder",
    ):
        super().__init__(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            dropout_rate=dropout_rate,
            positional_dropout_rate=positional_dropout_rate,
            input_layer=input_layer,
            use_output_layer=use_output_layer,
            pos_enc_class=pos_enc_class,
            normalize_before=normalize_before,
        )
        attention_dim = encoder_output_size

        if input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(vocab_size, attention_dim),
                # pos_enc_class(attention_dim, positional_dropout_rate),
            )
        elif input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(vocab_size, attention_dim),
                torch.nn.LayerNorm(attention_dim),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        else:
            raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")

        self.normalize_before = normalize_before
        if self.normalize_before:
            self.after_norm = LayerNorm(attention_dim)
        if use_output_layer:
            self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
        else:
            self.output_layer = None

        self.att_layer_num = att_layer_num
        self.num_blocks = num_blocks
        if sanm_shfit is None:
            sanm_shfit = (kernel_size - 1) // 2
        self.decoders = repeat(
            att_layer_num,
            lambda lnum: DecoderLayerSANM(
                attention_dim,
                MultiHeadedAttentionSANMDecoder(
                    attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
                ),
                MultiHeadedAttentionCrossAtt(
                    attention_heads, attention_dim, src_attention_dropout_rate,
                    lora_list, lora_rank, lora_alpha, lora_dropout
                ),
                PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        if num_blocks - att_layer_num <= 0:
            self.decoders2 = None
        else:
            self.decoders2 = repeat(
                num_blocks - att_layer_num,
                lambda lnum: DecoderLayerSANM(
                    attention_dim,
                    MultiHeadedAttentionSANMDecoder(
                        attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=0
                    ),
                    None,
                    PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
                    dropout_rate,
                    normalize_before,
                    concat_after,
                ),
            )
        self.decoders3 = repeat(
            1,
            lambda lnum: DecoderLayerSANM(
                attention_dim,
                None,
                None,
                PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
        self.chunk_multiply_factor = chunk_multiply_factor

    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
        chunk_mask: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward decoder.

        Args:
            hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
            hlens: (batch)
            ys_in_pad:
                input tensor (batch, maxlen_out, feat); consumed directly as
                the first-layer input (no embedding lookup is applied in this
                forward, see `x = tgt` below)
            ys_in_lens: (batch)

        Returns:
            (tuple): tuple containing:
                x: decoded token scores before softmax (batch, maxlen_out, token)
                    if use_output_layer is True
                olens: (batch,)
        """
        tgt = ys_in_pad
        tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]

        memory = hs_pad
        memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
        if chunk_mask is not None:
            memory_mask = memory_mask * chunk_mask
            if tgt_mask.size(1) != memory_mask.size(1):
                memory_mask = torch.cat((memory_mask, memory_mask[:, -2:-1, :]), dim=1)

        x = tgt
        x, tgt_mask, memory, memory_mask, _ = self.decoders(
            x, tgt_mask, memory, memory_mask
        )
        if self.decoders2 is not None:
            x, tgt_mask, memory, memory_mask, _ = self.decoders2(
                x, tgt_mask, memory, memory_mask
            )
        x, tgt_mask, memory, memory_mask, _ = self.decoders3(
            x, tgt_mask, memory, memory_mask
        )
        if self.normalize_before:
            x = self.after_norm(x)
        if self.output_layer is not None:
            x = self.output_layer(x)
        olens = tgt_mask.sum(1)
        return x, olens

    def score(self, ys, state, x):
        """Score."""
        ys_mask = myutils.sequence_mask(torch.tensor([len(ys)], dtype=torch.int32), device=x.device)[:, :, None]
        logp, state = self.forward_one_step(
            ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state
        )
        return logp.squeeze(0), state

    def forward_chunk(
        self,
        memory: torch.Tensor,
        tgt: torch.Tensor,
        cache: dict = None,
    ) -> torch.Tensor:
        """Forward decoder for one streaming chunk.

        Args:
            memory: encoded memory, float32 (batch, maxlen_in, feat)
            tgt: decoder input for this chunk (batch, maxlen_out, feat)
            cache: dict carrying "decode_fsmn", "opt", "chunk_size", and
                "decoder_chunk_look_back"; updated in place

        Returns:
            x: decoded token scores before softmax (batch, maxlen_out, token)
                if use_output_layer is True
        """
        x = tgt
        if cache["decode_fsmn"] is None:
            cache_layer_num = len(self.decoders)
            if self.decoders2 is not None:
                cache_layer_num += len(self.decoders2)
            fsmn_cache = [None] * cache_layer_num
        else:
            fsmn_cache = cache["decode_fsmn"]
        if cache["opt"] is None:
            cache_layer_num = len(self.decoders)
            opt_cache = [None] * cache_layer_num
        else:
            opt_cache = cache["opt"]
        for i in range(self.att_layer_num):
            decoder = self.decoders[i]
            x, memory, fsmn_cache[i], opt_cache[i] = decoder.forward_chunk(
                x, memory, fsmn_cache=fsmn_cache[i], opt_cache=opt_cache[i],
                chunk_size=cache["chunk_size"], look_back=cache["decoder_chunk_look_back"]
            )
        if self.num_blocks - self.att_layer_num > 1:
            for i in range(self.num_blocks - self.att_layer_num):
                j = i + self.att_layer_num
                decoder = self.decoders2[i]
                x, memory, fsmn_cache[j], _ = decoder.forward_chunk(
                    x, memory, fsmn_cache=fsmn_cache[j]
                )
        for decoder in self.decoders3:
            x, memory, _, _ = decoder.forward_chunk(x, memory)
        if self.normalize_before:
            x = self.after_norm(x)
        if self.output_layer is not None:
            x = self.output_layer(x)
        cache["decode_fsmn"] = fsmn_cache
        if cache["decoder_chunk_look_back"] > 0 or cache["decoder_chunk_look_back"] == -1:
            cache["opt"] = opt_cache
        return x
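
    # Cache contract for forward_chunk() above (a descriptive addition, not
    # original code): `cache` is a dict with keys "decode_fsmn" (per-layer
    # FSMN caches), "opt" (per-layer cross-attention caches, stored back only
    # when "decoder_chunk_look_back" is > 0 or -1), "chunk_size", and
    # "decoder_chunk_look_back". Set cache["decode_fsmn"] = None and
    # cache["opt"] = None before the first chunk; both are allocated here.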

    def forward_one_step(
        self,
        tgt: torch.Tensor,
        tgt_mask: torch.Tensor,
        memory: torch.Tensor,
        cache: List[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Forward one step.

        Args:
            tgt: input token ids, int64 (batch, maxlen_out)
            tgt_mask: input token mask, (batch, maxlen_out)
                dtype=torch.uint8 before PyTorch 1.2,
                dtype=torch.bool in PyTorch 1.2+ (including 1.2)
            memory: encoded memory, float32 (batch, maxlen_in, feat)
            cache: cached output list of (batch, max_time_out-1, size)

        Returns:
            y, cache: NN output value and cache per `self.decoders`.
                `y.shape` is (batch, maxlen_out, token)
        """
        x = self.embed(tgt)
        if cache is None:
            cache_layer_num = len(self.decoders)
            if self.decoders2 is not None:
                cache_layer_num += len(self.decoders2)
            cache = [None] * cache_layer_num
        new_cache = []
        # for c, decoder in zip(cache, self.decoders):
        for i in range(self.att_layer_num):
            decoder = self.decoders[i]
            c = cache[i]
            x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_one_step(
                x, tgt_mask, memory, None, cache=c
            )
            new_cache.append(c_ret)
        if self.num_blocks - self.att_layer_num > 1:
            for i in range(self.num_blocks - self.att_layer_num):
                j = i + self.att_layer_num
                decoder = self.decoders2[i]
                c = cache[j]
                x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_one_step(
                    x, tgt_mask, memory, None, cache=c
                )
                new_cache.append(c_ret)
        for decoder in self.decoders3:
            x, tgt_mask, memory, memory_mask, _ = decoder.forward_one_step(
                x, tgt_mask, memory, None, cache=None
            )
        if self.normalize_before:
            y = self.after_norm(x[:, -1])
        else:
            y = x[:, -1]
        if self.output_layer is not None:
            y = torch.log_softmax(self.output_layer(y), dim=-1)
        return y, new_cache
  1057. def gen_tf2torch_map_dict(self):
  1058. tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
  1059. tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
  1060. map_dict_local = {
  1061. ## decoder
  1062. # ffn
  1063. "{}.decoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
  1064. {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
  1065. "squeeze": None,
  1066. "transpose": None,
  1067. }, # (256,),(256,)
  1068. "{}.decoders.layeridx.norm1.bias".format(tensor_name_prefix_torch):
  1069. {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
  1070. "squeeze": None,
  1071. "transpose": None,
  1072. }, # (256,),(256,)
  1073. "{}.decoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
  1074. {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/kernel".format(tensor_name_prefix_tf),
  1075. "squeeze": 0,
  1076. "transpose": (1, 0),
  1077. }, # (1024,256),(1,256,1024)
  1078. "{}.decoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
  1079. {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/bias".format(tensor_name_prefix_tf),
  1080. "squeeze": None,
  1081. "transpose": None,
  1082. }, # (1024,),(1024,)
  1083. "{}.decoders.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
  1084. {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
  1085. "squeeze": None,
  1086. "transpose": None,
  1087. }, # (1024,),(1024,)
  1088. "{}.decoders.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
  1089. {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm_1/beta".format(tensor_name_prefix_tf),
  1090. "squeeze": None,
  1091. "transpose": None,
  1092. }, # (1024,),(1024,)
  1093. "{}.decoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
  1094. {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
  1095. "squeeze": 0,
  1096. "transpose": (1, 0),
  1097. }, # (256,1024),(1,1024,256)
  1098. # fsmn
  1099. "{}.decoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
  1100. {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/gamma".format(
  1101. tensor_name_prefix_tf),
  1102. "squeeze": None,
  1103. "transpose": None,
  1104. }, # (256,),(256,)
  1105. "{}.decoders.layeridx.norm2.bias".format(tensor_name_prefix_torch):
  1106. {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/beta".format(
  1107. tensor_name_prefix_tf),
  1108. "squeeze": None,
  1109. "transpose": None,
  1110. }, # (256,),(256,)
  1111. "{}.decoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch):
  1112. {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/depth_conv_w".format(
  1113. tensor_name_prefix_tf),
  1114. "squeeze": 0,
  1115. "transpose": (1, 2, 0),
  1116. }, # (256,1,31),(1,31,256,1)
  1117. # src att
  1118. "{}.decoders.layeridx.norm3.weight".format(tensor_name_prefix_torch):
  1119. {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
  1120. "squeeze": None,
  1121. "transpose": None,
  1122. }, # (256,),(256,)
  1123. "{}.decoders.layeridx.norm3.bias".format(tensor_name_prefix_torch):
  1124. {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
  1125. "squeeze": None,
  1126. "transpose": None,
  1127. }, # (256,),(256,)
  1128. "{}.decoders.layeridx.src_attn.linear_q.weight".format(tensor_name_prefix_torch):
  1129. {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
  1130. "squeeze": 0,
  1131. "transpose": (1, 0),
  1132. }, # (256,256),(1,256,256)
  1133. "{}.decoders.layeridx.src_attn.linear_q.bias".format(tensor_name_prefix_torch):
  1134. {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
  1135. "squeeze": None,
  1136. "transpose": None,
  1137. }, # (256,),(256,)
  1138. "{}.decoders.layeridx.src_attn.linear_k_v.weight".format(tensor_name_prefix_torch):
  1139. {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
  1140. "squeeze": 0,
  1141. "transpose": (1, 0),
  1142. }, # (1024,256),(1,256,1024)
  1143. "{}.decoders.layeridx.src_attn.linear_k_v.bias".format(tensor_name_prefix_torch):
  1144. {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
  1145. "squeeze": None,
  1146. "transpose": None,
  1147. }, # (1024,),(1024,)
  1148. "{}.decoders.layeridx.src_attn.linear_out.weight".format(tensor_name_prefix_torch):
  1149. {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_2/kernel".format(tensor_name_prefix_tf),
  1150. "squeeze": 0,
  1151. "transpose": (1, 0),
  1152. }, # (256,256),(1,256,256)
  1153. "{}.decoders.layeridx.src_attn.linear_out.bias".format(tensor_name_prefix_torch):
  1154. {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_2/bias".format(tensor_name_prefix_tf),
  1155. "squeeze": None,
  1156. "transpose": None,
  1157. }, # (256,),(256,)
  1158. # dnn
  1159. "{}.decoders3.layeridx.norm1.weight".format(tensor_name_prefix_torch):
  1160. {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm/gamma".format(tensor_name_prefix_tf),
  1161. "squeeze": None,
  1162. "transpose": None,
  1163. }, # (256,),(256,)
  1164. "{}.decoders3.layeridx.norm1.bias".format(tensor_name_prefix_torch):
  1165. {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm/beta".format(tensor_name_prefix_tf),
  1166. "squeeze": None,
  1167. "transpose": None,
  1168. }, # (256,),(256,)
  1169. "{}.decoders3.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
  1170. {"name": "{}/decoder_dnn_layer_layeridx/conv1d/kernel".format(tensor_name_prefix_tf),
  1171. "squeeze": 0,
  1172. "transpose": (1, 0),
  1173. }, # (1024,256),(1,256,1024)
  1174. "{}.decoders3.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
  1175. {"name": "{}/decoder_dnn_layer_layeridx/conv1d/bias".format(tensor_name_prefix_tf),
  1176. "squeeze": None,
  1177. "transpose": None,
  1178. }, # (1024,),(1024,)
  1179. "{}.decoders3.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
  1180. {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
  1181. "squeeze": None,
  1182. "transpose": None,
  1183. }, # (1024,),(1024,)
  1184. "{}.decoders3.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
  1185. {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm_1/beta".format(tensor_name_prefix_tf),
  1186. "squeeze": None,
  1187. "transpose": None,
  1188. }, # (1024,),(1024,)
  1189. "{}.decoders3.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
  1190. {"name": "{}/decoder_dnn_layer_layeridx/conv1d_1/kernel".format(tensor_name_prefix_tf),
  1191. "squeeze": 0,
  1192. "transpose": (1, 0),
  1193. }, # (256,1024),(1,1024,256)
            # embed_concat_ffn
            "{}.embed_concat_ffn.layeridx.norm1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/cif_concat/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.embed_concat_ffn.layeridx.norm1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/cif_concat/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.embed_concat_ffn.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/cif_concat/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (1024,256),(1,256,1024)
            "{}.embed_concat_ffn.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/cif_concat/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.embed_concat_ffn.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
                {"name": "{}/cif_concat/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.embed_concat_ffn.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
                {"name": "{}/cif_concat/LayerNorm_1/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.embed_concat_ffn.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
                {"name": "{}/cif_concat/conv1d_1/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,1024),(1,1024,256)
            # out norm
            "{}.after_norm.weight".format(tensor_name_prefix_torch):
                {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.after_norm.bias".format(tensor_name_prefix_torch):
                {"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            # in embed
            "{}.embed.0.weight".format(tensor_name_prefix_torch):
                {"name": "{}/w_embs".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (4235,256),(4235,256)
            # out layer
            "{}.output_layer.weight".format(tensor_name_prefix_torch):
                {"name": ["{}/dense/kernel".format(tensor_name_prefix_tf),
                          "{}/w_embs".format(tensor_name_prefix_tf)],
                 "squeeze": [None, None],
                 "transpose": [(1, 0), None],
                 },  # (4235,256),(256,4235)
            "{}.output_layer.bias".format(tensor_name_prefix_torch):
                {"name": ["{}/dense/bias".format(tensor_name_prefix_tf),
                          "seq2seq/2bias" if tensor_name_prefix_tf == "seq2seq/decoder/inputter_1" else "seq2seq/bias"],
                 "squeeze": [None, None],
                 "transpose": [None, None],
                 },  # (4235,),(4235,)
        }
        return map_dict_local

    def convert_tf2torch(self,
                         var_dict_tf,
                         var_dict_torch,
                         ):
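        """Convert TF checkpoint variables into tensors for this decoder's state dict.

        For every torch parameter name, the concrete layer index is replaced by
        the placeholder ``layeridx`` and looked up in the map produced by
        ``gen_tf2torch_map_dict``; the matching TF tensor is squeezed and
        transposed as specified there and returned as a float32 CPU tensor of
        identical shape.
        """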
        map_dict = self.gen_tf2torch_map_dict()
        var_dict_torch_update = dict()
        decoder_layeridx_sets = set()
        for name in sorted(var_dict_torch.keys(), reverse=False):
            names = name.split('.')
            if names[0] == self.tf2torch_tensor_name_prefix_torch:
                if names[1] == "decoders":
                    layeridx = int(names[2])
                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
                    layeridx_bias = 0
                    layeridx += layeridx_bias
                    decoder_layeridx_sets.add(layeridx)
                    if name_q in map_dict.keys():
                        name_v = map_dict[name_q]["name"]
                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name_q]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                        if map_dict[name_q]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), \
                            "{}, {}, {} != {}".format(name, name_tf, var_dict_torch[name].size(), data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                            name, data_tf.size(), name_v, var_dict_tf[name_tf].shape))
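                # "decoders2" shares the TF fsmn-layer prefix with "decoders";
                # its TF numbering continues after the layers collected above,
                # hence the offset by the number of layer indices seen so far.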
                elif names[1] == "decoders2":
                    layeridx = int(names[2])
                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
                    name_q = name_q.replace("decoders2", "decoders")
                    layeridx_bias = len(decoder_layeridx_sets)
                    layeridx += layeridx_bias
                    if "decoders." in name:
                        decoder_layeridx_sets.add(layeridx)
                    if name_q in map_dict.keys():
                        name_v = map_dict[name_q]["name"]
                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name_q]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                        if map_dict[name_q]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), \
                            "{}, {}, {} != {}".format(name, name_tf, var_dict_torch[name].size(), data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                            name, data_tf.size(), name_v, var_dict_tf[name_tf].shape))
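                # "decoders3" maps onto the TF decoder_dnn_layer_* blocks;
                # its layer indices correspond one-to-one, so no offset is applied.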
                elif names[1] == "decoders3":
                    layeridx = int(names[2])
                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
                    layeridx_bias = 0
                    layeridx += layeridx_bias
                    if "decoders." in name:
                        decoder_layeridx_sets.add(layeridx)
                    if name_q in map_dict.keys():
                        name_v = map_dict[name_q]["name"]
                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name_q]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                        if map_dict[name_q]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), \
                            "{}, {}, {} != {}".format(name, name_tf, var_dict_torch[name].size(), data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                            name, data_tf.size(), name_v, var_dict_tf[name_tf].shape))
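                # For the embedding and output layer, "name" may be a list of
                # candidate TF tensors (a dedicated dense kernel vs. the tied
                # w_embs); the first one present in the checkpoint is used.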
                elif names[1] == "embed" or names[1] == "output_layer":
                    name_tf = map_dict[name]["name"]
                    if isinstance(name_tf, list):
                        idx_list = 0 if name_tf[0] in var_dict_tf.keys() else 1
                        data_tf = var_dict_tf[name_tf[idx_list]]
                        if map_dict[name]["squeeze"][idx_list] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"][idx_list])
                        if map_dict[name]["transpose"][idx_list] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name]["transpose"][idx_list])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), \
                            "{}, {}, {} != {}".format(name, name_tf, var_dict_torch[name].size(), data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                            name, data_tf.size(), name_tf[idx_list], var_dict_tf[name_tf[idx_list]].shape))
                    else:
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
                        if map_dict[name]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), \
                            "{}, {}, {} != {}".format(name, name_tf, var_dict_torch[name].size(), data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                            name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape))
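                # The final LayerNorm parameters are copied verbatim; 1-D
                # gamma/beta need neither squeeze nor transpose.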
                elif names[1] == "after_norm":
                    name_tf = map_dict[name]["name"]
                    data_tf = var_dict_tf[name_tf]
                    data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                    var_dict_torch_update[name] = data_tf
                    logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                        name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape))
                elif names[1] == "embed_concat_ffn":
                    layeridx = int(names[2])
                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
                    layeridx_bias = 0
                    layeridx += layeridx_bias
                    if "decoders." in name:
                        decoder_layeridx_sets.add(layeridx)
                    if name_q in map_dict.keys():
                        name_v = map_dict[name_q]["name"]
                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name_q]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                        if map_dict[name_q]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), \
                            "{}, {}, {} != {}".format(name, name_tf, var_dict_torch[name].size(), data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                            name, data_tf.size(), name_v, var_dict_tf[name_tf].shape))
        return var_dict_torch_update
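
# A minimal usage sketch (illustrative assumption, not part of this module):
# read a TF checkpoint with TensorFlow's CheckpointReader, build the two
# variable dicts, and merge the converted tensors into the decoder. The
# checkpoint path and the `decoder` instance below are hypothetical.
#
#   import tensorflow as tf
#
#   reader = tf.train.load_checkpoint("/path/to/tf_model.ckpt")
#   var_dict_tf = {name: reader.get_tensor(name)
#                  for name in reader.get_variable_to_shape_map()}
#   var_dict_torch = decoder.state_dict()
#   var_dict_torch.update(decoder.convert_tf2torch(var_dict_tf, var_dict_torch))
#   decoder.load_state_dict(var_dict_torch)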