# NOTE: removed file-listing artifacts left by extraction ("base.py 20 KB" and a stray run of line numbers).
  1. # Copyright (c) Facebook, Inc. and its affiliates.
  2. #
  3. # This source code is licensed under the MIT license found in the
  4. # LICENSE file in the root directory of this source tree.
  5. import logging
  6. import math
  7. import numpy as np
  8. import torch
  9. import torch.nn as nn
  10. import torch.nn.functional as F
  11. from collections import namedtuple
  12. from dataclasses import dataclass
  13. from functools import partial
  14. from omegaconf import MISSING, II
  15. from typing import Optional, Callable
  16. from funasr.models.emotion2vec.fairseq_modules import compute_mask_indices
  17. from funasr.models.emotion2vec.fairseq_modules import GradMultiply
  18. from funasr.models.emotion2vec.fairseq_modules import index_put
# Module-level logger for this encoder implementation.
logger = logging.getLogger(__name__)

# Seed material for deterministic mask generation:
# (seed, update) identify the training step; ids are per-sample identifiers.
MaskSeed = namedtuple("MaskSeed", ["seed", "update", "ids"])

# Result of a masking pass:
#   x_unmasked  - kept timesteps gathered from x (or None when only shapes were given)
#   mask        - (B, T) tensor, nonzero where a position is masked
#   ids_restore - indices that restore the original temporal order
#   ids_keep    - indices of the kept (unmasked) positions
MaskInfo = namedtuple("MaskInfo", ["x_unmasked", "mask", "ids_restore", "ids_keep"])
class ModalitySpecificEncoder(nn.Module):
    """Per-modality encoding pipeline for data2vec-style pre-training.

    Composes: local feature extraction -> feature projection -> (optional)
    positional encoding -> masking -> transformer context encoding, plus a
    lightweight decoder used to reconstruct masked positions during
    pre-training. Optionally applies an ALiBi attention bias, which may be
    fixed, scaled per layer/head, or fully learned.
    """

    def __init__(
        self,
        modality_cfg,
        embed_dim: int,
        local_encoder: nn.Module,
        project_features: nn.Module,
        fixed_positional_encoder: Optional[nn.Module],
        relative_positional_encoder: Optional[nn.Module],
        context_encoder: nn.Module,
        decoder: nn.Module,
        get_alibi_bias: Optional[Callable[[int, int, str, str], torch.Tensor]],
    ):
        super().__init__()

        self.modality_cfg = modality_cfg
        self.local_encoder = local_encoder
        self.project_features = project_features
        self.fixed_positional_encoder = fixed_positional_encoder
        self.relative_positional_encoder = relative_positional_encoder
        self.context_encoder = context_encoder

        self.decoder = decoder
        # Only keep the alibi-bias factory when the config enables alibi.
        self.get_alibi_bias = get_alibi_bias if modality_cfg.use_alibi_encoder else None

        # Gradient multiplier for the local encoder (0 freezes it).
        self.local_grad_mult = self.modality_cfg.local_grad_mult

        # Learned prefix tokens (e.g. CLS-style) prepended to every sequence.
        self.extra_tokens = None
        if modality_cfg.num_extra_tokens > 0:
            self.extra_tokens = nn.Parameter(
                torch.zeros(1, modality_cfg.num_extra_tokens, embed_dim)
            )
            if not modality_cfg.init_extra_token_zero:
                nn.init.normal_(self.extra_tokens)
            elif self.extra_tokens.size(1) > 1:
                # Keep the first extra token zero-initialized; randomize the rest.
                nn.init.normal_(self.extra_tokens[:, 1:])

        # Optional learned scaling of the alibi bias, per layer and/or per head.
        # Shape: (layers-or-1, 1, heads-or-1, 1, 1).
        self.alibi_scale = None
        if self.get_alibi_bias is not None:
            self.alibi_scale = nn.Parameter(
                torch.full(
                    (
                        (modality_cfg.prenet_depth + modality_cfg.model_depth)
                        if modality_cfg.learned_alibi_scale_per_layer
                        else 1,
                        1,
                        self.modality_cfg.num_alibi_heads
                        if modality_cfg.learned_alibi_scale_per_head
                        else 1,
                        1,
                        1,
                    ),
                    modality_cfg.alibi_scale,
                    dtype=torch.float,
                ),
                requires_grad=modality_cfg.learned_alibi_scale,
            )

        if modality_cfg.learned_alibi and self.get_alibi_bias is not None:
            assert modality_cfg.alibi_max_pos is not None
            # Materialize a fixed-size bias once and learn it directly; at call
            # time _learned_alibi_bias pads/crops it to the requested length.
            alibi_bias = self.get_alibi_bias(
                batch_size=1,
                time_steps=modality_cfg.alibi_max_pos,
                heads=modality_cfg.num_alibi_heads,
                scale=1.0,
                dtype=torch.float,
                device="cpu",
            )
            self.alibi_bias = nn.Parameter(alibi_bias)
            self.get_alibi_bias = partial(
                _learned_alibi_bias, alibi_bias=self.alibi_bias
            )

    def upgrade_state_dict_named(self, state_dict, name):
        """Migrate old checkpoints in which alibi_scale had no leading layer dim."""
        k = f"{name}.alibi_scale"
        if k in state_dict and state_dict[k].dim() == 4:
            state_dict[k] = state_dict[k].unsqueeze(0)

        return state_dict

    def convert_padding_mask(self, x, padding_mask):
        # Identity here; subclasses may override to resample the mask so it
        # matches the temporal resolution of the local features.
        return padding_mask

    def decoder_input(self, x, mask_info: MaskInfo):
        """Prepare encoder output for the decoder.

        Re-inserts Gaussian noise tokens at masked positions, restores the
        original temporal order via ids_restore, and drops the extra prefix
        tokens. Returns (x, mask_info).
        """
        inp_drop = self.modality_cfg.decoder.input_dropout
        if inp_drop > 0:
            x = F.dropout(x, inp_drop, training=self.training, inplace=True)

        num_extra = self.modality_cfg.num_extra_tokens

        if mask_info is not None:
            num_masked = mask_info.ids_restore.shape[1] - x.shape[1] + num_extra

            # Noise tokens stand in for the masked positions.
            mask_tokens = x.new_empty(
                x.size(0),
                num_masked,
                x.size(-1),
            ).normal_(0, self.modality_cfg.mask_noise_std)

            # Drop the prefix tokens, append noise tokens, then unshuffle back
            # into the original order.
            x_ = torch.cat([x[:, num_extra:], mask_tokens], dim=1)
            x = torch.gather(x_, dim=1, index=mask_info.ids_restore)

            if self.modality_cfg.decoder.add_positions_masked:
                assert self.fixed_positional_encoder is not None
                pos = self.fixed_positional_encoder(x, None)
                # Only add positions at the masked positions.
                x = x + (pos * mask_info.mask.unsqueeze(-1))
        else:
            x = x[:, num_extra:]

        if self.modality_cfg.decoder.add_positions_all:
            assert self.fixed_positional_encoder is not None
            x = x + self.fixed_positional_encoder(x, None)

        return x, mask_info

    def local_features(self, features):
        """Run the modality-specific local encoder (with optional gradient
        scaling or freezing) and project to the model dimension."""
        if self.local_grad_mult > 0:
            if self.local_grad_mult == 1.0:
                x = self.local_encoder(features)
            else:
                # Scale gradients flowing into the local encoder.
                x = GradMultiply.apply(
                    self.local_encoder(features), self.local_grad_mult
                )
        else:
            # Frozen local encoder.
            with torch.no_grad():
                x = self.local_encoder(features)

        x = self.project_features(x)
        return x

    def contextualized_features(
        self,
        x,
        padding_mask,
        mask,
        remove_masked,
        clone_batch: int = 1,
        mask_seeds: Optional[torch.Tensor] = None,
        precomputed_mask=None,
    ):
        """Add positions, apply masking, and run the context encoder.

        When clone_batch > 1 each sample is repeated so several independent
        masks of the same input are trained in one pass. When remove_masked is
        True, masked timesteps are dropped from the sequence (MAE-style) and
        the padding mask / alibi bias are gathered to match.

        Returns a dict with the encoded "x", the pre-masking "local_features",
        the (possibly gathered) "padding_mask", "alibi_bias", the remaining
        per-layer "alibi_scale", and the "encoder_mask" MaskInfo.
        """
        if padding_mask is not None:
            padding_mask = self.convert_padding_mask(x, padding_mask)

        local_features = x
        if mask and clone_batch == 1:
            # Clone so masking does not corrupt the teacher-side features.
            local_features = local_features.clone()

        orig_B, orig_T, _ = x.shape
        pre_mask_B = orig_B
        mask_info = None

        x_pos = None
        if self.fixed_positional_encoder is not None:
            x = x + self.fixed_positional_encoder(x, padding_mask)

        if mask:
            if clone_batch > 1:
                x = x.repeat_interleave(clone_batch, 0)
                if mask_seeds is not None:
                    # Give every clone a distinct but deterministic mask id by
                    # hashing (seed, clone index) into the sample ids.
                    clone_hash = [
                        int(hash((mask_seeds.seed, ind)) % 1e10)
                        for ind in range(clone_batch - 1)
                    ]
                    clone_hash = torch.tensor([0] + clone_hash).long().view(1, -1)

                    id = mask_seeds.ids
                    id = id.repeat_interleave(clone_batch, 0)
                    id = id.view(-1, clone_batch) + clone_hash.to(id)
                    id = id.view(-1)
                    mask_seeds = MaskSeed(
                        seed=mask_seeds.seed, update=mask_seeds.update, ids=id
                    )
                if padding_mask is not None:
                    padding_mask = padding_mask.repeat_interleave(clone_batch, 0)

            x, mask_info = self.compute_mask(
                x,
                padding_mask,
                mask_seed=mask_seeds,
                # If masked positions will be removed and there is no relative
                # positional encoder, there is no need to overwrite them first.
                apply=self.relative_positional_encoder is not None or not remove_masked,
                precomputed_mask=precomputed_mask,
            )

        if self.relative_positional_encoder is not None:
            x_pos = self.relative_positional_encoder(x)

        masked_padding_mask = padding_mask
        if mask and remove_masked:
            x = mask_info.x_unmasked
            if x_pos is not None:
                # Keep only the positional embeddings of unmasked timesteps.
                x = x + gather_unmasked(x_pos, mask_info)

            if padding_mask is not None and padding_mask.any():
                masked_padding_mask = gather_unmasked_mask(padding_mask, mask_info)
                if not masked_padding_mask.any():
                    masked_padding_mask = None
            else:
                masked_padding_mask = None

        elif x_pos is not None:
            x = x + x_pos

        alibi_bias = None
        alibi_scale = self.alibi_scale

        if self.get_alibi_bias is not None:
            alibi_bias = self.get_alibi_bias(
                batch_size=pre_mask_B,
                time_steps=orig_T,
                heads=self.modality_cfg.num_alibi_heads,
                dtype=torch.float32,
                device=x.device,
            )

            if alibi_scale is not None:
                alibi_scale = alibi_scale.clamp_min(0)
                if alibi_scale.size(0) == 1:
                    # Single shared scale: fold it into the bias immediately.
                    alibi_bias = alibi_bias * alibi_scale.squeeze(0).type_as(alibi_bias)
                    alibi_scale = None

            if clone_batch > 1:
                alibi_bias = alibi_bias.repeat_interleave(clone_batch, 0)

            if mask_info is not None and remove_masked:
                alibi_bias = masked_alibi(alibi_bias, mask_info)

        if self.extra_tokens is not None:
            num = self.extra_tokens.size(1)
            x = torch.cat([self.extra_tokens.expand(x.size(0), -1, -1), x], dim=1)
            if masked_padding_mask is not None:
                # B x T
                masked_padding_mask = F.pad(masked_padding_mask, (num, 0))
            if alibi_bias is not None:
                # B x H x T x T
                alibi_bias = F.pad(alibi_bias, (num, 0, num, 0))

        x = self.context_encoder(
            x,
            masked_padding_mask,
            alibi_bias,
            # The context encoder consumes the prenet layers' scales; the rest
            # is returned for the main trunk.
            alibi_scale[: self.modality_cfg.prenet_depth]
            if alibi_scale is not None
            else None,
        )

        return {
            "x": x,
            "local_features": local_features,
            "padding_mask": masked_padding_mask,
            "alibi_bias": alibi_bias,
            "alibi_scale": alibi_scale[self.modality_cfg.prenet_depth :]
            if alibi_scale is not None and alibi_scale.size(0) > 1
            else alibi_scale,
            "encoder_mask": mask_info,
        }

    def forward(
        self,
        features,
        padding_mask,
        mask: bool,
        remove_masked: bool,
        clone_batch: int = 1,
        mask_seeds: Optional[torch.Tensor] = None,
        precomputed_mask=None,
    ):
        """Full pipeline: local features followed by contextualized features."""
        x = self.local_features(features)
        return self.contextualized_features(
            x,
            padding_mask,
            mask,
            remove_masked,
            clone_batch,
            mask_seeds,
            precomputed_mask,
        )

    def reset_parameters(self):
        # Hook for subclasses; the base encoder has nothing extra to reset.
        pass

    def compute_mask(
        self,
        x,
        padding_mask,
        mask_seed: Optional[MaskSeed],
        apply,
        precomputed_mask,
    ):
        """Compute (or adopt) a time mask for x and optionally apply it.

        Uses per-timestep random masking when mask_length == 1, otherwise
        span masking via compute_mask_indices. Returns (x, mask_info) where
        mask_info may be None when no masking was performed.
        """
        if precomputed_mask is not None:
            mask = precomputed_mask
            mask_info = self.make_maskinfo(x, mask)
        else:
            B, T, C = x.shape

            cfg = self.modality_cfg

            mask_prob = cfg.mask_prob

            if (
                cfg.mask_prob_min is not None
                and cfg.mask_prob_min >= 0
                and cfg.mask_prob_min < mask_prob
            ):
                # Sample the mask probability uniformly from [min, prob).
                mask_prob = np.random.uniform(cfg.mask_prob_min, mask_prob)

            if mask_prob > 0:
                if cfg.mask_length == 1:
                    mask_info = random_masking(x, mask_prob, mask_seed)
                else:
                    if self.modality_cfg.inverse_mask:
                        # Mask the complement: compute indices for keeping.
                        mask_prob = 1 - mask_prob

                    mask = compute_mask_indices(
                        (B, T),
                        padding_mask,
                        mask_prob,
                        cfg.mask_length,
                        min_masks=1,
                        require_same_masks=True,
                        mask_dropout=cfg.mask_dropout,
                        add_masks=cfg.add_masks,
                        seed=mask_seed.seed if mask_seed is not None else None,
                        epoch=mask_seed.update if mask_seed is not None else None,
                        indices=mask_seed.ids if mask_seed is not None else None,
                    )

                    mask = torch.from_numpy(mask).to(device=x.device)
                    if self.modality_cfg.inverse_mask:
                        mask = 1 - mask
                    mask_info = self.make_maskinfo(x, mask)
            else:
                mask_info = None

        if apply:
            x = self.apply_mask(x, mask_info)

        return x, mask_info

    def make_maskinfo(self, x, mask, shape=None):
        """Build a MaskInfo from a (B, T) mask (nonzero = masked).

        Assumes every row masks the same number of timesteps (len_keep is
        derived from row 0). When shape is given instead of using x's shape,
        x_unmasked is not materialized.
        """
        if shape is None:
            B, T, D = x.shape
        else:
            B, T, D = shape

        mask = mask.to(torch.uint8)
        # Stable argsort puts kept (0) positions first, masked (1) last.
        ids_shuffle = mask.argsort(dim=1)
        ids_restore = ids_shuffle.argsort(dim=1).unsqueeze(-1).expand(-1, -1, D)

        len_keep = T - mask[0].sum()
        if self.modality_cfg.keep_masked_pct > 0:
            # Optionally also keep a fraction of the masked positions.
            len_keep += round((T - int(len_keep)) * self.modality_cfg.keep_masked_pct)

        ids_keep = ids_shuffle[:, :len_keep]

        if shape is not None:
            x_unmasked = None
        else:
            ids_keep = ids_keep.unsqueeze(-1).expand(-1, -1, D)
            x_unmasked = torch.gather(x, dim=1, index=ids_keep)

        mask_info = MaskInfo(
            x_unmasked=x_unmasked,
            mask=mask,
            ids_restore=ids_restore,
            ids_keep=ids_keep,
        )
        return mask_info

    def apply_mask(self, x, mask_info):
        """Overwrite masked timesteps of x (B, T, C) with zeros or Gaussian
        noise, then optionally apply channel masking across features."""
        cfg = self.modality_cfg
        B, T, C = x.shape

        if mask_info is not None:
            mask = mask_info.mask
            if cfg.encoder_zero_mask:
                x = x * (1 - mask.type_as(x).unsqueeze(-1))
            else:
                num_masks = mask.sum().item()
                masks = x.new_empty(num_masks, x.size(-1)).normal_(
                    0, cfg.mask_noise_std
                )
                x = index_put(x, mask, masks)
        if cfg.mask_channel_prob > 0:
            mask_channel = compute_mask_indices(
                (B, C),
                None,
                cfg.mask_channel_prob,
                cfg.mask_channel_length,
            )
            mask_channel = (
                torch.from_numpy(mask_channel)
                .to(x.device)
                .unsqueeze(1)
                .expand(-1, T, -1)
            )
            x = index_put(x, mask_channel, 0)
        return x

    def remove_pretraining_modules(self, keep_decoder=False):
        """Drop modules only needed for pre-training (the decoder, unless kept)."""
        if not keep_decoder:
            self.decoder = None
  365. def get_annealed_rate(start, end, curr_step, total_steps):
  366. if curr_step >= total_steps:
  367. return end
  368. r = end - start
  369. pct_remaining = 1 - curr_step / total_steps
  370. return end - r * pct_remaining
  371. # adapted from MAE
  372. def random_masking(x, mask_ratio, mask_seed: Optional[MaskSeed]):
  373. N, L, D = x.shape # batch, length, dim
  374. len_keep = int(L * (1 - mask_ratio))
  375. generator = None
  376. if mask_seed is not None:
  377. seed = int(
  378. hash((mask_seed.seed, mask_seed.update, mask_seed.ids.sum().item())) % 1e6
  379. )
  380. generator = torch.Generator(device=x.device)
  381. generator.manual_seed(seed)
  382. noise = torch.rand(N, L, generator=generator, device=x.device) # noise in [0, 1]
  383. # sort noise for each sample
  384. ids_shuffle = noise.argsort(dim=1) # ascend: small is keep, large is remove
  385. ids_restore = ids_shuffle.argsort(dim=1)
  386. # keep the first subset
  387. ids_keep = ids_shuffle[:, :len_keep]
  388. ids_keep = ids_keep.unsqueeze(-1).expand(-1, -1, D)
  389. x_unmasked = torch.gather(x, dim=1, index=ids_keep)
  390. # generate the binary mask: 0 is keep, 1 is remove
  391. mask = torch.ones([N, L], dtype=x.dtype, device=x.device)
  392. mask[:, :len_keep] = 0
  393. # unshuffle to get the binary mask
  394. mask = torch.gather(mask, dim=1, index=ids_restore)
  395. ids_restore = ids_restore.unsqueeze(-1).expand(-1, -1, D)
  396. return MaskInfo(
  397. x_unmasked=x_unmasked, mask=mask, ids_restore=ids_restore, ids_keep=ids_keep
  398. )
  399. def gather_unmasked(x: torch.Tensor, mask_info: MaskInfo) -> torch.Tensor:
  400. return torch.gather(
  401. x,
  402. dim=1,
  403. index=mask_info.ids_keep,
  404. )
  405. def gather_unmasked_mask(x: torch.Tensor, mask_info: MaskInfo) -> torch.Tensor:
  406. return torch.gather(
  407. x,
  408. dim=1,
  409. index=mask_info.ids_keep[..., 0], # ignore the feature dimension
  410. )
  411. def get_alibi(
  412. max_positions: int,
  413. attention_heads: int,
  414. dims: int = 1,
  415. distance: str = "manhattan",
  416. ):
  417. def get_slopes(n):
  418. def get_slopes_power_of_2(n):
  419. start = 2 ** (-(2 ** -(math.log2(n) - 3)))
  420. ratio = start
  421. return [start * ratio**i for i in range(n)]
  422. # In the paper, we only train models that have 2^a heads for some
  423. # a. This function has some good properties that only occur when
  424. # the input is a power of 2. To maintain that even when the number
  425. # of heads is not a power of 2, we use this workaround.
  426. if math.log2(n).is_integer():
  427. return get_slopes_power_of_2(n)
  428. else:
  429. closest_power_of_2 = 2 ** math.floor(math.log2(n))
  430. return (
  431. get_slopes_power_of_2(closest_power_of_2)
  432. + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
  433. )
  434. maxpos = max_positions
  435. attn_heads = attention_heads
  436. slopes = torch.Tensor(get_slopes(attn_heads))
  437. if dims == 1:
  438. # prepare alibi position linear bias. Note that wav2vec2 is non
  439. # autoregressive model so we want a symmetric mask with 0 on the
  440. # diagonal and other wise linear decreasing valuees
  441. pos_bias = (
  442. torch.abs(
  443. torch.arange(maxpos).unsqueeze(0) - torch.arange(maxpos).unsqueeze(1)
  444. )
  445. * -1
  446. )
  447. elif dims == 2:
  448. if distance == "manhattan":
  449. df = lambda x1, y1, x2, y2: abs(x1 - x2) + abs(y1 - y2)
  450. elif distance == "euclidean":
  451. df = lambda x1, y1, x2, y2: math.sqrt((x1 - x2) ** 2 + (y1 - y2) ** 2)
  452. n = math.sqrt(max_positions)
  453. assert n.is_integer(), n
  454. n = int(n)
  455. pos_bias = torch.zeros((max_positions, max_positions))
  456. for i in range(n):
  457. for j in range(n):
  458. for k in range(n):
  459. for l in range(n):
  460. new_x = i * n + j
  461. new_y = k * n + l
  462. pos_bias[new_x, new_y] = -df(i, j, k, l)
  463. else:
  464. raise Exception(f"unsupported number of alibi dims: {dims}")
  465. alibi_bias = slopes.unsqueeze(1).unsqueeze(1) * pos_bias.unsqueeze(0).expand(
  466. attn_heads, -1, -1
  467. )
  468. return alibi_bias
  469. def get_alibi_bias(
  470. alibi_biases,
  471. batch_size,
  472. time_steps,
  473. heads,
  474. dtype,
  475. device,
  476. dims=1,
  477. distance="manhattan",
  478. ):
  479. cache_key = f"{dims}_{heads}_{distance}"
  480. buffered = alibi_biases.get(cache_key, None)
  481. target_size = heads * batch_size
  482. if (
  483. buffered is None
  484. or buffered.size(0) < target_size
  485. or buffered.size(1) < time_steps
  486. or buffered.dtype != dtype
  487. or buffered.device != device
  488. ):
  489. bt = max(time_steps, buffered.size(1) if buffered is not None else 0)
  490. bn = max(target_size, buffered.size(0) if buffered is not None else 0) // heads
  491. buffered = (
  492. get_alibi(bt, heads, dims=dims, distance=distance)
  493. .to(dtype=dtype, device=device)
  494. .repeat(bn, 1, 1)
  495. )
  496. alibi_biases[cache_key] = buffered
  497. b = buffered[:target_size, :time_steps, :time_steps]
  498. b = b.view(batch_size, heads, time_steps, time_steps)
  499. return b
  500. def _learned_alibi_bias(
  501. alibi_bias,
  502. batch_size,
  503. time_steps,
  504. heads,
  505. scale,
  506. dtype,
  507. device,
  508. ):
  509. assert alibi_bias.size(1) == heads, alibi_bias.shape
  510. assert alibi_bias.dtype == dtype, alibi_bias.dtype
  511. assert alibi_bias.device == device, alibi_bias.device
  512. if alibi_bias.size(-1) < time_steps:
  513. psz = math.ceil((time_steps - alibi_bias.size(-1)) / 2)
  514. alibi_bias = F.pad(alibi_bias, (psz, psz, psz, psz), mode="replicate")
  515. alibi_bias = alibi_bias.expand(batch_size, -1, -1, -1) * scale
  516. return alibi_bias[..., :time_steps, :time_steps]
  517. def masked_alibi(alibi_bias, mask_info):
  518. H = alibi_bias.size(1)
  519. orig_bias = alibi_bias
  520. index = mask_info.ids_keep.unsqueeze(1)[..., 0].unsqueeze(-1)
  521. alibi_bias = torch.gather(
  522. orig_bias,
  523. dim=-2,
  524. index=index.expand(-1, H, -1, mask_info.ids_restore.size(1)),
  525. )
  526. alibi_bias = torch.gather(
  527. alibi_bias,
  528. dim=-1,
  529. index=index.transpose(-1, -2).expand(-1, H, alibi_bias.size(-2), -1),
  530. )
  531. return alibi_bias