# sanm_encoder.py
  1. from typing import List
  2. from typing import Optional
  3. from typing import Sequence
  4. from typing import Tuple
  5. from typing import Union
  6. import logging
  7. import torch
  8. import torch.nn as nn
  9. import torch.nn.functional as F
  10. from funasr.modules.streaming_utils.chunk_utilis import overlap_chunk
  11. from typeguard import check_argument_types
  12. import numpy as np
  13. from funasr.torch_utils.device_funcs import to_device
  14. from funasr.modules.nets_utils import make_pad_mask
  15. from funasr.modules.attention import MultiHeadedAttention, MultiHeadedAttentionSANM, MultiHeadedAttentionSANMwithMask
  16. from funasr.modules.embedding import SinusoidalPositionEncoder, StreamSinusoidalPositionEncoder
  17. from funasr.modules.layer_norm import LayerNorm
  18. from funasr.modules.multi_layer_conv import Conv1dLinear
  19. from funasr.modules.multi_layer_conv import MultiLayeredConv1d
  20. from funasr.modules.positionwise_feed_forward import (
  21. PositionwiseFeedForward, # noqa: H301
  22. )
  23. from funasr.modules.repeat import repeat
  24. from funasr.modules.subsampling import Conv2dSubsampling
  25. from funasr.modules.subsampling import Conv2dSubsampling2
  26. from funasr.modules.subsampling import Conv2dSubsampling6
  27. from funasr.modules.subsampling import Conv2dSubsampling8
  28. from funasr.modules.subsampling import TooShortUttError
  29. from funasr.modules.subsampling import check_short_utt
  30. from funasr.modules.mask import subsequent_mask, vad_mask
  31. from funasr.models.ctc import CTC
  32. from funasr.models.encoder.abs_encoder import AbsEncoder
  33. class EncoderLayerSANM(nn.Module):
  34. def __init__(
  35. self,
  36. in_size,
  37. size,
  38. self_attn,
  39. feed_forward,
  40. dropout_rate,
  41. normalize_before=True,
  42. concat_after=False,
  43. stochastic_depth_rate=0.0,
  44. ):
  45. """Construct an EncoderLayer object."""
  46. super(EncoderLayerSANM, self).__init__()
  47. self.self_attn = self_attn
  48. self.feed_forward = feed_forward
  49. self.norm1 = LayerNorm(in_size)
  50. self.norm2 = LayerNorm(size)
  51. self.dropout = nn.Dropout(dropout_rate)
  52. self.in_size = in_size
  53. self.size = size
  54. self.normalize_before = normalize_before
  55. self.concat_after = concat_after
  56. if self.concat_after:
  57. self.concat_linear = nn.Linear(size + size, size)
  58. self.stochastic_depth_rate = stochastic_depth_rate
  59. self.dropout_rate = dropout_rate
  60. def forward(self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
  61. """Compute encoded features.
  62. Args:
  63. x_input (torch.Tensor): Input tensor (#batch, time, size).
  64. mask (torch.Tensor): Mask tensor for the input (#batch, time).
  65. cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
  66. Returns:
  67. torch.Tensor: Output tensor (#batch, time, size).
  68. torch.Tensor: Mask tensor (#batch, time).
  69. """
  70. skip_layer = False
  71. # with stochastic depth, residual connection `x + f(x)` becomes
  72. # `x <- x + 1 / (1 - p) * f(x)` at training time.
  73. stoch_layer_coeff = 1.0
  74. if self.training and self.stochastic_depth_rate > 0:
  75. skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
  76. stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
  77. if skip_layer:
  78. if cache is not None:
  79. x = torch.cat([cache, x], dim=1)
  80. return x, mask
  81. residual = x
  82. if self.normalize_before:
  83. x = self.norm1(x)
  84. if self.concat_after:
  85. x_concat = torch.cat((x, self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)), dim=-1)
  86. if self.in_size == self.size:
  87. x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
  88. else:
  89. x = stoch_layer_coeff * self.concat_linear(x_concat)
  90. else:
  91. if self.in_size == self.size:
  92. x = residual + stoch_layer_coeff * self.dropout(
  93. self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
  94. )
  95. else:
  96. x = stoch_layer_coeff * self.dropout(
  97. self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
  98. )
  99. if not self.normalize_before:
  100. x = self.norm1(x)
  101. residual = x
  102. if self.normalize_before:
  103. x = self.norm2(x)
  104. x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
  105. if not self.normalize_before:
  106. x = self.norm2(x)
  107. return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder
class SANMEncoder(AbsEncoder):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    San-m: Memory equipped self-attention for end-to-end speech recognition
    https://arxiv.org/abs/2006.01713

    Offline (full-utterance) SANM encoder: an input embedding/subsampling
    front-end followed by a stack of ``EncoderLayerSANM`` layers, with
    optional intermediate-CTC taps and TF->PyTorch checkpoint conversion
    helpers.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: Optional[str] = "conv2d",
        pos_enc_class=SinusoidalPositionEncoder,
        normalize_before: bool = True,
        concat_after: bool = False,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 1,
        padding_idx: int = -1,
        # NOTE(review): mutable default argument; it is only read here, never
        # mutated, but a None default would be safer.
        interctc_layer_idx: List[int] = [],
        interctc_use_conditioning: bool = False,
        kernel_size: int = 11,
        sanm_shfit: int = 0,
        selfattention_layer_type: str = "sanm",
        tf2torch_tensor_name_prefix_torch: str = "encoder",
        tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
    ):
        assert check_argument_types()
        super().__init__()
        self._output_size = output_size
        # Front-end: map raw features (input_size) to model dim (output_size).
        if input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(input_size, output_size),
                torch.nn.LayerNorm(output_size),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d":
            self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d2":
            self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d6":
            self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d8":
            self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
        elif input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
                SinusoidalPositionEncoder(),
            )
        elif input_layer is None:
            # No front-end: pass features through (or project if dims differ).
            if input_size == output_size:
                self.embed = None
            else:
                self.embed = torch.nn.Linear(input_size, output_size)
        elif input_layer == "pe":
            self.embed = SinusoidalPositionEncoder()
        elif input_layer == "pe_online":
            self.embed = StreamSinusoidalPositionEncoder()
        else:
            raise ValueError("unknown input_layer: " + input_layer)
        self.normalize_before = normalize_before
        # Position-wise feed-forward flavor shared by all layers.
        if positionwise_layer_type == "linear":
            positionwise_layer = PositionwiseFeedForward
            positionwise_layer_args = (
                output_size,
                linear_units,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d":
            positionwise_layer = MultiLayeredConv1d
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d-linear":
            positionwise_layer = Conv1dLinear
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        else:
            raise NotImplementedError("Support only linear or conv1d.")
        if selfattention_layer_type == "selfattn":
            encoder_selfattn_layer = MultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
            )
        elif selfattention_layer_type == "sanm":
            encoder_selfattn_layer = MultiHeadedAttentionSANM
            # args0 is for the first layer only: it maps input_size -> output_size.
            encoder_selfattn_layer_args0 = (
                attention_heads,
                input_size,
                output_size,
                attention_dropout_rate,
                kernel_size,
                sanm_shfit,
            )
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                output_size,
                attention_dropout_rate,
                kernel_size,
                sanm_shfit,
            )
        # NOTE(review): with selfattention_layer_type == "selfattn",
        # encoder_selfattn_layer_args0 is undefined and the line below would
        # raise NameError — only "sanm" appears to be fully supported here.
        self.encoders0 = repeat(
            1,
            lambda lnum: EncoderLayerSANM(
                input_size,
                output_size,
                encoder_selfattn_layer(*encoder_selfattn_layer_args0),
                positionwise_layer(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        self.encoders = repeat(
            num_blocks - 1,
            lambda lnum: EncoderLayerSANM(
                output_size,
                output_size,
                encoder_selfattn_layer(*encoder_selfattn_layer_args),
                positionwise_layer(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        if self.normalize_before:
            self.after_norm = LayerNorm(output_size)
        # Intermediate CTC: 1-based indices into self.encoders where auxiliary
        # CTC outputs are tapped (see forward()).
        self.interctc_layer_idx = interctc_layer_idx
        if len(interctc_layer_idx) > 0:
            assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
        self.interctc_use_conditioning = interctc_use_conditioning
        # Set externally (by the model) when interctc_use_conditioning is True.
        self.conditioning_layer = None
        self.dropout = nn.Dropout(dropout_rate)
        # Prefixes used by the TF->PyTorch checkpoint conversion helpers below.
        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf

    def output_size(self) -> int:
        """Return the encoder output feature dimension."""
        return self._output_size

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: torch.Tensor = None,
        ctc: CTC = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Embed positions in tensor.

        Args:
            xs_pad: input tensor (B, L, D)
            ilens: input length (B)
            prev_states: Not to be used now.
            ctc: CTC module used only for intermediate-CTC conditioning.

        Returns:
            position embedded tensor and mask; if intermediate CTC taps are
            configured the first element is ``(xs_pad, intermediate_outs)``.
        """
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
        # Scale features by sqrt(model dim) before positional encoding.
        xs_pad = xs_pad * self.output_size()**0.5
        if self.embed is None:
            xs_pad = xs_pad
        elif (
            isinstance(self.embed, Conv2dSubsampling)
            or isinstance(self.embed, Conv2dSubsampling2)
            or isinstance(self.embed, Conv2dSubsampling6)
            or isinstance(self.embed, Conv2dSubsampling8)
        ):
            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
            if short_status:
                raise TooShortUttError(
                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
                    + f"(it needs more than {limit_size} frames), return empty results",
                    xs_pad.size(1),
                    limit_size,
                )
            xs_pad, masks = self.embed(xs_pad, masks)
        else:
            xs_pad = self.embed(xs_pad)
        # xs_pad = self.dropout(xs_pad)
        # First layer (input_size -> output_size), then the remaining stack.
        encoder_outs = self.encoders0(xs_pad, masks)
        xs_pad, masks = encoder_outs[0], encoder_outs[1]
        intermediate_outs = []
        if len(self.interctc_layer_idx) == 0:
            encoder_outs = self.encoders(xs_pad, masks)
            xs_pad, masks = encoder_outs[0], encoder_outs[1]
        else:
            # Run layer by layer so intermediate outputs can be tapped.
            for layer_idx, encoder_layer in enumerate(self.encoders):
                encoder_outs = encoder_layer(xs_pad, masks)
                xs_pad, masks = encoder_outs[0], encoder_outs[1]
                if layer_idx + 1 in self.interctc_layer_idx:
                    encoder_out = xs_pad
                    # intermediate outputs are also normalized
                    if self.normalize_before:
                        encoder_out = self.after_norm(encoder_out)
                    intermediate_outs.append((layer_idx + 1, encoder_out))
                    # Self-conditioned CTC: feed CTC posteriors back into the
                    # encoder stream via the (externally set) conditioning layer.
                    if self.interctc_use_conditioning:
                        ctc_out = ctc.softmax(encoder_out)
                        xs_pad = xs_pad + self.conditioning_layer(ctc_out)
        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)
        olens = masks.squeeze(1).sum(1)
        if len(intermediate_outs) > 0:
            return (xs_pad, intermediate_outs), olens, None
        return xs_pad, olens, None

    # NOTE(review): mutable default argument ({}); safe here because an empty
    # cache is returned from immediately, but a None default would be safer.
    def _add_overlap_chunk(self, feats: np.ndarray, cache: dict = {}):
        """Prepend cached frames to ``feats`` and refresh the cache tail.

        ``cache`` is expected to hold "feats" (previous tail frames) and
        "chunk_size" (indexable; elements 0 and 2 are used) — TODO confirm
        exact cache schema against the streaming caller.
        """
        if len(cache) == 0:
            return feats
        cache["feats"] = to_device(cache["feats"], device=feats.device)
        overlap_feats = torch.cat((cache["feats"], feats), dim=1)
        # Keep only the trailing context needed for the next chunk.
        cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]):, :]
        return overlap_feats

    def forward_chunk(self,
                      xs_pad: torch.Tensor,
                      ilens: torch.Tensor,
                      cache: dict = None,
                      ctc: CTC = None,
                      ):
        """Streaming forward pass over one chunk.

        ``cache`` must provide at least "tail_chunk", "feats" and
        "chunk_size" entries. No padding masks are used: each chunk is
        assumed fully valid (all mask arguments passed as None).
        """
        # In-place scaling by sqrt(model dim).
        xs_pad *= self.output_size() ** 0.5
        if self.embed is None:
            xs_pad = xs_pad
        else:
            # Streaming embed variant consumes the cache as well.
            xs_pad = self.embed(xs_pad, cache)
        if cache["tail_chunk"]:
            # Final (tail) chunk: reuse the cached overlap features as input.
            xs_pad = to_device(cache["feats"], device=xs_pad.device)
        else:
            xs_pad = self._add_overlap_chunk(xs_pad, cache)
        encoder_outs = self.encoders0(xs_pad, None, None, None, None)
        xs_pad, masks = encoder_outs[0], encoder_outs[1]
        intermediate_outs = []
        if len(self.interctc_layer_idx) == 0:
            encoder_outs = self.encoders(xs_pad, None, None, None, None)
            xs_pad, masks = encoder_outs[0], encoder_outs[1]
        else:
            for layer_idx, encoder_layer in enumerate(self.encoders):
                encoder_outs = encoder_layer(xs_pad, None, None, None, None)
                xs_pad, masks = encoder_outs[0], encoder_outs[1]
                if layer_idx + 1 in self.interctc_layer_idx:
                    encoder_out = xs_pad
                    # intermediate outputs are also normalized
                    if self.normalize_before:
                        encoder_out = self.after_norm(encoder_out)
                    intermediate_outs.append((layer_idx + 1, encoder_out))
                    if self.interctc_use_conditioning:
                        ctc_out = ctc.softmax(encoder_out)
                        xs_pad = xs_pad + self.conditioning_layer(ctc_out)
        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)
        if len(intermediate_outs) > 0:
            return (xs_pad, intermediate_outs), None, None
        return xs_pad, ilens, None

    def gen_tf2torch_map_dict(self):
        """Build the PyTorch-name -> TF-variable mapping for checkpoint import.

        Keys are torch parameter names with the layer index replaced by the
        literal "layeridx"; values describe the TF variable name plus the
        squeeze/transpose needed to match torch tensor layout (TF conv1d
        kernels are (1, in, out); torch linear weights are (out, in)).
        """
        tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
        tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
        map_dict_local = {
            ## encoder
            # cicd
            "{}.encoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.encoders.layeridx.norm1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.encoders.layeridx.self_attn.linear_q_k_v.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (768,256),(1,256,768)
            "{}.encoders.layeridx.self_attn.linear_q_k_v.bias".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (768,),(768,)
            "{}.encoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/depth_conv_w".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 2, 0),
                 },  # (256,1,31),(1,31,256,1)
            "{}.encoders.layeridx.self_attn.linear_out.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,256),(1,256,256)
            "{}.encoders.layeridx.self_attn.linear_out.bias".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            # ffn
            "{}.encoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.encoders.layeridx.norm2.bias".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.encoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/ffn/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (1024,256),(1,256,1024)
            "{}.encoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/ffn/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.encoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,1024),(1,1024,256)
            "{}.encoders.layeridx.feed_forward.w_2.bias".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/ffn/conv1d_1/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            # out norm
            "{}.after_norm.weight".format(tensor_name_prefix_torch):
                {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.after_norm.bias".format(tensor_name_prefix_torch):
                {"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
        }
        return map_dict_local

    def convert_tf2torch(self,
                         var_dict_tf,
                         var_dict_torch,
                         ):
        """Convert a TF variable dict into torch state-dict tensors.

        For every torch parameter name under this encoder's prefix, look up
        the matching TF variable (via gen_tf2torch_map_dict), apply the
        recorded squeeze/transpose, and return the converted tensors.
        torch "encoders0" maps to TF layer 0 and "encoders" layer i maps to
        TF layer i+1 (layeridx_bias).

        Args:
            var_dict_tf: mapping TF variable name -> numpy array.
            var_dict_torch: torch state dict (used for names and size checks).

        Returns:
            dict of torch parameter name -> converted torch.FloatTensor.
        """
        map_dict = self.gen_tf2torch_map_dict()
        var_dict_torch_update = dict()
        for name in sorted(var_dict_torch.keys(), reverse=False):
            names = name.split('.')
            if names[0] == self.tf2torch_tensor_name_prefix_torch:
                if names[1] == "encoders0":
                    # First torch layer -> TF layer 0.
                    layeridx = int(names[2])
                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
                    name_q = name_q.replace("encoders0", "encoders")
                    layeridx_bias = 0
                    layeridx += layeridx_bias
                    if name_q in map_dict.keys():
                        name_v = map_dict[name_q]["name"]
                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name_q]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                        if map_dict[name_q]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(
                            name, name_tf, var_dict_torch[name].size(), data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info(
                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                                name, data_tf.size(), name_v, var_dict_tf[name_tf].shape))
                elif names[1] == "encoders":
                    # Remaining torch layers -> TF layers shifted by one.
                    layeridx = int(names[2])
                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
                    layeridx_bias = 1
                    layeridx += layeridx_bias
                    if name_q in map_dict.keys():
                        name_v = map_dict[name_q]["name"]
                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name_q]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                        if map_dict[name_q]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(
                            name, name_tf, var_dict_torch[name].size(), data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info(
                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                                name, data_tf.size(), name_v, var_dict_tf[name_tf].shape))
                elif names[1] == "after_norm":
                    # Final LayerNorm: direct copy, no reshaping needed.
                    name_tf = map_dict[name]["name"]
                    data_tf = var_dict_tf[name_tf]
                    data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                    var_dict_torch_update[name] = data_tf
                    logging.info(
                        "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                            name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape))
        return var_dict_torch_update
  517. class SANMEncoderChunkOpt(AbsEncoder):
  518. """
  519. Author: Speech Lab of DAMO Academy, Alibaba Group
  520. SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
  521. https://arxiv.org/abs/2006.01713
  522. """
    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: Optional[str] = "conv2d",
        pos_enc_class=SinusoidalPositionEncoder,
        normalize_before: bool = True,
        concat_after: bool = False,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 1,
        padding_idx: int = -1,
        # NOTE(review): mutable default argument; only read here, never mutated.
        interctc_layer_idx: List[int] = [],
        interctc_use_conditioning: bool = False,
        kernel_size: int = 11,
        sanm_shfit: int = 0,
        selfattention_layer_type: str = "sanm",
        # Streaming (chunk) geometry; sequences so multiple configs can be
        # selected at runtime via the `ind` argument of forward().
        chunk_size: Union[int, Sequence[int]] = (16,),
        stride: Union[int, Sequence[int]] = (10,),
        pad_left: Union[int, Sequence[int]] = (0,),
        encoder_att_look_back_factor: Union[int, Sequence[int]] = (1,),
        decoder_att_look_back_factor: Union[int, Sequence[int]] = (1,),
        tf2torch_tensor_name_prefix_torch: str = "encoder",
        tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
    ):
        """Construct the streaming chunk-optimized SANM encoder.

        Mirrors SANMEncoder.__init__ but additionally builds an
        ``overlap_chunk`` helper that generates chunk masks for SCAMA-style
        streaming attention.
        """
        assert check_argument_types()
        super().__init__()
        self._output_size = output_size
        # Front-end: map raw features (input_size) to model dim (output_size).
        if input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(input_size, output_size),
                torch.nn.LayerNorm(output_size),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d":
            self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d2":
            self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d6":
            self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d8":
            self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
        elif input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer is None:
            # No front-end: pass features through (or project if dims differ).
            if input_size == output_size:
                self.embed = None
            else:
                self.embed = torch.nn.Linear(input_size, output_size)
        elif input_layer == "pe":
            self.embed = SinusoidalPositionEncoder()
        elif input_layer == "pe_online":
            self.embed = StreamSinusoidalPositionEncoder()
        else:
            raise ValueError("unknown input_layer: " + input_layer)
        self.normalize_before = normalize_before
        # Position-wise feed-forward flavor shared by all layers.
        if positionwise_layer_type == "linear":
            positionwise_layer = PositionwiseFeedForward
            positionwise_layer_args = (
                output_size,
                linear_units,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d":
            positionwise_layer = MultiLayeredConv1d
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d-linear":
            positionwise_layer = Conv1dLinear
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        else:
            raise NotImplementedError("Support only linear or conv1d.")
        if selfattention_layer_type == "selfattn":
            encoder_selfattn_layer = MultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
            )
        elif selfattention_layer_type == "sanm":
            encoder_selfattn_layer = MultiHeadedAttentionSANM
            # args0 is for the first layer only: it maps input_size -> output_size.
            encoder_selfattn_layer_args0 = (
                attention_heads,
                input_size,
                output_size,
                attention_dropout_rate,
                kernel_size,
                sanm_shfit,
            )
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                output_size,
                attention_dropout_rate,
                kernel_size,
                sanm_shfit,
            )
        # NOTE(review): with selfattention_layer_type == "selfattn",
        # encoder_selfattn_layer_args0 is undefined and the line below would
        # raise NameError — only "sanm" appears to be fully supported here.
        self.encoders0 = repeat(
            1,
            lambda lnum: EncoderLayerSANM(
                input_size,
                output_size,
                encoder_selfattn_layer(*encoder_selfattn_layer_args0),
                positionwise_layer(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        self.encoders = repeat(
            num_blocks - 1,
            lambda lnum: EncoderLayerSANM(
                output_size,
                output_size,
                encoder_selfattn_layer(*encoder_selfattn_layer_args),
                positionwise_layer(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        if self.normalize_before:
            self.after_norm = LayerNorm(output_size)
        # Intermediate CTC: 1-based indices into self.encoders.
        self.interctc_layer_idx = interctc_layer_idx
        if len(interctc_layer_idx) > 0:
            assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
        self.interctc_use_conditioning = interctc_use_conditioning
        # Set externally (by the model) when interctc_use_conditioning is True.
        self.conditioning_layer = None
        # Half the FSMN kernel, used to shift chunk masks for the memory block.
        shfit_fsmn = (kernel_size - 1) // 2
        self.overlap_chunk_cls = overlap_chunk(
            chunk_size=chunk_size,
            stride=stride,
            pad_left=pad_left,
            shfit_fsmn=shfit_fsmn,
            encoder_att_look_back_factor=encoder_att_look_back_factor,
            decoder_att_look_back_factor=decoder_att_look_back_factor,
        )
        # Prefixes used by the TF->PyTorch checkpoint conversion helpers.
        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
    def output_size(self) -> int:
        """Return the encoder output feature dimension."""
        return self._output_size
    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: torch.Tensor = None,
        ctc: CTC = None,
        ind: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Encode a padded feature batch with overlap-chunk streaming masks.

        Args:
            xs_pad: input tensor (B, L, D); NOTE it is scaled in-place below.
            ilens: input length (B)
            prev_states: Not to be used now.
            ctc: CTC module; only its ``softmax`` is used, and only when
                ``self.interctc_use_conditioning`` is True.
            ind: forwarded to ``overlap_chunk_cls.gen_chunk_mask`` —
                presumably selects a chunk configuration; TODO confirm
                against overlap_chunk.

        Returns:
            Tuple of (encoded tensor, output lengths, None).  When
            intermediate-CTC layers are configured, the first element is
            ``(xs_pad, [(layer_idx, intermediate_out), ...])`` instead.
        """
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
        # NOTE(review): in-place scaling (*=) mutates the caller's tensor.
        xs_pad *= self.output_size() ** 0.5
        if self.embed is None:
            xs_pad = xs_pad
        elif (
            isinstance(self.embed, Conv2dSubsampling)
            or isinstance(self.embed, Conv2dSubsampling2)
            or isinstance(self.embed, Conv2dSubsampling6)
            or isinstance(self.embed, Conv2dSubsampling8)
        ):
            # Conv front-ends shrink the time axis; reject inputs that are
            # too short to survive the subsampling.
            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
            if short_status:
                raise TooShortUttError(
                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
                    + f"(it needs more than {limit_size} frames), return empty results",
                    xs_pad.size(1),
                    limit_size,
                )
            xs_pad, masks = self.embed(xs_pad, masks)
        else:
            xs_pad = self.embed(xs_pad)
        mask_shfit_chunk, mask_att_chunk_encoder = None, None
        if self.overlap_chunk_cls is not None:
            # Recompute lengths from the (possibly subsampled) masks, split
            # the sequence into overlapping chunks, and rebuild padding masks
            # for the chunked layout.
            ilens = masks.squeeze(1).sum(1)
            chunk_outs = self.overlap_chunk_cls.gen_chunk_mask(ilens, ind)
            xs_pad, ilens = self.overlap_chunk_cls.split_chunk(xs_pad, ilens, chunk_outs=chunk_outs)
            masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
            mask_shfit_chunk = self.overlap_chunk_cls.get_mask_shfit_chunk(chunk_outs, xs_pad.device, xs_pad.size(0),
                                                                           dtype=xs_pad.dtype)
            mask_att_chunk_encoder = self.overlap_chunk_cls.get_mask_att_chunk_encoder(chunk_outs, xs_pad.device,
                                                                                       xs_pad.size(0),
                                                                                       dtype=xs_pad.dtype)
        # First block (input_size -> output_size), then the remaining stack.
        encoder_outs = self.encoders0(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
        xs_pad, masks = encoder_outs[0], encoder_outs[1]
        intermediate_outs = []
        if len(self.interctc_layer_idx) == 0:
            encoder_outs = self.encoders(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
            xs_pad, masks = encoder_outs[0], encoder_outs[1]
        else:
            # Run layer by layer so intermediate-CTC taps can be collected.
            for layer_idx, encoder_layer in enumerate(self.encoders):
                encoder_outs = encoder_layer(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
                xs_pad, masks = encoder_outs[0], encoder_outs[1]
                if layer_idx + 1 in self.interctc_layer_idx:
                    encoder_out = xs_pad
                    # intermediate outputs are also normalized
                    if self.normalize_before:
                        encoder_out = self.after_norm(encoder_out)
                    intermediate_outs.append((layer_idx + 1, encoder_out))
                    if self.interctc_use_conditioning:
                        # NOTE(review): assumes ctc is not None and
                        # self.conditioning_layer was assigned externally.
                        ctc_out = ctc.softmax(encoder_out)
                        xs_pad = xs_pad + self.conditioning_layer(ctc_out)
        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)
        olens = masks.squeeze(1).sum(1)
        if len(intermediate_outs) > 0:
            return (xs_pad, intermediate_outs), olens, None
        return xs_pad, olens, None
  756. def _add_overlap_chunk(self, feats: np.ndarray, cache: dict = {}):
  757. if len(cache) == 0:
  758. return feats
  759. cache["feats"] = to_device(cache["feats"], device=feats.device)
  760. overlap_feats = torch.cat((cache["feats"], feats), dim=1)
  761. cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]):, :]
  762. return overlap_feats
    def forward_chunk(self,
                      xs_pad: torch.Tensor,
                      ilens: torch.Tensor,
                      cache: dict = None,
                      ctc: CTC = None,
                      ):
        """Streaming forward pass over a single chunk, using a decoding cache.

        Args:
            xs_pad: chunk features (B, T, D); scaled in-place below.
            ilens: chunk lengths (B); returned unchanged.
            cache: streaming state.  NOTE(review): despite the ``None``
                default, the body unconditionally reads ``cache["tail_chunk"]``
                (and ``cache["feats"]`` on the tail path), so callers must
                pass a populated dict.
            ctc: CTC module, used only for intermediate-CTC conditioning.

        Returns:
            Tuple of (encoded tensor, ilens, None); with intermediate-CTC
            layers configured, the first element is
            ``(xs_pad, [(layer_idx, out), ...])`` and lengths are None.
        """
        # NOTE(review): in-place scaling (*=) mutates the caller's tensor.
        xs_pad *= self.output_size() ** 0.5
        if self.embed is None:
            xs_pad = xs_pad
        else:
            xs_pad = self.embed(xs_pad, cache)
        if cache["tail_chunk"]:
            # Final chunk: reuse the cached overlap buffer as-is.
            xs_pad = to_device(cache["feats"], device=xs_pad.device)
        else:
            xs_pad = self._add_overlap_chunk(xs_pad, cache)
        # No padding/attention masks in streaming mode (all-None placeholders).
        encoder_outs = self.encoders0(xs_pad, None, None, None, None)
        xs_pad, masks = encoder_outs[0], encoder_outs[1]
        intermediate_outs = []
        if len(self.interctc_layer_idx) == 0:
            encoder_outs = self.encoders(xs_pad, None, None, None, None)
            xs_pad, masks = encoder_outs[0], encoder_outs[1]
        else:
            for layer_idx, encoder_layer in enumerate(self.encoders):
                encoder_outs = encoder_layer(xs_pad, None, None, None, None)
                xs_pad, masks = encoder_outs[0], encoder_outs[1]
                if layer_idx + 1 in self.interctc_layer_idx:
                    encoder_out = xs_pad
                    # intermediate outputs are also normalized
                    if self.normalize_before:
                        encoder_out = self.after_norm(encoder_out)
                    intermediate_outs.append((layer_idx + 1, encoder_out))
                    if self.interctc_use_conditioning:
                        ctc_out = ctc.softmax(encoder_out)
                        xs_pad = xs_pad + self.conditioning_layer(ctc_out)
        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)
        if len(intermediate_outs) > 0:
            return (xs_pad, intermediate_outs), None, None
        return xs_pad, ilens, None
  802. def gen_tf2torch_map_dict(self):
  803. tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
  804. tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
  805. map_dict_local = {
  806. ## encoder
  807. # cicd
  808. "{}.encoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
  809. {"name": "{}/layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
  810. "squeeze": None,
  811. "transpose": None,
  812. }, # (256,),(256,)
  813. "{}.encoders.layeridx.norm1.bias".format(tensor_name_prefix_torch):
  814. {"name": "{}/layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
  815. "squeeze": None,
  816. "transpose": None,
  817. }, # (256,),(256,)
  818. "{}.encoders.layeridx.self_attn.linear_q_k_v.weight".format(tensor_name_prefix_torch):
  819. {"name": "{}/layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
  820. "squeeze": 0,
  821. "transpose": (1, 0),
  822. }, # (768,256),(1,256,768)
  823. "{}.encoders.layeridx.self_attn.linear_q_k_v.bias".format(tensor_name_prefix_torch):
  824. {"name": "{}/layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
  825. "squeeze": None,
  826. "transpose": None,
  827. }, # (768,),(768,)
  828. "{}.encoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch):
  829. {"name": "{}/layer_layeridx/multi_head/depth_conv_w".format(tensor_name_prefix_tf),
  830. "squeeze": 0,
  831. "transpose": (1, 2, 0),
  832. }, # (256,1,31),(1,31,256,1)
  833. "{}.encoders.layeridx.self_attn.linear_out.weight".format(tensor_name_prefix_torch):
  834. {"name": "{}/layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
  835. "squeeze": 0,
  836. "transpose": (1, 0),
  837. }, # (256,256),(1,256,256)
  838. "{}.encoders.layeridx.self_attn.linear_out.bias".format(tensor_name_prefix_torch):
  839. {"name": "{}/layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
  840. "squeeze": None,
  841. "transpose": None,
  842. }, # (256,),(256,)
  843. # ffn
  844. "{}.encoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
  845. {"name": "{}/layer_layeridx/ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
  846. "squeeze": None,
  847. "transpose": None,
  848. }, # (256,),(256,)
  849. "{}.encoders.layeridx.norm2.bias".format(tensor_name_prefix_torch):
  850. {"name": "{}/layer_layeridx/ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
  851. "squeeze": None,
  852. "transpose": None,
  853. }, # (256,),(256,)
  854. "{}.encoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
  855. {"name": "{}/layer_layeridx/ffn/conv1d/kernel".format(tensor_name_prefix_tf),
  856. "squeeze": 0,
  857. "transpose": (1, 0),
  858. }, # (1024,256),(1,256,1024)
  859. "{}.encoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
  860. {"name": "{}/layer_layeridx/ffn/conv1d/bias".format(tensor_name_prefix_tf),
  861. "squeeze": None,
  862. "transpose": None,
  863. }, # (1024,),(1024,)
  864. "{}.encoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
  865. {"name": "{}/layer_layeridx/ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
  866. "squeeze": 0,
  867. "transpose": (1, 0),
  868. }, # (256,1024),(1,1024,256)
  869. "{}.encoders.layeridx.feed_forward.w_2.bias".format(tensor_name_prefix_torch):
  870. {"name": "{}/layer_layeridx/ffn/conv1d_1/bias".format(tensor_name_prefix_tf),
  871. "squeeze": None,
  872. "transpose": None,
  873. }, # (256,),(256,)
  874. # out norm
  875. "{}.after_norm.weight".format(tensor_name_prefix_torch):
  876. {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
  877. "squeeze": None,
  878. "transpose": None,
  879. }, # (256,),(256,)
  880. "{}.after_norm.bias".format(tensor_name_prefix_torch):
  881. {"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
  882. "squeeze": None,
  883. "transpose": None,
  884. }, # (256,),(256,)
  885. }
  886. return map_dict_local
  887. def convert_tf2torch(self,
  888. var_dict_tf,
  889. var_dict_torch,
  890. ):
  891. map_dict = self.gen_tf2torch_map_dict()
  892. var_dict_torch_update = dict()
  893. for name in sorted(var_dict_torch.keys(), reverse=False):
  894. names = name.split('.')
  895. if names[0] == self.tf2torch_tensor_name_prefix_torch:
  896. if names[1] == "encoders0":
  897. layeridx = int(names[2])
  898. name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
  899. name_q = name_q.replace("encoders0", "encoders")
  900. layeridx_bias = 0
  901. layeridx += layeridx_bias
  902. if name_q in map_dict.keys():
  903. name_v = map_dict[name_q]["name"]
  904. name_tf = name_v.replace("layeridx", "{}".format(layeridx))
  905. data_tf = var_dict_tf[name_tf]
  906. if map_dict[name_q]["squeeze"] is not None:
  907. data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
  908. if map_dict[name_q]["transpose"] is not None:
  909. data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
  910. data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
  911. assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
  912. var_dict_torch[
  913. name].size(),
  914. data_tf.size())
  915. var_dict_torch_update[name] = data_tf
  916. logging.info(
  917. "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
  918. var_dict_tf[name_tf].shape))
  919. elif names[1] == "encoders":
  920. layeridx = int(names[2])
  921. name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
  922. layeridx_bias = 1
  923. layeridx += layeridx_bias
  924. if name_q in map_dict.keys():
  925. name_v = map_dict[name_q]["name"]
  926. name_tf = name_v.replace("layeridx", "{}".format(layeridx))
  927. data_tf = var_dict_tf[name_tf]
  928. if map_dict[name_q]["squeeze"] is not None:
  929. data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
  930. if map_dict[name_q]["transpose"] is not None:
  931. data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
  932. data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
  933. assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
  934. var_dict_torch[
  935. name].size(),
  936. data_tf.size())
  937. var_dict_torch_update[name] = data_tf
  938. logging.info(
  939. "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
  940. var_dict_tf[name_tf].shape))
  941. elif names[1] == "after_norm":
  942. name_tf = map_dict[name]["name"]
  943. data_tf = var_dict_tf[name_tf]
  944. data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
  945. var_dict_torch_update[name] = data_tf
  946. logging.info(
  947. "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
  948. var_dict_tf[name_tf].shape))
  949. return var_dict_torch_update
  950. class SANMVadEncoder(AbsEncoder):
  951. """
  952. Author: Speech Lab of DAMO Academy, Alibaba Group
  953. """
  954. def __init__(
  955. self,
  956. input_size: int,
  957. output_size: int = 256,
  958. attention_heads: int = 4,
  959. linear_units: int = 2048,
  960. num_blocks: int = 6,
  961. dropout_rate: float = 0.1,
  962. positional_dropout_rate: float = 0.1,
  963. attention_dropout_rate: float = 0.0,
  964. input_layer: Optional[str] = "conv2d",
  965. pos_enc_class=SinusoidalPositionEncoder,
  966. normalize_before: bool = True,
  967. concat_after: bool = False,
  968. positionwise_layer_type: str = "linear",
  969. positionwise_conv_kernel_size: int = 1,
  970. padding_idx: int = -1,
  971. interctc_layer_idx: List[int] = [],
  972. interctc_use_conditioning: bool = False,
  973. kernel_size : int = 11,
  974. sanm_shfit : int = 0,
  975. selfattention_layer_type: str = "sanm",
  976. ):
  977. assert check_argument_types()
  978. super().__init__()
  979. self._output_size = output_size
  980. if input_layer == "linear":
  981. self.embed = torch.nn.Sequential(
  982. torch.nn.Linear(input_size, output_size),
  983. torch.nn.LayerNorm(output_size),
  984. torch.nn.Dropout(dropout_rate),
  985. torch.nn.ReLU(),
  986. pos_enc_class(output_size, positional_dropout_rate),
  987. )
  988. elif input_layer == "conv2d":
  989. self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
  990. elif input_layer == "conv2d2":
  991. self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
  992. elif input_layer == "conv2d6":
  993. self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
  994. elif input_layer == "conv2d8":
  995. self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
  996. elif input_layer == "embed":
  997. self.embed = torch.nn.Sequential(
  998. torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
  999. SinusoidalPositionEncoder(),
  1000. )
  1001. elif input_layer is None:
  1002. if input_size == output_size:
  1003. self.embed = None
  1004. else:
  1005. self.embed = torch.nn.Linear(input_size, output_size)
  1006. elif input_layer == "pe":
  1007. self.embed = SinusoidalPositionEncoder()
  1008. else:
  1009. raise ValueError("unknown input_layer: " + input_layer)
  1010. self.normalize_before = normalize_before
  1011. if positionwise_layer_type == "linear":
  1012. positionwise_layer = PositionwiseFeedForward
  1013. positionwise_layer_args = (
  1014. output_size,
  1015. linear_units,
  1016. dropout_rate,
  1017. )
  1018. elif positionwise_layer_type == "conv1d":
  1019. positionwise_layer = MultiLayeredConv1d
  1020. positionwise_layer_args = (
  1021. output_size,
  1022. linear_units,
  1023. positionwise_conv_kernel_size,
  1024. dropout_rate,
  1025. )
  1026. elif positionwise_layer_type == "conv1d-linear":
  1027. positionwise_layer = Conv1dLinear
  1028. positionwise_layer_args = (
  1029. output_size,
  1030. linear_units,
  1031. positionwise_conv_kernel_size,
  1032. dropout_rate,
  1033. )
  1034. else:
  1035. raise NotImplementedError("Support only linear or conv1d.")
  1036. if selfattention_layer_type == "selfattn":
  1037. encoder_selfattn_layer = MultiHeadedAttention
  1038. encoder_selfattn_layer_args = (
  1039. attention_heads,
  1040. output_size,
  1041. attention_dropout_rate,
  1042. )
  1043. elif selfattention_layer_type == "sanm":
  1044. self.encoder_selfattn_layer = MultiHeadedAttentionSANMwithMask
  1045. encoder_selfattn_layer_args0 = (
  1046. attention_heads,
  1047. input_size,
  1048. output_size,
  1049. attention_dropout_rate,
  1050. kernel_size,
  1051. sanm_shfit,
  1052. )
  1053. encoder_selfattn_layer_args = (
  1054. attention_heads,
  1055. output_size,
  1056. output_size,
  1057. attention_dropout_rate,
  1058. kernel_size,
  1059. sanm_shfit,
  1060. )
  1061. self.encoders0 = repeat(
  1062. 1,
  1063. lambda lnum: EncoderLayerSANM(
  1064. input_size,
  1065. output_size,
  1066. self.encoder_selfattn_layer(*encoder_selfattn_layer_args0),
  1067. positionwise_layer(*positionwise_layer_args),
  1068. dropout_rate,
  1069. normalize_before,
  1070. concat_after,
  1071. ),
  1072. )
  1073. self.encoders = repeat(
  1074. num_blocks-1,
  1075. lambda lnum: EncoderLayerSANM(
  1076. output_size,
  1077. output_size,
  1078. self.encoder_selfattn_layer(*encoder_selfattn_layer_args),
  1079. positionwise_layer(*positionwise_layer_args),
  1080. dropout_rate,
  1081. normalize_before,
  1082. concat_after,
  1083. ),
  1084. )
  1085. if self.normalize_before:
  1086. self.after_norm = LayerNorm(output_size)
  1087. self.interctc_layer_idx = interctc_layer_idx
  1088. if len(interctc_layer_idx) > 0:
  1089. assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
  1090. self.interctc_use_conditioning = interctc_use_conditioning
  1091. self.conditioning_layer = None
  1092. self.dropout = nn.Dropout(dropout_rate)
  1093. def output_size(self) -> int:
  1094. return self._output_size
    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        vad_indexes: torch.Tensor,
        prev_states: torch.Tensor = None,
        ctc: CTC = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Encode a padded batch with causal masks; the last layer uses a
        VAD-derived attention mask.

        Args:
            xs_pad: input tensor (B, L, D); scaled in-place below.
            ilens: input length (B)
            vad_indexes: per-utterance values passed to ``vad_mask``
                (presumably VAD boundary indices — TODO confirm against
                vad_mask's definition).
            prev_states: Not to be used now.
            ctc: accepted for interface parity; never used in this body.

        Returns:
            Tuple of (encoded tensor, output lengths, None).
        """
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
        # Causal (no-future) mask combined with the padding mask; used by
        # every layer except the last.
        sub_masks = subsequent_mask(masks.size(-1), device=xs_pad.device).unsqueeze(0)
        no_future_masks = masks & sub_masks
        # NOTE(review): in-place scaling (*=) mutates the caller's tensor.
        xs_pad *= self.output_size()**0.5
        if self.embed is None:
            xs_pad = xs_pad
        elif (isinstance(self.embed, Conv2dSubsampling) or isinstance(self.embed, Conv2dSubsampling2)
              or isinstance(self.embed, Conv2dSubsampling6) or isinstance(self.embed, Conv2dSubsampling8)):
            # Conv front-ends shrink the time axis; reject inputs that are
            # too short to survive the subsampling.
            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
            if short_status:
                raise TooShortUttError(
                    f"has {xs_pad.size(1)} frames and is too short for subsampling " +
                    f"(it needs more than {limit_size} frames), return empty results",
                    xs_pad.size(1),
                    limit_size,
                )
            xs_pad, masks = self.embed(xs_pad, masks)
        else:
            xs_pad = self.embed(xs_pad)
        # xs_pad = self.dropout(xs_pad)
        # Each layer receives [padding_mask, attention_mask] as a pair.
        mask_tup0 = [masks, no_future_masks]
        encoder_outs = self.encoders0(xs_pad, mask_tup0)
        xs_pad, _ = encoder_outs[0], encoder_outs[1]
        # NOTE(review): intermediate_outs is never appended to in this
        # method, so the tuple-return branch at the bottom is dead here.
        intermediate_outs = []
        for layer_idx, encoder_layer in enumerate(self.encoders):
            if layer_idx + 1 == len(self.encoders):
                # This is last layer.
                # Build a per-utterance VAD attention mask ("coner_mask";
                # presumably "corner") covering the full (T, T) grid.
                coner_mask = torch.ones(masks.size(0),
                                        masks.size(-1),
                                        masks.size(-1),
                                        device=xs_pad.device,
                                        dtype=torch.bool)
                # `length` is unused; the loop only needs the batch index.
                for word_index, length in enumerate(ilens):
                    coner_mask[word_index, :, :] = vad_mask(masks.size(-1),
                                                            vad_indexes[word_index],
                                                            device=xs_pad.device)
                layer_mask = masks & coner_mask
            else:
                layer_mask = no_future_masks
            mask_tup1 = [masks, layer_mask]
            encoder_outs = encoder_layer(xs_pad, mask_tup1)
            xs_pad, layer_mask = encoder_outs[0], encoder_outs[1]
        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)
        olens = masks.squeeze(1).sum(1)
        if len(intermediate_outs) > 0:
            return (xs_pad, intermediate_outs), olens, None
        return xs_pad, olens, None