data2vec_encoder.py

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import math

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F

from funasr.models.data2vec.data_utils import compute_mask_indices
from funasr.models.data2vec.ema_module import EMAModule
from funasr.models.data2vec.grad_multiply import GradMultiply
from funasr.models.data2vec.wav2vec2 import (
    ConvFeatureExtractionModel,
    TransformerEncoder,
)
from funasr.models.transformer.utils.nets_utils import make_pad_mask


def get_annealed_rate(start, end, curr_step, total_steps):
    """Linearly anneal a rate from ``start`` to ``end`` over ``total_steps``."""
    r = end - start
    pct_remaining = 1 - curr_step / total_steps
    return end - r * pct_remaining
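
# Example: get_annealed_rate(0.999, 0.9999, 50_000, 100_000) == 0.99945,
# i.e. halfway through the schedule the decay sits halfway between start and end.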


class Data2VecEncoder(nn.Module):
    def __init__(
        self,
        # for ConvFeatureExtractionModel
        input_size: int = None,
        extractor_mode: str = None,
        conv_feature_layers: str = "[(512,2,2)] + [(512,2,2)]",
        # for Transformer Encoder
        ## model architecture
        layer_type: str = "transformer",
        layer_norm_first: bool = False,
        encoder_layers: int = 12,
        encoder_embed_dim: int = 768,
        encoder_ffn_embed_dim: int = 3072,
        encoder_attention_heads: int = 12,
        activation_fn: str = "gelu",
        ## dropouts
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.0,
        encoder_layerdrop: float = 0.0,
        dropout_input: float = 0.0,
        dropout_features: float = 0.0,
        ## grad settings
        feature_grad_mult: float = 1.0,
        ## masking
        mask_prob: float = 0.65,
        mask_length: int = 10,
        mask_selection: str = "static",
        mask_other: int = 0,
        no_mask_overlap: bool = False,
        mask_min_space: int = 1,
        require_same_masks: bool = True,  # if True, collate_fn should clip samples to equal length
        mask_dropout: float = 0.0,
        ## channel masking
        mask_channel_length: int = 10,
        mask_channel_prob: float = 0.0,
        mask_channel_before: bool = False,
        mask_channel_selection: str = "static",
        mask_channel_other: int = 0,
        no_mask_channel_overlap: bool = False,
        mask_channel_min_space: int = 1,
        ## positional embeddings
        conv_pos: int = 128,
        conv_pos_groups: int = 16,
        pos_conv_depth: int = 1,
        max_positions: int = 100000,
        # EMA module
        average_top_k_layers: int = 8,
        layer_norm_target_layer: bool = False,
        instance_norm_target_layer: bool = False,
        instance_norm_targets: bool = False,
        layer_norm_targets: bool = False,
        batch_norm_target_layer: bool = False,
        group_norm_target_layer: bool = False,
        ema_decay: float = 0.999,
        ema_end_decay: float = 0.9999,
        ema_anneal_end_step: int = 100000,
        ema_transformer_only: bool = True,
        ema_layers_only: bool = True,
        min_target_var: float = 0.1,
        min_pred_var: float = 0.01,
        # Loss
        loss_beta: float = 0.0,
        loss_scale: float = None,
        # FP16 optimization
        required_seq_len_multiple: int = 2,
    ):
        super().__init__()

        # ConvFeatureExtractionModel
        self.conv_feature_layers = conv_feature_layers
        # eval() turns the config string into a list of (dim, kernel_size, stride)
        # tuples, e.g. "[(512,2,2)] + [(512,2,2)]" -> [(512, 2, 2), (512, 2, 2)]
        feature_enc_layers = eval(conv_feature_layers)
        self.extractor_embed = feature_enc_layers[-1][0]
        self.feature_extractor = ConvFeatureExtractionModel(
            conv_layers=feature_enc_layers,
            dropout=0.0,
            mode=extractor_mode,
            in_d=input_size,
        )

        # Transformer Encoder
        ## model architecture
        self.layer_type = layer_type
        self.layer_norm_first = layer_norm_first
        self.encoder_layers = encoder_layers
        self.encoder_embed_dim = encoder_embed_dim
        self.encoder_ffn_embed_dim = encoder_ffn_embed_dim
        self.encoder_attention_heads = encoder_attention_heads
        self.activation_fn = activation_fn
        ## dropout
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.activation_dropout = activation_dropout
        self.encoder_layerdrop = encoder_layerdrop
        self.dropout_input = dropout_input
        self.dropout_features = dropout_features
        ## grad settings
        self.feature_grad_mult = feature_grad_mult
        ## masking
        self.mask_prob = mask_prob
        self.mask_length = mask_length
        self.mask_selection = mask_selection
        self.mask_other = mask_other
        self.no_mask_overlap = no_mask_overlap
        self.mask_min_space = mask_min_space
        self.require_same_masks = require_same_masks  # if True, collate_fn should clip samples to equal length
        self.mask_dropout = mask_dropout
        ## channel masking
        self.mask_channel_length = mask_channel_length
        self.mask_channel_prob = mask_channel_prob
        self.mask_channel_before = mask_channel_before
        self.mask_channel_selection = mask_channel_selection
        self.mask_channel_other = mask_channel_other
        self.no_mask_channel_overlap = no_mask_channel_overlap
        self.mask_channel_min_space = mask_channel_min_space
        ## positional embeddings
        self.conv_pos = conv_pos
        self.conv_pos_groups = conv_pos_groups
        self.pos_conv_depth = pos_conv_depth
        self.max_positions = max_positions
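
        # Learned embedding that replaces the features at masked time steps
        # (see apply_mask below).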
        self.mask_emb = nn.Parameter(torch.FloatTensor(self.encoder_embed_dim).uniform_())
        self.encoder = TransformerEncoder(
            dropout=self.dropout,
            encoder_embed_dim=self.encoder_embed_dim,
            required_seq_len_multiple=required_seq_len_multiple,
            pos_conv_depth=self.pos_conv_depth,
            conv_pos=self.conv_pos,
            conv_pos_groups=self.conv_pos_groups,
            # transformer layers
            layer_type=self.layer_type,
            encoder_layers=self.encoder_layers,
            encoder_ffn_embed_dim=self.encoder_ffn_embed_dim,
            encoder_attention_heads=self.encoder_attention_heads,
            attention_dropout=self.attention_dropout,
            activation_dropout=self.activation_dropout,
            activation_fn=self.activation_fn,
            layer_norm_first=self.layer_norm_first,
            encoder_layerdrop=self.encoder_layerdrop,
            max_positions=self.max_positions,
        )

        ## projections and dropouts
        self.post_extract_proj = nn.Linear(self.extractor_embed, self.encoder_embed_dim)
        # note: these overwrite the float hyperparameters of the same name with modules
        self.dropout_input = nn.Dropout(self.dropout_input)
        self.dropout_features = nn.Dropout(self.dropout_features)
        self.layer_norm = torch.nn.LayerNorm(self.extractor_embed)
        self.final_proj = nn.Linear(self.encoder_embed_dim, self.encoder_embed_dim)

        # EMA module
        self.average_top_k_layers = average_top_k_layers
        self.layer_norm_target_layer = layer_norm_target_layer
        self.instance_norm_target_layer = instance_norm_target_layer
        self.instance_norm_targets = instance_norm_targets
        self.layer_norm_targets = layer_norm_targets
        self.batch_norm_target_layer = batch_norm_target_layer
        self.group_norm_target_layer = group_norm_target_layer
        self.ema_decay = ema_decay
        self.ema_end_decay = ema_end_decay
        self.ema_anneal_end_step = ema_anneal_end_step
        self.ema_transformer_only = ema_transformer_only
        self.ema_layers_only = ema_layers_only
        self.min_target_var = min_target_var
        self.min_pred_var = min_pred_var
        self.ema = None

        # Loss
        self.loss_beta = loss_beta
        self.loss_scale = loss_scale

        # FP16 optimization
        self.required_seq_len_multiple = required_seq_len_multiple

        self.num_updates = 0
        logging.info("Data2VecEncoder settings: {}".format(self.__dict__))
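
    # The EMA teacher is a decay-weighted copy of the student encoder; its
    # outputs provide the regression targets and it receives no gradients.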
    def make_ema_teacher(self):
        skip_keys = set()
        if self.ema_layers_only:
            self.ema_transformer_only = True
            # positional conv parameters are excluded from EMA averaging
            for k, _ in self.encoder.pos_conv.named_parameters():
                skip_keys.add(f"pos_conv.{k}")
        self.ema = EMAModule(
            self.encoder if self.ema_transformer_only else self,
            ema_decay=self.ema_decay,
            ema_fp32=True,
            skip_keys=skip_keys,
        )

    def set_num_updates(self, num_updates):
        if self.ema is None and self.final_proj is not None:
            logging.info("Making EMA Teacher")
            self.make_ema_teacher()
        elif self.training and self.ema is not None:
            if self.ema_decay != self.ema_end_decay:
                if num_updates >= self.ema_anneal_end_step:
                    decay = self.ema_end_decay
                else:
                    decay = get_annealed_rate(
                        self.ema_decay,
                        self.ema_end_decay,
                        num_updates,
                        self.ema_anneal_end_step,
                    )
                self.ema.set_decay(decay)
            if self.ema.get_decay() < 1:
                self.ema.step(self.encoder if self.ema_transformer_only else self)
        self.num_updates = num_updates
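
    # Typical pre-training step (a sketch; the surrounding trainer is assumed):
    #     encoder.set_num_updates(step)  # creates, then anneals, the EMA teacher
    #     result = encoder(xs_pad, ilens, mask=True, features_only=False)
    #     result["losses"]["regression"].backward()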

    def apply_mask(
        self,
        x,
        padding_mask,
        mask_indices=None,
        mask_channel_indices=None,
    ):
        B, T, C = x.shape

        if self.mask_channel_prob > 0 and self.mask_channel_before:
            mask_channel_indices = compute_mask_indices(
                (B, C),
                None,
                self.mask_channel_prob,
                self.mask_channel_length,
                self.mask_channel_selection,
                self.mask_channel_other,
                no_overlap=self.no_mask_channel_overlap,
                min_space=self.mask_channel_min_space,
            )
            mask_channel_indices = (
                torch.from_numpy(mask_channel_indices)
                .to(x.device)
                .unsqueeze(1)
                .expand(-1, T, -1)
            )
            x[mask_channel_indices] = 0

        if self.mask_prob > 0:
            if mask_indices is None:
                mask_indices = compute_mask_indices(
                    (B, T),
                    padding_mask,
                    self.mask_prob,
                    self.mask_length,
                    self.mask_selection,
                    self.mask_other,
                    min_masks=1,
                    no_overlap=self.no_mask_overlap,
                    min_space=self.mask_min_space,
                    require_same_masks=self.require_same_masks,
                    mask_dropout=self.mask_dropout,
                )
                mask_indices = torch.from_numpy(mask_indices).to(x.device)
            # masked time steps are replaced by the learned mask embedding
            x[mask_indices] = self.mask_emb
        else:
            mask_indices = None

        if self.mask_channel_prob > 0 and not self.mask_channel_before:
            if mask_channel_indices is None:
                mask_channel_indices = compute_mask_indices(
                    (B, C),
                    None,
                    self.mask_channel_prob,
                    self.mask_channel_length,
                    self.mask_channel_selection,
                    self.mask_channel_other,
                    no_overlap=self.no_mask_channel_overlap,
                    min_space=self.mask_channel_min_space,
                )
                mask_channel_indices = (
                    torch.from_numpy(mask_channel_indices)
                    .to(x.device)
                    .unsqueeze(1)
                    .expand(-1, T, -1)
                )
            x[mask_channel_indices] = 0

        return x, mask_indices
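
    # Note: apply_mask modifies x in place and also returns it; mask_indices is
    # a boolean (B, T) tensor marking the time steps that were replaced.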

    def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
        """
        Computes the output length of the convolutional layers
        """

        def _conv_out_length(input_length, kernel_size, stride):
            return torch.floor((input_length - kernel_size).to(torch.float32) / stride + 1)

        conv_cfg_list = eval(self.conv_feature_layers)
        for i in range(len(conv_cfg_list)):
            input_lengths = _conv_out_length(
                input_lengths, conv_cfg_list[i][1], conv_cfg_list[i][2]
            )
        return input_lengths.to(torch.long)
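
    # Example: with the default "[(512,2,2)] + [(512,2,2)]", a 100-frame input
    # shrinks to floor((100 - 2) / 2 + 1) = 50 frames, then to 25 frames.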

    def forward(
        self,
        xs_pad,
        ilens=None,
        mask=False,
        features_only=True,
        layer=None,
        mask_indices=None,
        mask_channel_indices=None,
        padding_count=None,
    ):
        # create padding_mask from ilens
        if ilens is not None:
            padding_mask = make_pad_mask(lengths=ilens).to(xs_pad.device)
        else:
            padding_mask = None

        features = xs_pad
        if self.feature_grad_mult > 0:
            features = self.feature_extractor(features)
            if self.feature_grad_mult != 1.0:
                # optionally scale the gradient flowing back into the extractor
                features = GradMultiply.apply(features, self.feature_grad_mult)
        else:
            with torch.no_grad():
                features = self.feature_extractor(features)
        features = features.transpose(1, 2)
        features = self.layer_norm(features)

        orig_padding_mask = padding_mask
        if padding_mask is not None:
            input_lengths = (1 - padding_mask.long()).sum(-1)
            # apply conv formula to get real output_lengths
            output_lengths = self._get_feat_extract_output_lengths(input_lengths)

            padding_mask = torch.zeros(
                features.shape[:2], dtype=features.dtype, device=features.device
            )
            # these two operations make sure that all values before the output
            # length indices are attended to: mark the last valid frame, then
            # flip/cumsum/flip so everything up to and including it is non-padding
            padding_mask[
                (
                    torch.arange(padding_mask.shape[0], device=padding_mask.device),
                    output_lengths - 1,
                )
            ] = 1
            padding_mask = (1 - padding_mask.flip([-1]).cumsum(-1).flip([-1])).bool()
        else:
            padding_mask = None

        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)

        pre_encoder_features = None
        if self.ema_transformer_only:
            # keep an unmasked, un-dropped copy as the teacher input
            pre_encoder_features = features.clone()

        features = self.dropout_input(features)

        if mask:
            x, mask_indices = self.apply_mask(
                features,
                padding_mask,
                mask_indices=mask_indices,
                mask_channel_indices=mask_channel_indices,
            )
        else:
            x = features
            mask_indices = None

        x, layer_results = self.encoder(
            x,
            padding_mask=padding_mask,
            layer=layer,
        )

        if features_only:
            encoder_out_lens = (
                (1 - padding_mask.long()).sum(1) if padding_mask is not None else None
            )
            return x, encoder_out_lens, None

        result = {
            "losses": {},
            "padding_mask": padding_mask,
            "x": x,
        }
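
        # Teacher forward pass: the EMA model re-encodes the unmasked features
        # under no_grad to build the regression targets.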
        with torch.no_grad():
            self.ema.model.eval()

            if self.ema_transformer_only:
                y, layer_results = self.ema.model.extract_features(
                    pre_encoder_features,
                    padding_mask=padding_mask,
                    min_layer=self.encoder_layers - self.average_top_k_layers,
                )
                y = {
                    "x": y,
                    "padding_mask": padding_mask,
                    "layer_results": layer_results,
                }
            else:
                y = self.ema.model.extract_features(
                    source=xs_pad,
                    padding_mask=orig_padding_mask,
                    mask=False,
                )

            target_layer_results = [l[2] for l in y["layer_results"]]

            permuted = False
            if self.instance_norm_target_layer or self.batch_norm_target_layer:
                target_layer_results = [
                    tl.permute(1, 2, 0) for tl in target_layer_results  # TBC -> BCT
                ]
                permuted = True
            if self.batch_norm_target_layer:
                target_layer_results = [
                    F.batch_norm(
                        tl.float(), running_mean=None, running_var=None, training=True
                    )
                    for tl in target_layer_results
                ]
            if self.instance_norm_target_layer:
                target_layer_results = [
                    F.instance_norm(tl.float()) for tl in target_layer_results
                ]
            if permuted:
                target_layer_results = [
                    tl.transpose(1, 2) for tl in target_layer_results  # BCT -> BTC
                ]
            if self.group_norm_target_layer:
                target_layer_results = [
                    F.layer_norm(tl.float(), tl.shape[-2:])
                    for tl in target_layer_results
                ]
            if self.layer_norm_target_layer:
                target_layer_results = [
                    F.layer_norm(tl.float(), tl.shape[-1:])
                    for tl in target_layer_results
                ]

            # average the top-k teacher layers to form the regression target
            y = sum(target_layer_results) / len(target_layer_results)

            if self.layer_norm_targets:
                y = F.layer_norm(y.float(), y.shape[-1:])
            if self.instance_norm_targets:
                y = F.instance_norm(y.float().transpose(1, 2)).transpose(1, 2)

            if not permuted:
                y = y.transpose(0, 1)  # TBC -> BTC

            # the loss is computed only over masked positions
            y = y[mask_indices]

        x = x[mask_indices]
        x = self.final_proj(x)

        sz = x.size(-1)
        if self.loss_beta == 0:
            loss = F.mse_loss(x.float(), y.float(), reduction="none").sum(dim=-1)
        else:
            # smooth L1 is less sensitive to outliers than plain L2
            loss = F.smooth_l1_loss(
                x.float(), y.float(), reduction="none", beta=self.loss_beta
            ).sum(dim=-1)

        if self.loss_scale is not None:
            scale = self.loss_scale
        else:
            # normalize by sqrt(embed_dim) so the loss magnitude stays
            # comparable across embedding sizes
            scale = 1 / math.sqrt(sz)

        result["losses"]["regression"] = loss.sum() * scale

        if "sample_size" not in result:
            result["sample_size"] = loss.numel()

        with torch.no_grad():
            result["target_var"] = self.compute_var(y)
            result["pred_var"] = self.compute_var(x.float())

        # guard against representation collapse: after warm-up, abort if either
        # the targets or the predictions lose almost all variance
        if self.num_updates > 5000 and result["target_var"] < self.min_target_var:
            logging.error(
                f"target var is {result['target_var'].item()} < {self.min_target_var}, exiting"
            )
            raise Exception(
                f"target var is {result['target_var'].item()} < {self.min_target_var}, exiting"
            )
        if self.num_updates > 5000 and result["pred_var"] < self.min_pred_var:
            logging.error(
                f"pred var is {result['pred_var'].item()} < {self.min_pred_var}, exiting"
            )
            raise Exception(
                f"pred var is {result['pred_var'].item()} < {self.min_pred_var}, exiting"
            )

        if self.ema is not None:
            result["ema_decay"] = self.ema.get_decay() * 1000

        return result

    @staticmethod
    def compute_var(y):
        y = y.view(-1, y.size(-1))
        if dist.is_initialized():
            # accumulate count, sum and sum of squares across workers, then use
            # the unbiased sample-variance identity
            #   var = (sum(y^2) - sum(y)^2 / n) / (n - 1)
            zc = torch.tensor(y.size(0)).cuda()
            zs = y.sum(dim=0)
            zss = (y ** 2).sum(dim=0)

            dist.all_reduce(zc)
            dist.all_reduce(zs)
            dist.all_reduce(zss)

            var = zss / (zc - 1) - (zs ** 2) / (zc * (zc - 1))
            return torch.sqrt(var + 1e-6).mean()
        else:
            return torch.sqrt(y.var(dim=0) + 1e-6).mean()

    def extract_features(self, xs_pad, ilens, mask=False, layer=None):
        res = self.forward(
            xs_pad,
            ilens,
            mask=mask,
            features_only=True,
            layer=layer,
        )
        return res

    def remove_pretraining_modules(self, last_layer=None):
        # drop the pre-training head and the EMA teacher before fine-tuning
        self.final_proj = None
        self.ema = None
        if last_layer is not None:
            self.encoder.layers = nn.ModuleList(
                l for i, l in enumerate(self.encoder.layers) if i <= last_layer
            )

    def output_size(self) -> int:
        return self.encoder_embed_dim
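

# Minimal usage sketch (illustrative comments only; it assumes the funasr
# data2vec modules are importable and that xs_pad holds (batch, time, feature)
# inputs matching input_size -- the shapes below are assumptions, not canonical):
#
#     encoder = Data2VecEncoder(input_size=80, extractor_mode="layer_norm")
#     xs_pad = torch.randn(2, 400, 80)
#     ilens = torch.tensor([400, 360])
#     feats, olens, _ = encoder(xs_pad, ilens)   # feature extraction path
#     encoder.set_num_updates(0)                 # builds the EMA teacher
#     out = encoder(xs_pad, ilens, mask=True, features_only=False)  # pre-training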