# data2vec.py
  1. # Copyright (c) Facebook, Inc. and its affiliates.
  2. #
  3. # This source code is licensed under the MIT license found in the
  4. # LICENSE file in the root directory of this source tree.
  5. from contextlib import contextmanager
  6. from distutils.version import LooseVersion
  7. from typing import Dict
  8. from typing import Optional
  9. from typing import Tuple
  10. import torch
  11. from typeguard import check_argument_types
  12. from funasr.layers.abs_normalize import AbsNormalize
  13. from funasr.models.encoder.abs_encoder import AbsEncoder
  14. from funasr.models.frontend.abs_frontend import AbsFrontend
  15. from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
  16. from funasr.models.specaug.abs_specaug import AbsSpecAug
  17. from funasr.torch_utils.device_funcs import force_gatherable
  18. from funasr.train.abs_espnet_model import AbsESPnetModel
# torch>=1.6 ships native automatic mixed precision; import the real
# autocast context manager when available.
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
else:
    # Nothing to do if torch<1.6.0: provide a no-op stand-in so call
    # sites can use `with autocast(False):` unconditionally.
    @contextmanager
    def autocast(enabled=True):
        yield
  26. class Data2VecPretrainModel(AbsESPnetModel):
  27. """Data2Vec Pretrain model"""
  28. def __init__(
  29. self,
  30. frontend: Optional[AbsFrontend],
  31. specaug: Optional[AbsSpecAug],
  32. normalize: Optional[AbsNormalize],
  33. preencoder: Optional[AbsPreEncoder],
  34. encoder: AbsEncoder,
  35. ):
  36. assert check_argument_types()
  37. super().__init__()
  38. self.frontend = frontend
  39. self.specaug = specaug
  40. self.normalize = normalize
  41. self.preencoder = preencoder
  42. self.encoder = encoder
  43. self.num_updates = 0
  44. def forward(
  45. self,
  46. speech: torch.Tensor,
  47. speech_lengths: torch.Tensor,
  48. ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
  49. """Frontend + Encoder + Calc loss
  50. Args:
  51. speech: (Batch, Length, ...)
  52. speech_lengths: (Batch, )
  53. """
  54. # Check that batch_size is unified
  55. assert (
  56. speech.shape[0]
  57. == speech_lengths.shape[0]
  58. ), (speech.shape, speech_lengths.shape)
  59. self.encoder.set_num_updates(self.num_updates)
  60. # 1. Encoder
  61. encoder_out = self.encode(speech, speech_lengths)
  62. losses = encoder_out["losses"]
  63. loss = sum(losses.values())
  64. sample_size = encoder_out["sample_size"]
  65. loss = loss.sum() / sample_size
  66. target_var = float(encoder_out["target_var"])
  67. pred_var = float(encoder_out["pred_var"])
  68. ema_decay = float(encoder_out["ema_decay"])
  69. stats = dict(
  70. loss=torch.clone(loss.detach()),
  71. target_var=target_var,
  72. pred_var=pred_var,
  73. ema_decay=ema_decay,
  74. )
  75. loss, stats, weight = force_gatherable((loss, stats, sample_size), loss.device)
  76. return loss, stats, weight
  77. def collect_feats(
  78. self,
  79. speech: torch.Tensor,
  80. speech_lengths: torch.Tensor
  81. ) -> Dict[str, torch.Tensor]:
  82. feats, feats_lengths = self._extract_feats(speech, speech_lengths)
  83. return {"feats": feats, "feats_lengths": feats_lengths}
  84. def encode(
  85. self,
  86. speech: torch.Tensor,
  87. speech_lengths: torch.Tensor,
  88. ):
  89. """Frontend + Encoder.
  90. Args:
  91. speech: (Batch, Length, ...)
  92. speech_lengths: (Batch, )
  93. """
  94. with autocast(False):
  95. # 1. Extract feats
  96. feats, feats_lengths = self._extract_feats(speech, speech_lengths)
  97. # 2. Data augmentation
  98. if self.specaug is not None and self.training:
  99. feats, feats_lengths = self.specaug(feats, feats_lengths)
  100. # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
  101. if self.normalize is not None:
  102. feats, feats_lengths = self.normalize(feats, feats_lengths)
  103. # Pre-encoder, e.g. used for raw input data
  104. if self.preencoder is not None:
  105. feats, feats_lengths = self.preencoder(feats, feats_lengths)
  106. # 4. Forward encoder
  107. if min(speech_lengths) == max(speech_lengths): # for clipping, set speech_lengths as None
  108. speech_lengths = None
  109. encoder_out = self.encoder(feats, speech_lengths, mask=True, features_only=False)
  110. return encoder_out
  111. def _extract_feats(
  112. self, speech: torch.Tensor, speech_lengths: torch.Tensor
  113. ) -> Tuple[torch.Tensor, torch.Tensor]:
  114. assert speech_lengths.dim() == 1, speech_lengths.shape
  115. # for data-parallel
  116. speech = speech[:, : speech_lengths.max()]
  117. if self.frontend is not None:
  118. # Frontend
  119. # e.g. STFT and Feature extract
  120. # data_loader may send time-domain signal in this case
  121. # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
  122. feats, feats_lengths = self.frontend(speech, speech_lengths)
  123. else:
  124. # No frontend and no feature extract
  125. feats, feats_lengths = speech, speech_lengths
  126. return feats, feats_lengths
  127. def set_num_updates(self, num_updates):
  128. self.num_updates = num_updates
  129. def get_num_updates(self):
  130. return self.num_updates