# data2vec.py
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from contextlib import contextmanager
# NOTE(review): distutils is deprecated (PEP 632) and removed in Python 3.12;
# kept for backward compatibility with existing callers of this module.
from distutils.version import LooseVersion
from typing import Dict
from typing import Optional
from typing import Tuple

import torch
from typeguard import check_argument_types

from funasr.models.base_model import FunASRModel
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr.torch_utils.device_funcs import force_gatherable
  15. if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
  16. from torch.cuda.amp import autocast
  17. else:
  18. # Nothing to do if torch<1.6.0
  19. @contextmanager
  20. def autocast(enabled=True):
  21. yield
  22. class Data2VecPretrainModel(FunASRModel):
  23. """Data2Vec Pretrain model"""
  24. def __init__(
  25. self,
  26. frontend: Optional[torch.nn.Module],
  27. specaug: Optional[torch.nn.Module],
  28. normalize: Optional[torch.nn.Module],
  29. preencoder: Optional[AbsPreEncoder],
  30. encoder: torch.nn.Module,
  31. ):
  32. assert check_argument_types()
  33. super().__init__()
  34. self.frontend = frontend
  35. self.specaug = specaug
  36. self.normalize = normalize
  37. self.preencoder = preencoder
  38. self.encoder = encoder
  39. self.num_updates = 0
  40. def forward(
  41. self,
  42. speech: torch.Tensor,
  43. speech_lengths: torch.Tensor,
  44. ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
  45. """Frontend + Encoder + Calc loss
  46. Args:
  47. speech: (Batch, Length, ...)
  48. speech_lengths: (Batch, )
  49. """
  50. # Check that batch_size is unified
  51. assert (
  52. speech.shape[0]
  53. == speech_lengths.shape[0]
  54. ), (speech.shape, speech_lengths.shape)
  55. self.encoder.set_num_updates(self.num_updates)
  56. # 1. Encoder
  57. encoder_out = self.encode(speech, speech_lengths)
  58. losses = encoder_out["losses"]
  59. loss = sum(losses.values())
  60. sample_size = encoder_out["sample_size"]
  61. loss = loss.sum() / sample_size
  62. target_var = float(encoder_out["target_var"])
  63. pred_var = float(encoder_out["pred_var"])
  64. ema_decay = float(encoder_out["ema_decay"])
  65. stats = dict(
  66. loss=torch.clone(loss.detach()),
  67. target_var=target_var,
  68. pred_var=pred_var,
  69. ema_decay=ema_decay,
  70. )
  71. loss, stats, weight = force_gatherable((loss, stats, sample_size), loss.device)
  72. return loss, stats, weight
  73. def collect_feats(
  74. self,
  75. speech: torch.Tensor,
  76. speech_lengths: torch.Tensor
  77. ) -> Dict[str, torch.Tensor]:
  78. feats, feats_lengths = self._extract_feats(speech, speech_lengths)
  79. return {"feats": feats, "feats_lengths": feats_lengths}
  80. def encode(
  81. self,
  82. speech: torch.Tensor,
  83. speech_lengths: torch.Tensor,
  84. ):
  85. """Frontend + Encoder.
  86. Args:
  87. speech: (Batch, Length, ...)
  88. speech_lengths: (Batch, )
  89. """
  90. with autocast(False):
  91. # 1. Extract feats
  92. feats, feats_lengths = self._extract_feats(speech, speech_lengths)
  93. # 2. Data augmentation
  94. if self.specaug is not None and self.training:
  95. feats, feats_lengths = self.specaug(feats, feats_lengths)
  96. # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
  97. if self.normalize is not None:
  98. feats, feats_lengths = self.normalize(feats, feats_lengths)
  99. # Pre-encoder, e.g. used for raw input data
  100. if self.preencoder is not None:
  101. feats, feats_lengths = self.preencoder(feats, feats_lengths)
  102. # 4. Forward encoder
  103. if min(speech_lengths) == max(speech_lengths): # for clipping, set speech_lengths as None
  104. speech_lengths = None
  105. encoder_out = self.encoder(feats, speech_lengths, mask=True, features_only=False)
  106. return encoder_out
  107. def _extract_feats(
  108. self, speech: torch.Tensor, speech_lengths: torch.Tensor
  109. ) -> Tuple[torch.Tensor, torch.Tensor]:
  110. assert speech_lengths.dim() == 1, speech_lengths.shape
  111. # for data-parallel
  112. speech = speech[:, : speech_lengths.max()]
  113. if self.frontend is not None:
  114. # Frontend
  115. # e.g. STFT and Feature extract
  116. # data_loader may send time-domain signal in this case
  117. # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
  118. feats, feats_lengths = self.frontend(speech, speech_lengths)
  119. else:
  120. # No frontend and no feature extract
  121. feats, feats_lengths = speech, speech_lengths
  122. return feats, feats_lengths
  123. def set_num_updates(self, num_updates):
  124. self.num_updates = num_updates
  125. def get_num_updates(self):
  126. return self.num_updates