# e2e_tp.py
  1. import logging
  2. from contextlib import contextmanager
  3. from distutils.version import LooseVersion
  4. from typing import Dict
  5. from typing import List
  6. from typing import Optional
  7. from typing import Tuple
  8. from typing import Union
  9. import torch
  10. import numpy as np
  11. from typeguard import check_argument_types
  12. from funasr.models.encoder.abs_encoder import AbsEncoder
  13. from funasr.models.frontend.abs_frontend import AbsFrontend
  14. from funasr.models.predictor.cif import mae_loss
  15. from funasr.modules.add_sos_eos import add_sos_eos
  16. from funasr.modules.nets_utils import make_pad_mask, pad_list
  17. from funasr.torch_utils.device_funcs import force_gatherable
  18. from funasr.models.base_model import FunASRModel
  19. from funasr.models.predictor.cif import CifPredictorV3
  20. if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
  21. from torch.cuda.amp import autocast
  22. else:
  23. # Nothing to do if torch<1.6.0
  24. @contextmanager
  25. def autocast(enabled=True):
  26. yield
  27. class TimestampPredictor(FunASRModel):
  28. """
  29. Author: Speech Lab of DAMO Academy, Alibaba Group
  30. """
  31. def __init__(
  32. self,
  33. frontend: Optional[AbsFrontend],
  34. encoder: AbsEncoder,
  35. predictor: CifPredictorV3,
  36. predictor_bias: int = 0,
  37. token_list=None,
  38. ):
  39. assert check_argument_types()
  40. super().__init__()
  41. # note that eos is the same as sos (equivalent ID)
  42. self.frontend = frontend
  43. self.encoder = encoder
  44. self.encoder.interctc_use_conditioning = False
  45. self.predictor = predictor
  46. self.predictor_bias = predictor_bias
  47. self.criterion_pre = mae_loss()
  48. self.token_list = token_list
  49. def forward(
  50. self,
  51. speech: torch.Tensor,
  52. speech_lengths: torch.Tensor,
  53. text: torch.Tensor,
  54. text_lengths: torch.Tensor,
  55. ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
  56. """Frontend + Encoder + Decoder + Calc loss
  57. Args:
  58. speech: (Batch, Length, ...)
  59. speech_lengths: (Batch, )
  60. text: (Batch, Length)
  61. text_lengths: (Batch,)
  62. """
  63. assert text_lengths.dim() == 1, text_lengths.shape
  64. # Check that batch_size is unified
  65. assert (
  66. speech.shape[0]
  67. == speech_lengths.shape[0]
  68. == text.shape[0]
  69. == text_lengths.shape[0]
  70. ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
  71. batch_size = speech.shape[0]
  72. # for data-parallel
  73. text = text[:, : text_lengths.max()]
  74. speech = speech[:, :speech_lengths.max()]
  75. # 1. Encoder
  76. encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
  77. encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
  78. encoder_out.device)
  79. if self.predictor_bias == 1:
  80. _, text = add_sos_eos(text, 1, 2, -1)
  81. text_lengths = text_lengths + self.predictor_bias
  82. _, _, _, _, pre_token_length2 = self.predictor(encoder_out, text, encoder_out_mask, ignore_id=-1)
  83. # loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
  84. loss_pre = self.criterion_pre(text_lengths.type_as(pre_token_length2), pre_token_length2)
  85. loss = loss_pre
  86. stats = dict()
  87. # Collect Attn branch stats
  88. stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
  89. stats["loss"] = torch.clone(loss.detach())
  90. # force_gatherable: to-device and to-tensor if scalar for DataParallel
  91. loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
  92. return loss, stats, weight
  93. def encode(
  94. self, speech: torch.Tensor, speech_lengths: torch.Tensor
  95. ) -> Tuple[torch.Tensor, torch.Tensor]:
  96. """Frontend + Encoder. Note that this method is used by asr_inference.py
  97. Args:
  98. speech: (Batch, Length, ...)
  99. speech_lengths: (Batch, )
  100. """
  101. with autocast(False):
  102. # 1. Extract feats
  103. feats, feats_lengths = self._extract_feats(speech, speech_lengths)
  104. # 4. Forward encoder
  105. # feats: (Batch, Length, Dim)
  106. # -> encoder_out: (Batch, Length2, Dim2)
  107. encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
  108. return encoder_out, encoder_out_lens
  109. def _extract_feats(
  110. self, speech: torch.Tensor, speech_lengths: torch.Tensor
  111. ) -> Tuple[torch.Tensor, torch.Tensor]:
  112. assert speech_lengths.dim() == 1, speech_lengths.shape
  113. # for data-parallel
  114. speech = speech[:, : speech_lengths.max()]
  115. if self.frontend is not None:
  116. # Frontend
  117. # e.g. STFT and Feature extract
  118. # data_loader may send time-domain signal in this case
  119. # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
  120. feats, feats_lengths = self.frontend(speech, speech_lengths)
  121. else:
  122. # No frontend and no feature extract
  123. feats, feats_lengths = speech, speech_lengths
  124. return feats, feats_lengths
  125. def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num):
  126. encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
  127. encoder_out.device)
  128. ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(encoder_out,
  129. encoder_out_mask,
  130. token_num)
  131. return ds_alphas, ds_cif_peak, us_alphas, us_peaks
  132. def collect_feats(
  133. self,
  134. speech: torch.Tensor,
  135. speech_lengths: torch.Tensor,
  136. text: torch.Tensor,
  137. text_lengths: torch.Tensor,
  138. ) -> Dict[str, torch.Tensor]:
  139. if self.extract_feats_in_collect_stats:
  140. feats, feats_lengths = self._extract_feats(speech, speech_lengths)
  141. else:
  142. # Generate dummy stats if extract_feats_in_collect_stats is False
  143. logging.warning(
  144. "Generating dummy stats for feats and feats_lengths, "
  145. "because encoder_conf.extract_feats_in_collect_stats is "
  146. f"{self.extract_feats_in_collect_stats}"
  147. )
  148. feats, feats_lengths = speech, speech_lengths
  149. return {"feats": feats, "feats_lengths": feats_lengths}