# sanm_encoder.py

from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
from funasr.modules.streaming_utils.chunk_utilis import overlap_chunk
import numpy as np
from funasr.torch_utils.device_funcs import to_device
from funasr.modules.nets_utils import make_pad_mask
from funasr.modules.attention import MultiHeadedAttention, MultiHeadedAttentionSANM, MultiHeadedAttentionSANMwithMask
from funasr.modules.embedding import SinusoidalPositionEncoder, StreamSinusoidalPositionEncoder
from funasr.modules.layer_norm import LayerNorm
from funasr.modules.multi_layer_conv import Conv1dLinear
from funasr.modules.multi_layer_conv import MultiLayeredConv1d
from funasr.modules.positionwise_feed_forward import (
    PositionwiseFeedForward,  # noqa: H301
)
from funasr.modules.repeat import repeat
from funasr.modules.subsampling import Conv2dSubsampling
from funasr.modules.subsampling import Conv2dSubsampling2
from funasr.modules.subsampling import Conv2dSubsampling6
from funasr.modules.subsampling import Conv2dSubsampling8
from funasr.modules.subsampling import TooShortUttError
from funasr.modules.subsampling import check_short_utt
from funasr.modules.mask import subsequent_mask, vad_mask
from funasr.models.ctc import CTC
from funasr.models.encoder.abs_encoder import AbsEncoder


class EncoderLayerSANM(nn.Module):
    def __init__(
        self,
        in_size,
        size,
        self_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
        stochastic_depth_rate=0.0,
    ):
        """Construct an EncoderLayer object."""
        super(EncoderLayerSANM, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.norm1 = LayerNorm(in_size)
        self.norm2 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.in_size = in_size
        self.size = size
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear = nn.Linear(size + size, size)
        self.stochastic_depth_rate = stochastic_depth_rate
        self.dropout_rate = dropout_rate

    def forward(self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
        """Compute encoded features.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, size).
            mask (torch.Tensor): Mask tensor for the input (#batch, time).
            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, time).
            The cache and chunk masks are passed through unchanged.
        """
        skip_layer = False
        # with stochastic depth, residual connection `x + f(x)` becomes
        # `x <- x + 1 / (1 - p) * f(x)` at training time.
        stoch_layer_coeff = 1.0
        if self.training and self.stochastic_depth_rate > 0:
            skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
            stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
        if skip_layer:
            if cache is not None:
                x = torch.cat([cache, x], dim=1)
            return x, mask
        residual = x
        if self.normalize_before:
            x = self.norm1(x)
        if self.concat_after:
            x_concat = torch.cat(
                (x, self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)),
                dim=-1,
            )
            if self.in_size == self.size:
                x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
            else:
                x = stoch_layer_coeff * self.concat_linear(x_concat)
        else:
            if self.in_size == self.size:
                x = residual + stoch_layer_coeff * self.dropout(
                    self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
                )
            else:
                x = stoch_layer_coeff * self.dropout(
                    self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
                )
        if not self.normalize_before:
            x = self.norm1(x)
        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm2(x)
        return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder

    def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0):
        """Compute encoded features for one streaming chunk.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, size).
            cache (torch.Tensor): Attention cache carried over from the previous chunk.
        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Updated cache.
        """
        residual = x
        if self.normalize_before:
            x = self.norm1(x)
        if self.in_size == self.size:
            attn, cache = self.self_attn.forward_chunk(x, cache, chunk_size, look_back)
            x = residual + attn
        else:
            x, cache = self.self_attn.forward_chunk(x, cache, chunk_size, look_back)
        if not self.normalize_before:
            x = self.norm1(x)
        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        x = residual + self.feed_forward(x)
        if not self.normalize_before:
            x = self.norm2(x)
        return x, cache
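

# A minimal, illustrative sketch (not part of the original FunASR source): it wires one
# EncoderLayerSANM together the same way SANMEncoder below does, pairing the FSMN-style
# SANM self-attention with a position-wise feed-forward block. All sizes are example values.
def _example_encoder_layer_sanm():
    size, heads, linear_units, kernel_size = 256, 4, 2048, 11
    self_attn = MultiHeadedAttentionSANM(heads, size, size, 0.1, kernel_size, 0)
    feed_forward = PositionwiseFeedForward(size, linear_units, 0.1)
    layer = EncoderLayerSANM(size, size, self_attn, feed_forward, dropout_rate=0.1)
    x = torch.randn(2, 50, size)                   # (batch, time, size)
    mask = torch.ones(2, 1, 50, dtype=torch.bool)  # padding mask, same layout as SANMEncoder.forward
    x_out, mask_out = layer(x, mask)[:2]           # forward also passes cache / chunk masks through
    return x_out, mask_out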


class SANMEncoder(AbsEncoder):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    San-m: Memory equipped self-attention for end-to-end speech recognition
    https://arxiv.org/abs/2006.01713
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: Optional[str] = "conv2d",
        pos_enc_class=SinusoidalPositionEncoder,
        normalize_before: bool = True,
        concat_after: bool = False,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 1,
        padding_idx: int = -1,
        interctc_layer_idx: List[int] = [],
        interctc_use_conditioning: bool = False,
        kernel_size: int = 11,
        sanm_shfit: int = 0,
        lora_list: List[str] = None,
        lora_rank: int = 8,
        lora_alpha: int = 16,
        lora_dropout: float = 0.1,
        selfattention_layer_type: str = "sanm",
        tf2torch_tensor_name_prefix_torch: str = "encoder",
        tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
    ):
        super().__init__()
        self._output_size = output_size
        if input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(input_size, output_size),
                torch.nn.LayerNorm(output_size),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d":
            self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d2":
            self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d6":
            self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d8":
            self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
        elif input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
                SinusoidalPositionEncoder(),
            )
        elif input_layer is None:
            if input_size == output_size:
                self.embed = None
            else:
                self.embed = torch.nn.Linear(input_size, output_size)
        elif input_layer == "pe":
            self.embed = SinusoidalPositionEncoder()
        elif input_layer == "pe_online":
            self.embed = StreamSinusoidalPositionEncoder()
        else:
            raise ValueError("unknown input_layer: " + input_layer)
        self.normalize_before = normalize_before
        if positionwise_layer_type == "linear":
            positionwise_layer = PositionwiseFeedForward
            positionwise_layer_args = (
                output_size,
                linear_units,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d":
            positionwise_layer = MultiLayeredConv1d
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d-linear":
            positionwise_layer = Conv1dLinear
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        else:
            raise NotImplementedError("Support only linear or conv1d.")
        if selfattention_layer_type == "selfattn":
            encoder_selfattn_layer = MultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
            )
        elif selfattention_layer_type == "sanm":
            encoder_selfattn_layer = MultiHeadedAttentionSANM
            encoder_selfattn_layer_args0 = (
                attention_heads,
                input_size,
                output_size,
                attention_dropout_rate,
                kernel_size,
                sanm_shfit,
                lora_list,
                lora_rank,
                lora_alpha,
                lora_dropout,
            )
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                output_size,
                attention_dropout_rate,
                kernel_size,
                sanm_shfit,
                lora_list,
                lora_rank,
                lora_alpha,
                lora_dropout,
            )
        self.encoders0 = repeat(
            1,
            lambda lnum: EncoderLayerSANM(
                input_size,
                output_size,
                encoder_selfattn_layer(*encoder_selfattn_layer_args0),
                positionwise_layer(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        self.encoders = repeat(
            num_blocks - 1,
            lambda lnum: EncoderLayerSANM(
                output_size,
                output_size,
                encoder_selfattn_layer(*encoder_selfattn_layer_args),
                positionwise_layer(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        if self.normalize_before:
            self.after_norm = LayerNorm(output_size)
        self.interctc_layer_idx = interctc_layer_idx
        if len(interctc_layer_idx) > 0:
            assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
        self.interctc_use_conditioning = interctc_use_conditioning
        self.conditioning_layer = None
        self.dropout = nn.Dropout(dropout_rate)
        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf

    def output_size(self) -> int:
        return self._output_size

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: torch.Tensor = None,
        ctc: CTC = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Embed positions in tensor.

        Args:
            xs_pad: input tensor (B, L, D)
            ilens: input length (B)
            prev_states: Not to be used now.
        Returns:
            position embedded tensor and mask
        """
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
        xs_pad = xs_pad * self.output_size() ** 0.5
        if self.embed is None:
            xs_pad = xs_pad
        elif (
            isinstance(self.embed, Conv2dSubsampling)
            or isinstance(self.embed, Conv2dSubsampling2)
            or isinstance(self.embed, Conv2dSubsampling6)
            or isinstance(self.embed, Conv2dSubsampling8)
        ):
            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
            if short_status:
                raise TooShortUttError(
                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
                    + f"(it needs more than {limit_size} frames), return empty results",
                    xs_pad.size(1),
                    limit_size,
                )
            xs_pad, masks = self.embed(xs_pad, masks)
        else:
            xs_pad = self.embed(xs_pad)
        # xs_pad = self.dropout(xs_pad)
        encoder_outs = self.encoders0(xs_pad, masks)
        xs_pad, masks = encoder_outs[0], encoder_outs[1]
        intermediate_outs = []
        if len(self.interctc_layer_idx) == 0:
            encoder_outs = self.encoders(xs_pad, masks)
            xs_pad, masks = encoder_outs[0], encoder_outs[1]
        else:
            for layer_idx, encoder_layer in enumerate(self.encoders):
                encoder_outs = encoder_layer(xs_pad, masks)
                xs_pad, masks = encoder_outs[0], encoder_outs[1]
                if layer_idx + 1 in self.interctc_layer_idx:
                    encoder_out = xs_pad
                    # intermediate outputs are also normalized
                    if self.normalize_before:
                        encoder_out = self.after_norm(encoder_out)
                    intermediate_outs.append((layer_idx + 1, encoder_out))
                    if self.interctc_use_conditioning:
                        ctc_out = ctc.softmax(encoder_out)
                        xs_pad = xs_pad + self.conditioning_layer(ctc_out)
        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)
        olens = masks.squeeze(1).sum(1)
        if len(intermediate_outs) > 0:
            return (xs_pad, intermediate_outs), olens, None
        return xs_pad, olens, None

    def _add_overlap_chunk(self, feats: np.ndarray, cache: dict = {}):
        if len(cache) == 0:
            return feats
        cache["feats"] = to_device(cache["feats"], device=feats.device)
        overlap_feats = torch.cat((cache["feats"], feats), dim=1)
        cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]):, :]
        return overlap_feats

    def forward_chunk(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        cache: dict = None,
        ctc: CTC = None,
    ):
        xs_pad *= self.output_size() ** 0.5
        if self.embed is None:
            xs_pad = xs_pad
        else:
            xs_pad = self.embed(xs_pad, cache)
        if cache["tail_chunk"]:
            xs_pad = to_device(cache["feats"], device=xs_pad.device)
        else:
            xs_pad = self._add_overlap_chunk(xs_pad, cache)
        encoder_outs = self.encoders0(xs_pad, None, None, None, None)
        xs_pad, masks = encoder_outs[0], encoder_outs[1]
        intermediate_outs = []
        if len(self.interctc_layer_idx) == 0:
            encoder_outs = self.encoders(xs_pad, None, None, None, None)
            xs_pad, masks = encoder_outs[0], encoder_outs[1]
        else:
            for layer_idx, encoder_layer in enumerate(self.encoders):
                encoder_outs = encoder_layer(xs_pad, None, None, None, None)
                xs_pad, masks = encoder_outs[0], encoder_outs[1]
                if layer_idx + 1 in self.interctc_layer_idx:
                    encoder_out = xs_pad
                    # intermediate outputs are also normalized
                    if self.normalize_before:
                        encoder_out = self.after_norm(encoder_out)
                    intermediate_outs.append((layer_idx + 1, encoder_out))
                    if self.interctc_use_conditioning:
                        ctc_out = ctc.softmax(encoder_out)
                        xs_pad = xs_pad + self.conditioning_layer(ctc_out)
        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)
        if len(intermediate_outs) > 0:
            return (xs_pad, intermediate_outs), None, None
        return xs_pad, ilens, None

    def gen_tf2torch_map_dict(self):
        tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
        tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
        map_dict_local = {
            ## encoder
            # cicd
            "{}.encoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.encoders.layeridx.norm1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.encoders.layeridx.self_attn.linear_q_k_v.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (768,256),(1,256,768)
            "{}.encoders.layeridx.self_attn.linear_q_k_v.bias".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (768,),(768,)
            "{}.encoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/depth_conv_w".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 2, 0),
                 },  # (256,1,31),(1,31,256,1)
            "{}.encoders.layeridx.self_attn.linear_out.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,256),(1,256,256)
            "{}.encoders.layeridx.self_attn.linear_out.bias".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            # ffn
            "{}.encoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.encoders.layeridx.norm2.bias".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.encoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/ffn/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (1024,256),(1,256,1024)
            "{}.encoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/ffn/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.encoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,1024),(1,1024,256)
            "{}.encoders.layeridx.feed_forward.w_2.bias".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/ffn/conv1d_1/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            # out norm
            "{}.after_norm.weight".format(tensor_name_prefix_torch):
                {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.after_norm.bias".format(tensor_name_prefix_torch):
                {"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
        }
        return map_dict_local

    def convert_tf2torch(
        self,
        var_dict_tf,
        var_dict_torch,
    ):
        map_dict = self.gen_tf2torch_map_dict()
        var_dict_torch_update = dict()
        for name in sorted(var_dict_torch.keys(), reverse=False):
            names = name.split('.')
            if names[0] == self.tf2torch_tensor_name_prefix_torch:
                if names[1] == "encoders0":
                    layeridx = int(names[2])
                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
                    name_q = name_q.replace("encoders0", "encoders")
                    layeridx_bias = 0
                    layeridx += layeridx_bias
                    if name_q in map_dict.keys():
                        name_v = map_dict[name_q]["name"]
                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name_q]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                        if map_dict[name_q]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(
                            name, name_tf, var_dict_torch[name].size(), data_tf.size()
                        )
                        var_dict_torch_update[name] = data_tf
                        logging.info(
                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                                name, data_tf.size(), name_v, var_dict_tf[name_tf].shape
                            )
                        )
                elif names[1] == "encoders":
                    layeridx = int(names[2])
                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
                    layeridx_bias = 1
                    layeridx += layeridx_bias
                    if name_q in map_dict.keys():
                        name_v = map_dict[name_q]["name"]
                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name_q]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                        if map_dict[name_q]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(
                            name, name_tf, var_dict_torch[name].size(), data_tf.size()
                        )
                        var_dict_torch_update[name] = data_tf
                        logging.info(
                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                                name, data_tf.size(), name_v, var_dict_tf[name_tf].shape
                            )
                        )
                elif names[1] == "after_norm":
                    name_tf = map_dict[name]["name"]
                    data_tf = var_dict_tf[name_tf]
                    data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                    var_dict_torch_update[name] = data_tf
                    logging.info(
                        "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                            name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
                        )
                    )
        return var_dict_torch_update
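

# A minimal usage sketch (not part of the original source): push a dummy batch through
# SANMEncoder. The 560-dim input and the "pe" input layer mirror a typical LFR-feature setup,
# but all values here are assumptions for illustration, not settings mandated by this file.
def _example_sanm_encoder_forward():
    encoder = SANMEncoder(input_size=560, output_size=256, attention_heads=4,
                          linear_units=2048, num_blocks=6, input_layer="pe",
                          kernel_size=11, sanm_shfit=0)
    xs_pad = torch.randn(2, 100, 560)          # (batch, time, feature)
    ilens = torch.tensor([100, 80])            # valid length of each utterance
    xs_out, olens, _ = encoder(xs_pad, ilens)  # xs_out: (batch, time, output_size)
    return xs_out, olens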


class SANMEncoderChunkOpt(AbsEncoder):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
    https://arxiv.org/abs/2006.01712
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: Optional[str] = "conv2d",
        pos_enc_class=SinusoidalPositionEncoder,
        normalize_before: bool = True,
        concat_after: bool = False,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 1,
        padding_idx: int = -1,
        interctc_layer_idx: List[int] = [],
        interctc_use_conditioning: bool = False,
        kernel_size: int = 11,
        sanm_shfit: int = 0,
        selfattention_layer_type: str = "sanm",
        chunk_size: Union[int, Sequence[int]] = (16,),
        stride: Union[int, Sequence[int]] = (10,),
        pad_left: Union[int, Sequence[int]] = (0,),
        encoder_att_look_back_factor: Union[int, Sequence[int]] = (1,),
        decoder_att_look_back_factor: Union[int, Sequence[int]] = (1,),
        tf2torch_tensor_name_prefix_torch: str = "encoder",
        tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
    ):
        super().__init__()
        self._output_size = output_size
        if input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(input_size, output_size),
                torch.nn.LayerNorm(output_size),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d":
            self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d2":
            self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d6":
            self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d8":
            self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
        elif input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer is None:
            if input_size == output_size:
                self.embed = None
            else:
                self.embed = torch.nn.Linear(input_size, output_size)
        elif input_layer == "pe":
            self.embed = SinusoidalPositionEncoder()
        elif input_layer == "pe_online":
            self.embed = StreamSinusoidalPositionEncoder()
        else:
            raise ValueError("unknown input_layer: " + input_layer)
        self.normalize_before = normalize_before
        if positionwise_layer_type == "linear":
            positionwise_layer = PositionwiseFeedForward
            positionwise_layer_args = (
                output_size,
                linear_units,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d":
            positionwise_layer = MultiLayeredConv1d
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d-linear":
            positionwise_layer = Conv1dLinear
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        else:
            raise NotImplementedError("Support only linear or conv1d.")
        if selfattention_layer_type == "selfattn":
            encoder_selfattn_layer = MultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
            )
        elif selfattention_layer_type == "sanm":
            encoder_selfattn_layer = MultiHeadedAttentionSANM
            encoder_selfattn_layer_args0 = (
                attention_heads,
                input_size,
                output_size,
                attention_dropout_rate,
                kernel_size,
                sanm_shfit,
            )
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                output_size,
                attention_dropout_rate,
                kernel_size,
                sanm_shfit,
            )
        self.encoders0 = repeat(
            1,
            lambda lnum: EncoderLayerSANM(
                input_size,
                output_size,
                encoder_selfattn_layer(*encoder_selfattn_layer_args0),
                positionwise_layer(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        self.encoders = repeat(
            num_blocks - 1,
            lambda lnum: EncoderLayerSANM(
                output_size,
                output_size,
                encoder_selfattn_layer(*encoder_selfattn_layer_args),
                positionwise_layer(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        if self.normalize_before:
            self.after_norm = LayerNorm(output_size)
        self.interctc_layer_idx = interctc_layer_idx
        if len(interctc_layer_idx) > 0:
            assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
        self.interctc_use_conditioning = interctc_use_conditioning
        self.conditioning_layer = None
        shfit_fsmn = (kernel_size - 1) // 2
        self.overlap_chunk_cls = overlap_chunk(
            chunk_size=chunk_size,
            stride=stride,
            pad_left=pad_left,
            shfit_fsmn=shfit_fsmn,
            encoder_att_look_back_factor=encoder_att_look_back_factor,
            decoder_att_look_back_factor=decoder_att_look_back_factor,
        )
        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf

    def output_size(self) -> int:
        return self._output_size

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: torch.Tensor = None,
        ctc: CTC = None,
        ind: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Embed positions in tensor.

        Args:
            xs_pad: input tensor (B, L, D)
            ilens: input length (B)
            prev_states: Not to be used now.
        Returns:
            position embedded tensor and mask
        """
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
        xs_pad *= self.output_size() ** 0.5
        if self.embed is None:
            xs_pad = xs_pad
        elif (
            isinstance(self.embed, Conv2dSubsampling)
            or isinstance(self.embed, Conv2dSubsampling2)
            or isinstance(self.embed, Conv2dSubsampling6)
            or isinstance(self.embed, Conv2dSubsampling8)
        ):
            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
            if short_status:
                raise TooShortUttError(
                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
                    + f"(it needs more than {limit_size} frames), return empty results",
                    xs_pad.size(1),
                    limit_size,
                )
            xs_pad, masks = self.embed(xs_pad, masks)
        else:
            xs_pad = self.embed(xs_pad)
        mask_shfit_chunk, mask_att_chunk_encoder = None, None
        if self.overlap_chunk_cls is not None:
            ilens = masks.squeeze(1).sum(1)
            chunk_outs = self.overlap_chunk_cls.gen_chunk_mask(ilens, ind)
            xs_pad, ilens = self.overlap_chunk_cls.split_chunk(xs_pad, ilens, chunk_outs=chunk_outs)
            masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
            mask_shfit_chunk = self.overlap_chunk_cls.get_mask_shfit_chunk(
                chunk_outs, xs_pad.device, xs_pad.size(0), dtype=xs_pad.dtype
            )
            mask_att_chunk_encoder = self.overlap_chunk_cls.get_mask_att_chunk_encoder(
                chunk_outs, xs_pad.device, xs_pad.size(0), dtype=xs_pad.dtype
            )
        encoder_outs = self.encoders0(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
        xs_pad, masks = encoder_outs[0], encoder_outs[1]
        intermediate_outs = []
        if len(self.interctc_layer_idx) == 0:
            encoder_outs = self.encoders(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
            xs_pad, masks = encoder_outs[0], encoder_outs[1]
        else:
            for layer_idx, encoder_layer in enumerate(self.encoders):
                encoder_outs = encoder_layer(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
                xs_pad, masks = encoder_outs[0], encoder_outs[1]
                if layer_idx + 1 in self.interctc_layer_idx:
                    encoder_out = xs_pad
                    # intermediate outputs are also normalized
                    if self.normalize_before:
                        encoder_out = self.after_norm(encoder_out)
                    intermediate_outs.append((layer_idx + 1, encoder_out))
                    if self.interctc_use_conditioning:
                        ctc_out = ctc.softmax(encoder_out)
                        xs_pad = xs_pad + self.conditioning_layer(ctc_out)
        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)
        olens = masks.squeeze(1).sum(1)
        if len(intermediate_outs) > 0:
            return (xs_pad, intermediate_outs), olens, None
        return xs_pad, olens, None

    def _add_overlap_chunk(self, feats: np.ndarray, cache: dict = {}):
        if len(cache) == 0:
            return feats
        cache["feats"] = to_device(cache["feats"], device=feats.device)
        overlap_feats = torch.cat((cache["feats"], feats), dim=1)
        cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]):, :]
        return overlap_feats

    def forward_chunk(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        cache: dict = None,
    ):
        xs_pad *= self.output_size() ** 0.5
        if self.embed is None:
            xs_pad = xs_pad
        else:
            xs_pad = self.embed(xs_pad, cache)
        if cache["tail_chunk"]:
            xs_pad = to_device(cache["feats"], device=xs_pad.device)
        else:
            xs_pad = self._add_overlap_chunk(xs_pad, cache)
        if cache["opt"] is None:
            cache_layer_num = len(self.encoders0) + len(self.encoders)
            new_cache = [None] * cache_layer_num
        else:
            new_cache = cache["opt"]
        for layer_idx, encoder_layer in enumerate(self.encoders0):
            encoder_outs = encoder_layer.forward_chunk(
                xs_pad, new_cache[layer_idx], cache["chunk_size"], cache["encoder_chunk_look_back"]
            )
            xs_pad, new_cache[0] = encoder_outs[0], encoder_outs[1]
        for layer_idx, encoder_layer in enumerate(self.encoders):
            encoder_outs = encoder_layer.forward_chunk(
                xs_pad,
                new_cache[layer_idx + len(self.encoders0)],
                cache["chunk_size"],
                cache["encoder_chunk_look_back"],
            )
            xs_pad, new_cache[layer_idx + len(self.encoders0)] = encoder_outs[0], encoder_outs[1]
        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)
        if cache["encoder_chunk_look_back"] > 0 or cache["encoder_chunk_look_back"] == -1:
            cache["opt"] = new_cache
        return xs_pad, ilens, None

    def gen_tf2torch_map_dict(self):
        tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
        tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
        map_dict_local = {
            ## encoder
            # cicd
            "{}.encoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.encoders.layeridx.norm1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.encoders.layeridx.self_attn.linear_q_k_v.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (768,256),(1,256,768)
            "{}.encoders.layeridx.self_attn.linear_q_k_v.bias".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (768,),(768,)
            "{}.encoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/depth_conv_w".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 2, 0),
                 },  # (256,1,31),(1,31,256,1)
            "{}.encoders.layeridx.self_attn.linear_out.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,256),(1,256,256)
            "{}.encoders.layeridx.self_attn.linear_out.bias".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            # ffn
            "{}.encoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.encoders.layeridx.norm2.bias".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.encoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/ffn/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (1024,256),(1,256,1024)
            "{}.encoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/ffn/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.encoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,1024),(1,1024,256)
            "{}.encoders.layeridx.feed_forward.w_2.bias".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/ffn/conv1d_1/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            # out norm
            "{}.after_norm.weight".format(tensor_name_prefix_torch):
                {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.after_norm.bias".format(tensor_name_prefix_torch):
                {"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
        }
        return map_dict_local

    def convert_tf2torch(
        self,
        var_dict_tf,
        var_dict_torch,
    ):
        map_dict = self.gen_tf2torch_map_dict()
        var_dict_torch_update = dict()
        for name in sorted(var_dict_torch.keys(), reverse=False):
            names = name.split('.')
            if names[0] == self.tf2torch_tensor_name_prefix_torch:
                if names[1] == "encoders0":
                    layeridx = int(names[2])
                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
                    name_q = name_q.replace("encoders0", "encoders")
                    layeridx_bias = 0
                    layeridx += layeridx_bias
                    if name_q in map_dict.keys():
                        name_v = map_dict[name_q]["name"]
                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name_q]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                        if map_dict[name_q]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(
                            name, name_tf, var_dict_torch[name].size(), data_tf.size()
                        )
                        var_dict_torch_update[name] = data_tf
                        logging.info(
                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                                name, data_tf.size(), name_v, var_dict_tf[name_tf].shape
                            )
                        )
                elif names[1] == "encoders":
                    layeridx = int(names[2])
                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
                    layeridx_bias = 1
                    layeridx += layeridx_bias
                    if name_q in map_dict.keys():
                        name_v = map_dict[name_q]["name"]
                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name_q]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                        if map_dict[name_q]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(
                            name, name_tf, var_dict_torch[name].size(), data_tf.size()
                        )
                        var_dict_torch_update[name] = data_tf
                        logging.info(
                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                                name, data_tf.size(), name_v, var_dict_tf[name_tf].shape
                            )
                        )
                elif names[1] == "after_norm":
                    name_tf = map_dict[name]["name"]
                    data_tf = var_dict_tf[name_tf]
                    data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                    var_dict_torch_update[name] = data_tf
                    logging.info(
                        "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                            name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
                        )
                    )
        return var_dict_torch_update
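

# A minimal usage sketch (not part of the original source): offline forward of the streaming
# (SCAMA) encoder, which internally splits the input into overlapping chunks and builds the
# matching chunk/attention masks. The chunk_size/stride values are illustrative assumptions.
def _example_sanm_encoder_chunk_opt_forward():
    encoder = SANMEncoderChunkOpt(input_size=560, output_size=256, num_blocks=6,
                                  input_layer="pe", kernel_size=11,
                                  chunk_size=(16,), stride=(10,), pad_left=(0,),
                                  encoder_att_look_back_factor=(1,),
                                  decoder_att_look_back_factor=(1,))
    xs_pad = torch.randn(2, 100, 560)          # (batch, time, feature)
    ilens = torch.tensor([100, 80])
    xs_out, olens, _ = encoder(xs_pad, ilens)  # output is laid out chunk by chunk
    return xs_out, olens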


class SANMVadEncoder(AbsEncoder):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: Optional[str] = "conv2d",
        pos_enc_class=SinusoidalPositionEncoder,
        normalize_before: bool = True,
        concat_after: bool = False,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 1,
        padding_idx: int = -1,
        interctc_layer_idx: List[int] = [],
        interctc_use_conditioning: bool = False,
        kernel_size: int = 11,
        sanm_shfit: int = 0,
        selfattention_layer_type: str = "sanm",
    ):
        super().__init__()
        self._output_size = output_size
        if input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(input_size, output_size),
                torch.nn.LayerNorm(output_size),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d":
            self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d2":
            self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d6":
            self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d8":
            self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
        elif input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
                SinusoidalPositionEncoder(),
            )
        elif input_layer is None:
            if input_size == output_size:
                self.embed = None
            else:
                self.embed = torch.nn.Linear(input_size, output_size)
        elif input_layer == "pe":
            self.embed = SinusoidalPositionEncoder()
        else:
            raise ValueError("unknown input_layer: " + input_layer)
        self.normalize_before = normalize_before
        if positionwise_layer_type == "linear":
            positionwise_layer = PositionwiseFeedForward
            positionwise_layer_args = (
                output_size,
                linear_units,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d":
            positionwise_layer = MultiLayeredConv1d
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d-linear":
            positionwise_layer = Conv1dLinear
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        else:
            raise NotImplementedError("Support only linear or conv1d.")
        if selfattention_layer_type == "selfattn":
            encoder_selfattn_layer = MultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
            )
        elif selfattention_layer_type == "sanm":
            self.encoder_selfattn_layer = MultiHeadedAttentionSANMwithMask
            encoder_selfattn_layer_args0 = (
                attention_heads,
                input_size,
                output_size,
                attention_dropout_rate,
                kernel_size,
                sanm_shfit,
            )
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                output_size,
                attention_dropout_rate,
                kernel_size,
                sanm_shfit,
            )
        self.encoders0 = repeat(
            1,
            lambda lnum: EncoderLayerSANM(
                input_size,
                output_size,
                self.encoder_selfattn_layer(*encoder_selfattn_layer_args0),
                positionwise_layer(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        self.encoders = repeat(
            num_blocks - 1,
            lambda lnum: EncoderLayerSANM(
                output_size,
                output_size,
                self.encoder_selfattn_layer(*encoder_selfattn_layer_args),
                positionwise_layer(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        if self.normalize_before:
            self.after_norm = LayerNorm(output_size)
        self.interctc_layer_idx = interctc_layer_idx
        if len(interctc_layer_idx) > 0:
            assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
        self.interctc_use_conditioning = interctc_use_conditioning
        self.conditioning_layer = None
        self.dropout = nn.Dropout(dropout_rate)

    def output_size(self) -> int:
        return self._output_size

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        vad_indexes: torch.Tensor,
        prev_states: torch.Tensor = None,
        ctc: CTC = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Embed positions in tensor.

        Args:
            xs_pad: input tensor (B, L, D)
            ilens: input length (B)
            vad_indexes: VAD indexes (B), used to build the attention mask of the last layer
            prev_states: Not to be used now.
        Returns:
            position embedded tensor and mask
        """
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
        sub_masks = subsequent_mask(masks.size(-1), device=xs_pad.device).unsqueeze(0)
        no_future_masks = masks & sub_masks
        xs_pad *= self.output_size() ** 0.5
        if self.embed is None:
            xs_pad = xs_pad
        elif (
            isinstance(self.embed, Conv2dSubsampling)
            or isinstance(self.embed, Conv2dSubsampling2)
            or isinstance(self.embed, Conv2dSubsampling6)
            or isinstance(self.embed, Conv2dSubsampling8)
        ):
            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
            if short_status:
                raise TooShortUttError(
                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
                    + f"(it needs more than {limit_size} frames), return empty results",
                    xs_pad.size(1),
                    limit_size,
                )
            xs_pad, masks = self.embed(xs_pad, masks)
        else:
            xs_pad = self.embed(xs_pad)
        # xs_pad = self.dropout(xs_pad)
        mask_tup0 = [masks, no_future_masks]
        encoder_outs = self.encoders0(xs_pad, mask_tup0)
        xs_pad, _ = encoder_outs[0], encoder_outs[1]
        intermediate_outs = []
        for layer_idx, encoder_layer in enumerate(self.encoders):
            if layer_idx + 1 == len(self.encoders):
                # This is the last layer.
                coner_mask = torch.ones(
                    masks.size(0),
                    masks.size(-1),
                    masks.size(-1),
                    device=xs_pad.device,
                    dtype=torch.bool,
                )
                for word_index, length in enumerate(ilens):
                    coner_mask[word_index, :, :] = vad_mask(
                        masks.size(-1), vad_indexes[word_index], device=xs_pad.device
                    )
                layer_mask = masks & coner_mask
            else:
                layer_mask = no_future_masks
            mask_tup1 = [masks, layer_mask]
            encoder_outs = encoder_layer(xs_pad, mask_tup1)
            xs_pad, layer_mask = encoder_outs[0], encoder_outs[1]
        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)
        olens = masks.squeeze(1).sum(1)
        if len(intermediate_outs) > 0:
            return (xs_pad, intermediate_outs), olens, None
        return xs_pad, olens, None
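

# A minimal usage sketch (not part of the original source): the VAD-aware encoder additionally
# takes vad_indexes, which drive vad_mask() in the last encoder layer. The dummy indexes below
# assume one boundary index per utterance; see funasr.modules.mask.vad_mask for the exact semantics.
def _example_sanm_vad_encoder_forward():
    encoder = SANMVadEncoder(input_size=560, output_size=256, num_blocks=6,
                             input_layer="pe", kernel_size=11)
    xs_pad = torch.randn(2, 100, 560)
    ilens = torch.tensor([100, 80])
    vad_indexes = torch.tensor([50, 40])  # assumed format: one VAD index per utterance
    xs_out, olens, _ = encoder(xs_pad, ilens, vad_indexes)
    return xs_out, olens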