# data2vec.py (extraction artifacts — file-size banner and fused line-number run — removed)
  1. # Copyright (c) Facebook, Inc. and its affiliates.
  2. #
  3. # This source code is licensed under the MIT license found in the
  4. # LICENSE file in the root directory of this source tree.
  5. from contextlib import contextmanager
  6. from distutils.version import LooseVersion
  7. from typing import Dict
  8. from typing import Optional
  9. from typing import Tuple
  10. import torch
  11. from funasr.layers.abs_normalize import AbsNormalize
  12. from funasr.models.base_model import FunASRModel
  13. from funasr.models.encoder.abs_encoder import AbsEncoder
  14. from funasr.models.frontend.abs_frontend import AbsFrontend
  15. from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
  16. from funasr.models.specaug.abs_specaug import AbsSpecAug
  17. from funasr.torch_utils.device_funcs import force_gatherable
  18. if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
  19. from torch.cuda.amp import autocast
  20. else:
  21. # Nothing to do if torch<1.6.0
  22. @contextmanager
  23. def autocast(enabled=True):
  24. yield
class Data2VecPretrainModel(FunASRModel):
    """Data2Vec self-supervised pretraining model.

    Pipeline: frontend -> specaug -> normalize -> preencoder -> encoder.
    The encoder is expected to implement the data2vec objective internally
    and to return a dict containing at least the keys ``"losses"``,
    ``"sample_size"``, ``"target_var"``, ``"pred_var"`` and ``"ema_decay"``
    (see :meth:`forward`).
    """

    def __init__(
        self,
        frontend: Optional[AbsFrontend],
        specaug: Optional[AbsSpecAug],
        normalize: Optional[AbsNormalize],
        encoder: AbsEncoder,
        preencoder: Optional[AbsPreEncoder] = None,
    ) -> None:
        super().__init__()
        self.frontend = frontend
        self.specaug = specaug
        self.normalize = normalize
        self.preencoder = preencoder
        self.encoder = encoder
        # Global training-step counter; pushed into the encoder on every
        # forward() call (presumably used there for EMA/teacher scheduling —
        # confirm against the encoder implementation).
        self.num_updates = 0

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Calc loss.

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )

        Returns:
            Tuple of (loss, stats dict, weight); all three are passed through
            ``force_gatherable`` so they can be gathered under DataParallel.
        """
        # Check that batch_size is unified
        assert (
            speech.shape[0]
            == speech_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape)
        # Let the encoder know the current optimization step.
        self.encoder.set_num_updates(self.num_updates)
        # 1. Encoder
        encoder_out = self.encode(speech, speech_lengths)
        # Sum every loss term the encoder reports, then normalize by the
        # encoder-reported sample count.
        losses = encoder_out["losses"]
        loss = sum(losses.values())
        sample_size = encoder_out["sample_size"]
        loss = loss.sum() / sample_size
        # Scalar diagnostics for logging (variance of targets/predictions,
        # current EMA decay of the teacher).
        target_var = float(encoder_out["target_var"])
        pred_var = float(encoder_out["pred_var"])
        ema_decay = float(encoder_out["ema_decay"])
        stats = dict(
            loss=torch.clone(loss.detach()),
            target_var=target_var,
            pred_var=pred_var,
            ema_decay=ema_decay,
        )
        # force_gatherable: to-device and to-tensor if scalar for DataParallel.
        loss, stats, weight = force_gatherable((loss, stats, sample_size), loss.device)
        return loss, stats, weight

    def collect_feats(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        """Extract features only (no augmentation/encoding), for statistics collection."""
        feats, feats_lengths = self._extract_feats(speech, speech_lengths)
        return {"feats": feats, "feats_lengths": feats_lengths}

    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
    ):
        """Frontend + Encoder.

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )

        Returns:
            The encoder's raw output dict (masked data2vec forward,
            ``features_only=False``).
        """
        # Feature extraction is kept in fp32 even under AMP.
        with autocast(False):
            # 1. Extract feats
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
            # 2. Data augmentation
            if self.specaug is not None and self.training:
                feats, feats_lengths = self.specaug(feats, feats_lengths)
            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                feats, feats_lengths = self.normalize(feats, feats_lengths)
        # Pre-encoder, e.g. used for raw input data
        if self.preencoder is not None:
            feats, feats_lengths = self.preencoder(feats, feats_lengths)
        # 4. Forward encoder
        if min(speech_lengths) == max(speech_lengths):  # for clipping, set speech_lengths as None
            speech_lengths = None
        # NOTE(review): the encoder receives the ORIGINAL speech_lengths, not
        # feats_lengths — this is only correct if the encoder expects
        # raw-sample lengths (e.g. frontend is None for raw-waveform
        # data2vec). Confirm against the encoder's expected input.
        encoder_out = self.encoder(feats, speech_lengths, mask=True, features_only=False)
        return encoder_out

    def _extract_feats(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the frontend (if any) to raw speech; otherwise pass through."""
        assert speech_lengths.dim() == 1, speech_lengths.shape
        # for data-parallel: trim padding beyond the longest utterance in batch
        speech = speech[:, : speech_lengths.max()]
        if self.frontend is not None:
            # Frontend
            # e.g. STFT and Feature extract
            # data_loader may send time-domain signal in this case
            # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
            feats, feats_lengths = self.frontend(speech, speech_lengths)
        else:
            # No frontend and no feature extract
            feats, feats_lengths = speech, speech_lengths
        return feats, feats_lengths

    def set_num_updates(self, num_updates: int) -> None:
        """Record the current optimization step (forwarded to the encoder in forward())."""
        self.num_updates = num_updates

    def get_num_updates(self) -> int:
        """Return the last recorded optimization step."""
        return self.num_updates