# sanm_encoder.py
  1. from typing import List
  2. from typing import Optional
  3. from typing import Sequence
  4. from typing import Tuple
  5. from typing import Union
  6. import logging
  7. import torch
  8. import torch.nn as nn
  9. import torch.nn.functional as F
  10. from funasr.modules.streaming_utils.chunk_utilis import overlap_chunk
  11. import numpy as np
  12. from funasr.torch_utils.device_funcs import to_device
  13. from funasr.modules.nets_utils import make_pad_mask
  14. from funasr.modules.attention import MultiHeadedAttention, MultiHeadedAttentionSANM, MultiHeadedAttentionSANMwithMask
  15. from funasr.modules.embedding import SinusoidalPositionEncoder, StreamSinusoidalPositionEncoder
  16. from funasr.modules.layer_norm import LayerNorm
  17. from funasr.modules.multi_layer_conv import Conv1dLinear
  18. from funasr.modules.multi_layer_conv import MultiLayeredConv1d
  19. from funasr.modules.positionwise_feed_forward import (
  20. PositionwiseFeedForward, # noqa: H301
  21. )
  22. from funasr.modules.repeat import repeat
  23. from funasr.modules.subsampling import Conv2dSubsampling
  24. from funasr.modules.subsampling import Conv2dSubsampling2
  25. from funasr.modules.subsampling import Conv2dSubsampling6
  26. from funasr.modules.subsampling import Conv2dSubsampling8
  27. from funasr.modules.subsampling import TooShortUttError
  28. from funasr.modules.subsampling import check_short_utt
  29. from funasr.modules.mask import subsequent_mask, vad_mask
  30. from funasr.models.ctc import CTC
  31. from funasr.models.encoder.abs_encoder import AbsEncoder
class EncoderLayerSANM(nn.Module):
    """One SAN-M encoder layer.

    Memory-equipped self-attention followed by a position-wise feed-forward
    block, with residual connections, configurable pre-/post-layer
    normalization, optional "concat-after" attention fusion, and stochastic
    depth (random layer skipping) at training time.
    """

    def __init__(
        self,
        in_size,
        size,
        self_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
        stochastic_depth_rate=0.0,
    ):
        """Construct an EncoderLayer object.

        Args:
            in_size: width of the incoming features; may differ from ``size``
                for the first layer of a stack (the projection happens inside
                ``self_attn``).
            size: output width of the layer.
            self_attn: self-attention module (e.g. ``MultiHeadedAttentionSANM``).
            feed_forward: position-wise feed-forward module.
            dropout_rate: dropout probability applied to sub-layer outputs.
            normalize_before: apply LayerNorm before each sub-layer (pre-norm)
                when True, after (post-norm) when False.
            concat_after: concatenate attention input and output and fuse them
                with a linear layer instead of dropout, when True.
            stochastic_depth_rate: probability of skipping the whole layer at
                training time.
        """
        super(EncoderLayerSANM, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        # norm1 is sized to the *input* width because the attention sub-layer
        # is the one that maps in_size -> size for the first stack layer.
        self.norm1 = LayerNorm(in_size)
        self.norm2 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.in_size = in_size
        self.size = size
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            # NOTE(review): sized for in_size == size; with concat_after and
            # in_size != size, x_concat would be (in_size + size) wide --
            # confirm that combination is never used upstream.
            self.concat_linear = nn.Linear(size + size, size)
        self.stochastic_depth_rate = stochastic_depth_rate
        self.dropout_rate = dropout_rate

    def forward(self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
        """Compute encoded features.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, size).
            mask (torch.Tensor): Mask tensor for the input (#batch, time).
            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
            mask_shfit_chunk: optional streaming chunk mask, passed through to
                the attention module unchanged.
            mask_att_chunk_encoder: optional chunked attention mask, passed
                through to the attention module unchanged.

        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, time).
            The cache and both chunk masks are returned unchanged so layers
            can be chained via ``repeat``.
        """
        skip_layer = False
        # with stochastic depth, residual connection `x + f(x)` becomes
        # `x <- x + 1 / (1 - p) * f(x)` at training time.
        stoch_layer_coeff = 1.0
        if self.training and self.stochastic_depth_rate > 0:
            skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
            stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
        if skip_layer:
            if cache is not None:
                x = torch.cat([cache, x], dim=1)
            # NOTE(review): the skip path returns a 2-tuple while the normal
            # path returns a 5-tuple; callers in this file unpack only the
            # first two elements, so both shapes are accepted.
            return x, mask
        residual = x
        if self.normalize_before:
            x = self.norm1(x)
        if self.concat_after:
            x_concat = torch.cat((x, self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)), dim=-1)
            if self.in_size == self.size:
                x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
            else:
                # widths differ: no residual around the attention sub-layer
                x = stoch_layer_coeff * self.concat_linear(x_concat)
        else:
            if self.in_size == self.size:
                x = residual + stoch_layer_coeff * self.dropout(
                    self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
                )
            else:
                # first layer of a stack (in_size != size): skip the residual
                x = stoch_layer_coeff * self.dropout(
                    self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
                )
        if not self.normalize_before:
            x = self.norm1(x)
        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm2(x)
        return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder
class SANMEncoder(AbsEncoder):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    San-m: Memory equipped self-attention for end-to-end speech recognition
    https://arxiv.org/abs/2006.01713

    A stack of SAN-M layers: ``encoders0`` holds the single layer that maps
    ``input_size`` -> ``output_size``; ``encoders`` holds the remaining
    ``num_blocks - 1`` layers at ``output_size``. Supports optional
    intermediate-CTC taps and TF-checkpoint conversion helpers.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: Optional[str] = "conv2d",
        pos_enc_class=SinusoidalPositionEncoder,
        normalize_before: bool = True,
        concat_after: bool = False,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 1,
        padding_idx: int = -1,
        # NOTE(review): mutable default [] -- only read here (len/min/max),
        # never mutated, but prefer None upstream.
        interctc_layer_idx: List[int] = [],
        interctc_use_conditioning: bool = False,
        kernel_size: int = 11,
        sanm_shfit: int = 0,
        selfattention_layer_type: str = "sanm",
        tf2torch_tensor_name_prefix_torch: str = "encoder",
        tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
    ):
        """Build the encoder: input/embedding layer, attention and FFN layer
        factories, and the two layer stacks (see class docstring)."""
        super().__init__()
        self._output_size = output_size

        # --- input / embedding layer selection ---
        if input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(input_size, output_size),
                torch.nn.LayerNorm(output_size),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d":
            self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d2":
            self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d6":
            self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d8":
            self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
        elif input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
                SinusoidalPositionEncoder(),
            )
        elif input_layer is None:
            # no embedding: only project when widths differ
            if input_size == output_size:
                self.embed = None
            else:
                self.embed = torch.nn.Linear(input_size, output_size)
        elif input_layer == "pe":
            self.embed = SinusoidalPositionEncoder()
        elif input_layer == "pe_online":
            self.embed = StreamSinusoidalPositionEncoder()
        else:
            raise ValueError("unknown input_layer: " + input_layer)

        self.normalize_before = normalize_before

        # --- position-wise feed-forward factory ---
        if positionwise_layer_type == "linear":
            positionwise_layer = PositionwiseFeedForward
            positionwise_layer_args = (
                output_size,
                linear_units,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d":
            positionwise_layer = MultiLayeredConv1d
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d-linear":
            positionwise_layer = Conv1dLinear
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        else:
            raise NotImplementedError("Support only linear or conv1d.")

        # --- self-attention factory ---
        if selfattention_layer_type == "selfattn":
            encoder_selfattn_layer = MultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
            )
        elif selfattention_layer_type == "sanm":
            encoder_selfattn_layer = MultiHeadedAttentionSANM
            # args0 is for the first layer (input_size -> output_size)
            encoder_selfattn_layer_args0 = (
                attention_heads,
                input_size,
                output_size,
                attention_dropout_rate,
                kernel_size,
                sanm_shfit,
            )
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                output_size,
                attention_dropout_rate,
                kernel_size,
                sanm_shfit,
            )

        # first layer: projects input_size -> output_size
        self.encoders0 = repeat(
            1,
            lambda lnum: EncoderLayerSANM(
                input_size,
                output_size,
                encoder_selfattn_layer(*encoder_selfattn_layer_args0),
                positionwise_layer(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        # remaining num_blocks - 1 layers at output_size
        self.encoders = repeat(
            num_blocks - 1,
            lambda lnum: EncoderLayerSANM(
                output_size,
                output_size,
                encoder_selfattn_layer(*encoder_selfattn_layer_args),
                positionwise_layer(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        if self.normalize_before:
            self.after_norm = LayerNorm(output_size)

        # --- intermediate CTC configuration ---
        self.interctc_layer_idx = interctc_layer_idx
        if len(interctc_layer_idx) > 0:
            assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
        self.interctc_use_conditioning = interctc_use_conditioning
        self.conditioning_layer = None
        self.dropout = nn.Dropout(dropout_rate)
        # prefixes used by convert_tf2torch for checkpoint name mapping
        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf

    def output_size(self) -> int:
        """Return the encoder output feature dimension."""
        return self._output_size

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: torch.Tensor = None,
        ctc: CTC = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Embed positions in tensor.

        Args:
            xs_pad: input tensor (B, L, D)
            ilens: input length (B)
            prev_states: Not to be used now.
            ctc: CTC module, used only for intermediate-CTC conditioning.
        Returns:
            position embedded tensor and mask; when intermediate-CTC taps are
            configured the first element is ``(xs_pad, intermediate_outs)``.
        """
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
        # scale features by sqrt(d_model) before positional encoding
        xs_pad = xs_pad * self.output_size() ** 0.5
        if self.embed is None:
            xs_pad = xs_pad
        elif (
            isinstance(self.embed, Conv2dSubsampling)
            or isinstance(self.embed, Conv2dSubsampling2)
            or isinstance(self.embed, Conv2dSubsampling6)
            or isinstance(self.embed, Conv2dSubsampling8)
        ):
            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
            if short_status:
                raise TooShortUttError(
                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
                    + f"(it needs more than {limit_size} frames), return empty results",
                    xs_pad.size(1),
                    limit_size,
                )
            # conv subsampling shortens both the features and the mask
            xs_pad, masks = self.embed(xs_pad, masks)
        else:
            xs_pad = self.embed(xs_pad)
        # xs_pad = self.dropout(xs_pad)
        encoder_outs = self.encoders0(xs_pad, masks)
        xs_pad, masks = encoder_outs[0], encoder_outs[1]
        intermediate_outs = []
        if len(self.interctc_layer_idx) == 0:
            # fast path: run the whole stack in one call
            encoder_outs = self.encoders(xs_pad, masks)
            xs_pad, masks = encoder_outs[0], encoder_outs[1]
        else:
            # layer-by-layer so intermediate CTC taps can be collected
            for layer_idx, encoder_layer in enumerate(self.encoders):
                encoder_outs = encoder_layer(xs_pad, masks)
                xs_pad, masks = encoder_outs[0], encoder_outs[1]
                if layer_idx + 1 in self.interctc_layer_idx:
                    encoder_out = xs_pad
                    # intermediate outputs are also normalized
                    if self.normalize_before:
                        encoder_out = self.after_norm(encoder_out)
                    intermediate_outs.append((layer_idx + 1, encoder_out))
                    if self.interctc_use_conditioning:
                        # self-conditioned CTC: feed the CTC posterior back in
                        # (conditioning_layer is expected to be assigned by the
                        # owning model before use -- it is None after __init__)
                        ctc_out = ctc.softmax(encoder_out)
                        xs_pad = xs_pad + self.conditioning_layer(ctc_out)
        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)
        olens = masks.squeeze(1).sum(1)
        if len(intermediate_outs) > 0:
            return (xs_pad, intermediate_outs), olens, None
        return xs_pad, olens, None

    def _add_overlap_chunk(self, feats: torch.Tensor, cache: dict = {}):
        """Prepend the cached tail of the previous chunk to ``feats`` and
        refresh the cache with the tail of the combined sequence.

        NOTE(review): mutable default {} -- only read/keyed into here; an
        empty cache means "first chunk", so no shared-state mutation occurs.
        """
        if len(cache) == 0:
            return feats
        cache["feats"] = to_device(cache["feats"], device=feats.device)
        overlap_feats = torch.cat((cache["feats"], feats), dim=1)
        # keep (chunk_size[0] + chunk_size[2]) trailing frames for next call
        cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]):, :]
        return overlap_feats

    def forward_chunk(self,
                      xs_pad: torch.Tensor,
                      ilens: torch.Tensor,
                      cache: dict = None,
                      ctc: CTC = None,
                      ):
        """Streaming forward pass over a single chunk.

        Like :meth:`forward` but without padding masks (masks are None for
        every layer) and with a feature cache carrying chunk overlap state.
        Returns ``(xs_pad, ilens, None)`` (or intermediate outs, same as
        ``forward``).
        """
        # NOTE(review): in-place scaling mutates the caller's tensor,
        # unlike forward() which rebinds -- confirm intended.
        xs_pad *= self.output_size() ** 0.5
        if self.embed is None:
            xs_pad = xs_pad
        else:
            # streaming embed consumes/updates the cache
            xs_pad = self.embed(xs_pad, cache)
        if cache["tail_chunk"]:
            # tail flush: reuse the cached features as the final chunk
            xs_pad = to_device(cache["feats"], device=xs_pad.device)
        else:
            xs_pad = self._add_overlap_chunk(xs_pad, cache)
        encoder_outs = self.encoders0(xs_pad, None, None, None, None)
        xs_pad, masks = encoder_outs[0], encoder_outs[1]
        intermediate_outs = []
        if len(self.interctc_layer_idx) == 0:
            encoder_outs = self.encoders(xs_pad, None, None, None, None)
            xs_pad, masks = encoder_outs[0], encoder_outs[1]
        else:
            for layer_idx, encoder_layer in enumerate(self.encoders):
                encoder_outs = encoder_layer(xs_pad, None, None, None, None)
                xs_pad, masks = encoder_outs[0], encoder_outs[1]
                if layer_idx + 1 in self.interctc_layer_idx:
                    encoder_out = xs_pad
                    # intermediate outputs are also normalized
                    if self.normalize_before:
                        encoder_out = self.after_norm(encoder_out)
                    intermediate_outs.append((layer_idx + 1, encoder_out))
                    if self.interctc_use_conditioning:
                        ctc_out = ctc.softmax(encoder_out)
                        xs_pad = xs_pad + self.conditioning_layer(ctc_out)
        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)
        if len(intermediate_outs) > 0:
            return (xs_pad, intermediate_outs), None, None
        return xs_pad, ilens, None

    def gen_tf2torch_map_dict(self):
        """Return the torch-name -> TF-name mapping used by convert_tf2torch.

        Keys are templated torch parameter names (with the literal token
        ``layeridx`` standing in for the layer index); values give the TF
        variable name plus the squeeze axis / transpose permutation needed to
        match the torch tensor layout. Trailing comments show
        (torch shape),(tf shape).
        """
        tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
        tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
        map_dict_local = {
            ## encoder
            # cicd
            "{}.encoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.encoders.layeridx.norm1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.encoders.layeridx.self_attn.linear_q_k_v.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (768,256),(1,256,768)
            "{}.encoders.layeridx.self_attn.linear_q_k_v.bias".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (768,),(768,)
            "{}.encoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/depth_conv_w".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 2, 0),
                 },  # (256,1,31),(1,31,256,1)
            "{}.encoders.layeridx.self_attn.linear_out.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,256),(1,256,256)
            "{}.encoders.layeridx.self_attn.linear_out.bias".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            # ffn
            "{}.encoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.encoders.layeridx.norm2.bias".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.encoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/ffn/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (1024,256),(1,256,1024)
            "{}.encoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/ffn/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.encoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,1024),(1,1024,256)
            "{}.encoders.layeridx.feed_forward.w_2.bias".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/ffn/conv1d_1/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            # out norm
            "{}.after_norm.weight".format(tensor_name_prefix_torch):
                {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.after_norm.bias".format(tensor_name_prefix_torch):
                {"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
        }
        return map_dict_local

    def convert_tf2torch(self,
                         var_dict_tf,
                         var_dict_torch,
                         ):
        """Translate a TF checkpoint variable dict into a torch state-dict
        update, using the mapping from :meth:`gen_tf2torch_map_dict`.

        Args:
            var_dict_tf: TF variable name -> numpy array.
            var_dict_torch: torch parameter name -> tensor (used for shape
                checks and to enumerate what needs loading).
        Returns:
            dict of torch parameter name -> converted ``torch.float32`` tensor.
        """
        map_dict = self.gen_tf2torch_map_dict()
        var_dict_torch_update = dict()
        for name in sorted(var_dict_torch.keys(), reverse=False):
            names = name.split('.')
            if names[0] == self.tf2torch_tensor_name_prefix_torch:
                if names[1] == "encoders0":
                    # torch "encoders0.0.*" corresponds to TF layer 0:
                    # fold it into the generic "encoders.layeridx.*" template
                    layeridx = int(names[2])
                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
                    name_q = name_q.replace("encoders0", "encoders")
                    layeridx_bias = 0
                    layeridx += layeridx_bias
                    if name_q in map_dict.keys():
                        name_v = map_dict[name_q]["name"]
                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name_q]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                        if map_dict[name_q]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
                                                                                                        var_dict_torch[
                                                                                                            name].size(),
                                                                                                        data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info(
                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
                                                                                          var_dict_tf[name_tf].shape))
                elif names[1] == "encoders":
                    # torch "encoders.N.*" maps to TF layer N+1 (TF numbers the
                    # encoders0 layer as layer 0)
                    layeridx = int(names[2])
                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
                    layeridx_bias = 1
                    layeridx += layeridx_bias
                    if name_q in map_dict.keys():
                        name_v = map_dict[name_q]["name"]
                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name_q]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                        if map_dict[name_q]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
                                                                                                        var_dict_torch[
                                                                                                            name].size(),
                                                                                                        data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info(
                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
                                                                                          var_dict_tf[name_tf].shape))
                elif names[1] == "after_norm":
                    # final LayerNorm: direct copy, no squeeze/transpose
                    name_tf = map_dict[name]["name"]
                    data_tf = var_dict_tf[name_tf]
                    data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                    var_dict_torch_update[name] = data_tf
                    logging.info(
                        "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
                                                                                      var_dict_tf[name_tf].shape))
        return var_dict_torch_update
  515. class SANMEncoderChunkOpt(AbsEncoder):
  516. """
  517. Author: Speech Lab of DAMO Academy, Alibaba Group
  518. SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
  519. https://arxiv.org/abs/2006.01713
  520. """
    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: Optional[str] = "conv2d",
        pos_enc_class=SinusoidalPositionEncoder,
        normalize_before: bool = True,
        concat_after: bool = False,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 1,
        padding_idx: int = -1,
        # NOTE(review): mutable default [] -- read-only here; prefer None
        interctc_layer_idx: List[int] = [],
        interctc_use_conditioning: bool = False,
        kernel_size: int = 11,
        sanm_shfit: int = 0,
        selfattention_layer_type: str = "sanm",
        chunk_size: Union[int, Sequence[int]] = (16,),
        stride: Union[int, Sequence[int]] = (10,),
        pad_left: Union[int, Sequence[int]] = (0,),
        encoder_att_look_back_factor: Union[int, Sequence[int]] = (1,),
        decoder_att_look_back_factor: Union[int, Sequence[int]] = (1,),
        tf2torch_tensor_name_prefix_torch: str = "encoder",
        tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
    ):
        """Build the chunk-streaming SAN-M encoder.

        Mirrors ``SANMEncoder.__init__`` and additionally configures
        ``overlap_chunk_cls`` with the chunking geometry (chunk size, stride,
        left padding, FSMN shift, attention look-back factors) used to
        generate streaming masks.
        """
        super().__init__()
        self._output_size = output_size

        # --- input / embedding layer selection ---
        if input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(input_size, output_size),
                torch.nn.LayerNorm(output_size),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d":
            self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d2":
            self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d6":
            self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d8":
            self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
        elif input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer is None:
            if input_size == output_size:
                self.embed = None
            else:
                self.embed = torch.nn.Linear(input_size, output_size)
        elif input_layer == "pe":
            self.embed = SinusoidalPositionEncoder()
        elif input_layer == "pe_online":
            self.embed = StreamSinusoidalPositionEncoder()
        else:
            raise ValueError("unknown input_layer: " + input_layer)

        self.normalize_before = normalize_before

        # --- position-wise feed-forward factory ---
        if positionwise_layer_type == "linear":
            positionwise_layer = PositionwiseFeedForward
            positionwise_layer_args = (
                output_size,
                linear_units,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d":
            positionwise_layer = MultiLayeredConv1d
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d-linear":
            positionwise_layer = Conv1dLinear
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        else:
            raise NotImplementedError("Support only linear or conv1d.")

        # --- self-attention factory ---
        if selfattention_layer_type == "selfattn":
            encoder_selfattn_layer = MultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
            )
        elif selfattention_layer_type == "sanm":
            encoder_selfattn_layer = MultiHeadedAttentionSANM
            # args0 is for the first layer (input_size -> output_size)
            encoder_selfattn_layer_args0 = (
                attention_heads,
                input_size,
                output_size,
                attention_dropout_rate,
                kernel_size,
                sanm_shfit,
            )
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                output_size,
                attention_dropout_rate,
                kernel_size,
                sanm_shfit,
            )

        # first layer: projects input_size -> output_size
        self.encoders0 = repeat(
            1,
            lambda lnum: EncoderLayerSANM(
                input_size,
                output_size,
                encoder_selfattn_layer(*encoder_selfattn_layer_args0),
                positionwise_layer(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        # remaining num_blocks - 1 layers at output_size
        self.encoders = repeat(
            num_blocks - 1,
            lambda lnum: EncoderLayerSANM(
                output_size,
                output_size,
                encoder_selfattn_layer(*encoder_selfattn_layer_args),
                positionwise_layer(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        if self.normalize_before:
            self.after_norm = LayerNorm(output_size)

        # --- intermediate CTC configuration ---
        self.interctc_layer_idx = interctc_layer_idx
        if len(interctc_layer_idx) > 0:
            assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
        self.interctc_use_conditioning = interctc_use_conditioning
        self.conditioning_layer = None

        # half the FSMN kernel: extra frames each chunk must carry so the
        # memory block sees its full left context
        shfit_fsmn = (kernel_size - 1) // 2
        self.overlap_chunk_cls = overlap_chunk(
            chunk_size=chunk_size,
            stride=stride,
            pad_left=pad_left,
            shfit_fsmn=shfit_fsmn,
            encoder_att_look_back_factor=encoder_att_look_back_factor,
            decoder_att_look_back_factor=decoder_att_look_back_factor,
        )
        # prefixes used for TF checkpoint name mapping
        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
  678. def output_size(self) -> int:
  679. return self._output_size
def forward(
    self,
    xs_pad: torch.Tensor,
    ilens: torch.Tensor,
    prev_states: torch.Tensor = None,
    ctc: CTC = None,
    ind: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Embed positions in tensor.

    Args:
        xs_pad: input tensor (B, L, D)
        ilens: input length (B)
        prev_states: Not to be used now.
        ctc: CTC module; only used when self.interctc_use_conditioning is True.
        ind: index forwarded to overlap_chunk_cls.gen_chunk_mask.
    Returns:
        position embedded tensor and mask
    """
    masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
    # Scale features by sqrt(d_model). NOTE(review): in-place, mutates the caller's tensor.
    xs_pad *= self.output_size() ** 0.5
    if self.embed is None:
        xs_pad = xs_pad
    elif (
        isinstance(self.embed, Conv2dSubsampling)
        or isinstance(self.embed, Conv2dSubsampling2)
        or isinstance(self.embed, Conv2dSubsampling6)
        or isinstance(self.embed, Conv2dSubsampling8)
    ):
        # Conv front-ends shrink the time axis; reject inputs too short to survive subsampling.
        short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
        if short_status:
            raise TooShortUttError(
                f"has {xs_pad.size(1)} frames and is too short for subsampling "
                + f"(it needs more than {limit_size} frames), return empty results",
                xs_pad.size(1),
                limit_size,
            )
        xs_pad, masks = self.embed(xs_pad, masks)
    else:
        xs_pad = self.embed(xs_pad)
    mask_shfit_chunk, mask_att_chunk_encoder = None, None
    if self.overlap_chunk_cls is not None:
        # Re-derive lengths from the (possibly subsampled) mask, split the sequence
        # into overlapping chunks, then build the chunk-local shift/attention masks.
        ilens = masks.squeeze(1).sum(1)
        chunk_outs = self.overlap_chunk_cls.gen_chunk_mask(ilens, ind)
        xs_pad, ilens = self.overlap_chunk_cls.split_chunk(xs_pad, ilens, chunk_outs=chunk_outs)
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
        mask_shfit_chunk = self.overlap_chunk_cls.get_mask_shfit_chunk(chunk_outs, xs_pad.device, xs_pad.size(0),
                                                                       dtype=xs_pad.dtype)
        mask_att_chunk_encoder = self.overlap_chunk_cls.get_mask_att_chunk_encoder(chunk_outs, xs_pad.device,
                                                                                   xs_pad.size(0),
                                                                                   dtype=xs_pad.dtype)
    # First encoder layer (handles the input_size -> output_size projection).
    encoder_outs = self.encoders0(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
    xs_pad, masks = encoder_outs[0], encoder_outs[1]
    intermediate_outs = []
    if len(self.interctc_layer_idx) == 0:
        encoder_outs = self.encoders(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
        xs_pad, masks = encoder_outs[0], encoder_outs[1]
    else:
        # Run layer by layer so intermediate CTC taps can be collected.
        for layer_idx, encoder_layer in enumerate(self.encoders):
            encoder_outs = encoder_layer(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
            xs_pad, masks = encoder_outs[0], encoder_outs[1]
            if layer_idx + 1 in self.interctc_layer_idx:
                encoder_out = xs_pad
                # intermediate outputs are also normalized
                if self.normalize_before:
                    encoder_out = self.after_norm(encoder_out)
                intermediate_outs.append((layer_idx + 1, encoder_out))
                if self.interctc_use_conditioning:
                    # Self-conditioned CTC: feed the CTC posterior back into the stream.
                    ctc_out = ctc.softmax(encoder_out)
                    xs_pad = xs_pad + self.conditioning_layer(ctc_out)
    if self.normalize_before:
        xs_pad = self.after_norm(xs_pad)
    olens = masks.squeeze(1).sum(1)
    if len(intermediate_outs) > 0:
        return (xs_pad, intermediate_outs), olens, None
    return xs_pad, olens, None
  753. def _add_overlap_chunk(self, feats: np.ndarray, cache: dict = {}):
  754. if len(cache) == 0:
  755. return feats
  756. cache["feats"] = to_device(cache["feats"], device=feats.device)
  757. overlap_feats = torch.cat((cache["feats"], feats), dim=1)
  758. cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]):, :]
  759. return overlap_feats
def forward_chunk(self,
                  xs_pad: torch.Tensor,
                  ilens: torch.Tensor,
                  cache: dict = None,
                  ctc: CTC = None,
                  ):
    """Streaming forward pass over a single chunk.

    Args:
        xs_pad: chunk features (B, T, D)
        ilens: chunk lengths (B)
        cache: streaming state; read for "tail_chunk" and "feats".
            NOTE(review): accessed without a None check, so despite the None
            default a cache dict is effectively required.
        ctc: CTC module; only used when self.interctc_use_conditioning is True.
    Returns:
        encoded chunk, ilens (or None when intermediate outputs exist), None
    """
    # Scale features by sqrt(d_model); in-place on the caller's tensor.
    xs_pad *= self.output_size() ** 0.5
    if self.embed is None:
        xs_pad = xs_pad
    else:
        xs_pad = self.embed(xs_pad, cache)
    if cache["tail_chunk"]:
        # Final (tail) chunk: reuse the cached feature window as-is.
        xs_pad = to_device(cache["feats"], device=xs_pad.device)
    else:
        # Prepend cached left context; also updates cache["feats"].
        xs_pad = self._add_overlap_chunk(xs_pad, cache)
    # All masks are None in streaming mode: each chunk is already sized for
    # the attention window, so no padding/attention masking is applied.
    encoder_outs = self.encoders0(xs_pad, None, None, None, None)
    xs_pad, masks = encoder_outs[0], encoder_outs[1]
    intermediate_outs = []
    if len(self.interctc_layer_idx) == 0:
        encoder_outs = self.encoders(xs_pad, None, None, None, None)
        xs_pad, masks = encoder_outs[0], encoder_outs[1]
    else:
        # Per-layer loop so intermediate CTC taps can be collected.
        for layer_idx, encoder_layer in enumerate(self.encoders):
            encoder_outs = encoder_layer(xs_pad, None, None, None, None)
            xs_pad, masks = encoder_outs[0], encoder_outs[1]
            if layer_idx + 1 in self.interctc_layer_idx:
                encoder_out = xs_pad
                # intermediate outputs are also normalized
                if self.normalize_before:
                    encoder_out = self.after_norm(encoder_out)
                intermediate_outs.append((layer_idx + 1, encoder_out))
                if self.interctc_use_conditioning:
                    # Self-conditioned CTC feedback into the main stream.
                    ctc_out = ctc.softmax(encoder_out)
                    xs_pad = xs_pad + self.conditioning_layer(ctc_out)
    if self.normalize_before:
        xs_pad = self.after_norm(xs_pad)
    if len(intermediate_outs) > 0:
        return (xs_pad, intermediate_outs), None, None
    return xs_pad, ilens, None
  799. def gen_tf2torch_map_dict(self):
  800. tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
  801. tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
  802. map_dict_local = {
  803. ## encoder
  804. # cicd
  805. "{}.encoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
  806. {"name": "{}/layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
  807. "squeeze": None,
  808. "transpose": None,
  809. }, # (256,),(256,)
  810. "{}.encoders.layeridx.norm1.bias".format(tensor_name_prefix_torch):
  811. {"name": "{}/layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
  812. "squeeze": None,
  813. "transpose": None,
  814. }, # (256,),(256,)
  815. "{}.encoders.layeridx.self_attn.linear_q_k_v.weight".format(tensor_name_prefix_torch):
  816. {"name": "{}/layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
  817. "squeeze": 0,
  818. "transpose": (1, 0),
  819. }, # (768,256),(1,256,768)
  820. "{}.encoders.layeridx.self_attn.linear_q_k_v.bias".format(tensor_name_prefix_torch):
  821. {"name": "{}/layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
  822. "squeeze": None,
  823. "transpose": None,
  824. }, # (768,),(768,)
  825. "{}.encoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch):
  826. {"name": "{}/layer_layeridx/multi_head/depth_conv_w".format(tensor_name_prefix_tf),
  827. "squeeze": 0,
  828. "transpose": (1, 2, 0),
  829. }, # (256,1,31),(1,31,256,1)
  830. "{}.encoders.layeridx.self_attn.linear_out.weight".format(tensor_name_prefix_torch):
  831. {"name": "{}/layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
  832. "squeeze": 0,
  833. "transpose": (1, 0),
  834. }, # (256,256),(1,256,256)
  835. "{}.encoders.layeridx.self_attn.linear_out.bias".format(tensor_name_prefix_torch):
  836. {"name": "{}/layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
  837. "squeeze": None,
  838. "transpose": None,
  839. }, # (256,),(256,)
  840. # ffn
  841. "{}.encoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
  842. {"name": "{}/layer_layeridx/ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
  843. "squeeze": None,
  844. "transpose": None,
  845. }, # (256,),(256,)
  846. "{}.encoders.layeridx.norm2.bias".format(tensor_name_prefix_torch):
  847. {"name": "{}/layer_layeridx/ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
  848. "squeeze": None,
  849. "transpose": None,
  850. }, # (256,),(256,)
  851. "{}.encoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
  852. {"name": "{}/layer_layeridx/ffn/conv1d/kernel".format(tensor_name_prefix_tf),
  853. "squeeze": 0,
  854. "transpose": (1, 0),
  855. }, # (1024,256),(1,256,1024)
  856. "{}.encoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
  857. {"name": "{}/layer_layeridx/ffn/conv1d/bias".format(tensor_name_prefix_tf),
  858. "squeeze": None,
  859. "transpose": None,
  860. }, # (1024,),(1024,)
  861. "{}.encoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
  862. {"name": "{}/layer_layeridx/ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
  863. "squeeze": 0,
  864. "transpose": (1, 0),
  865. }, # (256,1024),(1,1024,256)
  866. "{}.encoders.layeridx.feed_forward.w_2.bias".format(tensor_name_prefix_torch):
  867. {"name": "{}/layer_layeridx/ffn/conv1d_1/bias".format(tensor_name_prefix_tf),
  868. "squeeze": None,
  869. "transpose": None,
  870. }, # (256,),(256,)
  871. # out norm
  872. "{}.after_norm.weight".format(tensor_name_prefix_torch):
  873. {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
  874. "squeeze": None,
  875. "transpose": None,
  876. }, # (256,),(256,)
  877. "{}.after_norm.bias".format(tensor_name_prefix_torch):
  878. {"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
  879. "squeeze": None,
  880. "transpose": None,
  881. }, # (256,),(256,)
  882. }
  883. return map_dict_local
  884. def convert_tf2torch(self,
  885. var_dict_tf,
  886. var_dict_torch,
  887. ):
  888. map_dict = self.gen_tf2torch_map_dict()
  889. var_dict_torch_update = dict()
  890. for name in sorted(var_dict_torch.keys(), reverse=False):
  891. names = name.split('.')
  892. if names[0] == self.tf2torch_tensor_name_prefix_torch:
  893. if names[1] == "encoders0":
  894. layeridx = int(names[2])
  895. name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
  896. name_q = name_q.replace("encoders0", "encoders")
  897. layeridx_bias = 0
  898. layeridx += layeridx_bias
  899. if name_q in map_dict.keys():
  900. name_v = map_dict[name_q]["name"]
  901. name_tf = name_v.replace("layeridx", "{}".format(layeridx))
  902. data_tf = var_dict_tf[name_tf]
  903. if map_dict[name_q]["squeeze"] is not None:
  904. data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
  905. if map_dict[name_q]["transpose"] is not None:
  906. data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
  907. data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
  908. assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
  909. var_dict_torch[
  910. name].size(),
  911. data_tf.size())
  912. var_dict_torch_update[name] = data_tf
  913. logging.info(
  914. "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
  915. var_dict_tf[name_tf].shape))
  916. elif names[1] == "encoders":
  917. layeridx = int(names[2])
  918. name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
  919. layeridx_bias = 1
  920. layeridx += layeridx_bias
  921. if name_q in map_dict.keys():
  922. name_v = map_dict[name_q]["name"]
  923. name_tf = name_v.replace("layeridx", "{}".format(layeridx))
  924. data_tf = var_dict_tf[name_tf]
  925. if map_dict[name_q]["squeeze"] is not None:
  926. data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
  927. if map_dict[name_q]["transpose"] is not None:
  928. data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
  929. data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
  930. assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
  931. var_dict_torch[
  932. name].size(),
  933. data_tf.size())
  934. var_dict_torch_update[name] = data_tf
  935. logging.info(
  936. "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
  937. var_dict_tf[name_tf].shape))
  938. elif names[1] == "after_norm":
  939. name_tf = map_dict[name]["name"]
  940. data_tf = var_dict_tf[name_tf]
  941. data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
  942. var_dict_torch_update[name] = data_tf
  943. logging.info(
  944. "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
  945. var_dict_tf[name_tf].shape))
  946. return var_dict_torch_update
  947. class SANMVadEncoder(AbsEncoder):
  948. """
  949. Author: Speech Lab of DAMO Academy, Alibaba Group
  950. """
  951. def __init__(
  952. self,
  953. input_size: int,
  954. output_size: int = 256,
  955. attention_heads: int = 4,
  956. linear_units: int = 2048,
  957. num_blocks: int = 6,
  958. dropout_rate: float = 0.1,
  959. positional_dropout_rate: float = 0.1,
  960. attention_dropout_rate: float = 0.0,
  961. input_layer: Optional[str] = "conv2d",
  962. pos_enc_class=SinusoidalPositionEncoder,
  963. normalize_before: bool = True,
  964. concat_after: bool = False,
  965. positionwise_layer_type: str = "linear",
  966. positionwise_conv_kernel_size: int = 1,
  967. padding_idx: int = -1,
  968. interctc_layer_idx: List[int] = [],
  969. interctc_use_conditioning: bool = False,
  970. kernel_size : int = 11,
  971. sanm_shfit : int = 0,
  972. selfattention_layer_type: str = "sanm",
  973. ):
  974. super().__init__()
  975. self._output_size = output_size
  976. if input_layer == "linear":
  977. self.embed = torch.nn.Sequential(
  978. torch.nn.Linear(input_size, output_size),
  979. torch.nn.LayerNorm(output_size),
  980. torch.nn.Dropout(dropout_rate),
  981. torch.nn.ReLU(),
  982. pos_enc_class(output_size, positional_dropout_rate),
  983. )
  984. elif input_layer == "conv2d":
  985. self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
  986. elif input_layer == "conv2d2":
  987. self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
  988. elif input_layer == "conv2d6":
  989. self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
  990. elif input_layer == "conv2d8":
  991. self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
  992. elif input_layer == "embed":
  993. self.embed = torch.nn.Sequential(
  994. torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
  995. SinusoidalPositionEncoder(),
  996. )
  997. elif input_layer is None:
  998. if input_size == output_size:
  999. self.embed = None
  1000. else:
  1001. self.embed = torch.nn.Linear(input_size, output_size)
  1002. elif input_layer == "pe":
  1003. self.embed = SinusoidalPositionEncoder()
  1004. else:
  1005. raise ValueError("unknown input_layer: " + input_layer)
  1006. self.normalize_before = normalize_before
  1007. if positionwise_layer_type == "linear":
  1008. positionwise_layer = PositionwiseFeedForward
  1009. positionwise_layer_args = (
  1010. output_size,
  1011. linear_units,
  1012. dropout_rate,
  1013. )
  1014. elif positionwise_layer_type == "conv1d":
  1015. positionwise_layer = MultiLayeredConv1d
  1016. positionwise_layer_args = (
  1017. output_size,
  1018. linear_units,
  1019. positionwise_conv_kernel_size,
  1020. dropout_rate,
  1021. )
  1022. elif positionwise_layer_type == "conv1d-linear":
  1023. positionwise_layer = Conv1dLinear
  1024. positionwise_layer_args = (
  1025. output_size,
  1026. linear_units,
  1027. positionwise_conv_kernel_size,
  1028. dropout_rate,
  1029. )
  1030. else:
  1031. raise NotImplementedError("Support only linear or conv1d.")
  1032. if selfattention_layer_type == "selfattn":
  1033. encoder_selfattn_layer = MultiHeadedAttention
  1034. encoder_selfattn_layer_args = (
  1035. attention_heads,
  1036. output_size,
  1037. attention_dropout_rate,
  1038. )
  1039. elif selfattention_layer_type == "sanm":
  1040. self.encoder_selfattn_layer = MultiHeadedAttentionSANMwithMask
  1041. encoder_selfattn_layer_args0 = (
  1042. attention_heads,
  1043. input_size,
  1044. output_size,
  1045. attention_dropout_rate,
  1046. kernel_size,
  1047. sanm_shfit,
  1048. )
  1049. encoder_selfattn_layer_args = (
  1050. attention_heads,
  1051. output_size,
  1052. output_size,
  1053. attention_dropout_rate,
  1054. kernel_size,
  1055. sanm_shfit,
  1056. )
  1057. self.encoders0 = repeat(
  1058. 1,
  1059. lambda lnum: EncoderLayerSANM(
  1060. input_size,
  1061. output_size,
  1062. self.encoder_selfattn_layer(*encoder_selfattn_layer_args0),
  1063. positionwise_layer(*positionwise_layer_args),
  1064. dropout_rate,
  1065. normalize_before,
  1066. concat_after,
  1067. ),
  1068. )
  1069. self.encoders = repeat(
  1070. num_blocks-1,
  1071. lambda lnum: EncoderLayerSANM(
  1072. output_size,
  1073. output_size,
  1074. self.encoder_selfattn_layer(*encoder_selfattn_layer_args),
  1075. positionwise_layer(*positionwise_layer_args),
  1076. dropout_rate,
  1077. normalize_before,
  1078. concat_after,
  1079. ),
  1080. )
  1081. if self.normalize_before:
  1082. self.after_norm = LayerNorm(output_size)
  1083. self.interctc_layer_idx = interctc_layer_idx
  1084. if len(interctc_layer_idx) > 0:
  1085. assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
  1086. self.interctc_use_conditioning = interctc_use_conditioning
  1087. self.conditioning_layer = None
  1088. self.dropout = nn.Dropout(dropout_rate)
  1089. def output_size(self) -> int:
  1090. return self._output_size
def forward(
    self,
    xs_pad: torch.Tensor,
    ilens: torch.Tensor,
    vad_indexes: torch.Tensor,
    prev_states: torch.Tensor = None,
    ctc: CTC = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Embed positions in tensor.

    Args:
        xs_pad: input tensor (B, L, D)
        ilens: input length (B)
        vad_indexes: per-utterance indices handed to vad_mask to build the
            final layer's attention mask
        prev_states: Not to be used now.
        ctc: accepted for interface parity; not used in this method.
    Returns:
        position embedded tensor and mask
    """
    masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
    # Causal mask: frames may attend only to themselves and the past; applied
    # to every layer except the last one (which uses the VAD mask below).
    sub_masks = subsequent_mask(masks.size(-1), device=xs_pad.device).unsqueeze(0)
    no_future_masks = masks & sub_masks
    # Scale features by sqrt(d_model); in-place on the caller's tensor.
    xs_pad *= self.output_size()**0.5
    if self.embed is None:
        xs_pad = xs_pad
    elif (isinstance(self.embed, Conv2dSubsampling) or isinstance(self.embed, Conv2dSubsampling2)
          or isinstance(self.embed, Conv2dSubsampling6) or isinstance(self.embed, Conv2dSubsampling8)):
        # Conv front-ends shrink the time axis; reject inputs too short for subsampling.
        short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
        if short_status:
            raise TooShortUttError(
                f"has {xs_pad.size(1)} frames and is too short for subsampling " +
                f"(it needs more than {limit_size} frames), return empty results",
                xs_pad.size(1),
                limit_size,
            )
        xs_pad, masks = self.embed(xs_pad, masks)
    else:
        xs_pad = self.embed(xs_pad)
    # xs_pad = self.dropout(xs_pad)
    # Each layer receives a pair [padding mask, attention mask] as one argument.
    mask_tup0 = [masks, no_future_masks]
    encoder_outs = self.encoders0(xs_pad, mask_tup0)
    xs_pad, _ = encoder_outs[0], encoder_outs[1]
    intermediate_outs = []  # NOTE(review): never appended to in this method
    for layer_idx, encoder_layer in enumerate(self.encoders):
        if layer_idx + 1 == len(self.encoders):
            # This is last layer.
            coner_mask = torch.ones(masks.size(0),
                                    masks.size(-1),
                                    masks.size(-1),
                                    device=xs_pad.device,
                                    dtype=torch.bool)
            # Build a per-utterance VAD attention mask for the final layer.
            for word_index, length in enumerate(ilens):  # NOTE(review): length is unused
                coner_mask[word_index, :, :] = vad_mask(masks.size(-1),
                                                        vad_indexes[word_index],
                                                        device=xs_pad.device)
            layer_mask = masks & coner_mask
        else:
            layer_mask = no_future_masks
        mask_tup1 = [masks, layer_mask]
        encoder_outs = encoder_layer(xs_pad, mask_tup1)
        xs_pad, layer_mask = encoder_outs[0], encoder_outs[1]
    if self.normalize_before:
        xs_pad = self.after_norm(xs_pad)
    olens = masks.squeeze(1).sum(1)
    if len(intermediate_outs) > 0:
        return (xs_pad, intermediate_outs), olens, None
    return xs_pad, olens, None