# data2vec_encoder.py

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import math

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from typeguard import check_argument_types

from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.modules.data2vec.data_utils import compute_mask_indices
from funasr.modules.data2vec.ema_module import EMAModule
from funasr.modules.data2vec.grad_multiply import GradMultiply
from funasr.modules.data2vec.wav2vec2 import (
    ConvFeatureExtractionModel,
    TransformerEncoder,
)
from funasr.modules.nets_utils import make_pad_mask


def get_annealed_rate(start, end, curr_step, total_steps):
    r = end - start
    pct_remaining = 1 - curr_step / total_steps
    return end - r * pct_remaining
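

# A quick sanity check on the schedule above (added note, not in the original
# source): with the class defaults ema_decay=0.999, ema_end_decay=0.9999 and
# ema_anneal_end_step=100000, get_annealed_rate(0.999, 0.9999, 50000, 100000)
# returns 0.9999 - 0.0009 * 0.5 = 0.99945, so the decay ramps linearly from
# `start` at step 0 to `end` at `total_steps`.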


class Data2VecEncoder(AbsEncoder):
    def __init__(
        self,
        # for ConvFeatureExtractionModel
        input_size: int = None,
        extractor_mode: str = None,
        conv_feature_layers: str = "[(512,2,2)] + [(512,2,2)]",
        # for Transformer Encoder
        ## model architecture
        layer_type: str = "transformer",
        layer_norm_first: bool = False,
        encoder_layers: int = 12,
        encoder_embed_dim: int = 768,
        encoder_ffn_embed_dim: int = 3072,
        encoder_attention_heads: int = 12,
        activation_fn: str = "gelu",
        ## dropouts
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.0,
        encoder_layerdrop: float = 0.0,
        dropout_input: float = 0.0,
        dropout_features: float = 0.0,
        ## grad settings
        feature_grad_mult: float = 1.0,
        ## masking
        mask_prob: float = 0.65,
        mask_length: int = 10,
        mask_selection: str = "static",
        mask_other: int = 0,
        no_mask_overlap: bool = False,
        mask_min_space: int = 1,
        require_same_masks: bool = True,  # if True, the collate_fn should clip samples to equal length
        mask_dropout: float = 0.0,
        ## channel masking
        mask_channel_length: int = 10,
        mask_channel_prob: float = 0.0,
        mask_channel_before: bool = False,
        mask_channel_selection: str = "static",
        mask_channel_other: int = 0,
        no_mask_channel_overlap: bool = False,
        mask_channel_min_space: int = 1,
        ## positional embeddings
        conv_pos: int = 128,
        conv_pos_groups: int = 16,
        pos_conv_depth: int = 1,
        max_positions: int = 100000,
        # EMA module
        average_top_k_layers: int = 8,
        layer_norm_target_layer: bool = False,
        instance_norm_target_layer: bool = False,
        instance_norm_targets: bool = False,
        layer_norm_targets: bool = False,
        batch_norm_target_layer: bool = False,
        group_norm_target_layer: bool = False,
        ema_decay: float = 0.999,
        ema_end_decay: float = 0.9999,
        ema_anneal_end_step: int = 100000,
        ema_transformer_only: bool = True,
        ema_layers_only: bool = True,
        min_target_var: float = 0.1,
        min_pred_var: float = 0.01,
        # Loss
        loss_beta: float = 0.0,
        loss_scale: float = None,
        # FP16 optimization
        required_seq_len_multiple: int = 2,
    ):
        assert check_argument_types()
        super().__init__()
        # ConvFeatureExtractionModel
        self.conv_feature_layers = conv_feature_layers
        feature_enc_layers = eval(conv_feature_layers)
        self.extractor_embed = feature_enc_layers[-1][0]
        self.feature_extractor = ConvFeatureExtractionModel(
            conv_layers=feature_enc_layers,
            dropout=0.0,
            mode=extractor_mode,
            in_d=input_size,
        )
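        # Note (added for clarity): `conv_feature_layers` is a Python-literal
        # string; the default "[(512,2,2)] + [(512,2,2)]" evaluates to
        # [(512, 2, 2), (512, 2, 2)], i.e. two conv layers given as
        # (output_dim, kernel_size, stride). The last layer's output_dim sets
        # `self.extractor_embed`.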
        # Transformer Encoder
        ## model architecture
        self.layer_type = layer_type
        self.layer_norm_first = layer_norm_first
        self.encoder_layers = encoder_layers
        self.encoder_embed_dim = encoder_embed_dim
        self.encoder_ffn_embed_dim = encoder_ffn_embed_dim
        self.encoder_attention_heads = encoder_attention_heads
        self.activation_fn = activation_fn
        ## dropout
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.activation_dropout = activation_dropout
        self.encoder_layerdrop = encoder_layerdrop
        self.dropout_input = dropout_input
        self.dropout_features = dropout_features
        ## grad settings
        self.feature_grad_mult = feature_grad_mult
        ## masking
        self.mask_prob = mask_prob
        self.mask_length = mask_length
        self.mask_selection = mask_selection
        self.mask_other = mask_other
        self.no_mask_overlap = no_mask_overlap
        self.mask_min_space = mask_min_space
        self.require_same_masks = require_same_masks  # if True, the collate_fn should clip samples to equal length
        self.mask_dropout = mask_dropout
        ## channel masking
        self.mask_channel_length = mask_channel_length
        self.mask_channel_prob = mask_channel_prob
        self.mask_channel_before = mask_channel_before
        self.mask_channel_selection = mask_channel_selection
        self.mask_channel_other = mask_channel_other
        self.no_mask_channel_overlap = no_mask_channel_overlap
        self.mask_channel_min_space = mask_channel_min_space
        ## positional embeddings
        self.conv_pos = conv_pos
        self.conv_pos_groups = conv_pos_groups
        self.pos_conv_depth = pos_conv_depth
        self.max_positions = max_positions

        self.mask_emb = nn.Parameter(torch.FloatTensor(self.encoder_embed_dim).uniform_())
        self.encoder = TransformerEncoder(
            dropout=self.dropout,
            encoder_embed_dim=self.encoder_embed_dim,
            required_seq_len_multiple=required_seq_len_multiple,
            pos_conv_depth=self.pos_conv_depth,
            conv_pos=self.conv_pos,
            conv_pos_groups=self.conv_pos_groups,
            # transformer layers
            layer_type=self.layer_type,
            encoder_layers=self.encoder_layers,
            encoder_ffn_embed_dim=self.encoder_ffn_embed_dim,
            encoder_attention_heads=self.encoder_attention_heads,
            attention_dropout=self.attention_dropout,
            activation_dropout=self.activation_dropout,
            activation_fn=self.activation_fn,
            layer_norm_first=self.layer_norm_first,
            encoder_layerdrop=self.encoder_layerdrop,
            max_positions=self.max_positions,
        )
        ## projections and dropouts
        self.post_extract_proj = nn.Linear(self.extractor_embed, self.encoder_embed_dim)
        self.dropout_input = nn.Dropout(self.dropout_input)
        self.dropout_features = nn.Dropout(self.dropout_features)
        self.layer_norm = torch.nn.LayerNorm(self.extractor_embed)
        self.final_proj = nn.Linear(self.encoder_embed_dim, self.encoder_embed_dim)

        # EMA module
        self.average_top_k_layers = average_top_k_layers
        self.layer_norm_target_layer = layer_norm_target_layer
        self.instance_norm_target_layer = instance_norm_target_layer
        self.instance_norm_targets = instance_norm_targets
        self.layer_norm_targets = layer_norm_targets
        self.batch_norm_target_layer = batch_norm_target_layer
        self.group_norm_target_layer = group_norm_target_layer
        self.ema_decay = ema_decay
        self.ema_end_decay = ema_end_decay
        self.ema_anneal_end_step = ema_anneal_end_step
        self.ema_transformer_only = ema_transformer_only
        self.ema_layers_only = ema_layers_only
        self.min_target_var = min_target_var
        self.min_pred_var = min_pred_var
        self.ema = None

        # Loss
        self.loss_beta = loss_beta
        self.loss_scale = loss_scale

        # FP16 optimization
        self.required_seq_len_multiple = required_seq_len_multiple

        self.num_updates = 0
        logging.info("Data2VecEncoder settings: {}".format(self.__dict__))

    def make_ema_teacher(self):
        skip_keys = set()
        if self.ema_layers_only:
            self.ema_transformer_only = True
            for k, _ in self.encoder.pos_conv.named_parameters():
                skip_keys.add(f"pos_conv.{k}")
        self.ema = EMAModule(
            self.encoder if self.ema_transformer_only else self,
            ema_decay=self.ema_decay,
            ema_fp32=True,
            skip_keys=skip_keys,
        )
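
    # Note (added for clarity): the EMA teacher is a weight-averaged copy of
    # the student. With ema_layers_only=True, the convolutional positional
    # embedding (pos_conv) is added to `skip_keys`, so EMAModule tracks the
    # student's pos_conv weights directly instead of averaging them; with
    # ema_transformer_only=True, only the transformer stack (not the feature
    # extractor) is averaged.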

    def set_num_updates(self, num_updates):
        if self.ema is None and self.final_proj is not None:
            logging.info("Making EMA Teacher")
            self.make_ema_teacher()
        elif self.training and self.ema is not None:
            if self.ema_decay != self.ema_end_decay:
                if num_updates >= self.ema_anneal_end_step:
                    decay = self.ema_end_decay
                else:
                    decay = get_annealed_rate(
                        self.ema_decay,
                        self.ema_end_decay,
                        num_updates,
                        self.ema_anneal_end_step,
                    )
                self.ema.set_decay(decay)
            if self.ema.get_decay() < 1:
                self.ema.step(self.encoder if self.ema_transformer_only else self)
        self.num_updates = num_updates
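
    # Note (added for clarity): the training loop is expected to call
    # set_num_updates() once per optimizer step. The first call lazily builds
    # the EMA teacher; later calls anneal the decay toward `ema_end_decay`
    # (see get_annealed_rate above) and then step the teacher weights.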

    def apply_mask(
        self,
        x,
        padding_mask,
        mask_indices=None,
        mask_channel_indices=None,
    ):
        B, T, C = x.shape
        if self.mask_channel_prob > 0 and self.mask_channel_before:
            mask_channel_indices = compute_mask_indices(
                (B, C),
                None,
                self.mask_channel_prob,
                self.mask_channel_length,
                self.mask_channel_selection,
                self.mask_channel_other,
                no_overlap=self.no_mask_channel_overlap,
                min_space=self.mask_channel_min_space,
            )
            mask_channel_indices = (
                torch.from_numpy(mask_channel_indices)
                .to(x.device)
                .unsqueeze(1)
                .expand(-1, T, -1)
            )
            x[mask_channel_indices] = 0

        if self.mask_prob > 0:
            if mask_indices is None:
                mask_indices = compute_mask_indices(
                    (B, T),
                    padding_mask,
                    self.mask_prob,
                    self.mask_length,
                    self.mask_selection,
                    self.mask_other,
                    min_masks=1,
                    no_overlap=self.no_mask_overlap,
                    min_space=self.mask_min_space,
                    require_same_masks=self.require_same_masks,
                    mask_dropout=self.mask_dropout,
                )
                mask_indices = torch.from_numpy(mask_indices).to(x.device)
            x[mask_indices] = self.mask_emb
        else:
            mask_indices = None

        if self.mask_channel_prob > 0 and not self.mask_channel_before:
            if mask_channel_indices is None:
                mask_channel_indices = compute_mask_indices(
                    (B, C),
                    None,
                    self.mask_channel_prob,
                    self.mask_channel_length,
                    self.mask_channel_selection,
                    self.mask_channel_other,
                    no_overlap=self.no_mask_channel_overlap,
                    min_space=self.mask_channel_min_space,
                )
                mask_channel_indices = (
                    torch.from_numpy(mask_channel_indices)
                    .to(x.device)
                    .unsqueeze(1)
                    .expand(-1, T, -1)
                )
            x[mask_channel_indices] = 0

        return x, mask_indices
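
    # Note (added for clarity): time masks overwrite the selected frames with
    # the learned `self.mask_emb` vector (these are the positions the student
    # must predict), while channel masks zero out whole feature dimensions
    # across all time steps.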

    def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
        """
        Computes the output length of the convolutional layers
        """

        def _conv_out_length(input_length, kernel_size, stride):
            return torch.floor((input_length - kernel_size).to(torch.float32) / stride + 1)

        conv_cfg_list = eval(self.conv_feature_layers)
        for i in range(len(conv_cfg_list)):
            input_lengths = _conv_out_length(
                input_lengths, conv_cfg_list[i][1], conv_cfg_list[i][2]
            )
        return input_lengths.to(torch.long)
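
    # Worked example (added, not in the original source): with the default
    # layers [(512, 2, 2), (512, 2, 2)], an input of 1000 frames becomes
    # floor((1000 - 2) / 2 + 1) = 500 frames after the first conv and 250
    # after the second, i.e. the time axis is downsampled 4x.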

    def forward(
        self,
        xs_pad,
        ilens=None,
        mask=False,
        features_only=True,
        layer=None,
        mask_indices=None,
        mask_channel_indices=None,
        padding_count=None,
    ):
        # create padding_mask from ilens
        if ilens is not None:
            padding_mask = make_pad_mask(lengths=ilens).to(xs_pad.device)
        else:
            padding_mask = None

        features = xs_pad
        if self.feature_grad_mult > 0:
            features = self.feature_extractor(features)
            if self.feature_grad_mult != 1.0:
                features = GradMultiply.apply(features, self.feature_grad_mult)
        else:
            with torch.no_grad():
                features = self.feature_extractor(features)
        features = features.transpose(1, 2)
        features = self.layer_norm(features)

        orig_padding_mask = padding_mask
        if padding_mask is not None:
            input_lengths = (1 - padding_mask.long()).sum(-1)
            # apply the conv formula to get the real output lengths
            output_lengths = self._get_feat_extract_output_lengths(input_lengths)
            padding_mask = torch.zeros(
                features.shape[:2], dtype=features.dtype, device=features.device
            )
            # these two operations make sure that all values
            # before the output length indices are attended to
            padding_mask[
                (
                    torch.arange(padding_mask.shape[0], device=padding_mask.device),
                    output_lengths - 1,
                )
            ] = 1
            padding_mask = (1 - padding_mask.flip([-1]).cumsum(-1).flip([-1])).bool()
        else:
            padding_mask = None

        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)

        pre_encoder_features = None
        if self.ema_transformer_only:
            pre_encoder_features = features.clone()

        features = self.dropout_input(features)

        if mask:
            x, mask_indices = self.apply_mask(
                features,
                padding_mask,
                mask_indices=mask_indices,
                mask_channel_indices=mask_channel_indices,
            )
        else:
            x = features
            mask_indices = None

        x, layer_results = self.encoder(
            x,
            padding_mask=padding_mask,
            layer=layer,
        )

        if features_only:
            encoder_out_lens = (1 - padding_mask.long()).sum(1)
            return x, encoder_out_lens, None
        result = {
            "losses": {},
            "padding_mask": padding_mask,
            "x": x,
        }

        # build the regression targets with the EMA teacher
        with torch.no_grad():
            self.ema.model.eval()
            if self.ema_transformer_only:
                y, layer_results = self.ema.model.extract_features(
                    pre_encoder_features,
                    padding_mask=padding_mask,
                    min_layer=self.encoder_layers - self.average_top_k_layers,
                )
                y = {
                    "x": y,
                    "padding_mask": padding_mask,
                    "layer_results": layer_results,
                }
            else:
                y = self.ema.model.extract_features(
                    source=xs_pad,
                    padding_mask=orig_padding_mask,
                    mask=False,
                )

            target_layer_results = [l[2] for l in y["layer_results"]]

            permuted = False
            if self.instance_norm_target_layer or self.batch_norm_target_layer:
                target_layer_results = [
                    tl.permute(1, 2, 0) for tl in target_layer_results  # TBC -> BCT
                ]
                permuted = True
            if self.batch_norm_target_layer:
                target_layer_results = [
                    F.batch_norm(
                        tl.float(), running_mean=None, running_var=None, training=True
                    )
                    for tl in target_layer_results
                ]
            if self.instance_norm_target_layer:
                target_layer_results = [
                    F.instance_norm(tl.float()) for tl in target_layer_results
                ]
            if permuted:
                target_layer_results = [
                    tl.transpose(1, 2) for tl in target_layer_results  # BCT -> BTC
                ]
            if self.group_norm_target_layer:
                target_layer_results = [
                    F.layer_norm(tl.float(), tl.shape[-2:])
                    for tl in target_layer_results
                ]
            if self.layer_norm_target_layer:
                target_layer_results = [
                    F.layer_norm(tl.float(), tl.shape[-1:])
                    for tl in target_layer_results
                ]

            # average the top-k teacher layers to form the target
            y = sum(target_layer_results) / len(target_layer_results)

            if self.layer_norm_targets:
                y = F.layer_norm(y.float(), y.shape[-1:])
            if self.instance_norm_targets:
                y = F.instance_norm(y.float().transpose(1, 2)).transpose(1, 2)

            if not permuted:
                y = y.transpose(0, 1)

            y = y[mask_indices]

        x = x[mask_indices]
        x = self.final_proj(x)

        sz = x.size(-1)
        if self.loss_beta == 0:
            loss = F.mse_loss(x.float(), y.float(), reduction="none").sum(dim=-1)
        else:
            loss = F.smooth_l1_loss(
                x.float(), y.float(), reduction="none", beta=self.loss_beta
            ).sum(dim=-1)

        if self.loss_scale is not None:
            scale = self.loss_scale
        else:
            scale = 1 / math.sqrt(sz)
        result["losses"]["regression"] = loss.sum() * scale

        if "sample_size" not in result:
            result["sample_size"] = loss.numel()

        with torch.no_grad():
            result["target_var"] = self.compute_var(y)
            result["pred_var"] = self.compute_var(x.float())

        if self.num_updates > 5000 and result["target_var"] < self.min_target_var:
            logging.error(
                f"target var is {result['target_var'].item()} < {self.min_target_var}, exiting"
            )
            raise Exception(
                f"target var is {result['target_var'].item()} < {self.min_target_var}, exiting"
            )
        if self.num_updates > 5000 and result["pred_var"] < self.min_pred_var:
            logging.error(
                f"pred var is {result['pred_var'].item()} < {self.min_pred_var}, exiting"
            )
            raise Exception(
                f"pred var is {result['pred_var'].item()} < {self.min_pred_var}, exiting"
            )

        if self.ema is not None:
            result["ema_decay"] = self.ema.get_decay() * 1000
        return result
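
    # Note (added for clarity): the regression loss is computed only over
    # masked positions. It is summed over the feature dimension and, when
    # `loss_scale` is unset, scaled by 1 / sqrt(embed_dim) so its magnitude
    # stays roughly independent of the embedding size. The target/pred
    # variance checks in forward() guard against representation collapse.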

    @staticmethod
    def compute_var(y):
        y = y.view(-1, y.size(-1))
        if dist.is_initialized():
            zc = torch.tensor(y.size(0)).cuda()
            zs = y.sum(dim=0)
            zss = (y ** 2).sum(dim=0)

            dist.all_reduce(zc)
            dist.all_reduce(zs)
            dist.all_reduce(zss)

            var = zss / (zc - 1) - (zs ** 2) / (zc * (zc - 1))
            return torch.sqrt(var + 1e-6).mean()
        else:
            return torch.sqrt(y.var(dim=0) + 1e-6).mean()
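
    # Note (added for clarity): in the distributed branch, the per-dimension
    # unbiased sample variance is rebuilt from globally all-reduced sums:
    # var = (sum(y^2) - (sum(y))^2 / n) / (n - 1)
    #     = zss / (zc - 1) - zs^2 / (zc * (zc - 1)),
    # which matches y.var(dim=0) in the single-process branch. The returned
    # value is the mean per-dimension standard deviation.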

    def extract_features(self, xs_pad, ilens, mask=False, layer=None):
        res = self.forward(
            xs_pad,
            ilens,
            mask=mask,
            features_only=True,
            layer=layer,
        )
        return res

    def remove_pretraining_modules(self, last_layer=None):
        self.final_proj = None
        self.ema = None
        if last_layer is not None:
            self.encoder.layers = nn.ModuleList(
                l for i, l in enumerate(self.encoder.layers) if i <= last_layer
            )

    def output_size(self) -> int:
        return self.encoder_embed_dim
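

# ---------------------------------------------------------------------------
# Minimal usage sketch (added; not part of the original file). The
# extractor_mode value and the (batch, time, feature) input layout are
# illustrative assumptions, not taken from this file.
# ---------------------------------------------------------------------------
# if __name__ == "__main__":
#     encoder = Data2VecEncoder(input_size=80, extractor_mode="layer_norm")
#     encoder.set_num_updates(0)           # lazily builds the EMA teacher
#     xs_pad = torch.randn(2, 400, 80)     # assumed padded fbank features
#     ilens = torch.tensor([400, 320])     # true lengths before padding
#     result = encoder(xs_pad, ilens, mask=True, features_only=False)
#     print(result["losses"]["regression"], result["sample_size"])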