data2vec_encoder.py

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import math

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F

from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.modules.data2vec.data_utils import compute_mask_indices
from funasr.modules.data2vec.ema_module import EMAModule
from funasr.modules.data2vec.grad_multiply import GradMultiply
from funasr.modules.data2vec.wav2vec2 import (
    ConvFeatureExtractionModel,
    TransformerEncoder,
)
from funasr.modules.nets_utils import make_pad_mask


def get_annealed_rate(start, end, curr_step, total_steps):
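    """Linearly interpolate a rate from ``start`` to ``end`` over ``total_steps``.

    Used here to anneal the EMA decay of the teacher between updates.
    """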
    r = end - start
    pct_remaining = 1 - curr_step / total_steps
    return end - r * pct_remaining


class Data2VecEncoder(AbsEncoder):
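    """data2vec encoder for speech (https://arxiv.org/abs/2202.03555).

    A convolutional feature extractor followed by a Transformer encoder.
    During pre-training, an EMA copy of the encoder serves as the teacher:
    at masked positions the student regresses the teacher's representations
    averaged over the top ``average_top_k_layers`` Transformer layers.
    """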

    def __init__(
        self,
        # for ConvFeatureExtractionModel
        input_size: int = None,
        extractor_mode: str = None,
        conv_feature_layers: str = "[(512,2,2)] + [(512,2,2)]",
        # for Transformer Encoder
        ## model architecture
        layer_type: str = "transformer",
        layer_norm_first: bool = False,
        encoder_layers: int = 12,
        encoder_embed_dim: int = 768,
        encoder_ffn_embed_dim: int = 3072,
        encoder_attention_heads: int = 12,
        activation_fn: str = "gelu",
        ## dropouts
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.0,
        encoder_layerdrop: float = 0.0,
        dropout_input: float = 0.0,
        dropout_features: float = 0.0,
        ## grad settings
        feature_grad_mult: float = 1.0,
        ## masking
        mask_prob: float = 0.65,
        mask_length: int = 10,
        mask_selection: str = "static",
        mask_other: int = 0,
        no_mask_overlap: bool = False,
        mask_min_space: int = 1,
        require_same_masks: bool = True,  # if True, the collate_fn must clip samples to a common length
        mask_dropout: float = 0.0,
        ## channel masking
        mask_channel_length: int = 10,
        mask_channel_prob: float = 0.0,
        mask_channel_before: bool = False,
        mask_channel_selection: str = "static",
        mask_channel_other: int = 0,
        no_mask_channel_overlap: bool = False,
        mask_channel_min_space: int = 1,
        ## positional embeddings
        conv_pos: int = 128,
        conv_pos_groups: int = 16,
        pos_conv_depth: int = 1,
        max_positions: int = 100000,
        # EMA module
        average_top_k_layers: int = 8,
        layer_norm_target_layer: bool = False,
        instance_norm_target_layer: bool = False,
        instance_norm_targets: bool = False,
        layer_norm_targets: bool = False,
        batch_norm_target_layer: bool = False,
        group_norm_target_layer: bool = False,
        ema_decay: float = 0.999,
        ema_end_decay: float = 0.9999,
        ema_anneal_end_step: int = 100000,
        ema_transformer_only: bool = True,
        ema_layers_only: bool = True,
        min_target_var: float = 0.1,
        min_pred_var: float = 0.01,
        # Loss
        loss_beta: float = 0.0,
        loss_scale: float = None,
        # FP16 optimization
        required_seq_len_multiple: int = 2,
    ):
        super().__init__()
        # ConvFeatureExtractionModel
        self.conv_feature_layers = conv_feature_layers
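        # the layer config is a python literal: a list of (dim, kernel_size, stride) tuples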
        feature_enc_layers = eval(conv_feature_layers)
        self.extractor_embed = feature_enc_layers[-1][0]
        self.feature_extractor = ConvFeatureExtractionModel(
            conv_layers=feature_enc_layers,
            dropout=0.0,
            mode=extractor_mode,
            in_d=input_size,
        )
        # Transformer Encoder
        ## model architecture
        self.layer_type = layer_type
        self.layer_norm_first = layer_norm_first
        self.encoder_layers = encoder_layers
        self.encoder_embed_dim = encoder_embed_dim
        self.encoder_ffn_embed_dim = encoder_ffn_embed_dim
        self.encoder_attention_heads = encoder_attention_heads
        self.activation_fn = activation_fn
        ## dropout
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.activation_dropout = activation_dropout
        self.encoder_layerdrop = encoder_layerdrop
        self.dropout_input = dropout_input
        self.dropout_features = dropout_features
        ## grad settings
        self.feature_grad_mult = feature_grad_mult
        ## masking
        self.mask_prob = mask_prob
        self.mask_length = mask_length
        self.mask_selection = mask_selection
        self.mask_other = mask_other
        self.no_mask_overlap = no_mask_overlap
        self.mask_min_space = mask_min_space
        self.require_same_masks = require_same_masks  # if True, the collate_fn must clip samples to a common length
        self.mask_dropout = mask_dropout
        ## channel masking
        self.mask_channel_length = mask_channel_length
        self.mask_channel_prob = mask_channel_prob
        self.mask_channel_before = mask_channel_before
        self.mask_channel_selection = mask_channel_selection
        self.mask_channel_other = mask_channel_other
        self.no_mask_channel_overlap = no_mask_channel_overlap
        self.mask_channel_min_space = mask_channel_min_space
        ## positional embeddings
        self.conv_pos = conv_pos
        self.conv_pos_groups = conv_pos_groups
        self.pos_conv_depth = pos_conv_depth
        self.max_positions = max_positions
        self.mask_emb = nn.Parameter(torch.FloatTensor(self.encoder_embed_dim).uniform_())
        self.encoder = TransformerEncoder(
            dropout=self.dropout,
            encoder_embed_dim=self.encoder_embed_dim,
            required_seq_len_multiple=required_seq_len_multiple,
            pos_conv_depth=self.pos_conv_depth,
            conv_pos=self.conv_pos,
            conv_pos_groups=self.conv_pos_groups,
            # transformer layers
            layer_type=self.layer_type,
            encoder_layers=self.encoder_layers,
            encoder_ffn_embed_dim=self.encoder_ffn_embed_dim,
            encoder_attention_heads=self.encoder_attention_heads,
            attention_dropout=self.attention_dropout,
            activation_dropout=self.activation_dropout,
            activation_fn=self.activation_fn,
            layer_norm_first=self.layer_norm_first,
            encoder_layerdrop=self.encoder_layerdrop,
            max_positions=self.max_positions,
        )
        ## projections and dropouts
        self.post_extract_proj = nn.Linear(self.extractor_embed, self.encoder_embed_dim)
        self.dropout_input = nn.Dropout(self.dropout_input)
        self.dropout_features = nn.Dropout(self.dropout_features)
        self.layer_norm = torch.nn.LayerNorm(self.extractor_embed)
        self.final_proj = nn.Linear(self.encoder_embed_dim, self.encoder_embed_dim)
        # EMA module
        self.average_top_k_layers = average_top_k_layers
        self.layer_norm_target_layer = layer_norm_target_layer
        self.instance_norm_target_layer = instance_norm_target_layer
        self.instance_norm_targets = instance_norm_targets
        self.layer_norm_targets = layer_norm_targets
        self.batch_norm_target_layer = batch_norm_target_layer
        self.group_norm_target_layer = group_norm_target_layer
        self.ema_decay = ema_decay
        self.ema_end_decay = ema_end_decay
        self.ema_anneal_end_step = ema_anneal_end_step
        self.ema_transformer_only = ema_transformer_only
        self.ema_layers_only = ema_layers_only
        self.min_target_var = min_target_var
        self.min_pred_var = min_pred_var
        self.ema = None
        # Loss
        self.loss_beta = loss_beta
        self.loss_scale = loss_scale
        # FP16 optimization
        self.required_seq_len_multiple = required_seq_len_multiple
        self.num_updates = 0
        logging.info("Data2VecEncoder settings: {}".format(self.__dict__))

    def make_ema_teacher(self):
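        """Instantiate the EMA teacher from the current student weights.

        With ``ema_layers_only`` set, only the Transformer encoder is tracked
        and the convolutional positional encoding (``pos_conv``) parameters
        are added to ``skip_keys``, i.e. excluded from EMA smoothing.
        """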
        skip_keys = set()
        if self.ema_layers_only:
            self.ema_transformer_only = True
            for k, _ in self.encoder.pos_conv.named_parameters():
                skip_keys.add(f"pos_conv.{k}")
        self.ema = EMAModule(
            self.encoder if self.ema_transformer_only else self,
            ema_decay=self.ema_decay,
            ema_fp32=True,
            skip_keys=skip_keys,
        )

    def set_num_updates(self, num_updates):
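        """Track training progress: build the EMA teacher lazily and step it.

        On the first call the teacher is created; on later training calls the
        EMA decay is annealed from ``ema_decay`` to ``ema_end_decay`` over
        ``ema_anneal_end_step`` updates before each EMA step.
        """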
        if self.ema is None and self.final_proj is not None:
            logging.info("Making EMA Teacher")
            self.make_ema_teacher()
        elif self.training and self.ema is not None:
            if self.ema_decay != self.ema_end_decay:
                if num_updates >= self.ema_anneal_end_step:
                    decay = self.ema_end_decay
                else:
                    decay = get_annealed_rate(
                        self.ema_decay,
                        self.ema_end_decay,
                        num_updates,
                        self.ema_anneal_end_step,
                    )
                self.ema.set_decay(decay)
            if self.ema.get_decay() < 1:
                self.ema.step(self.encoder if self.ema_transformer_only else self)
        self.num_updates = num_updates

    def apply_mask(
        self,
        x,
        padding_mask,
        mask_indices=None,
        mask_channel_indices=None,
    ):
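        """Mask the input features along the time and/or channel axes.

        Time-masked positions are replaced with the learned ``mask_emb``;
        channel-masked positions are zeroed. Returns the masked features
        (modified in place) and the boolean time-mask indices (``None`` when
        ``mask_prob`` is 0).
        """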
        B, T, C = x.shape
        if self.mask_channel_prob > 0 and self.mask_channel_before:
            mask_channel_indices = compute_mask_indices(
                (B, C),
                None,
                self.mask_channel_prob,
                self.mask_channel_length,
                self.mask_channel_selection,
                self.mask_channel_other,
                no_overlap=self.no_mask_channel_overlap,
                min_space=self.mask_channel_min_space,
            )
            mask_channel_indices = (
                torch.from_numpy(mask_channel_indices)
                .to(x.device)
                .unsqueeze(1)
                .expand(-1, T, -1)
            )
            x[mask_channel_indices] = 0
        if self.mask_prob > 0:
            if mask_indices is None:
                mask_indices = compute_mask_indices(
                    (B, T),
                    padding_mask,
                    self.mask_prob,
                    self.mask_length,
                    self.mask_selection,
                    self.mask_other,
                    min_masks=1,
                    no_overlap=self.no_mask_overlap,
                    min_space=self.mask_min_space,
                    require_same_masks=self.require_same_masks,
                    mask_dropout=self.mask_dropout,
                )
                mask_indices = torch.from_numpy(mask_indices).to(x.device)
            x[mask_indices] = self.mask_emb
        else:
            mask_indices = None
        if self.mask_channel_prob > 0 and not self.mask_channel_before:
            if mask_channel_indices is None:
                mask_channel_indices = compute_mask_indices(
                    (B, C),
                    None,
                    self.mask_channel_prob,
                    self.mask_channel_length,
                    self.mask_channel_selection,
                    self.mask_channel_other,
                    no_overlap=self.no_mask_channel_overlap,
                    min_space=self.mask_channel_min_space,
                )
                mask_channel_indices = (
                    torch.from_numpy(mask_channel_indices)
                    .to(x.device)
                    .unsqueeze(1)
                    .expand(-1, T, -1)
                )
            x[mask_channel_indices] = 0
        return x, mask_indices

    def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
        """
        Computes the output lengths of the convolutional feature extractor.
        """

        def _conv_out_length(input_length, kernel_size, stride):
            # standard conv output-length formula: floor((L - kernel) / stride + 1)
            return torch.floor((input_length - kernel_size).to(torch.float32) / stride + 1)

        conv_cfg_list = eval(self.conv_feature_layers)
        for i in range(len(conv_cfg_list)):
            input_lengths = _conv_out_length(
                input_lengths, conv_cfg_list[i][1], conv_cfg_list[i][2]
            )
        return input_lengths.to(torch.long)

    def forward(
        self,
        xs_pad,
        ilens=None,
        mask=False,
        features_only=True,
        layer=None,
        mask_indices=None,
        mask_channel_indices=None,
        padding_count=None,
    ):
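        """Encode ``xs_pad``, optionally computing the data2vec pre-training loss.

        With ``features_only=True`` (fine-tuning / inference), returns
        ``(x, encoder_out_lens, None)``. Otherwise the EMA teacher re-encodes
        the unmasked features, its top ``average_top_k_layers`` layer outputs
        are averaged into targets, and a regression loss over the masked
        positions is returned in the ``result`` dict.
        """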
        # create padding_mask from ilens
        if ilens is not None:
            padding_mask = make_pad_mask(lengths=ilens).to(xs_pad.device)
        else:
            padding_mask = None

        features = xs_pad
        if self.feature_grad_mult > 0:
            features = self.feature_extractor(features)
            if self.feature_grad_mult != 1.0:
                features = GradMultiply.apply(features, self.feature_grad_mult)
        else:
            with torch.no_grad():
                features = self.feature_extractor(features)

        features = features.transpose(1, 2)
        features = self.layer_norm(features)

        orig_padding_mask = padding_mask
        if padding_mask is not None:
            input_lengths = (1 - padding_mask.long()).sum(-1)
            # apply the conv formula to get the true output lengths
            output_lengths = self._get_feat_extract_output_lengths(input_lengths)
            padding_mask = torch.zeros(
                features.shape[:2], dtype=features.dtype, device=features.device
            )
            # these two operations make sure that all values
            # before the output length indices are attended to
            padding_mask[
                (
                    torch.arange(padding_mask.shape[0], device=padding_mask.device),
                    output_lengths - 1,
                )
            ] = 1
            padding_mask = (1 - padding_mask.flip([-1]).cumsum(-1).flip([-1])).bool()
        else:
            padding_mask = None

        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)

        pre_encoder_features = None
        if self.ema_transformer_only:
            pre_encoder_features = features.clone()

        features = self.dropout_input(features)

        if mask:
            x, mask_indices = self.apply_mask(
                features,
                padding_mask,
                mask_indices=mask_indices,
                mask_channel_indices=mask_channel_indices,
            )
        else:
            x = features
            mask_indices = None

        x, layer_results = self.encoder(
            x,
            padding_mask=padding_mask,
            layer=layer,
        )

        if features_only:
            if padding_mask is not None:
                encoder_out_lens = (1 - padding_mask.long()).sum(1)
            else:
                # ilens was not given, so treat every frame as valid
                encoder_out_lens = x.new_full((x.size(0),), x.size(1), dtype=torch.long)
            return x, encoder_out_lens, None

        result = {
            "losses": {},
            "padding_mask": padding_mask,
            "x": x,
        }

        with torch.no_grad():
            self.ema.model.eval()

            if self.ema_transformer_only:
                y, layer_results = self.ema.model.extract_features(
                    pre_encoder_features,
                    padding_mask=padding_mask,
                    min_layer=self.encoder_layers - self.average_top_k_layers,
                )
                y = {
                    "x": y,
                    "padding_mask": padding_mask,
                    "layer_results": layer_results,
                }
            else:
                y = self.ema.model.extract_features(
                    source=xs_pad,
                    padding_mask=orig_padding_mask,
                    mask=False,
                )
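
            # each layer_results entry is a tuple; index 2 holds that layer's
            # output in time-first (T, B, C) layout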
            target_layer_results = [l[2] for l in y["layer_results"]]

            permuted = False
            if self.instance_norm_target_layer or self.batch_norm_target_layer:
                target_layer_results = [
                    tl.permute(1, 2, 0) for tl in target_layer_results  # TBC -> BCT
                ]
                permuted = True

            if self.batch_norm_target_layer:
                target_layer_results = [
                    F.batch_norm(
                        tl.float(), running_mean=None, running_var=None, training=True
                    )
                    for tl in target_layer_results
                ]

            if self.instance_norm_target_layer:
                target_layer_results = [
                    F.instance_norm(tl.float()) for tl in target_layer_results
                ]

            if permuted:
                target_layer_results = [
                    tl.transpose(1, 2) for tl in target_layer_results  # BCT -> BTC
                ]

            if self.group_norm_target_layer:
                target_layer_results = [
                    F.layer_norm(tl.float(), tl.shape[-2:])
                    for tl in target_layer_results
                ]

            if self.layer_norm_target_layer:
                target_layer_results = [
                    F.layer_norm(tl.float(), tl.shape[-1:])
                    for tl in target_layer_results
                ]
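
            # average the (optionally normalized) top-k teacher layers into one target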
            y = sum(target_layer_results) / len(target_layer_results)

            if self.layer_norm_targets:
                y = F.layer_norm(y.float(), y.shape[-1:])

            if self.instance_norm_targets:
                y = F.instance_norm(y.float().transpose(1, 2)).transpose(1, 2)

            if not permuted:
                y = y.transpose(0, 1)  # TBC -> BTC

            y = y[mask_indices]

        x = x[mask_indices]
        x = self.final_proj(x)

        sz = x.size(-1)
        if self.loss_beta == 0:
            loss = F.mse_loss(x.float(), y.float(), reduction="none").sum(dim=-1)
        else:
            loss = F.smooth_l1_loss(
                x.float(), y.float(), reduction="none", beta=self.loss_beta
            ).sum(dim=-1)

        if self.loss_scale is not None:
            scale = self.loss_scale
        else:
            scale = 1 / math.sqrt(sz)
        result["losses"]["regression"] = loss.sum() * scale

        if "sample_size" not in result:
            result["sample_size"] = loss.numel()

        with torch.no_grad():
            result["target_var"] = self.compute_var(y)
            result["pred_var"] = self.compute_var(x.float())

        if self.num_updates > 5000 and result["target_var"] < self.min_target_var:
            logging.error(
                f"target var is {result['target_var'].item()} < {self.min_target_var}, exiting"
            )
            raise Exception(
                f"target var is {result['target_var'].item()} < {self.min_target_var}, exiting"
            )
        if self.num_updates > 5000 and result["pred_var"] < self.min_pred_var:
            logging.error(
                f"pred var is {result['pred_var'].item()} < {self.min_pred_var}, exiting"
            )
            raise Exception(
                f"pred var is {result['pred_var'].item()} < {self.min_pred_var}, exiting"
            )

        if self.ema is not None:
            result["ema_decay"] = self.ema.get_decay() * 1000

        return result

    @staticmethod
    def compute_var(y):
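        """Mean per-dimension standard deviation of ``y``.

        In distributed training the count, sum, and squared sum are
        all-reduced first, giving the global unbiased variance
        ``zss / (n - 1) - zs**2 / (n * (n - 1))`` on every rank.
        """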
        y = y.view(-1, y.size(-1))
        if dist.is_initialized():
            # keep the count on the same device as y so all_reduce works on CPU or GPU
            zc = torch.tensor(y.size(0), device=y.device)
            zs = y.sum(dim=0)
            zss = (y**2).sum(dim=0)
            dist.all_reduce(zc)
            dist.all_reduce(zs)
            dist.all_reduce(zss)
            var = zss / (zc - 1) - (zs**2) / (zc * (zc - 1))
            return torch.sqrt(var + 1e-6).mean()
        else:
            return torch.sqrt(y.var(dim=0) + 1e-6).mean()

    def extract_features(self, xs_pad, ilens, mask=False, layer=None):
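        """Thin wrapper over ``forward`` with ``features_only=True``."""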
        res = self.forward(
            xs_pad,
            ilens,
            mask=mask,
            features_only=True,
            layer=layer,
        )
        return res

    def remove_pretraining_modules(self, last_layer=None):
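        """Drop pre-training-only modules (final projection and EMA teacher).

        If ``last_layer`` is given, encoder layers above that index are
        discarded as well.
        """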
        self.final_proj = None
        self.ema = None
        if last_layer is not None:
            self.encoder.layers = nn.ModuleList(
                l for i, l in enumerate(self.encoder.layers) if i <= last_layer
            )

    def output_size(self) -> int:
        return self.encoder_embed_dim