# data2vec.py (5.0 KB)
  1. # Copyright (c) Facebook, Inc. and its affiliates.
  2. #
  3. # This source code is licensed under the MIT license found in the
  4. # LICENSE file in the root directory of this source tree.
  5. from contextlib import contextmanager
  6. from distutils.version import LooseVersion
  7. from typing import Dict
  8. from typing import Optional
  9. from typing import Tuple
  10. import torch
  11. import torch.nn as nn
  12. # from funasr.layers.abs_normalize import AbsNormalize
  13. # from funasr.models.base_model import FunASRModel
  14. # from funasr.models.encoder.abs_encoder import AbsEncoder
  15. from funasr.frontends.abs_frontend import AbsFrontend
  16. # from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
  17. # from funasr.models.specaug.abs_specaug import AbsSpecAug
  18. from funasr.train_utils.device_funcs import force_gatherable
  19. if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
  20. from torch.cuda.amp import autocast
  21. else:
  22. # Nothing to do if torch<1.6.0
  23. @contextmanager
  24. def autocast(enabled=True):
  25. yield
  26. class Data2VecPretrainModel(nn.Module):
  27. """Data2Vec Pretrain model"""
  28. def __init__(
  29. self,
  30. frontend = None,
  31. specaug = None,
  32. normalize = None,
  33. encoder = None,
  34. preencoder = None,
  35. ):
  36. super().__init__()
  37. self.frontend = frontend
  38. self.specaug = specaug
  39. self.normalize = normalize
  40. self.preencoder = preencoder
  41. self.encoder = encoder
  42. self.num_updates = 0
  43. def forward(
  44. self,
  45. speech: torch.Tensor,
  46. speech_lengths: torch.Tensor,
  47. ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
  48. """Frontend + Encoder + Calc loss
  49. Args:
  50. speech: (Batch, Length, ...)
  51. speech_lengths: (Batch, )
  52. """
  53. # Check that batch_size is unified
  54. assert (
  55. speech.shape[0]
  56. == speech_lengths.shape[0]
  57. ), (speech.shape, speech_lengths.shape)
  58. self.encoder.set_num_updates(self.num_updates)
  59. # 1. Encoder
  60. encoder_out = self.encode(speech, speech_lengths)
  61. losses = encoder_out["losses"]
  62. loss = sum(losses.values())
  63. sample_size = encoder_out["sample_size"]
  64. loss = loss.sum() / sample_size
  65. target_var = float(encoder_out["target_var"])
  66. pred_var = float(encoder_out["pred_var"])
  67. ema_decay = float(encoder_out["ema_decay"])
  68. stats = dict(
  69. loss=torch.clone(loss.detach()),
  70. target_var=target_var,
  71. pred_var=pred_var,
  72. ema_decay=ema_decay,
  73. )
  74. loss, stats, weight = force_gatherable((loss, stats, sample_size), loss.device)
  75. return loss, stats, weight
  76. def collect_feats(
  77. self,
  78. speech: torch.Tensor,
  79. speech_lengths: torch.Tensor
  80. ) -> Dict[str, torch.Tensor]:
  81. feats, feats_lengths = self._extract_feats(speech, speech_lengths)
  82. return {"feats": feats, "feats_lengths": feats_lengths}
  83. def encode(
  84. self,
  85. speech: torch.Tensor,
  86. speech_lengths: torch.Tensor,
  87. ):
  88. """Frontend + Encoder.
  89. Args:
  90. speech: (Batch, Length, ...)
  91. speech_lengths: (Batch, )
  92. """
  93. with autocast(False):
  94. # 1. Extract feats
  95. feats, feats_lengths = self._extract_feats(speech, speech_lengths)
  96. # 2. Data augmentation
  97. if self.specaug is not None and self.training:
  98. feats, feats_lengths = self.specaug(feats, feats_lengths)
  99. # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
  100. if self.normalize is not None:
  101. feats, feats_lengths = self.normalize(feats, feats_lengths)
  102. # Pre-encoder, e.g. used for raw input data
  103. if self.preencoder is not None:
  104. feats, feats_lengths = self.preencoder(feats, feats_lengths)
  105. # 4. Forward encoder
  106. if min(speech_lengths) == max(speech_lengths): # for clipping, set speech_lengths as None
  107. speech_lengths = None
  108. encoder_out = self.encoder(feats, speech_lengths, mask=True, features_only=False)
  109. return encoder_out
  110. def _extract_feats(
  111. self, speech: torch.Tensor, speech_lengths: torch.Tensor
  112. ) -> Tuple[torch.Tensor, torch.Tensor]:
  113. assert speech_lengths.dim() == 1, speech_lengths.shape
  114. # for data-parallel
  115. speech = speech[:, : speech_lengths.max()]
  116. if self.frontend is not None:
  117. # Frontend
  118. # e.g. STFT and Feature extract
  119. # data_loader may send time-domain signal in this case
  120. # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
  121. feats, feats_lengths = self.frontend(speech, speech_lengths)
  122. else:
  123. # No frontend and no feature extract
  124. feats, feats_lengths = speech, speech_lengths
  125. return feats, feats_lengths
  126. def set_num_updates(self, num_updates):
  127. self.num_updates = num_updates
  128. def get_num_updates(self):
  129. return self.num_updates