# e2e_sv.py

  1. """
  2. Author: Speech Lab, Alibaba Group, China
  3. """
  4. import logging
  5. from contextlib import contextmanager
  6. from distutils.version import LooseVersion
  7. from typing import Dict
  8. from typing import List
  9. from typing import Optional
  10. from typing import Tuple
  11. from typing import Union
  12. import torch
  13. from funasr.layers.abs_normalize import AbsNormalize
  14. from funasr.losses.label_smoothing_loss import (
  15. LabelSmoothingLoss, # noqa: H301
  16. )
  17. from funasr.models.ctc import CTC
  18. from funasr.models.decoder.abs_decoder import AbsDecoder
  19. from funasr.models.encoder.abs_encoder import AbsEncoder
  20. from funasr.models.frontend.abs_frontend import AbsFrontend
  21. from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
  22. from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
  23. from funasr.models.specaug.abs_specaug import AbsSpecAug
  24. from funasr.modules.add_sos_eos import add_sos_eos
  25. from funasr.modules.e2e_asr_common import ErrorCalculator
  26. from funasr.modules.nets_utils import th_accuracy
  27. from funasr.torch_utils.device_funcs import force_gatherable
  28. from funasr.models.base_model import FunASRModel
  29. if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
  30. from torch.cuda.amp import autocast
  31. else:
  32. # Nothing to do if torch<1.6.0
  33. @contextmanager
  34. def autocast(enabled=True):
  35. yield


class ESPnetSVModel(FunASRModel):
    """CTC-attention hybrid encoder-decoder model for speaker verification."""

    def __init__(
        self,
        vocab_size: int,
        token_list: Union[Tuple[str, ...], List[str]],
        frontend: Optional[AbsFrontend],
        specaug: Optional[AbsSpecAug],
        normalize: Optional[AbsNormalize],
        preencoder: Optional[AbsPreEncoder],
        encoder: AbsEncoder,
        postencoder: Optional[AbsPostEncoder],
        pooling_layer: torch.nn.Module,
        decoder: AbsDecoder,
    ):
        super().__init__()
        # note that eos is the same as sos (equivalent ID)
        self.vocab_size = vocab_size
        self.token_list = token_list.copy()

        self.frontend = frontend
        self.specaug = specaug
        self.normalize = normalize
        self.preencoder = preencoder
        self.postencoder = postencoder
        self.encoder = encoder
        self.pooling_layer = pooling_layer
        self.decoder = decoder
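        # NOTE (editor's assumption): forward() below reads self.ctc_weight,
        # self.interctc_weight and self.use_transducer_decoder, and
        # collect_feats() reads self.extract_feats_in_collect_stats, yet none
        # of them were set by the original constructor. The defaults here
        # (attention-only branch, no CTC/transducer) are added so the class is
        # usable as written; the _calc_*_loss helpers invoked in forward() are
        # likewise assumed to be provided by FunASRModel or a subclass, since
        # they are not defined in this file.
        self.ctc_weight = 0.0
        self.interctc_weight = 0.0
        self.use_transducer_decoder = False
        self.extract_feats_in_collect_stats = True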

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch,)
            text: (Batch, Length)
            text_lengths: (Batch,)
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (
            speech.shape[0]
            == speech_lengths.shape[0]
            == text.shape[0]
            == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
        batch_size = speech.shape[0]

        # for data-parallel
        text = text[:, : text_lengths.max()]

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]

        loss_att, acc_att, cer_att, wer_att = None, None, None, None
        loss_ctc, cer_ctc = None, None
        loss_transducer, cer_transducer, wer_transducer = None, None, None
        stats = dict()

        # 1. CTC branch
        if self.ctc_weight != 0.0:
            loss_ctc, cer_ctc = self._calc_ctc_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )

            # Collect CTC branch stats
            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
            stats["cer_ctc"] = cer_ctc

        # Intermediate CTC (optional)
        loss_interctc = 0.0
        if self.interctc_weight != 0.0 and intermediate_outs is not None:
            for layer_idx, intermediate_out in intermediate_outs:
                # we assume intermediate_out has the same length & padding
                # as those of encoder_out
                loss_ic, cer_ic = self._calc_ctc_loss(
                    intermediate_out, encoder_out_lens, text, text_lengths
                )
                loss_interctc = loss_interctc + loss_ic

                # Collect Intermediate CTC stats
                stats["loss_interctc_layer{}".format(layer_idx)] = (
                    loss_ic.detach() if loss_ic is not None else None
                )
                stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic

            loss_interctc = loss_interctc / len(intermediate_outs)

            # calculate the whole-encoder loss as a weighted sum of the
            # final-layer and intermediate CTC losses
            loss_ctc = (
                1 - self.interctc_weight
            ) * loss_ctc + self.interctc_weight * loss_interctc

        if self.use_transducer_decoder:
            # 2a. Transducer decoder branch
            (
                loss_transducer,
                cer_transducer,
                wer_transducer,
            ) = self._calc_transducer_loss(
                encoder_out,
                encoder_out_lens,
                text,
            )

            if loss_ctc is not None:
                loss = loss_transducer + (self.ctc_weight * loss_ctc)
            else:
                loss = loss_transducer

            # Collect Transducer branch stats
            stats["loss_transducer"] = (
                loss_transducer.detach() if loss_transducer is not None else None
            )
            stats["cer_transducer"] = cer_transducer
            stats["wer_transducer"] = wer_transducer
        else:
            # 2b. Attention decoder branch
            if self.ctc_weight != 1.0:
                loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
                    encoder_out, encoder_out_lens, text, text_lengths
                )

            # 3. CTC-Att loss definition
            if self.ctc_weight == 0.0:
                loss = loss_att
            elif self.ctc_weight == 1.0:
                loss = loss_ctc
            else:
                loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att

            # Collect Attn branch stats
            stats["loss_att"] = loss_att.detach() if loss_att is not None else None
            stats["acc"] = acc_att
            stats["cer"] = cer_att
            stats["wer"] = wer_att

        # Collect total loss stats
        stats["loss"] = torch.clone(loss.detach())

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight
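
    # Editor's note: the (loss, stats, weight) triple above follows the ESPnet
    # trainer convention; `weight` equals the batch size so that losses from
    # DataParallel replicas can be re-averaged correctly once force_gatherable
    # has moved everything onto a single device.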

    def collect_feats(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        if self.extract_feats_in_collect_stats:
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
        else:
            # Generate dummy stats if extract_feats_in_collect_stats is False
            logging.warning(
                "Generating dummy stats for feats and feats_lengths, "
                "because encoder_conf.extract_feats_in_collect_stats is "
                f"{self.extract_feats_in_collect_stats}"
            )
            feats, feats_lengths = speech, speech_lengths
        return {"feats": feats, "feats_lengths": feats_lengths}

    def encode(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Frontend + Encoder. Note that this method is used by asr_inference.py

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch,)
        """
        with autocast(False):
            # 1. Extract feats
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)

            # 2. Data augmentation
            if self.specaug is not None and self.training:
                feats, feats_lengths = self.specaug(feats, feats_lengths)

            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                feats, feats_lengths = self.normalize(feats, feats_lengths)

        # Pre-encoder, e.g. used for raw input data
        if self.preencoder is not None:
            feats, feats_lengths = self.preencoder(feats, feats_lengths)

        # 4. Forward encoder
        # feats: (Batch, Length, Dim) -> encoder_out: (Batch, Length2, Dim2)
        encoder_out, encoder_out_lens = self.encoder(feats, feats_lengths)

        # Post-encoder, e.g. NLU
        if self.postencoder is not None:
            encoder_out, encoder_out_lens = self.postencoder(
                encoder_out, encoder_out_lens
            )

        return encoder_out, encoder_out_lens
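
    # Usage sketch (editor's assumption: pooling_layer is stored by __init__
    # but never called in this file, so the exact call signature is inferred):
    #
    #     enc_out, enc_lens = model.encode(speech, speech_lengths)
    #     spk_embedding = model.pooling_layer(enc_out)  # pool over time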

    def _extract_feats(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        assert speech_lengths.dim() == 1, speech_lengths.shape

        # for data-parallel
        speech = speech[:, : speech_lengths.max()]

        if self.frontend is not None:
            # Frontend
            #  e.g. STFT and Feature extract
            #       data_loader may send time-domain signal in this case
            # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
            feats, feats_lengths = self.frontend(speech, speech_lengths)
        else:
            # No frontend and no feature extract
            feats, feats_lengths = speech, speech_lengths
        return feats, feats_lengths
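

if __name__ == "__main__":
    # Editor's smoke test, not part of the original file: exercise the
    # frontend-less feature path with stand-in modules. With frontend=None,
    # _extract_feats passes precomputed features through unchanged, and
    # encode() just truncates to the longest length and runs the encoder.
    # This assumes FunASRModel behaves like a plain torch.nn.Module here;
    # all modules below are hypothetical placeholders.
    class _IdentityEncoder(torch.nn.Module):
        """Stand-in encoder that returns its input unchanged."""

        def forward(self, x, lens):
            return x, lens

    model = ESPnetSVModel(
        vocab_size=10,
        token_list=[str(i) for i in range(10)],
        frontend=None,
        specaug=None,
        normalize=None,
        preencoder=None,
        encoder=_IdentityEncoder(),
        postencoder=None,
        pooling_layer=torch.nn.Identity(),
        decoder=torch.nn.Identity(),
    )
    feats = torch.randn(2, 50, 80)          # (Batch, NFrames, Dim)
    feats_lengths = torch.tensor([50, 40])  # (Batch,)
    enc_out, enc_lens = model.encode(feats, feats_lengths)
    print(enc_out.shape, enc_lens)          # torch.Size([2, 50, 80]) tensor([50, 40])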