e2e_tp.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. import logging
  2. from contextlib import contextmanager
  3. from distutils.version import LooseVersion
  4. from typing import Dict
  5. from typing import List
  6. from typing import Optional
  7. from typing import Tuple
  8. from typing import Union
  9. import torch
  10. import numpy as np
  11. from funasr.models.encoder.abs_encoder import AbsEncoder
  12. from funasr.models.frontend.abs_frontend import AbsFrontend
  13. from funasr.models.predictor.cif import mae_loss
  14. from funasr.modules.add_sos_eos import add_sos_eos
  15. from funasr.modules.nets_utils import make_pad_mask, pad_list
  16. from funasr.torch_utils.device_funcs import force_gatherable
  17. from funasr.models.base_model import FunASRModel
  18. from funasr.models.predictor.cif import CifPredictorV3
  19. if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
  20. from torch.cuda.amp import autocast
  21. else:
  22. # Nothing to do if torch<1.6.0
  23. @contextmanager
  24. def autocast(enabled=True):
  25. yield
  26. class TimestampPredictor(FunASRModel):
  27. """
  28. Author: Speech Lab of DAMO Academy, Alibaba Group
  29. """
  30. def __init__(
  31. self,
  32. frontend: Optional[AbsFrontend],
  33. encoder: AbsEncoder,
  34. predictor: CifPredictorV3,
  35. predictor_bias: int = 0,
  36. token_list=None,
  37. ):
  38. super().__init__()
  39. # note that eos is the same as sos (equivalent ID)
  40. self.frontend = frontend
  41. self.encoder = encoder
  42. self.encoder.interctc_use_conditioning = False
  43. self.predictor = predictor
  44. self.predictor_bias = predictor_bias
  45. self.criterion_pre = mae_loss()
  46. self.token_list = token_list
  47. def forward(
  48. self,
  49. speech: torch.Tensor,
  50. speech_lengths: torch.Tensor,
  51. text: torch.Tensor,
  52. text_lengths: torch.Tensor,
  53. ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
  54. """Frontend + Encoder + Decoder + Calc loss
  55. Args:
  56. speech: (Batch, Length, ...)
  57. speech_lengths: (Batch, )
  58. text: (Batch, Length)
  59. text_lengths: (Batch,)
  60. """
  61. assert text_lengths.dim() == 1, text_lengths.shape
  62. # Check that batch_size is unified
  63. assert (
  64. speech.shape[0]
  65. == speech_lengths.shape[0]
  66. == text.shape[0]
  67. == text_lengths.shape[0]
  68. ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
  69. batch_size = speech.shape[0]
  70. # for data-parallel
  71. text = text[:, : text_lengths.max()]
  72. speech = speech[:, :speech_lengths.max()]
  73. # 1. Encoder
  74. encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
  75. encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
  76. encoder_out.device)
  77. if self.predictor_bias == 1:
  78. _, text = add_sos_eos(text, 1, 2, -1)
  79. text_lengths = text_lengths + self.predictor_bias
  80. _, _, _, _, pre_token_length2 = self.predictor(encoder_out, text, encoder_out_mask, ignore_id=-1)
  81. # loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
  82. loss_pre = self.criterion_pre(text_lengths.type_as(pre_token_length2), pre_token_length2)
  83. loss = loss_pre
  84. stats = dict()
  85. # Collect Attn branch stats
  86. stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
  87. stats["loss"] = torch.clone(loss.detach())
  88. # force_gatherable: to-device and to-tensor if scalar for DataParallel
  89. loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
  90. return loss, stats, weight
  91. def encode(
  92. self, speech: torch.Tensor, speech_lengths: torch.Tensor
  93. ) -> Tuple[torch.Tensor, torch.Tensor]:
  94. """Frontend + Encoder. Note that this method is used by asr_inference.py
  95. Args:
  96. speech: (Batch, Length, ...)
  97. speech_lengths: (Batch, )
  98. """
  99. with autocast(False):
  100. # 1. Extract feats
  101. feats, feats_lengths = self._extract_feats(speech, speech_lengths)
  102. # 4. Forward encoder
  103. # feats: (Batch, Length, Dim)
  104. # -> encoder_out: (Batch, Length2, Dim2)
  105. encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
  106. return encoder_out, encoder_out_lens
  107. def _extract_feats(
  108. self, speech: torch.Tensor, speech_lengths: torch.Tensor
  109. ) -> Tuple[torch.Tensor, torch.Tensor]:
  110. assert speech_lengths.dim() == 1, speech_lengths.shape
  111. # for data-parallel
  112. speech = speech[:, : speech_lengths.max()]
  113. if self.frontend is not None:
  114. # Frontend
  115. # e.g. STFT and Feature extract
  116. # data_loader may send time-domain signal in this case
  117. # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
  118. feats, feats_lengths = self.frontend(speech, speech_lengths)
  119. else:
  120. # No frontend and no feature extract
  121. feats, feats_lengths = speech, speech_lengths
  122. return feats, feats_lengths
  123. def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num):
  124. encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
  125. encoder_out.device)
  126. ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(encoder_out,
  127. encoder_out_mask,
  128. token_num)
  129. return ds_alphas, ds_cif_peak, us_alphas, us_peaks
  130. def collect_feats(
  131. self,
  132. speech: torch.Tensor,
  133. speech_lengths: torch.Tensor,
  134. text: torch.Tensor,
  135. text_lengths: torch.Tensor,
  136. ) -> Dict[str, torch.Tensor]:
  137. if self.extract_feats_in_collect_stats:
  138. feats, feats_lengths = self._extract_feats(speech, speech_lengths)
  139. else:
  140. # Generate dummy stats if extract_feats_in_collect_stats is False
  141. logging.warning(
  142. "Generating dummy stats for feats and feats_lengths, "
  143. "because encoder_conf.extract_feats_in_collect_stats is "
  144. f"{self.extract_feats_in_collect_stats}"
  145. )
  146. feats, feats_lengths = speech, speech_lengths
  147. return {"feats": feats, "feats_lengths": feats_lengths}