# encoder.py
  1. #!/usr/bin/env python3
  2. # -*- encoding: utf-8 -*-
  3. # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
  4. # MIT License (https://opensource.org/licenses/MIT)
  5. from typing import List
  6. from typing import Optional
  7. from typing import Sequence
  8. from typing import Tuple
  9. from typing import Union
  10. import logging
  11. import torch
  12. import torch.nn as nn
  13. import torch.nn.functional as F
  14. from funasr.models.scama.chunk_utilis import overlap_chunk
  15. import numpy as np
  16. from funasr.train_utils.device_funcs import to_device
  17. from funasr.models.transformer.utils.nets_utils import make_pad_mask
  18. from funasr.models.sanm.attention import MultiHeadedAttention, MultiHeadedAttentionSANM
  19. from funasr.models.transformer.embedding import SinusoidalPositionEncoder, StreamSinusoidalPositionEncoder
  20. from funasr.models.transformer.layer_norm import LayerNorm
  21. from funasr.models.transformer.utils.multi_layer_conv import Conv1dLinear
  22. from funasr.models.transformer.utils.multi_layer_conv import MultiLayeredConv1d
  23. from funasr.models.transformer.positionwise_feed_forward import (
  24. PositionwiseFeedForward, # noqa: H301
  25. )
  26. from funasr.models.transformer.utils.repeat import repeat
  27. from funasr.models.transformer.utils.subsampling import Conv2dSubsampling
  28. from funasr.models.transformer.utils.subsampling import Conv2dSubsampling2
  29. from funasr.models.transformer.utils.subsampling import Conv2dSubsampling6
  30. from funasr.models.transformer.utils.subsampling import Conv2dSubsampling8
  31. from funasr.models.transformer.utils.subsampling import TooShortUttError
  32. from funasr.models.transformer.utils.subsampling import check_short_utt
  33. from funasr.models.transformer.utils.mask import subsequent_mask, vad_mask
  34. from funasr.models.ctc.ctc import CTC
  35. from funasr.register import tables
class EncoderLayerSANM(nn.Module):
    """One SANM encoder layer: self-attention (FSMN-augmented) plus a
    position-wise feed-forward block, each wrapped with LayerNorm, dropout,
    a residual connection, and optional stochastic depth.

    The residual around the attention block is only added when
    ``in_size == size`` (the first layer may project from the raw feature
    dimension to the model dimension, where a residual is not shape-valid).
    """

    def __init__(
        self,
        in_size,
        size,
        self_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
        stochastic_depth_rate=0.0,
    ):
        """Construct an EncoderLayer object.

        Args:
            in_size (int): Input feature dimension of this layer.
            size (int): Output (model) dimension of this layer.
            self_attn (torch.nn.Module): Self-attention module (e.g. MultiHeadedAttentionSANM).
            feed_forward (torch.nn.Module): Position-wise feed-forward module.
            dropout_rate (float): Dropout probability applied to sub-block outputs.
            normalize_before (bool): If True, pre-norm (LayerNorm before each sub-block);
                otherwise post-norm.
            concat_after (bool): If True, concatenate attention input and output and
                project with a linear layer instead of using dropout on the attention output.
            stochastic_depth_rate (float): Probability of skipping this layer during training.
        """
        super(EncoderLayerSANM, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        # norm1 is sized for the layer input, norm2 for the layer output,
        # because in_size and size may differ on the first layer.
        self.norm1 = LayerNorm(in_size)
        self.norm2 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.in_size = in_size
        self.size = size
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            # Projects the concatenation [x ; self_attn(x)] back to `size`.
            self.concat_linear = nn.Linear(size + size, size)
        self.stochastic_depth_rate = stochastic_depth_rate
        self.dropout_rate = dropout_rate

    def forward(self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
        """Compute encoded features.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, size).
            mask (torch.Tensor): Mask tensor for the input (#batch, time).
            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
            mask_shfit_chunk: Extra chunk mask forwarded to the SANM attention
                (semantics defined by the attention module).
            mask_att_chunk_encoder: Extra chunk attention mask forwarded to the
                SANM attention.

        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, time).

        NOTE(review): the stochastic-depth skip path returns a 2-tuple while the
        normal path returns a 5-tuple; callers appear to index only [0]/[1] —
        confirm before relying on the trailing elements.
        """
        skip_layer = False
        # with stochastic depth, residual connection `x + f(x)` becomes
        # `x <- x + 1 / (1 - p) * f(x)` at training time.
        stoch_layer_coeff = 1.0
        if self.training and self.stochastic_depth_rate > 0:
            skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
            stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
        if skip_layer:
            # Layer skipped: pass the input through unchanged (prepending the
            # cache if one was provided).
            if cache is not None:
                x = torch.cat([cache, x], dim=1)
            return x, mask
        residual = x
        if self.normalize_before:
            x = self.norm1(x)
        if self.concat_after:
            # Concatenate attention input and output, then project.
            x_concat = torch.cat(
                (
                    x,
                    self.self_attn(
                        x,
                        mask,
                        mask_shfit_chunk=mask_shfit_chunk,
                        mask_att_chunk_encoder=mask_att_chunk_encoder,
                    ),
                ),
                dim=-1,
            )
            if self.in_size == self.size:
                x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
            else:
                # Dimension change: residual is not shape-compatible, drop it.
                x = stoch_layer_coeff * self.concat_linear(x_concat)
        else:
            if self.in_size == self.size:
                x = residual + stoch_layer_coeff * self.dropout(
                    self.self_attn(
                        x,
                        mask,
                        mask_shfit_chunk=mask_shfit_chunk,
                        mask_att_chunk_encoder=mask_att_chunk_encoder,
                    )
                )
            else:
                # Dimension change: residual is not shape-compatible, drop it.
                x = stoch_layer_coeff * self.dropout(
                    self.self_attn(
                        x,
                        mask,
                        mask_shfit_chunk=mask_shfit_chunk,
                        mask_att_chunk_encoder=mask_att_chunk_encoder,
                    )
                )
        if not self.normalize_before:
            x = self.norm1(x)
        # Feed-forward sub-block with its own pre/post-norm and residual.
        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm2(x)
        return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder

    def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0):
        """Compute encoded features for one streaming chunk.

        Delegates the cached, incremental attention computation to
        ``self.self_attn.forward_chunk``; no dropout is applied here
        (inference-time streaming path).

        Args:
            x (torch.Tensor): Input tensor (#batch, time, size).
            cache: Attention cache carried across chunks (opaque to this layer;
                produced/consumed by the attention module).
            chunk_size: Chunk size forwarded to the attention module.
            look_back (int): Number of look-back chunks forwarded to the
                attention module.

        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            Updated attention cache.
        """
        residual = x
        if self.normalize_before:
            x = self.norm1(x)
        if self.in_size == self.size:
            attn, cache = self.self_attn.forward_chunk(x, cache, chunk_size, look_back)
            x = residual + attn
        else:
            # Dimension change: no residual around the attention block.
            x, cache = self.self_attn.forward_chunk(x, cache, chunk_size, look_back)
        if not self.normalize_before:
            x = self.norm1(x)
        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        x = residual + self.feed_forward(x)
        if not self.normalize_before:
            x = self.norm2(x)
        return x, cache
@tables.register("encoder_classes", "SANMEncoderChunkOpt")
class SANMEncoderChunkOpt(nn.Module):
    """Chunk-optimized SANM encoder for streaming end-to-end ASR.

    Author: Shiliang Zhang, Zhifu Gao, Haoneng Luo, Ming Lei, Jie Gao, Zhijie Yan, Lei Xie
    SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
    https://arxiv.org/abs/2006.01712

    The encoder stacks one input-projection layer (``encoders0``, mapping
    ``input_size`` -> ``output_size``) followed by ``num_blocks - 1`` uniform
    ``EncoderLayerSANM`` layers (``encoders``).  An ``overlap_chunk`` helper
    splits utterances into overlapping chunks and generates the chunk masks
    used by the SANM attention for streaming.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: Optional[str] = "conv2d",
        pos_enc_class=SinusoidalPositionEncoder,
        normalize_before: bool = True,
        concat_after: bool = False,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 1,
        padding_idx: int = -1,
        # NOTE(review): mutable default argument; safe only as long as it is
        # never mutated — confirm, or switch to None upstream.
        interctc_layer_idx: List[int] = [],
        interctc_use_conditioning: bool = False,
        kernel_size: int = 11,
        sanm_shfit: int = 0,
        selfattention_layer_type: str = "sanm",
        chunk_size: Union[int, Sequence[int]] = (16,),
        stride: Union[int, Sequence[int]] = (10,),
        pad_left: Union[int, Sequence[int]] = (0,),
        encoder_att_look_back_factor: Union[int, Sequence[int]] = (1,),
        decoder_att_look_back_factor: Union[int, Sequence[int]] = (1,),
        tf2torch_tensor_name_prefix_torch: str = "encoder",
        tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
    ):
        """Construct the encoder.

        Args mirror the attribute names; ``chunk_size``/``stride``/``pad_left``
        and the look-back factors are forwarded to ``overlap_chunk`` for
        streaming chunk-mask generation.  ``tf2torch_*`` prefixes are only used
        by the TF-checkpoint conversion helpers.
        """
        super().__init__()
        self._output_size = output_size

        # ---- input embedding / subsampling front-end ----
        if input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(input_size, output_size),
                torch.nn.LayerNorm(output_size),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d":
            self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d2":
            self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d6":
            self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d8":
            self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
        elif input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer is None:
            if input_size == output_size:
                self.embed = None
            else:
                self.embed = torch.nn.Linear(input_size, output_size)
        elif input_layer == "pe":
            self.embed = SinusoidalPositionEncoder()
        elif input_layer == "pe_online":
            self.embed = StreamSinusoidalPositionEncoder()
        else:
            raise ValueError("unknown input_layer: " + input_layer)

        self.normalize_before = normalize_before

        # ---- position-wise feed-forward factory ----
        if positionwise_layer_type == "linear":
            positionwise_layer = PositionwiseFeedForward
            positionwise_layer_args = (
                output_size,
                linear_units,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d":
            positionwise_layer = MultiLayeredConv1d
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d-linear":
            positionwise_layer = Conv1dLinear
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        else:
            raise NotImplementedError("Support only linear or conv1d.")

        # ---- self-attention factory ----
        if selfattention_layer_type == "selfattn":
            # NOTE(review): this branch never defines
            # `encoder_selfattn_layer_args0`, so building `encoders0` below
            # would raise NameError; only "sanm" appears to be used — confirm.
            encoder_selfattn_layer = MultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
            )
        elif selfattention_layer_type == "sanm":
            encoder_selfattn_layer = MultiHeadedAttentionSANM
            # args0: first layer, projecting input_size -> output_size.
            encoder_selfattn_layer_args0 = (
                attention_heads,
                input_size,
                output_size,
                attention_dropout_rate,
                kernel_size,
                sanm_shfit,
            )
            # args: remaining layers, output_size -> output_size.
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                output_size,
                attention_dropout_rate,
                kernel_size,
                sanm_shfit,
            )

        # Single projection layer followed by num_blocks - 1 uniform layers.
        self.encoders0 = repeat(
            1,
            lambda lnum: EncoderLayerSANM(
                input_size,
                output_size,
                encoder_selfattn_layer(*encoder_selfattn_layer_args0),
                positionwise_layer(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        self.encoders = repeat(
            num_blocks - 1,
            lambda lnum: EncoderLayerSANM(
                output_size,
                output_size,
                encoder_selfattn_layer(*encoder_selfattn_layer_args),
                positionwise_layer(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        if self.normalize_before:
            self.after_norm = LayerNorm(output_size)

        # ---- intermediate-CTC bookkeeping ----
        self.interctc_layer_idx = interctc_layer_idx
        if len(interctc_layer_idx) > 0:
            assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
        self.interctc_use_conditioning = interctc_use_conditioning
        # Set externally when interctc conditioning is enabled.
        self.conditioning_layer = None

        # FSMN look-around shift derived from the memory-block kernel size.
        shfit_fsmn = (kernel_size - 1) // 2
        self.overlap_chunk_cls = overlap_chunk(
            chunk_size=chunk_size,
            stride=stride,
            pad_left=pad_left,
            shfit_fsmn=shfit_fsmn,
            encoder_att_look_back_factor=encoder_att_look_back_factor,
            decoder_att_look_back_factor=decoder_att_look_back_factor,
        )
        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf

    def output_size(self) -> int:
        """Return the encoder output dimension."""
        return self._output_size

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: torch.Tensor = None,
        ctc: CTC = None,
        ind: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Embed positions in tensor.

        Args:
            xs_pad: input tensor (B, L, D)
            ilens: input length (B)
            prev_states: Not to be used now.
            ctc: CTC module used only when interctc conditioning is enabled.
            ind: chunk-mask index forwarded to ``overlap_chunk.gen_chunk_mask``.

        Returns:
            position embedded tensor and mask; when intermediate CTC layers
            are configured, the first element is ``(xs_pad, intermediate_outs)``.
        """
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
        # NOTE(review): in-place scaling mutates the caller's tensor.
        xs_pad *= self.output_size() ** 0.5
        if self.embed is None:
            xs_pad = xs_pad
        elif (
            isinstance(self.embed, Conv2dSubsampling)
            or isinstance(self.embed, Conv2dSubsampling2)
            or isinstance(self.embed, Conv2dSubsampling6)
            or isinstance(self.embed, Conv2dSubsampling8)
        ):
            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
            if short_status:
                raise TooShortUttError(
                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
                    + f"(it needs more than {limit_size} frames), return empty results",
                    xs_pad.size(1),
                    limit_size,
                )
            xs_pad, masks = self.embed(xs_pad, masks)
        else:
            xs_pad = self.embed(xs_pad)

        # Split the utterance into overlapping chunks and build the chunk
        # masks consumed by the SANM attention.
        mask_shfit_chunk, mask_att_chunk_encoder = None, None
        if self.overlap_chunk_cls is not None:
            ilens = masks.squeeze(1).sum(1)
            chunk_outs = self.overlap_chunk_cls.gen_chunk_mask(ilens, ind)
            xs_pad, ilens = self.overlap_chunk_cls.split_chunk(
                xs_pad, ilens, chunk_outs=chunk_outs
            )
            masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
            mask_shfit_chunk = self.overlap_chunk_cls.get_mask_shfit_chunk(
                chunk_outs, xs_pad.device, xs_pad.size(0), dtype=xs_pad.dtype
            )
            mask_att_chunk_encoder = self.overlap_chunk_cls.get_mask_att_chunk_encoder(
                chunk_outs, xs_pad.device, xs_pad.size(0), dtype=xs_pad.dtype
            )

        # First (projection) layer, then the uniform stack.
        encoder_outs = self.encoders0(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
        xs_pad, masks = encoder_outs[0], encoder_outs[1]
        intermediate_outs = []
        if len(self.interctc_layer_idx) == 0:
            encoder_outs = self.encoders(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
            xs_pad, masks = encoder_outs[0], encoder_outs[1]
        else:
            # Run layer by layer so intermediate outputs can be tapped.
            for layer_idx, encoder_layer in enumerate(self.encoders):
                encoder_outs = encoder_layer(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
                xs_pad, masks = encoder_outs[0], encoder_outs[1]
                if layer_idx + 1 in self.interctc_layer_idx:
                    encoder_out = xs_pad
                    # intermediate outputs are also normalized
                    if self.normalize_before:
                        encoder_out = self.after_norm(encoder_out)
                    intermediate_outs.append((layer_idx + 1, encoder_out))
                    # Condition the remaining layers on the CTC posterior.
                    if self.interctc_use_conditioning:
                        ctc_out = ctc.softmax(encoder_out)
                        xs_pad = xs_pad + self.conditioning_layer(ctc_out)
        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)
        olens = masks.squeeze(1).sum(1)
        if len(intermediate_outs) > 0:
            return (xs_pad, intermediate_outs), olens, None
        return xs_pad, olens, None

    def _add_overlap_chunk(self, feats: np.ndarray, cache: dict = {}):
        """Prepend cached look-back frames to the current chunk's features.

        Updates ``cache["feats"]`` in place with the trailing frames needed by
        the next chunk.  Returns the concatenated features.

        NOTE(review): mutable default `cache={}`; all visible callers pass a
        cache explicitly, so the default is only hit as an empty no-op —
        confirm before relying on it.
        """
        if len(cache) == 0:
            return feats
        cache["feats"] = to_device(cache["feats"], device=feats.device)
        overlap_feats = torch.cat((cache["feats"], feats), dim=1)
        # Keep only the frames the next chunk needs as left context
        # (presumably chunk_size[0] + chunk_size[2] frames — TODO confirm
        # against the cache initializer).
        cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]):, :]
        return overlap_feats

    def forward_chunk(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        cache: dict = None,
        **kwargs,
    ):
        """Streaming forward for a single chunk.

        Expects ``cache`` to carry the streaming state with keys
        ``"feats"``, ``"tail_chunk"``, ``"opt"`` (per-layer attention caches),
        ``"chunk_size"`` and ``"encoder_chunk_look_back"`` (schema defined by
        the caller — not validated here).

        Returns:
            Tuple of (encoded chunk, ilens, None).
        """
        is_final = kwargs.get("is_final", False)
        # NOTE(review): in-place scaling mutates the caller's tensor.
        xs_pad *= self.output_size() ** 0.5
        if self.embed is None:
            xs_pad = xs_pad
        else:
            xs_pad = self.embed(xs_pad, cache)
        if cache["tail_chunk"]:
            # Final partial chunk: reuse the cached padded features.
            xs_pad = to_device(cache["feats"], device=xs_pad.device)
        else:
            xs_pad = self._add_overlap_chunk(xs_pad, cache)
        # Lazily create one cache slot per encoder layer.
        if cache["opt"] is None:
            cache_layer_num = len(self.encoders0) + len(self.encoders)
            new_cache = [None] * cache_layer_num
        else:
            new_cache = cache["opt"]
        for layer_idx, encoder_layer in enumerate(self.encoders0):
            encoder_outs = encoder_layer.forward_chunk(
                xs_pad, new_cache[layer_idx], cache["chunk_size"], cache["encoder_chunk_look_back"]
            )
            # encoders0 holds a single layer, so slot 0 is its cache.
            xs_pad, new_cache[0] = encoder_outs[0], encoder_outs[1]
        for layer_idx, encoder_layer in enumerate(self.encoders):
            encoder_outs = encoder_layer.forward_chunk(
                xs_pad,
                new_cache[layer_idx + len(self.encoders0)],
                cache["chunk_size"],
                cache["encoder_chunk_look_back"],
            )
            xs_pad, new_cache[layer_idx + len(self.encoders0)] = encoder_outs[0], encoder_outs[1]
        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)
        # Persist layer caches only when look-back is enabled (-1 = unlimited).
        if cache["encoder_chunk_look_back"] > 0 or cache["encoder_chunk_look_back"] == -1:
            cache["opt"] = new_cache
        return xs_pad, ilens, None

    def gen_tf2torch_map_dict(self):
        """Build the torch-name -> TF-variable mapping used by ``convert_tf2torch``.

        Keys are torch parameter names with the literal placeholder
        ``layeridx``; each value gives the TF variable name (same placeholder)
        plus optional ``squeeze`` axis and ``transpose`` permutation to apply
        to the TF array.  Shape comments are (torch),(tf).
        """
        tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
        tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
        map_dict_local = {
            ## encoder
            # cicd
            "{}.encoders.layeridx.norm1.weight".format(tensor_name_prefix_torch): {
                "name": "{}/layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (256,),(256,)
            "{}.encoders.layeridx.norm1.bias".format(tensor_name_prefix_torch): {
                "name": "{}/layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (256,),(256,)
            "{}.encoders.layeridx.self_attn.linear_q_k_v.weight".format(tensor_name_prefix_torch): {
                "name": "{}/layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
                "squeeze": 0,
                "transpose": (1, 0),
            },  # (768,256),(1,256,768)
            "{}.encoders.layeridx.self_attn.linear_q_k_v.bias".format(tensor_name_prefix_torch): {
                "name": "{}/layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (768,),(768,)
            "{}.encoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch): {
                "name": "{}/layer_layeridx/multi_head/depth_conv_w".format(tensor_name_prefix_tf),
                "squeeze": 0,
                "transpose": (1, 2, 0),
            },  # (256,1,31),(1,31,256,1)
            "{}.encoders.layeridx.self_attn.linear_out.weight".format(tensor_name_prefix_torch): {
                "name": "{}/layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
                "squeeze": 0,
                "transpose": (1, 0),
            },  # (256,256),(1,256,256)
            "{}.encoders.layeridx.self_attn.linear_out.bias".format(tensor_name_prefix_torch): {
                "name": "{}/layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (256,),(256,)
            # ffn
            "{}.encoders.layeridx.norm2.weight".format(tensor_name_prefix_torch): {
                "name": "{}/layer_layeridx/ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (256,),(256,)
            "{}.encoders.layeridx.norm2.bias".format(tensor_name_prefix_torch): {
                "name": "{}/layer_layeridx/ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (256,),(256,)
            "{}.encoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch): {
                "name": "{}/layer_layeridx/ffn/conv1d/kernel".format(tensor_name_prefix_tf),
                "squeeze": 0,
                "transpose": (1, 0),
            },  # (1024,256),(1,256,1024)
            "{}.encoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch): {
                "name": "{}/layer_layeridx/ffn/conv1d/bias".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (1024,),(1024,)
            "{}.encoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch): {
                "name": "{}/layer_layeridx/ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
                "squeeze": 0,
                "transpose": (1, 0),
            },  # (256,1024),(1,1024,256)
            "{}.encoders.layeridx.feed_forward.w_2.bias".format(tensor_name_prefix_torch): {
                "name": "{}/layer_layeridx/ffn/conv1d_1/bias".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (256,),(256,)
            # out norm
            "{}.after_norm.weight".format(tensor_name_prefix_torch): {
                "name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (256,),(256,)
            "{}.after_norm.bias".format(tensor_name_prefix_torch): {
                "name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (256,),(256,)
        }
        return map_dict_local

    def convert_tf2torch(
        self,
        var_dict_tf,
        var_dict_torch,
    ):
        """Convert a TF checkpoint variable dict into torch state-dict tensors.

        For each torch parameter name, looks up the matching TF variable via
        ``gen_tf2torch_map_dict``, applies the recorded squeeze/transpose, and
        asserts the resulting shape matches the torch parameter.

        Args:
            var_dict_tf: Mapping of TF variable name -> numpy array.
            var_dict_torch: Mapping of torch parameter name -> torch tensor
                (used for name enumeration and shape checks).

        Returns:
            dict: torch parameter name -> converted ``torch.float32`` tensor.
        """
        map_dict = self.gen_tf2torch_map_dict()
        var_dict_torch_update = dict()
        for name in sorted(var_dict_torch.keys(), reverse=False):
            names = name.split('.')
            if names[0] == self.tf2torch_tensor_name_prefix_torch:
                if names[1] == "encoders0":
                    # First (projection) layer: maps to TF layer index 0.
                    layeridx = int(names[2])
                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
                    name_q = name_q.replace("encoders0", "encoders")
                    layeridx_bias = 0
                    layeridx += layeridx_bias
                    if name_q in map_dict.keys():
                        name_v = map_dict[name_q]["name"]
                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name_q]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                        if map_dict[name_q]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(
                            name, name_tf, var_dict_torch[name].size(), data_tf.size()
                        )
                        var_dict_torch_update[name] = data_tf
                        logging.info(
                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                                name, data_tf.size(), name_v, var_dict_tf[name_tf].shape
                            )
                        )
                elif names[1] == "encoders":
                    # Uniform stack: torch index i maps to TF layer i + 1.
                    layeridx = int(names[2])
                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
                    layeridx_bias = 1
                    layeridx += layeridx_bias
                    if name_q in map_dict.keys():
                        name_v = map_dict[name_q]["name"]
                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name_q]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                        if map_dict[name_q]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(
                            name, name_tf, var_dict_torch[name].size(), data_tf.size()
                        )
                        var_dict_torch_update[name] = data_tf
                        logging.info(
                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                                name, data_tf.size(), name_v, var_dict_tf[name_tf].shape
                            )
                        )
                elif names[1] == "after_norm":
                    # Final LayerNorm: direct lookup, no squeeze/transpose.
                    name_tf = map_dict[name]["name"]
                    data_tf = var_dict_tf[name_tf]
                    data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                    var_dict_torch_update[name] = data_tf
                    logging.info(
                        "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                            name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
                        )
                    )
        return var_dict_torch_update