# sanm_encoder.py
import logging
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union

import numpy as np
import torch
import torch.nn as nn
from typeguard import check_argument_types

from funasr.models.ctc import CTC
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.modules.attention import MultiHeadedAttention, MultiHeadedAttentionSANM, MultiHeadedAttentionSANMwithMask
from funasr.modules.embedding import SinusoidalPositionEncoder, StreamSinusoidalPositionEncoder
from funasr.modules.layer_norm import LayerNorm
from funasr.modules.mask import subsequent_mask, vad_mask
from funasr.modules.multi_layer_conv import Conv1dLinear
from funasr.modules.multi_layer_conv import MultiLayeredConv1d
from funasr.modules.nets_utils import make_pad_mask
from funasr.modules.positionwise_feed_forward import (
    PositionwiseFeedForward,  # noqa: H301
)
from funasr.modules.repeat import repeat
from funasr.modules.streaming_utils.chunk_utilis import overlap_chunk
from funasr.modules.subsampling import Conv2dSubsampling
from funasr.modules.subsampling import Conv2dSubsampling2
from funasr.modules.subsampling import Conv2dSubsampling6
from funasr.modules.subsampling import Conv2dSubsampling8
from funasr.modules.subsampling import TooShortUttError
from funasr.modules.subsampling import check_short_utt
  31. class EncoderLayerSANM(nn.Module):
  32. def __init__(
  33. self,
  34. in_size,
  35. size,
  36. self_attn,
  37. feed_forward,
  38. dropout_rate,
  39. normalize_before=True,
  40. concat_after=False,
  41. stochastic_depth_rate=0.0,
  42. ):
  43. """Construct an EncoderLayer object."""
  44. super(EncoderLayerSANM, self).__init__()
  45. self.self_attn = self_attn
  46. self.feed_forward = feed_forward
  47. self.norm1 = LayerNorm(in_size)
  48. self.norm2 = LayerNorm(size)
  49. self.dropout = nn.Dropout(dropout_rate)
  50. self.in_size = in_size
  51. self.size = size
  52. self.normalize_before = normalize_before
  53. self.concat_after = concat_after
  54. if self.concat_after:
  55. self.concat_linear = nn.Linear(size + size, size)
  56. self.stochastic_depth_rate = stochastic_depth_rate
  57. self.dropout_rate = dropout_rate
  58. def forward(self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
  59. """Compute encoded features.
  60. Args:
  61. x_input (torch.Tensor): Input tensor (#batch, time, size).
  62. mask (torch.Tensor): Mask tensor for the input (#batch, time).
  63. cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
  64. Returns:
  65. torch.Tensor: Output tensor (#batch, time, size).
  66. torch.Tensor: Mask tensor (#batch, time).
  67. """
  68. skip_layer = False
  69. # with stochastic depth, residual connection `x + f(x)` becomes
  70. # `x <- x + 1 / (1 - p) * f(x)` at training time.
  71. stoch_layer_coeff = 1.0
  72. if self.training and self.stochastic_depth_rate > 0:
  73. skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
  74. stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
  75. if skip_layer:
  76. if cache is not None:
  77. x = torch.cat([cache, x], dim=1)
  78. return x, mask
  79. residual = x
  80. if self.normalize_before:
  81. x = self.norm1(x)
  82. if self.concat_after:
  83. x_concat = torch.cat((x, self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)), dim=-1)
  84. if self.in_size == self.size:
  85. x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
  86. else:
  87. x = stoch_layer_coeff * self.concat_linear(x_concat)
  88. else:
  89. if self.in_size == self.size:
  90. x = residual + stoch_layer_coeff * self.dropout(
  91. self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
  92. )
  93. else:
  94. x = stoch_layer_coeff * self.dropout(
  95. self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
  96. )
  97. if not self.normalize_before:
  98. x = self.norm1(x)
  99. residual = x
  100. if self.normalize_before:
  101. x = self.norm2(x)
  102. x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
  103. if not self.normalize_before:
  104. x = self.norm2(x)
  105. return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder
class SANMEncoder(AbsEncoder):
    """SAN-M encoder (memory-equipped self-attention).

    Author: Speech Lab of DAMO Academy, Alibaba Group
    San-m: Memory equipped self-attention for end-to-end speech recognition
    https://arxiv.org/abs/2006.01713
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: Optional[str] = "conv2d",
        pos_enc_class=SinusoidalPositionEncoder,
        normalize_before: bool = True,
        concat_after: bool = False,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 1,
        padding_idx: int = -1,
        interctc_layer_idx: List[int] = [],
        interctc_use_conditioning: bool = False,
        kernel_size: int = 11,
        sanm_shfit: int = 0,
        selfattention_layer_type: str = "sanm",
        tf2torch_tensor_name_prefix_torch: str = "encoder",
        tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
    ):
        # NOTE(review): interctc_layer_idx uses a mutable default; it is only
        # read here, so sharing across calls is harmless but worth confirming.
        assert check_argument_types()
        super().__init__()
        self._output_size = output_size

        # ---- input front-end (embedding / subsampling) ----
        if input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(input_size, output_size),
                torch.nn.LayerNorm(output_size),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d":
            self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d2":
            self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d6":
            self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d8":
            self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
        elif input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
                SinusoidalPositionEncoder(),
            )
        elif input_layer is None:
            # No front-end; add a projection only when widths differ.
            if input_size == output_size:
                self.embed = None
            else:
                self.embed = torch.nn.Linear(input_size, output_size)
        elif input_layer == "pe":
            self.embed = SinusoidalPositionEncoder()
        elif input_layer == "pe_online":
            # Streaming (chunk-wise) positional encoding.
            self.embed = StreamSinusoidalPositionEncoder()
        else:
            raise ValueError("unknown input_layer: " + input_layer)
        self.normalize_before = normalize_before

        # ---- position-wise feed-forward layer factory ----
        if positionwise_layer_type == "linear":
            positionwise_layer = PositionwiseFeedForward
            positionwise_layer_args = (
                output_size,
                linear_units,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d":
            positionwise_layer = MultiLayeredConv1d
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d-linear":
            positionwise_layer = Conv1dLinear
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        else:
            raise NotImplementedError("Support only linear or conv1d.")

        # ---- self-attention layer factory ----
        if selfattention_layer_type == "selfattn":
            encoder_selfattn_layer = MultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
            )
        elif selfattention_layer_type == "sanm":
            encoder_selfattn_layer = MultiHeadedAttentionSANM
            # First layer maps input_size -> output_size; later layers are
            # output_size -> output_size.
            encoder_selfattn_layer_args0 = (
                attention_heads,
                input_size,
                output_size,
                attention_dropout_rate,
                kernel_size,
                sanm_shfit,
            )
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                output_size,
                attention_dropout_rate,
                kernel_size,
                sanm_shfit,
            )

        # encoders0 holds the single width-changing first layer; encoders the
        # remaining num_blocks - 1 layers.
        # NOTE(review): with selfattention_layer_type == "selfattn" the
        # *_args0 tuple is undefined; only "sanm" appears supported here.
        self.encoders0 = repeat(
            1,
            lambda lnum: EncoderLayerSANM(
                input_size,
                output_size,
                encoder_selfattn_layer(*encoder_selfattn_layer_args0),
                positionwise_layer(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        self.encoders = repeat(
            num_blocks - 1,
            lambda lnum: EncoderLayerSANM(
                output_size,
                output_size,
                encoder_selfattn_layer(*encoder_selfattn_layer_args),
                positionwise_layer(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        if self.normalize_before:
            self.after_norm = LayerNorm(output_size)

        # ---- intermediate-CTC configuration ----
        self.interctc_layer_idx = interctc_layer_idx
        if len(interctc_layer_idx) > 0:
            assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
        self.interctc_use_conditioning = interctc_use_conditioning
        # Set externally when interctc conditioning is used (see forward).
        self.conditioning_layer = None
        self.dropout = nn.Dropout(dropout_rate)

        # Prefixes used when converting TF checkpoints to torch state dicts.
        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf

    def output_size(self) -> int:
        """Return the encoder output feature dimension."""
        return self._output_size

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: torch.Tensor = None,
        ctc: CTC = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Embed positions in tensor and encode.

        Args:
            xs_pad: input tensor (B, L, D)
            ilens: input length (B)
            prev_states: Not to be used now.
            ctc: CTC module used only for intermediate-CTC conditioning.

        Returns:
            Encoded tensor (or (encoded, intermediate_outs) when
            intermediate CTC layers are configured), output lengths, None.
        """
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
        # Scale inputs by sqrt(output_size) before positional encoding.
        xs_pad = xs_pad * self.output_size()**0.5
        if self.embed is None:
            xs_pad = xs_pad
        elif (
            isinstance(self.embed, Conv2dSubsampling)
            or isinstance(self.embed, Conv2dSubsampling2)
            or isinstance(self.embed, Conv2dSubsampling6)
            or isinstance(self.embed, Conv2dSubsampling8)
        ):
            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
            if short_status:
                raise TooShortUttError(
                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
                    + f"(it needs more than {limit_size} frames), return empty results",
                    xs_pad.size(1),
                    limit_size,
                )
            xs_pad, masks = self.embed(xs_pad, masks)
        else:
            xs_pad = self.embed(xs_pad)
        # xs_pad = self.dropout(xs_pad)

        # First (width-changing) layer, then the remaining stack.
        encoder_outs = self.encoders0(xs_pad, masks)
        xs_pad, masks = encoder_outs[0], encoder_outs[1]
        intermediate_outs = []
        if len(self.interctc_layer_idx) == 0:
            encoder_outs = self.encoders(xs_pad, masks)
            xs_pad, masks = encoder_outs[0], encoder_outs[1]
        else:
            # Run layer by layer so intermediate outputs can be tapped.
            for layer_idx, encoder_layer in enumerate(self.encoders):
                encoder_outs = encoder_layer(xs_pad, masks)
                xs_pad, masks = encoder_outs[0], encoder_outs[1]
                if layer_idx + 1 in self.interctc_layer_idx:
                    encoder_out = xs_pad
                    # intermediate outputs are also normalized
                    if self.normalize_before:
                        encoder_out = self.after_norm(encoder_out)
                    intermediate_outs.append((layer_idx + 1, encoder_out))
                    if self.interctc_use_conditioning:
                        # Condition the stream on the intermediate CTC posteriors.
                        ctc_out = ctc.softmax(encoder_out)
                        xs_pad = xs_pad + self.conditioning_layer(ctc_out)
        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)
        olens = masks.squeeze(1).sum(1)
        if len(intermediate_outs) > 0:
            return (xs_pad, intermediate_outs), olens, None
        return xs_pad, olens, None

    def forward_chunk(self,
                      xs_pad: torch.Tensor,
                      ilens: torch.Tensor,
                      cache: dict = None,
                      ctc: CTC = None,
                      ):
        """Streaming forward pass over one chunk.

        Masks are not used (None is passed through); the embed front-end
        receives the streaming cache.
        """
        xs_pad *= self.output_size() ** 0.5
        if self.embed is None:
            xs_pad = xs_pad
        else:
            xs_pad = self.embed(xs_pad, cache)
        encoder_outs = self.encoders0(xs_pad, None, None, None, None)
        xs_pad, masks = encoder_outs[0], encoder_outs[1]
        intermediate_outs = []
        if len(self.interctc_layer_idx) == 0:
            encoder_outs = self.encoders(xs_pad, None, None, None, None)
            xs_pad, masks = encoder_outs[0], encoder_outs[1]
        else:
            for layer_idx, encoder_layer in enumerate(self.encoders):
                encoder_outs = encoder_layer(xs_pad, None, None, None, None)
                xs_pad, masks = encoder_outs[0], encoder_outs[1]
                if layer_idx + 1 in self.interctc_layer_idx:
                    encoder_out = xs_pad
                    # intermediate outputs are also normalized
                    if self.normalize_before:
                        encoder_out = self.after_norm(encoder_out)
                    intermediate_outs.append((layer_idx + 1, encoder_out))
                    if self.interctc_use_conditioning:
                        ctc_out = ctc.softmax(encoder_out)
                        xs_pad = xs_pad + self.conditioning_layer(ctc_out)
        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)
        if len(intermediate_outs) > 0:
            return (xs_pad, intermediate_outs), None, None
        return xs_pad, ilens, None

    def gen_tf2torch_map_dict(self):
        """Map torch parameter names (with a 'layeridx' placeholder) to TF
        variable specs: the TF tensor name plus optional squeeze axis and
        transpose permutation needed to match the torch shape.
        """
        tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
        tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
        map_dict_local = {
            ## encoder
            # cicd
            "{}.encoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.encoders.layeridx.norm1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.encoders.layeridx.self_attn.linear_q_k_v.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (768,256),(1,256,768)
            "{}.encoders.layeridx.self_attn.linear_q_k_v.bias".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (768,),(768,)
            "{}.encoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/depth_conv_w".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 2, 0),
                 },  # (256,1,31),(1,31,256,1)
            "{}.encoders.layeridx.self_attn.linear_out.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,256),(1,256,256)
            "{}.encoders.layeridx.self_attn.linear_out.bias".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            # ffn
            "{}.encoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.encoders.layeridx.norm2.bias".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.encoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/ffn/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (1024,256),(1,256,1024)
            "{}.encoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/ffn/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.encoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,1024),(1,1024,256)
            "{}.encoders.layeridx.feed_forward.w_2.bias".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/ffn/conv1d_1/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            # out norm
            "{}.after_norm.weight".format(tensor_name_prefix_torch):
                {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.after_norm.bias".format(tensor_name_prefix_torch):
                {"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
        }
        return map_dict_local

    def convert_tf2torch(self,
                         var_dict_tf,
                         var_dict_torch,
                         ):
        """Build a torch state-dict update from a TF variable dict.

        For every torch parameter name, look up the matching TF tensor via
        gen_tf2torch_map_dict, apply the recorded squeeze/transpose, and
        verify shapes before storing the converted tensor.
        """
        map_dict = self.gen_tf2torch_map_dict()
        var_dict_torch_update = dict()
        for name in sorted(var_dict_torch.keys(), reverse=False):
            names = name.split('.')
            if names[0] == self.tf2torch_tensor_name_prefix_torch:
                if names[1] == "encoders0":
                    # encoders0 is torch layer 0 -> TF layer index 0.
                    layeridx = int(names[2])
                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
                    name_q = name_q.replace("encoders0", "encoders")
                    layeridx_bias = 0
                    layeridx += layeridx_bias
                    if name_q in map_dict.keys():
                        name_v = map_dict[name_q]["name"]
                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name_q]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                        if map_dict[name_q]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), \
                            "{}, {}, {} != {}".format(name, name_tf, var_dict_torch[name].size(), data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info(
                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                                name, data_tf.size(), name_v, var_dict_tf[name_tf].shape))
                elif names[1] == "encoders":
                    # torch "encoders" layer i -> TF layer i + 1 (layer 0 is encoders0).
                    layeridx = int(names[2])
                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
                    layeridx_bias = 1
                    layeridx += layeridx_bias
                    if name_q in map_dict.keys():
                        name_v = map_dict[name_q]["name"]
                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name_q]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                        if map_dict[name_q]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), \
                            "{}, {}, {} != {}".format(name, name_tf, var_dict_torch[name].size(), data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info(
                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                                name, data_tf.size(), name_v, var_dict_tf[name_tf].shape))
                elif names[1] == "after_norm":
                    # Final LayerNorm: direct copy, no squeeze/transpose.
                    name_tf = map_dict[name]["name"]
                    data_tf = var_dict_tf[name_tf]
                    data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                    var_dict_torch_update[name] = data_tf
                    logging.info(
                        "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                            name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape))
        return var_dict_torch_update
  504. class SANMEncoderChunkOpt(AbsEncoder):
  505. """
  506. Author: Speech Lab of DAMO Academy, Alibaba Group
  507. SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
  508. https://arxiv.org/abs/2006.01713
  509. """
    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: Optional[str] = "conv2d",
        pos_enc_class=SinusoidalPositionEncoder,
        normalize_before: bool = True,
        concat_after: bool = False,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 1,
        padding_idx: int = -1,
        interctc_layer_idx: List[int] = [],
        interctc_use_conditioning: bool = False,
        kernel_size: int = 11,
        sanm_shfit: int = 0,
        selfattention_layer_type: str = "sanm",
        chunk_size: Union[int, Sequence[int]] = (16,),
        stride: Union[int, Sequence[int]] = (10,),
        pad_left: Union[int, Sequence[int]] = (0,),
        encoder_att_look_back_factor: Union[int, Sequence[int]] = (1,),
        decoder_att_look_back_factor: Union[int, Sequence[int]] = (1,),
        tf2torch_tensor_name_prefix_torch: str = "encoder",
        tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
    ):
        """Build the chunk-streaming SAN-M encoder.

        Mirrors SANMEncoder.__init__ and additionally configures the
        overlap_chunk helper that generates chunk/attention masks for
        streaming (SCAMA) decoding.
        """
        assert check_argument_types()
        super().__init__()
        self._output_size = output_size

        # ---- input front-end (embedding / subsampling) ----
        if input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(input_size, output_size),
                torch.nn.LayerNorm(output_size),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d":
            self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d2":
            self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d6":
            self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d8":
            self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
        elif input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer is None:
            # No front-end; add a projection only when widths differ.
            if input_size == output_size:
                self.embed = None
            else:
                self.embed = torch.nn.Linear(input_size, output_size)
        elif input_layer == "pe":
            self.embed = SinusoidalPositionEncoder()
        else:
            raise ValueError("unknown input_layer: " + input_layer)
        self.normalize_before = normalize_before

        # ---- position-wise feed-forward layer factory ----
        if positionwise_layer_type == "linear":
            positionwise_layer = PositionwiseFeedForward
            positionwise_layer_args = (
                output_size,
                linear_units,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d":
            positionwise_layer = MultiLayeredConv1d
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d-linear":
            positionwise_layer = Conv1dLinear
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        else:
            raise NotImplementedError("Support only linear or conv1d.")

        # ---- self-attention layer factory ----
        if selfattention_layer_type == "selfattn":
            encoder_selfattn_layer = MultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
            )
        elif selfattention_layer_type == "sanm":
            encoder_selfattn_layer = MultiHeadedAttentionSANM
            # First layer maps input_size -> output_size; later layers are
            # output_size -> output_size.
            encoder_selfattn_layer_args0 = (
                attention_heads,
                input_size,
                output_size,
                attention_dropout_rate,
                kernel_size,
                sanm_shfit,
            )
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                output_size,
                attention_dropout_rate,
                kernel_size,
                sanm_shfit,
            )

        # NOTE(review): with selfattention_layer_type == "selfattn" the
        # *_args0 tuple is undefined; only "sanm" appears supported here.
        self.encoders0 = repeat(
            1,
            lambda lnum: EncoderLayerSANM(
                input_size,
                output_size,
                encoder_selfattn_layer(*encoder_selfattn_layer_args0),
                positionwise_layer(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        self.encoders = repeat(
            num_blocks - 1,
            lambda lnum: EncoderLayerSANM(
                output_size,
                output_size,
                encoder_selfattn_layer(*encoder_selfattn_layer_args),
                positionwise_layer(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        if self.normalize_before:
            self.after_norm = LayerNorm(output_size)

        # ---- intermediate-CTC configuration ----
        self.interctc_layer_idx = interctc_layer_idx
        if len(interctc_layer_idx) > 0:
            assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
        self.interctc_use_conditioning = interctc_use_conditioning
        self.conditioning_layer = None

        # ---- streaming chunk helper ----
        # Half the FSMN kernel: extra left context the chunking must carry.
        shfit_fsmn = (kernel_size - 1) // 2
        self.overlap_chunk_cls = overlap_chunk(
            chunk_size=chunk_size,
            stride=stride,
            pad_left=pad_left,
            shfit_fsmn=shfit_fsmn,
            encoder_att_look_back_factor=encoder_att_look_back_factor,
            decoder_att_look_back_factor=decoder_att_look_back_factor,
        )
        # Prefixes used when converting TF checkpoints to torch state dicts.
        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
  666. def output_size(self) -> int:
  667. return self._output_size
    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: torch.Tensor = None,
        ctc: CTC = None,
        ind: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Embed positions in tensor and encode with chunk-aware attention.

        Args:
            xs_pad: input tensor (B, L, D)
            ilens: input length (B)
            prev_states: Not to be used now.
            ctc: CTC module used only for intermediate-CTC conditioning.
            ind: index selecting the chunk configuration in overlap_chunk_cls.

        Returns:
            Encoded tensor (or (encoded, intermediate_outs) when
            intermediate CTC layers are configured), output lengths, None.
        """
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
        # Scale inputs by sqrt(output_size) before positional encoding.
        xs_pad *= self.output_size() ** 0.5
        if self.embed is None:
            xs_pad = xs_pad
        elif (
            isinstance(self.embed, Conv2dSubsampling)
            or isinstance(self.embed, Conv2dSubsampling2)
            or isinstance(self.embed, Conv2dSubsampling6)
            or isinstance(self.embed, Conv2dSubsampling8)
        ):
            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
            if short_status:
                raise TooShortUttError(
                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
                    + f"(it needs more than {limit_size} frames), return empty results",
                    xs_pad.size(1),
                    limit_size,
                )
            xs_pad, masks = self.embed(xs_pad, masks)
        else:
            xs_pad = self.embed(xs_pad)

        # Split the utterance into overlapping chunks and build the chunk /
        # attention masks consumed by the SANM layers.
        mask_shfit_chunk, mask_att_chunk_encoder = None, None
        if self.overlap_chunk_cls is not None:
            ilens = masks.squeeze(1).sum(1)
            chunk_outs = self.overlap_chunk_cls.gen_chunk_mask(ilens, ind)
            xs_pad, ilens = self.overlap_chunk_cls.split_chunk(xs_pad, ilens, chunk_outs=chunk_outs)
            masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
            mask_shfit_chunk = self.overlap_chunk_cls.get_mask_shfit_chunk(
                chunk_outs, xs_pad.device, xs_pad.size(0), dtype=xs_pad.dtype)
            mask_att_chunk_encoder = self.overlap_chunk_cls.get_mask_att_chunk_encoder(
                chunk_outs, xs_pad.device, xs_pad.size(0), dtype=xs_pad.dtype)

        # First (width-changing) layer, then the remaining stack.
        encoder_outs = self.encoders0(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
        xs_pad, masks = encoder_outs[0], encoder_outs[1]
        intermediate_outs = []
        if len(self.interctc_layer_idx) == 0:
            encoder_outs = self.encoders(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
            xs_pad, masks = encoder_outs[0], encoder_outs[1]
        else:
            # Run layer by layer so intermediate outputs can be tapped.
            for layer_idx, encoder_layer in enumerate(self.encoders):
                encoder_outs = encoder_layer(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
                xs_pad, masks = encoder_outs[0], encoder_outs[1]
                if layer_idx + 1 in self.interctc_layer_idx:
                    encoder_out = xs_pad
                    # intermediate outputs are also normalized
                    if self.normalize_before:
                        encoder_out = self.after_norm(encoder_out)
                    intermediate_outs.append((layer_idx + 1, encoder_out))
                    if self.interctc_use_conditioning:
                        # Condition the stream on the intermediate CTC posteriors.
                        ctc_out = ctc.softmax(encoder_out)
                        xs_pad = xs_pad + self.conditioning_layer(ctc_out)
        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)
        olens = masks.squeeze(1).sum(1)
        if len(intermediate_outs) > 0:
            return (xs_pad, intermediate_outs), olens, None
        return xs_pad, olens, None
  741. def gen_tf2torch_map_dict(self):
  742. tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
  743. tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
  744. map_dict_local = {
  745. ## encoder
  746. # cicd
  747. "{}.encoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
  748. {"name": "{}/layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
  749. "squeeze": None,
  750. "transpose": None,
  751. }, # (256,),(256,)
  752. "{}.encoders.layeridx.norm1.bias".format(tensor_name_prefix_torch):
  753. {"name": "{}/layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
  754. "squeeze": None,
  755. "transpose": None,
  756. }, # (256,),(256,)
  757. "{}.encoders.layeridx.self_attn.linear_q_k_v.weight".format(tensor_name_prefix_torch):
  758. {"name": "{}/layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
  759. "squeeze": 0,
  760. "transpose": (1, 0),
  761. }, # (768,256),(1,256,768)
  762. "{}.encoders.layeridx.self_attn.linear_q_k_v.bias".format(tensor_name_prefix_torch):
  763. {"name": "{}/layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
  764. "squeeze": None,
  765. "transpose": None,
  766. }, # (768,),(768,)
  767. "{}.encoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch):
  768. {"name": "{}/layer_layeridx/multi_head/depth_conv_w".format(tensor_name_prefix_tf),
  769. "squeeze": 0,
  770. "transpose": (1, 2, 0),
  771. }, # (256,1,31),(1,31,256,1)
  772. "{}.encoders.layeridx.self_attn.linear_out.weight".format(tensor_name_prefix_torch):
  773. {"name": "{}/layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
  774. "squeeze": 0,
  775. "transpose": (1, 0),
  776. }, # (256,256),(1,256,256)
  777. "{}.encoders.layeridx.self_attn.linear_out.bias".format(tensor_name_prefix_torch):
  778. {"name": "{}/layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
  779. "squeeze": None,
  780. "transpose": None,
  781. }, # (256,),(256,)
  782. # ffn
  783. "{}.encoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
  784. {"name": "{}/layer_layeridx/ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
  785. "squeeze": None,
  786. "transpose": None,
  787. }, # (256,),(256,)
  788. "{}.encoders.layeridx.norm2.bias".format(tensor_name_prefix_torch):
  789. {"name": "{}/layer_layeridx/ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
  790. "squeeze": None,
  791. "transpose": None,
  792. }, # (256,),(256,)
  793. "{}.encoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
  794. {"name": "{}/layer_layeridx/ffn/conv1d/kernel".format(tensor_name_prefix_tf),
  795. "squeeze": 0,
  796. "transpose": (1, 0),
  797. }, # (1024,256),(1,256,1024)
  798. "{}.encoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
  799. {"name": "{}/layer_layeridx/ffn/conv1d/bias".format(tensor_name_prefix_tf),
  800. "squeeze": None,
  801. "transpose": None,
  802. }, # (1024,),(1024,)
  803. "{}.encoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
  804. {"name": "{}/layer_layeridx/ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
  805. "squeeze": 0,
  806. "transpose": (1, 0),
  807. }, # (256,1024),(1,1024,256)
  808. "{}.encoders.layeridx.feed_forward.w_2.bias".format(tensor_name_prefix_torch):
  809. {"name": "{}/layer_layeridx/ffn/conv1d_1/bias".format(tensor_name_prefix_tf),
  810. "squeeze": None,
  811. "transpose": None,
  812. }, # (256,),(256,)
  813. # out norm
  814. "{}.after_norm.weight".format(tensor_name_prefix_torch):
  815. {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
  816. "squeeze": None,
  817. "transpose": None,
  818. }, # (256,),(256,)
  819. "{}.after_norm.bias".format(tensor_name_prefix_torch):
  820. {"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
  821. "squeeze": None,
  822. "transpose": None,
  823. }, # (256,),(256,)
  824. }
  825. return map_dict_local
  826. def convert_tf2torch(self,
  827. var_dict_tf,
  828. var_dict_torch,
  829. ):
  830. map_dict = self.gen_tf2torch_map_dict()
  831. var_dict_torch_update = dict()
  832. for name in sorted(var_dict_torch.keys(), reverse=False):
  833. names = name.split('.')
  834. if names[0] == self.tf2torch_tensor_name_prefix_torch:
  835. if names[1] == "encoders0":
  836. layeridx = int(names[2])
  837. name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
  838. name_q = name_q.replace("encoders0", "encoders")
  839. layeridx_bias = 0
  840. layeridx += layeridx_bias
  841. if name_q in map_dict.keys():
  842. name_v = map_dict[name_q]["name"]
  843. name_tf = name_v.replace("layeridx", "{}".format(layeridx))
  844. data_tf = var_dict_tf[name_tf]
  845. if map_dict[name_q]["squeeze"] is not None:
  846. data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
  847. if map_dict[name_q]["transpose"] is not None:
  848. data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
  849. data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
  850. assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
  851. var_dict_torch[
  852. name].size(),
  853. data_tf.size())
  854. var_dict_torch_update[name] = data_tf
  855. logging.info(
  856. "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
  857. var_dict_tf[name_tf].shape))
  858. elif names[1] == "encoders":
  859. layeridx = int(names[2])
  860. name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
  861. layeridx_bias = 1
  862. layeridx += layeridx_bias
  863. if name_q in map_dict.keys():
  864. name_v = map_dict[name_q]["name"]
  865. name_tf = name_v.replace("layeridx", "{}".format(layeridx))
  866. data_tf = var_dict_tf[name_tf]
  867. if map_dict[name_q]["squeeze"] is not None:
  868. data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
  869. if map_dict[name_q]["transpose"] is not None:
  870. data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
  871. data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
  872. assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
  873. var_dict_torch[
  874. name].size(),
  875. data_tf.size())
  876. var_dict_torch_update[name] = data_tf
  877. logging.info(
  878. "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
  879. var_dict_tf[name_tf].shape))
  880. elif names[1] == "after_norm":
  881. name_tf = map_dict[name]["name"]
  882. data_tf = var_dict_tf[name_tf]
  883. data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
  884. var_dict_torch_update[name] = data_tf
  885. logging.info(
  886. "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
  887. var_dict_tf[name_tf].shape))
  888. return var_dict_torch_update
  889. class SANMVadEncoder(AbsEncoder):
  890. """
  891. Author: Speech Lab of DAMO Academy, Alibaba Group
  892. """
def __init__(
    self,
    input_size: int,
    output_size: int = 256,
    attention_heads: int = 4,
    linear_units: int = 2048,
    num_blocks: int = 6,
    dropout_rate: float = 0.1,
    positional_dropout_rate: float = 0.1,
    attention_dropout_rate: float = 0.0,
    input_layer: Optional[str] = "conv2d",
    pos_enc_class=SinusoidalPositionEncoder,
    normalize_before: bool = True,
    concat_after: bool = False,
    positionwise_layer_type: str = "linear",
    positionwise_conv_kernel_size: int = 1,
    padding_idx: int = -1,
    # NOTE(review): mutable default argument; harmless here because it is
    # only read, never mutated, but a None default would be safer.
    interctc_layer_idx: List[int] = [],
    interctc_use_conditioning: bool = False,
    kernel_size : int = 11,
    sanm_shfit : int = 0,
    selfattention_layer_type: str = "sanm",
):
    """Construct the SANM VAD encoder.

    Builds an input embedding/subsampling front-end, one input-projection
    encoder layer (``encoders0``, input_size -> output_size) and
    ``num_blocks - 1`` further SANM layers (``encoders``).

    Args:
        input_size: dimension of the input features.
        output_size: dimension of attention / encoder output.
        attention_heads: number of attention heads.
        linear_units: hidden units of the position-wise feed-forward.
        num_blocks: total number of encoder layers.
        dropout_rate: general dropout probability.
        positional_dropout_rate: dropout after positional encoding.
        attention_dropout_rate: dropout inside attention.
        input_layer: front-end type ("linear", "conv2d", "conv2d2",
            "conv2d6", "conv2d8", "embed", "pe", or None).
        pos_enc_class: positional-encoding class for the "linear" front-end.
        normalize_before: apply LayerNorm before each sub-block.
        concat_after: concatenate attention input and output.
        positionwise_layer_type: "linear", "conv1d" or "conv1d-linear".
        positionwise_conv_kernel_size: kernel size for conv1d FFN variants.
        padding_idx: padding index for the "embed" front-end.
        interctc_layer_idx: layer indices for intermediate CTC taps.
        interctc_use_conditioning: condition later layers on CTC posteriors.
        kernel_size: FSMN memory-block kernel size.
        sanm_shfit: FSMN memory-block shift.
        selfattention_layer_type: "selfattn" or "sanm".
            NOTE(review): the "selfattn" branch only assigns local
            variables; the layer construction below reads
            ``self.encoder_selfattn_layer`` and
            ``encoder_selfattn_layer_args0``, which only the "sanm" branch
            defines, so "selfattn" would fail at construction time —
            confirm before relying on it.
    """
    assert check_argument_types()
    super().__init__()
    self._output_size = output_size
    # --- input embedding / subsampling front-end ---
    if input_layer == "linear":
        self.embed = torch.nn.Sequential(
            torch.nn.Linear(input_size, output_size),
            torch.nn.LayerNorm(output_size),
            torch.nn.Dropout(dropout_rate),
            torch.nn.ReLU(),
            pos_enc_class(output_size, positional_dropout_rate),
        )
    elif input_layer == "conv2d":
        self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
    elif input_layer == "conv2d2":
        self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
    elif input_layer == "conv2d6":
        self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
    elif input_layer == "conv2d8":
        self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
    elif input_layer == "embed":
        self.embed = torch.nn.Sequential(
            torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
            SinusoidalPositionEncoder(),
        )
    elif input_layer is None:
        # No front-end: add a projection only if the dimensions differ.
        if input_size == output_size:
            self.embed = None
        else:
            self.embed = torch.nn.Linear(input_size, output_size)
    elif input_layer == "pe":
        self.embed = SinusoidalPositionEncoder()
    else:
        raise ValueError("unknown input_layer: " + input_layer)
    self.normalize_before = normalize_before
    # --- position-wise feed-forward factory (class + ctor args) ---
    if positionwise_layer_type == "linear":
        positionwise_layer = PositionwiseFeedForward
        positionwise_layer_args = (
            output_size,
            linear_units,
            dropout_rate,
        )
    elif positionwise_layer_type == "conv1d":
        positionwise_layer = MultiLayeredConv1d
        positionwise_layer_args = (
            output_size,
            linear_units,
            positionwise_conv_kernel_size,
            dropout_rate,
        )
    elif positionwise_layer_type == "conv1d-linear":
        positionwise_layer = Conv1dLinear
        positionwise_layer_args = (
            output_size,
            linear_units,
            positionwise_conv_kernel_size,
            dropout_rate,
        )
    else:
        raise NotImplementedError("Support only linear or conv1d.")
    # --- self-attention factory (see NOTE in the docstring re "selfattn") ---
    if selfattention_layer_type == "selfattn":
        encoder_selfattn_layer = MultiHeadedAttention
        encoder_selfattn_layer_args = (
            attention_heads,
            output_size,
            attention_dropout_rate,
        )
    elif selfattention_layer_type == "sanm":
        self.encoder_selfattn_layer = MultiHeadedAttentionSANMwithMask
        # Args for the first layer, which maps input_size -> output_size.
        encoder_selfattn_layer_args0 = (
            attention_heads,
            input_size,
            output_size,
            attention_dropout_rate,
            kernel_size,
            sanm_shfit,
        )
        # Args for the remaining layers (output_size -> output_size).
        encoder_selfattn_layer_args = (
            attention_heads,
            output_size,
            output_size,
            attention_dropout_rate,
            kernel_size,
            sanm_shfit,
        )
    # Single input-projection encoder layer.
    self.encoders0 = repeat(
        1,
        lambda lnum: EncoderLayerSANM(
            input_size,
            output_size,
            self.encoder_selfattn_layer(*encoder_selfattn_layer_args0),
            positionwise_layer(*positionwise_layer_args),
            dropout_rate,
            normalize_before,
            concat_after,
        ),
    )
    # Remaining num_blocks - 1 encoder layers.
    self.encoders = repeat(
        num_blocks-1,
        lambda lnum: EncoderLayerSANM(
            output_size,
            output_size,
            self.encoder_selfattn_layer(*encoder_selfattn_layer_args),
            positionwise_layer(*positionwise_layer_args),
            dropout_rate,
            normalize_before,
            concat_after,
        ),
    )
    if self.normalize_before:
        self.after_norm = LayerNorm(output_size)
    self.interctc_layer_idx = interctc_layer_idx
    if len(interctc_layer_idx) > 0:
        assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
    self.interctc_use_conditioning = interctc_use_conditioning
    # Filled in externally when interctc conditioning is enabled; forward()
    # in this class does not use it.
    self.conditioning_layer = None
    self.dropout = nn.Dropout(dropout_rate)
  1032. def output_size(self) -> int:
  1033. return self._output_size
  1034. def forward(
  1035. self,
  1036. xs_pad: torch.Tensor,
  1037. ilens: torch.Tensor,
  1038. vad_indexes: torch.Tensor,
  1039. prev_states: torch.Tensor = None,
  1040. ctc: CTC = None,
  1041. ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
  1042. """Embed positions in tensor.
  1043. Args:
  1044. xs_pad: input tensor (B, L, D)
  1045. ilens: input length (B)
  1046. prev_states: Not to be used now.
  1047. Returns:
  1048. position embedded tensor and mask
  1049. """
  1050. masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
  1051. sub_masks = subsequent_mask(masks.size(-1), device=xs_pad.device).unsqueeze(0)
  1052. no_future_masks = masks & sub_masks
  1053. xs_pad *= self.output_size()**0.5
  1054. if self.embed is None:
  1055. xs_pad = xs_pad
  1056. elif (isinstance(self.embed, Conv2dSubsampling) or isinstance(self.embed, Conv2dSubsampling2)
  1057. or isinstance(self.embed, Conv2dSubsampling6) or isinstance(self.embed, Conv2dSubsampling8)):
  1058. short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
  1059. if short_status:
  1060. raise TooShortUttError(
  1061. f"has {xs_pad.size(1)} frames and is too short for subsampling " +
  1062. f"(it needs more than {limit_size} frames), return empty results",
  1063. xs_pad.size(1),
  1064. limit_size,
  1065. )
  1066. xs_pad, masks = self.embed(xs_pad, masks)
  1067. else:
  1068. xs_pad = self.embed(xs_pad)
  1069. # xs_pad = self.dropout(xs_pad)
  1070. mask_tup0 = [masks, no_future_masks]
  1071. encoder_outs = self.encoders0(xs_pad, mask_tup0)
  1072. xs_pad, _ = encoder_outs[0], encoder_outs[1]
  1073. intermediate_outs = []
  1074. for layer_idx, encoder_layer in enumerate(self.encoders):
  1075. if layer_idx + 1 == len(self.encoders):
  1076. # This is last layer.
  1077. coner_mask = torch.ones(masks.size(0),
  1078. masks.size(-1),
  1079. masks.size(-1),
  1080. device=xs_pad.device,
  1081. dtype=torch.bool)
  1082. for word_index, length in enumerate(ilens):
  1083. coner_mask[word_index, :, :] = vad_mask(masks.size(-1),
  1084. vad_indexes[word_index],
  1085. device=xs_pad.device)
  1086. layer_mask = masks & coner_mask
  1087. else:
  1088. layer_mask = no_future_masks
  1089. mask_tup1 = [masks, layer_mask]
  1090. encoder_outs = encoder_layer(xs_pad, mask_tup1)
  1091. xs_pad, layer_mask = encoder_outs[0], encoder_outs[1]
  1092. if self.normalize_before:
  1093. xs_pad = self.after_norm(xs_pad)
  1094. olens = masks.squeeze(1).sum(1)
  1095. if len(intermediate_outs) > 0:
  1096. return (xs_pad, intermediate_outs), olens, None
  1097. return xs_pad, olens, None