# sanm_encoder.py

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267
  1. from typing import List
  2. from typing import Optional
  3. from typing import Sequence
  4. from typing import Tuple
  5. from typing import Union
  6. import logging
  7. import torch
  8. import torch.nn as nn
  9. import torch.nn.functional as F
  10. from funasr.modules.streaming_utils.chunk_utilis import overlap_chunk
  11. import numpy as np
  12. from funasr.torch_utils.device_funcs import to_device
  13. from funasr.modules.nets_utils import make_pad_mask
  14. from funasr.modules.attention import MultiHeadedAttention, MultiHeadedAttentionSANM, MultiHeadedAttentionSANMwithMask
  15. from funasr.modules.embedding import SinusoidalPositionEncoder, StreamSinusoidalPositionEncoder
  16. from funasr.modules.layer_norm import LayerNorm
  17. from funasr.modules.multi_layer_conv import Conv1dLinear
  18. from funasr.modules.multi_layer_conv import MultiLayeredConv1d
  19. from funasr.modules.positionwise_feed_forward import (
  20. PositionwiseFeedForward, # noqa: H301
  21. )
  22. from funasr.modules.repeat import repeat
  23. from funasr.modules.subsampling import Conv2dSubsampling
  24. from funasr.modules.subsampling import Conv2dSubsampling2
  25. from funasr.modules.subsampling import Conv2dSubsampling6
  26. from funasr.modules.subsampling import Conv2dSubsampling8
  27. from funasr.modules.subsampling import TooShortUttError
  28. from funasr.modules.subsampling import check_short_utt
  29. from funasr.modules.mask import subsequent_mask, vad_mask
  30. from funasr.models.ctc import CTC
  31. from funasr.models.encoder.abs_encoder import AbsEncoder
class EncoderLayerSANM(nn.Module):
    """Single SANM encoder layer: memory-equipped self-attention + position-wise FFN.

    Supports pre-/post-LayerNorm placement, an optional "concat after attention"
    projection, and stochastic depth (layer drop) at training time.
    """

    def __init__(
        self,
        in_size,
        size,
        self_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
        stochastic_depth_rate=0.0,
    ):
        """Construct an EncoderLayer object.

        Args:
            in_size: dimension of the layer input (may differ from `size` for
                the first layer of the stack).
            size: dimension of the layer output.
            self_attn: self-attention module instance (called with the chunk masks).
            feed_forward: position-wise feed-forward module instance.
            dropout_rate: dropout probability applied after attention and FFN.
            normalize_before: if True, apply LayerNorm before each sub-block
                (pre-norm); otherwise after (post-norm).
            concat_after: if True, concatenate attention input and output and
                project with a linear layer instead of using dropout.
            stochastic_depth_rate: probability of skipping this layer entirely
                during training (0.0 disables layer drop).
        """
        super(EncoderLayerSANM, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        # norm1 is sized to the layer input, norm2 to the layer output;
        # they differ only when in_size != size (the first layer).
        self.norm1 = LayerNorm(in_size)
        self.norm2 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.in_size = in_size
        self.size = size
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear = nn.Linear(size + size, size)
        self.stochastic_depth_rate = stochastic_depth_rate
        self.dropout_rate = dropout_rate

    def forward(self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
        """Compute encoded features.
        Args:
            x_input (torch.Tensor): Input tensor (#batch, time, size).
            mask (torch.Tensor): Mask tensor for the input (#batch, time).
            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, time).
        """
        skip_layer = False
        # with stochastic depth, residual connection `x + f(x)` becomes
        # `x <- x + 1 / (1 - p) * f(x)` at training time.
        stoch_layer_coeff = 1.0
        if self.training and self.stochastic_depth_rate > 0:
            skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
            stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)

        if skip_layer:
            if cache is not None:
                x = torch.cat([cache, x], dim=1)
            # NOTE(review): the skip path returns a 2-tuple while the normal
            # path returns a 5-tuple; downstream code only indexes [0] and [1],
            # and the next chained layer's extra parameters default to None.
            return x, mask

        residual = x
        if self.normalize_before:
            x = self.norm1(x)

        if self.concat_after:
            # Concatenate attention input and output, then project back to `size`.
            x_concat = torch.cat((x, self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)), dim=-1)
            if self.in_size == self.size:
                # Residual connection only when input/output dims match.
                x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
            else:
                x = stoch_layer_coeff * self.concat_linear(x_concat)
        else:
            if self.in_size == self.size:
                x = residual + stoch_layer_coeff * self.dropout(
                    self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
                )
            else:
                # First layer changes dimensionality, so no residual is possible.
                x = stoch_layer_coeff * self.dropout(
                    self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
                )
        if not self.normalize_before:
            x = self.norm1(x)

        # Position-wise feed-forward sub-block with its own residual.
        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm2(x)

        # Masks and cache are passed through unchanged so that layers can be
        # chained via `repeat` (each layer receives the previous layer's outputs).
        return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder
class SANMEncoder(AbsEncoder):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    San-m: Memory equipped self-attention for end-to-end speech recognition
    https://arxiv.org/abs/2006.01713
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: Optional[str] = "conv2d",
        pos_enc_class=SinusoidalPositionEncoder,
        normalize_before: bool = True,
        concat_after: bool = False,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 1,
        padding_idx: int = -1,
        # NOTE(review): mutable default argument; it is only read here
        # (stored and len()/min()/max()-checked), never mutated.
        interctc_layer_idx: List[int] = [],
        interctc_use_conditioning: bool = False,
        kernel_size: int = 11,
        sanm_shfit: int = 0,
        lora_list: List[str] = None,
        lora_rank: int = 8,
        lora_alpha: int = 16,
        lora_dropout: float = 0.1,
        selfattention_layer_type: str = "sanm",
        tf2torch_tensor_name_prefix_torch: str = "encoder",
        tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
    ):
        """Build the SANM encoder stack.

        The stack is split into `encoders0` (one layer mapping input_size ->
        output_size) followed by `encoders` (num_blocks - 1 layers at
        output_size), preceded by a configurable embedding/subsampling
        front-end selected by `input_layer`.
        """
        super().__init__()
        self._output_size = output_size

        # --- input embedding / subsampling front-end ---
        if input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(input_size, output_size),
                torch.nn.LayerNorm(output_size),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d":
            self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d2":
            self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d6":
            self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d8":
            self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
        elif input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
                SinusoidalPositionEncoder(),
            )
        elif input_layer is None:
            # No embedding: pass through, or project if dims differ.
            if input_size == output_size:
                self.embed = None
            else:
                self.embed = torch.nn.Linear(input_size, output_size)
        elif input_layer == "pe":
            self.embed = SinusoidalPositionEncoder()
        elif input_layer == "pe_online":
            self.embed = StreamSinusoidalPositionEncoder()
        else:
            raise ValueError("unknown input_layer: " + input_layer)

        self.normalize_before = normalize_before

        # --- position-wise feed-forward layer factory ---
        if positionwise_layer_type == "linear":
            positionwise_layer = PositionwiseFeedForward
            positionwise_layer_args = (
                output_size,
                linear_units,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d":
            positionwise_layer = MultiLayeredConv1d
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d-linear":
            positionwise_layer = Conv1dLinear
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        else:
            raise NotImplementedError("Support only linear or conv1d.")

        # --- self-attention layer factory ---
        if selfattention_layer_type == "selfattn":
            encoder_selfattn_layer = MultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
            )
        elif selfattention_layer_type == "sanm":
            encoder_selfattn_layer = MultiHeadedAttentionSANM
            # args0 is for the first layer only (input_size -> output_size).
            encoder_selfattn_layer_args0 = (
                attention_heads,
                input_size,
                output_size,
                attention_dropout_rate,
                kernel_size,
                sanm_shfit,
                lora_list,
                lora_rank,
                lora_alpha,
                lora_dropout,
            )
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                output_size,
                attention_dropout_rate,
                kernel_size,
                sanm_shfit,
                lora_list,
                lora_rank,
                lora_alpha,
                lora_dropout,
            )

        # First layer handles the input_size -> output_size transition.
        self.encoders0 = repeat(
            1,
            lambda lnum: EncoderLayerSANM(
                input_size,
                output_size,
                encoder_selfattn_layer(*encoder_selfattn_layer_args0),
                positionwise_layer(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        # Remaining num_blocks - 1 layers operate at output_size.
        self.encoders = repeat(
            num_blocks - 1,
            lambda lnum: EncoderLayerSANM(
                output_size,
                output_size,
                encoder_selfattn_layer(*encoder_selfattn_layer_args),
                positionwise_layer(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        if self.normalize_before:
            self.after_norm = LayerNorm(output_size)

        # Intermediate-CTC configuration (indices of layers whose output is
        # also normalized and optionally conditioned on CTC posteriors).
        self.interctc_layer_idx = interctc_layer_idx
        if len(interctc_layer_idx) > 0:
            assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
        self.interctc_use_conditioning = interctc_use_conditioning
        self.conditioning_layer = None
        self.dropout = nn.Dropout(dropout_rate)
        # Prefixes used when mapping TF checkpoint variable names to torch.
        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf

    def output_size(self) -> int:
        """Return the encoder output feature dimension."""
        return self._output_size

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: torch.Tensor = None,
        ctc: CTC = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Embed positions in tensor.
        Args:
            xs_pad: input tensor (B, L, D)
            ilens: input length (B)
            prev_states: Not to be used now.
        Returns:
            position embedded tensor and mask
        """
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
        # Scale inputs by sqrt(d_model) before positional encoding.
        xs_pad = xs_pad * self.output_size()**0.5
        if self.embed is None:
            xs_pad = xs_pad
        elif (
            isinstance(self.embed, Conv2dSubsampling)
            or isinstance(self.embed, Conv2dSubsampling2)
            or isinstance(self.embed, Conv2dSubsampling6)
            or isinstance(self.embed, Conv2dSubsampling8)
        ):
            # Conv subsampling shortens the sequence; reject inputs that are
            # too short to survive the subsampling.
            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
            if short_status:
                raise TooShortUttError(
                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
                    + f"(it needs more than {limit_size} frames), return empty results",
                    xs_pad.size(1),
                    limit_size,
                )
            xs_pad, masks = self.embed(xs_pad, masks)
        else:
            xs_pad = self.embed(xs_pad)

        # xs_pad = self.dropout(xs_pad)
        # First (dimension-changing) layer.
        encoder_outs = self.encoders0(xs_pad, masks)
        xs_pad, masks = encoder_outs[0], encoder_outs[1]
        intermediate_outs = []
        if len(self.interctc_layer_idx) == 0:
            # Fast path: run the remaining layers as one chained module.
            encoder_outs = self.encoders(xs_pad, masks)
            xs_pad, masks = encoder_outs[0], encoder_outs[1]
        else:
            # Layer-by-layer so intermediate CTC taps can be collected.
            for layer_idx, encoder_layer in enumerate(self.encoders):
                encoder_outs = encoder_layer(xs_pad, masks)
                xs_pad, masks = encoder_outs[0], encoder_outs[1]
                if layer_idx + 1 in self.interctc_layer_idx:
                    encoder_out = xs_pad
                    # intermediate outputs are also normalized
                    if self.normalize_before:
                        encoder_out = self.after_norm(encoder_out)
                    intermediate_outs.append((layer_idx + 1, encoder_out))
                    if self.interctc_use_conditioning:
                        # Condition subsequent layers on the CTC posteriors.
                        ctc_out = ctc.softmax(encoder_out)
                        xs_pad = xs_pad + self.conditioning_layer(ctc_out)

        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)

        # Output lengths recovered from the (possibly subsampled) masks.
        olens = masks.squeeze(1).sum(1)
        if len(intermediate_outs) > 0:
            return (xs_pad, intermediate_outs), olens, None
        return xs_pad, olens, None

    def _add_overlap_chunk(self, feats: np.ndarray, cache: dict = {}):
        """Prepend cached frames to `feats` for streaming overlap, and update
        the cache with the tail of the concatenated features.

        NOTE(review): the mutable default `cache={}` is only ever hit when an
        empty dict is passed/implied, in which case it is returned untouched.
        """
        if len(cache) == 0:
            return feats
        cache["feats"] = to_device(cache["feats"], device=feats.device)
        overlap_feats = torch.cat((cache["feats"], feats), dim=1)
        # Keep only the last (chunk_size[0] + chunk_size[2]) frames as context
        # for the next chunk.
        cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]):, :]
        return overlap_feats

    def forward_chunk(self,
                      xs_pad: torch.Tensor,
                      ilens: torch.Tensor,
                      cache: dict = None,
                      ctc: CTC = None,
                      ):
        """Streaming forward over one chunk, using `cache` for cross-chunk state.

        `cache` is expected to provide "tail_chunk", "feats" and "chunk_size"
        entries (produced by the streaming front-end). Masks are not used in
        this mode; layers are invoked with None masks.
        """
        xs_pad *= self.output_size() ** 0.5
        if self.embed is None:
            xs_pad = xs_pad
        else:
            xs_pad = self.embed(xs_pad, cache)
        if cache["tail_chunk"]:
            # Final (padding) chunk: reuse the cached context features as input.
            xs_pad = to_device(cache["feats"], device=xs_pad.device)
        else:
            xs_pad = self._add_overlap_chunk(xs_pad, cache)

        encoder_outs = self.encoders0(xs_pad, None, None, None, None)
        xs_pad, masks = encoder_outs[0], encoder_outs[1]
        intermediate_outs = []
        if len(self.interctc_layer_idx) == 0:
            encoder_outs = self.encoders(xs_pad, None, None, None, None)
            xs_pad, masks = encoder_outs[0], encoder_outs[1]
        else:
            for layer_idx, encoder_layer in enumerate(self.encoders):
                encoder_outs = encoder_layer(xs_pad, None, None, None, None)
                xs_pad, masks = encoder_outs[0], encoder_outs[1]
                if layer_idx + 1 in self.interctc_layer_idx:
                    encoder_out = xs_pad
                    # intermediate outputs are also normalized
                    if self.normalize_before:
                        encoder_out = self.after_norm(encoder_out)
                    intermediate_outs.append((layer_idx + 1, encoder_out))
                    if self.interctc_use_conditioning:
                        ctc_out = ctc.softmax(encoder_out)
                        xs_pad = xs_pad + self.conditioning_layer(ctc_out)

        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)
        if len(intermediate_outs) > 0:
            return (xs_pad, intermediate_outs), None, None
        return xs_pad, ilens, None

    def gen_tf2torch_map_dict(self):
        """Return the torch-name -> TF-checkpoint-variable mapping template.

        Keys contain the literal placeholder "layeridx" which is substituted
        per layer in convert_tf2torch(). Each value records the TF variable
        name plus the squeeze axis / transpose permutation needed to match the
        torch parameter layout (trailing comments show torch vs. TF shapes).
        """
        tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
        tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
        map_dict_local = {
            ## encoder
            # cicd
            "{}.encoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.encoders.layeridx.norm1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.encoders.layeridx.self_attn.linear_q_k_v.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (768,256),(1,256,768)
            "{}.encoders.layeridx.self_attn.linear_q_k_v.bias".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (768,),(768,)
            "{}.encoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/depth_conv_w".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 2, 0),
                 },  # (256,1,31),(1,31,256,1)
            "{}.encoders.layeridx.self_attn.linear_out.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,256),(1,256,256)
            "{}.encoders.layeridx.self_attn.linear_out.bias".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            # ffn
            "{}.encoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.encoders.layeridx.norm2.bias".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.encoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/ffn/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (1024,256),(1,256,1024)
            "{}.encoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/ffn/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.encoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,1024),(1,1024,256)
            "{}.encoders.layeridx.feed_forward.w_2.bias".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/ffn/conv1d_1/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            # out norm
            "{}.after_norm.weight".format(tensor_name_prefix_torch):
                {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.after_norm.bias".format(tensor_name_prefix_torch):
                {"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
        }
        return map_dict_local

    def convert_tf2torch(self,
                         var_dict_tf,
                         var_dict_torch,
                         ):
        """Translate a TF checkpoint's variables into this module's state-dict.

        Iterates the torch parameter names, resolves each through the map from
        gen_tf2torch_map_dict() (substituting the layer index), applies the
        recorded squeeze/transpose, and returns the updated name -> tensor dict.
        """
        map_dict = self.gen_tf2torch_map_dict()
        var_dict_torch_update = dict()
        for name in sorted(var_dict_torch.keys(), reverse=False):
            names = name.split('.')
            if names[0] == self.tf2torch_tensor_name_prefix_torch:
                if names[1] == "encoders0":
                    # torch `encoders0` is TF layer 0; the map template only
                    # has "encoders" keys, so rewrite the prefix before lookup.
                    layeridx = int(names[2])
                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
                    name_q = name_q.replace("encoders0", "encoders")
                    layeridx_bias = 0
                    layeridx += layeridx_bias
                    if name_q in map_dict.keys():
                        name_v = map_dict[name_q]["name"]
                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name_q]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                        if map_dict[name_q]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
                                                                                                        var_dict_torch[
                                                                                                            name].size(),
                                                                                                        data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info(
                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
                                                                                          var_dict_tf[name_tf].shape))
                elif names[1] == "encoders":
                    # torch `encoders` index i corresponds to TF layer i + 1
                    # (layer 0 is consumed by `encoders0` above).
                    layeridx = int(names[2])
                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
                    layeridx_bias = 1
                    layeridx += layeridx_bias
                    if name_q in map_dict.keys():
                        name_v = map_dict[name_q]["name"]
                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name_q]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                        if map_dict[name_q]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
                                                                                                        var_dict_torch[
                                                                                                            name].size(),
                                                                                                        data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info(
                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
                                                                                          var_dict_tf[name_tf].shape))
                elif names[1] == "after_norm":
                    # Final LayerNorm has a direct (non-templated) map entry.
                    name_tf = map_dict[name]["name"]
                    data_tf = var_dict_tf[name_tf]
                    data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                    var_dict_torch_update[name] = data_tf
                    logging.info(
                        "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
                                                                                      var_dict_tf[name_tf].shape))
        return var_dict_torch_update
  527. class SANMEncoderChunkOpt(AbsEncoder):
  528. """
  529. Author: Speech Lab of DAMO Academy, Alibaba Group
  530. SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
  531. https://arxiv.org/abs/2006.01713
  532. """
  533. def __init__(
  534. self,
  535. input_size: int,
  536. output_size: int = 256,
  537. attention_heads: int = 4,
  538. linear_units: int = 2048,
  539. num_blocks: int = 6,
  540. dropout_rate: float = 0.1,
  541. positional_dropout_rate: float = 0.1,
  542. attention_dropout_rate: float = 0.0,
  543. input_layer: Optional[str] = "conv2d",
  544. pos_enc_class=SinusoidalPositionEncoder,
  545. normalize_before: bool = True,
  546. concat_after: bool = False,
  547. positionwise_layer_type: str = "linear",
  548. positionwise_conv_kernel_size: int = 1,
  549. padding_idx: int = -1,
  550. interctc_layer_idx: List[int] = [],
  551. interctc_use_conditioning: bool = False,
  552. kernel_size: int = 11,
  553. sanm_shfit: int = 0,
  554. selfattention_layer_type: str = "sanm",
  555. chunk_size: Union[int, Sequence[int]] = (16,),
  556. stride: Union[int, Sequence[int]] = (10,),
  557. pad_left: Union[int, Sequence[int]] = (0,),
  558. encoder_att_look_back_factor: Union[int, Sequence[int]] = (1,),
  559. decoder_att_look_back_factor: Union[int, Sequence[int]] = (1,),
  560. tf2torch_tensor_name_prefix_torch: str = "encoder",
  561. tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
  562. ):
  563. super().__init__()
  564. self._output_size = output_size
  565. if input_layer == "linear":
  566. self.embed = torch.nn.Sequential(
  567. torch.nn.Linear(input_size, output_size),
  568. torch.nn.LayerNorm(output_size),
  569. torch.nn.Dropout(dropout_rate),
  570. torch.nn.ReLU(),
  571. pos_enc_class(output_size, positional_dropout_rate),
  572. )
  573. elif input_layer == "conv2d":
  574. self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
  575. elif input_layer == "conv2d2":
  576. self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
  577. elif input_layer == "conv2d6":
  578. self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
  579. elif input_layer == "conv2d8":
  580. self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
  581. elif input_layer == "embed":
  582. self.embed = torch.nn.Sequential(
  583. torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
  584. pos_enc_class(output_size, positional_dropout_rate),
  585. )
  586. elif input_layer is None:
  587. if input_size == output_size:
  588. self.embed = None
  589. else:
  590. self.embed = torch.nn.Linear(input_size, output_size)
  591. elif input_layer == "pe":
  592. self.embed = SinusoidalPositionEncoder()
  593. elif input_layer == "pe_online":
  594. self.embed = StreamSinusoidalPositionEncoder()
  595. else:
  596. raise ValueError("unknown input_layer: " + input_layer)
  597. self.normalize_before = normalize_before
  598. if positionwise_layer_type == "linear":
  599. positionwise_layer = PositionwiseFeedForward
  600. positionwise_layer_args = (
  601. output_size,
  602. linear_units,
  603. dropout_rate,
  604. )
  605. elif positionwise_layer_type == "conv1d":
  606. positionwise_layer = MultiLayeredConv1d
  607. positionwise_layer_args = (
  608. output_size,
  609. linear_units,
  610. positionwise_conv_kernel_size,
  611. dropout_rate,
  612. )
  613. elif positionwise_layer_type == "conv1d-linear":
  614. positionwise_layer = Conv1dLinear
  615. positionwise_layer_args = (
  616. output_size,
  617. linear_units,
  618. positionwise_conv_kernel_size,
  619. dropout_rate,
  620. )
  621. else:
  622. raise NotImplementedError("Support only linear or conv1d.")
  623. if selfattention_layer_type == "selfattn":
  624. encoder_selfattn_layer = MultiHeadedAttention
  625. encoder_selfattn_layer_args = (
  626. attention_heads,
  627. output_size,
  628. attention_dropout_rate,
  629. )
  630. elif selfattention_layer_type == "sanm":
  631. encoder_selfattn_layer = MultiHeadedAttentionSANM
  632. encoder_selfattn_layer_args0 = (
  633. attention_heads,
  634. input_size,
  635. output_size,
  636. attention_dropout_rate,
  637. kernel_size,
  638. sanm_shfit,
  639. )
  640. encoder_selfattn_layer_args = (
  641. attention_heads,
  642. output_size,
  643. output_size,
  644. attention_dropout_rate,
  645. kernel_size,
  646. sanm_shfit,
  647. )
  648. self.encoders0 = repeat(
  649. 1,
  650. lambda lnum: EncoderLayerSANM(
  651. input_size,
  652. output_size,
  653. encoder_selfattn_layer(*encoder_selfattn_layer_args0),
  654. positionwise_layer(*positionwise_layer_args),
  655. dropout_rate,
  656. normalize_before,
  657. concat_after,
  658. ),
  659. )
  660. self.encoders = repeat(
  661. num_blocks - 1,
  662. lambda lnum: EncoderLayerSANM(
  663. output_size,
  664. output_size,
  665. encoder_selfattn_layer(*encoder_selfattn_layer_args),
  666. positionwise_layer(*positionwise_layer_args),
  667. dropout_rate,
  668. normalize_before,
  669. concat_after,
  670. ),
  671. )
  672. if self.normalize_before:
  673. self.after_norm = LayerNorm(output_size)
  674. self.interctc_layer_idx = interctc_layer_idx
  675. if len(interctc_layer_idx) > 0:
  676. assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
  677. self.interctc_use_conditioning = interctc_use_conditioning
  678. self.conditioning_layer = None
  679. shfit_fsmn = (kernel_size - 1) // 2
  680. self.overlap_chunk_cls = overlap_chunk(
  681. chunk_size=chunk_size,
  682. stride=stride,
  683. pad_left=pad_left,
  684. shfit_fsmn=shfit_fsmn,
  685. encoder_att_look_back_factor=encoder_att_look_back_factor,
  686. decoder_att_look_back_factor=decoder_att_look_back_factor,
  687. )
  688. self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
  689. self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
  690. def output_size(self) -> int:
  691. return self._output_size
  692. def forward(
  693. self,
  694. xs_pad: torch.Tensor,
  695. ilens: torch.Tensor,
  696. prev_states: torch.Tensor = None,
  697. ctc: CTC = None,
  698. ind: int = 0,
  699. ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
  700. """Embed positions in tensor.
  701. Args:
  702. xs_pad: input tensor (B, L, D)
  703. ilens: input length (B)
  704. prev_states: Not to be used now.
  705. Returns:
  706. position embedded tensor and mask
  707. """
  708. masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
  709. xs_pad *= self.output_size() ** 0.5
  710. if self.embed is None:
  711. xs_pad = xs_pad
  712. elif (
  713. isinstance(self.embed, Conv2dSubsampling)
  714. or isinstance(self.embed, Conv2dSubsampling2)
  715. or isinstance(self.embed, Conv2dSubsampling6)
  716. or isinstance(self.embed, Conv2dSubsampling8)
  717. ):
  718. short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
  719. if short_status:
  720. raise TooShortUttError(
  721. f"has {xs_pad.size(1)} frames and is too short for subsampling "
  722. + f"(it needs more than {limit_size} frames), return empty results",
  723. xs_pad.size(1),
  724. limit_size,
  725. )
  726. xs_pad, masks = self.embed(xs_pad, masks)
  727. else:
  728. xs_pad = self.embed(xs_pad)
  729. mask_shfit_chunk, mask_att_chunk_encoder = None, None
  730. if self.overlap_chunk_cls is not None:
  731. ilens = masks.squeeze(1).sum(1)
  732. chunk_outs = self.overlap_chunk_cls.gen_chunk_mask(ilens, ind)
  733. xs_pad, ilens = self.overlap_chunk_cls.split_chunk(xs_pad, ilens, chunk_outs=chunk_outs)
  734. masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
  735. mask_shfit_chunk = self.overlap_chunk_cls.get_mask_shfit_chunk(chunk_outs, xs_pad.device, xs_pad.size(0),
  736. dtype=xs_pad.dtype)
  737. mask_att_chunk_encoder = self.overlap_chunk_cls.get_mask_att_chunk_encoder(chunk_outs, xs_pad.device,
  738. xs_pad.size(0),
  739. dtype=xs_pad.dtype)
  740. encoder_outs = self.encoders0(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
  741. xs_pad, masks = encoder_outs[0], encoder_outs[1]
  742. intermediate_outs = []
  743. if len(self.interctc_layer_idx) == 0:
  744. encoder_outs = self.encoders(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
  745. xs_pad, masks = encoder_outs[0], encoder_outs[1]
  746. else:
  747. for layer_idx, encoder_layer in enumerate(self.encoders):
  748. encoder_outs = encoder_layer(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
  749. xs_pad, masks = encoder_outs[0], encoder_outs[1]
  750. if layer_idx + 1 in self.interctc_layer_idx:
  751. encoder_out = xs_pad
  752. # intermediate outputs are also normalized
  753. if self.normalize_before:
  754. encoder_out = self.after_norm(encoder_out)
  755. intermediate_outs.append((layer_idx + 1, encoder_out))
  756. if self.interctc_use_conditioning:
  757. ctc_out = ctc.softmax(encoder_out)
  758. xs_pad = xs_pad + self.conditioning_layer(ctc_out)
  759. if self.normalize_before:
  760. xs_pad = self.after_norm(xs_pad)
  761. olens = masks.squeeze(1).sum(1)
  762. if len(intermediate_outs) > 0:
  763. return (xs_pad, intermediate_outs), olens, None
  764. return xs_pad, olens, None
  765. def _add_overlap_chunk(self, feats: np.ndarray, cache: dict = {}):
  766. if len(cache) == 0:
  767. return feats
  768. cache["feats"] = to_device(cache["feats"], device=feats.device)
  769. overlap_feats = torch.cat((cache["feats"], feats), dim=1)
  770. cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]):, :]
  771. return overlap_feats
def forward_chunk(self,
                  xs_pad: torch.Tensor,
                  ilens: torch.Tensor,
                  cache: dict = None,
                  ctc: CTC = None,
                  ):
    """Encode a single chunk of features in streaming mode.

    Args:
        xs_pad: chunk feature tensor (B, T, D).
        ilens: input lengths (B); returned unchanged in the non-interctc path.
        cache: streaming state dict; read for "tail_chunk" and "feats" and
            forwarded to the embed layer.
            NOTE(review): the default of ``None`` would crash at
            ``cache["tail_chunk"]`` below, so callers apparently always
            pass a populated dict — confirm.
        ctc: CTC module, used only for intermediate-CTC conditioning.

    Returns:
        (encoded tensor [plus intermediate outputs], ilens or None, None).
    """
    xs_pad *= self.output_size() ** 0.5
    if self.embed is None:
        xs_pad = xs_pad
    else:
        xs_pad = self.embed(xs_pad, cache)
    if cache["tail_chunk"]:
        # Tail chunk: encode the cached features instead of the new input.
        xs_pad = to_device(cache["feats"], device=xs_pad.device)
    else:
        # Prepend the cached overlap from the previous chunk.
        xs_pad = self._add_overlap_chunk(xs_pad, cache)
    # All mask arguments are None here — NOTE(review): presumably streaming
    # chunks carry no padding, so no masks are needed; confirm with callers.
    encoder_outs = self.encoders0(xs_pad, None, None, None, None)
    xs_pad, masks = encoder_outs[0], encoder_outs[1]
    intermediate_outs = []
    if len(self.interctc_layer_idx) == 0:
        encoder_outs = self.encoders(xs_pad, None, None, None, None)
        xs_pad, masks = encoder_outs[0], encoder_outs[1]
    else:
        for layer_idx, encoder_layer in enumerate(self.encoders):
            encoder_outs = encoder_layer(xs_pad, None, None, None, None)
            xs_pad, masks = encoder_outs[0], encoder_outs[1]
            if layer_idx + 1 in self.interctc_layer_idx:
                encoder_out = xs_pad
                # intermediate outputs are also normalized
                if self.normalize_before:
                    encoder_out = self.after_norm(encoder_out)
                intermediate_outs.append((layer_idx + 1, encoder_out))
                # Self-conditioned CTC feedback into the encoder stream.
                if self.interctc_use_conditioning:
                    ctc_out = ctc.softmax(encoder_out)
                    xs_pad = xs_pad + self.conditioning_layer(ctc_out)
    if self.normalize_before:
        xs_pad = self.after_norm(xs_pad)
    if len(intermediate_outs) > 0:
        return (xs_pad, intermediate_outs), None, None
    return xs_pad, ilens, None
  811. def gen_tf2torch_map_dict(self):
  812. tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
  813. tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
  814. map_dict_local = {
  815. ## encoder
  816. # cicd
  817. "{}.encoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
  818. {"name": "{}/layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
  819. "squeeze": None,
  820. "transpose": None,
  821. }, # (256,),(256,)
  822. "{}.encoders.layeridx.norm1.bias".format(tensor_name_prefix_torch):
  823. {"name": "{}/layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
  824. "squeeze": None,
  825. "transpose": None,
  826. }, # (256,),(256,)
  827. "{}.encoders.layeridx.self_attn.linear_q_k_v.weight".format(tensor_name_prefix_torch):
  828. {"name": "{}/layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
  829. "squeeze": 0,
  830. "transpose": (1, 0),
  831. }, # (768,256),(1,256,768)
  832. "{}.encoders.layeridx.self_attn.linear_q_k_v.bias".format(tensor_name_prefix_torch):
  833. {"name": "{}/layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
  834. "squeeze": None,
  835. "transpose": None,
  836. }, # (768,),(768,)
  837. "{}.encoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch):
  838. {"name": "{}/layer_layeridx/multi_head/depth_conv_w".format(tensor_name_prefix_tf),
  839. "squeeze": 0,
  840. "transpose": (1, 2, 0),
  841. }, # (256,1,31),(1,31,256,1)
  842. "{}.encoders.layeridx.self_attn.linear_out.weight".format(tensor_name_prefix_torch):
  843. {"name": "{}/layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
  844. "squeeze": 0,
  845. "transpose": (1, 0),
  846. }, # (256,256),(1,256,256)
  847. "{}.encoders.layeridx.self_attn.linear_out.bias".format(tensor_name_prefix_torch):
  848. {"name": "{}/layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
  849. "squeeze": None,
  850. "transpose": None,
  851. }, # (256,),(256,)
  852. # ffn
  853. "{}.encoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
  854. {"name": "{}/layer_layeridx/ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
  855. "squeeze": None,
  856. "transpose": None,
  857. }, # (256,),(256,)
  858. "{}.encoders.layeridx.norm2.bias".format(tensor_name_prefix_torch):
  859. {"name": "{}/layer_layeridx/ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
  860. "squeeze": None,
  861. "transpose": None,
  862. }, # (256,),(256,)
  863. "{}.encoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
  864. {"name": "{}/layer_layeridx/ffn/conv1d/kernel".format(tensor_name_prefix_tf),
  865. "squeeze": 0,
  866. "transpose": (1, 0),
  867. }, # (1024,256),(1,256,1024)
  868. "{}.encoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
  869. {"name": "{}/layer_layeridx/ffn/conv1d/bias".format(tensor_name_prefix_tf),
  870. "squeeze": None,
  871. "transpose": None,
  872. }, # (1024,),(1024,)
  873. "{}.encoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
  874. {"name": "{}/layer_layeridx/ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
  875. "squeeze": 0,
  876. "transpose": (1, 0),
  877. }, # (256,1024),(1,1024,256)
  878. "{}.encoders.layeridx.feed_forward.w_2.bias".format(tensor_name_prefix_torch):
  879. {"name": "{}/layer_layeridx/ffn/conv1d_1/bias".format(tensor_name_prefix_tf),
  880. "squeeze": None,
  881. "transpose": None,
  882. }, # (256,),(256,)
  883. # out norm
  884. "{}.after_norm.weight".format(tensor_name_prefix_torch):
  885. {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
  886. "squeeze": None,
  887. "transpose": None,
  888. }, # (256,),(256,)
  889. "{}.after_norm.bias".format(tensor_name_prefix_torch):
  890. {"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
  891. "squeeze": None,
  892. "transpose": None,
  893. }, # (256,),(256,)
  894. }
  895. return map_dict_local
  896. def convert_tf2torch(self,
  897. var_dict_tf,
  898. var_dict_torch,
  899. ):
  900. map_dict = self.gen_tf2torch_map_dict()
  901. var_dict_torch_update = dict()
  902. for name in sorted(var_dict_torch.keys(), reverse=False):
  903. names = name.split('.')
  904. if names[0] == self.tf2torch_tensor_name_prefix_torch:
  905. if names[1] == "encoders0":
  906. layeridx = int(names[2])
  907. name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
  908. name_q = name_q.replace("encoders0", "encoders")
  909. layeridx_bias = 0
  910. layeridx += layeridx_bias
  911. if name_q in map_dict.keys():
  912. name_v = map_dict[name_q]["name"]
  913. name_tf = name_v.replace("layeridx", "{}".format(layeridx))
  914. data_tf = var_dict_tf[name_tf]
  915. if map_dict[name_q]["squeeze"] is not None:
  916. data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
  917. if map_dict[name_q]["transpose"] is not None:
  918. data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
  919. data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
  920. assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
  921. var_dict_torch[
  922. name].size(),
  923. data_tf.size())
  924. var_dict_torch_update[name] = data_tf
  925. logging.info(
  926. "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
  927. var_dict_tf[name_tf].shape))
  928. elif names[1] == "encoders":
  929. layeridx = int(names[2])
  930. name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
  931. layeridx_bias = 1
  932. layeridx += layeridx_bias
  933. if name_q in map_dict.keys():
  934. name_v = map_dict[name_q]["name"]
  935. name_tf = name_v.replace("layeridx", "{}".format(layeridx))
  936. data_tf = var_dict_tf[name_tf]
  937. if map_dict[name_q]["squeeze"] is not None:
  938. data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
  939. if map_dict[name_q]["transpose"] is not None:
  940. data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
  941. data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
  942. assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
  943. var_dict_torch[
  944. name].size(),
  945. data_tf.size())
  946. var_dict_torch_update[name] = data_tf
  947. logging.info(
  948. "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
  949. var_dict_tf[name_tf].shape))
  950. elif names[1] == "after_norm":
  951. name_tf = map_dict[name]["name"]
  952. data_tf = var_dict_tf[name_tf]
  953. data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
  954. var_dict_torch_update[name] = data_tf
  955. logging.info(
  956. "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
  957. var_dict_tf[name_tf].shape))
  958. return var_dict_torch_update
  959. class SANMVadEncoder(AbsEncoder):
  960. """
  961. Author: Speech Lab of DAMO Academy, Alibaba Group
  962. """
def __init__(
    self,
    input_size: int,
    output_size: int = 256,
    attention_heads: int = 4,
    linear_units: int = 2048,
    num_blocks: int = 6,
    dropout_rate: float = 0.1,
    positional_dropout_rate: float = 0.1,
    attention_dropout_rate: float = 0.0,
    input_layer: Optional[str] = "conv2d",
    pos_enc_class=SinusoidalPositionEncoder,
    normalize_before: bool = True,
    concat_after: bool = False,
    positionwise_layer_type: str = "linear",
    positionwise_conv_kernel_size: int = 1,
    padding_idx: int = -1,
    interctc_layer_idx: List[int] = [],
    interctc_use_conditioning: bool = False,
    kernel_size: int = 11,
    sanm_shfit: int = 0,
    selfattention_layer_type: str = "sanm",
):
    """Construct the SANM VAD encoder.

    Args:
        input_size: dimension of the input features.
        output_size: dimension of the encoder output (and hidden layers).
        attention_heads: number of attention heads.
        linear_units: hidden units of the position-wise feed-forward layer.
        num_blocks: total number of encoder layers (1 in `encoders0` +
            num_blocks - 1 in `encoders`).
        dropout_rate / positional_dropout_rate / attention_dropout_rate:
            dropout probabilities for the respective sub-modules.
        input_layer: input embedding type ("linear", "conv2d", "conv2d2",
            "conv2d6", "conv2d8", "embed", "pe", or None).
        pos_enc_class: positional-encoding class used by the "linear" embed.
        normalize_before: apply LayerNorm before each sub-block (and a final
            `after_norm` on the output).
        concat_after: concatenate attention input and output (passed through
            to EncoderLayerSANM).
        positionwise_layer_type: "linear", "conv1d", or "conv1d-linear".
        positionwise_conv_kernel_size: kernel size of the conv1d variants.
        padding_idx: padding index for the "embed" input layer.
        interctc_layer_idx: 1-based indices of layers with intermediate CTC.
            NOTE(review): mutable default argument ([]); harmless as long as
            it is never mutated, but a None sentinel would be safer.
        interctc_use_conditioning: enable self-conditioned CTC feedback.
        kernel_size / sanm_shfit: FSMN memory-block kernel size and shift.
        selfattention_layer_type: "selfattn" or "sanm".
            NOTE(review): with "selfattn", `encoder_selfattn_layer_args0` and
            `self.encoder_selfattn_layer` are never assigned, so building
            `self.encoders0` below would raise — only "sanm" looks usable.
    """
    super().__init__()
    self._output_size = output_size
    # --- input embedding ---------------------------------------------------
    if input_layer == "linear":
        self.embed = torch.nn.Sequential(
            torch.nn.Linear(input_size, output_size),
            torch.nn.LayerNorm(output_size),
            torch.nn.Dropout(dropout_rate),
            torch.nn.ReLU(),
            pos_enc_class(output_size, positional_dropout_rate),
        )
    elif input_layer == "conv2d":
        self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
    elif input_layer == "conv2d2":
        self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
    elif input_layer == "conv2d6":
        self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
    elif input_layer == "conv2d8":
        self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
    elif input_layer == "embed":
        self.embed = torch.nn.Sequential(
            torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
            SinusoidalPositionEncoder(),
        )
    elif input_layer is None:
        # No embedding needed when sizes already match; otherwise project.
        if input_size == output_size:
            self.embed = None
        else:
            self.embed = torch.nn.Linear(input_size, output_size)
    elif input_layer == "pe":
        self.embed = SinusoidalPositionEncoder()
    else:
        raise ValueError("unknown input_layer: " + input_layer)
    self.normalize_before = normalize_before
    # --- position-wise feed-forward factory --------------------------------
    if positionwise_layer_type == "linear":
        positionwise_layer = PositionwiseFeedForward
        positionwise_layer_args = (
            output_size,
            linear_units,
            dropout_rate,
        )
    elif positionwise_layer_type == "conv1d":
        positionwise_layer = MultiLayeredConv1d
        positionwise_layer_args = (
            output_size,
            linear_units,
            positionwise_conv_kernel_size,
            dropout_rate,
        )
    elif positionwise_layer_type == "conv1d-linear":
        positionwise_layer = Conv1dLinear
        positionwise_layer_args = (
            output_size,
            linear_units,
            positionwise_conv_kernel_size,
            dropout_rate,
        )
    else:
        raise NotImplementedError("Support only linear or conv1d.")
    # --- self-attention factory --------------------------------------------
    if selfattention_layer_type == "selfattn":
        encoder_selfattn_layer = MultiHeadedAttention
        encoder_selfattn_layer_args = (
            attention_heads,
            output_size,
            attention_dropout_rate,
        )
    elif selfattention_layer_type == "sanm":
        self.encoder_selfattn_layer = MultiHeadedAttentionSANMwithMask
        # args0: first layer maps input_size -> output_size.
        encoder_selfattn_layer_args0 = (
            attention_heads,
            input_size,
            output_size,
            attention_dropout_rate,
            kernel_size,
            sanm_shfit,
        )
        # args: remaining layers stay at output_size.
        encoder_selfattn_layer_args = (
            attention_heads,
            output_size,
            output_size,
            attention_dropout_rate,
            kernel_size,
            sanm_shfit,
        )
    # --- encoder stack ------------------------------------------------------
    self.encoders0 = repeat(
        1,
        lambda lnum: EncoderLayerSANM(
            input_size,
            output_size,
            self.encoder_selfattn_layer(*encoder_selfattn_layer_args0),
            positionwise_layer(*positionwise_layer_args),
            dropout_rate,
            normalize_before,
            concat_after,
        ),
    )
    self.encoders = repeat(
        num_blocks - 1,
        lambda lnum: EncoderLayerSANM(
            output_size,
            output_size,
            self.encoder_selfattn_layer(*encoder_selfattn_layer_args),
            positionwise_layer(*positionwise_layer_args),
            dropout_rate,
            normalize_before,
            concat_after,
        ),
    )
    if self.normalize_before:
        self.after_norm = LayerNorm(output_size)
    # --- intermediate CTC configuration ------------------------------------
    self.interctc_layer_idx = interctc_layer_idx
    if len(interctc_layer_idx) > 0:
        assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
    self.interctc_use_conditioning = interctc_use_conditioning
    # Set externally when interctc conditioning is enabled.
    self.conditioning_layer = None
    self.dropout = nn.Dropout(dropout_rate)
  1101. def output_size(self) -> int:
  1102. return self._output_size
def forward(
    self,
    xs_pad: torch.Tensor,
    ilens: torch.Tensor,
    vad_indexes: torch.Tensor,
    prev_states: torch.Tensor = None,
    ctc: CTC = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Embed positions in tensor and encode with VAD-aware attention masks.

    Args:
        xs_pad: input tensor (B, L, D)
        ilens: input length (B)
        vad_indexes: per-utterance VAD boundary indexes consumed by
            `vad_mask` when building the last layer's attention mask.
        prev_states: Not to be used now.
        ctc: accepted but never read in this method.
            NOTE(review): confirm whether interctc support was intended here.
    Returns:
        position embedded tensor and mask
    """
    masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
    # Combine padding mask with a causal (no-future) mask for all but the
    # last layer.
    sub_masks = subsequent_mask(masks.size(-1), device=xs_pad.device).unsqueeze(0)
    no_future_masks = masks & sub_masks
    xs_pad *= self.output_size()**0.5
    if self.embed is None:
        xs_pad = xs_pad
    elif (isinstance(self.embed, Conv2dSubsampling) or isinstance(self.embed, Conv2dSubsampling2)
          or isinstance(self.embed, Conv2dSubsampling6) or isinstance(self.embed, Conv2dSubsampling8)):
        short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
        if short_status:
            raise TooShortUttError(
                f"has {xs_pad.size(1)} frames and is too short for subsampling " +
                f"(it needs more than {limit_size} frames), return empty results",
                xs_pad.size(1),
                limit_size,
            )
        xs_pad, masks = self.embed(xs_pad, masks)
    else:
        xs_pad = self.embed(xs_pad)
    # xs_pad = self.dropout(xs_pad)
    # Each layer receives [padding_mask, attention_mask].
    mask_tup0 = [masks, no_future_masks]
    encoder_outs = self.encoders0(xs_pad, mask_tup0)
    xs_pad, _ = encoder_outs[0], encoder_outs[1]
    # NOTE(review): never filled in this method, so the branch returning
    # intermediate outputs below is dead code.
    intermediate_outs = []
    for layer_idx, encoder_layer in enumerate(self.encoders):
        if layer_idx + 1 == len(self.encoders):
            # This is last layer: replace the causal mask with a VAD mask.
            coner_mask = torch.ones(masks.size(0),
                                    masks.size(-1),
                                    masks.size(-1),
                                    device=xs_pad.device,
                                    dtype=torch.bool)
            # Build a per-utterance VAD attention mask; `length` is unused.
            for word_index, length in enumerate(ilens):
                coner_mask[word_index, :, :] = vad_mask(masks.size(-1),
                                                        vad_indexes[word_index],
                                                        device=xs_pad.device)
            layer_mask = masks & coner_mask
        else:
            layer_mask = no_future_masks
        mask_tup1 = [masks, layer_mask]
        encoder_outs = encoder_layer(xs_pad, mask_tup1)
        xs_pad, layer_mask = encoder_outs[0], encoder_outs[1]
    if self.normalize_before:
        xs_pad = self.after_norm(xs_pad)
    olens = masks.squeeze(1).sum(1)
    if len(intermediate_outs) > 0:
        return (xs_pad, intermediate_outs), olens, None
    return xs_pad, olens, None