e2e_sv.py

  1. """
  2. Author: Speech Lab, Alibaba Group, China
  3. """
  4. import logging
  5. from contextlib import contextmanager
  6. from distutils.version import LooseVersion
  7. from typing import Dict
  8. from typing import List
  9. from typing import Optional
  10. from typing import Tuple
  11. from typing import Union
  12. import torch
  13. from typeguard import check_argument_types
  14. from funasr.layers.abs_normalize import AbsNormalize
  15. from funasr.losses.label_smoothing_loss import (
  16. LabelSmoothingLoss, # noqa: H301
  17. )
  18. from funasr.models.ctc import CTC
  19. from funasr.models.decoder.abs_decoder import AbsDecoder
  20. from funasr.models.encoder.abs_encoder import AbsEncoder
  21. from funasr.models.frontend.abs_frontend import AbsFrontend
  22. from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
  23. from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
  24. from funasr.models.specaug.abs_specaug import AbsSpecAug
  25. from funasr.modules.add_sos_eos import add_sos_eos
  26. from funasr.modules.e2e_asr_common import ErrorCalculator
  27. from funasr.modules.nets_utils import th_accuracy
  28. from funasr.torch_utils.device_funcs import force_gatherable
  29. from funasr.models.base_model import FunASRModel
  30. if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
  31. from torch.cuda.amp import autocast
  32. else:
  33. # Nothing to do if torch<1.6.0
  34. @contextmanager
  35. def autocast(enabled=True):
  36. yield


class ESPnetSVModel(FunASRModel):
    """CTC-attention hybrid encoder-decoder model for speaker verification."""

    def __init__(
        self,
        vocab_size: int,
        token_list: Union[Tuple[str, ...], List[str]],
        frontend: Optional[AbsFrontend],
        specaug: Optional[AbsSpecAug],
        normalize: Optional[AbsNormalize],
        preencoder: Optional[AbsPreEncoder],
        encoder: AbsEncoder,
        postencoder: Optional[AbsPostEncoder],
        pooling_layer: torch.nn.Module,
        decoder: AbsDecoder,
    ):
        assert check_argument_types()
        super().__init__()

        self.vocab_size = vocab_size
        self.token_list = token_list.copy()
        self.frontend = frontend
        self.specaug = specaug
        self.normalize = normalize
        self.preencoder = preencoder
        self.postencoder = postencoder
        self.encoder = encoder
        self.pooling_layer = pooling_layer
        self.decoder = decoder
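
        # Note: forward() and collect_feats() below also reference
        # self.ctc_weight, self.interctc_weight, self.use_transducer_decoder,
        # self.extract_feats_in_collect_stats, and the _calc_*_loss() helpers,
        # none of which are defined in this class. They appear to be expected
        # from the training configuration or a subclass; as written, forward()
        # requires them to be set before it is called.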

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch,)
            text: (Batch, Length)
            text_lengths: (Batch,)
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (
            speech.shape[0]
            == speech_lengths.shape[0]
            == text.shape[0]
            == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
        batch_size = speech.shape[0]

        # for data-parallel
        text = text[:, : text_lengths.max()]

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]

        loss_att, acc_att, cer_att, wer_att = None, None, None, None
        loss_ctc, cer_ctc = None, None
        loss_transducer, cer_transducer, wer_transducer = None, None, None
        stats = dict()

        # 1. CTC branch
        if self.ctc_weight != 0.0:
            loss_ctc, cer_ctc = self._calc_ctc_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )

            # Collect CTC branch stats
            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
            stats["cer_ctc"] = cer_ctc

        # Intermediate CTC (optional)
        loss_interctc = 0.0
        if self.interctc_weight != 0.0 and intermediate_outs is not None:
            for layer_idx, intermediate_out in intermediate_outs:
                # we assume intermediate_out has the same length & padding
                # as those of encoder_out
                loss_ic, cer_ic = self._calc_ctc_loss(
                    intermediate_out, encoder_out_lens, text, text_lengths
                )
                loss_interctc = loss_interctc + loss_ic

                # Collect Intermediate CTC stats
                stats["loss_interctc_layer{}".format(layer_idx)] = (
                    loss_ic.detach() if loss_ic is not None else None
                )
                stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic

            loss_interctc = loss_interctc / len(intermediate_outs)

            # calculate whole encoder loss
            loss_ctc = (
                1 - self.interctc_weight
            ) * loss_ctc + self.interctc_weight * loss_interctc
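
        # When intermediate CTC is active, the loss_ctc leaving the block above
        # is the blend (1 - interctc_weight) * loss_ctc + interctc_weight *
        # loss_interctc, where loss_interctc is averaged over the tapped
        # encoder layers; the branch weighting below applies to this blend.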

        if self.use_transducer_decoder:
            # 2a. Transducer decoder branch
            (
                loss_transducer,
                cer_transducer,
                wer_transducer,
            ) = self._calc_transducer_loss(
                encoder_out,
                encoder_out_lens,
                text,
            )

            if loss_ctc is not None:
                loss = loss_transducer + (self.ctc_weight * loss_ctc)
            else:
                loss = loss_transducer

            # Collect Transducer branch stats
            stats["loss_transducer"] = (
                loss_transducer.detach() if loss_transducer is not None else None
            )
            stats["cer_transducer"] = cer_transducer
            stats["wer_transducer"] = wer_transducer
        else:
            # 2b. Attention decoder branch
            if self.ctc_weight != 1.0:
                loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
                    encoder_out, encoder_out_lens, text, text_lengths
                )

            # 3. CTC-Att loss definition
            if self.ctc_weight == 0.0:
                loss = loss_att
            elif self.ctc_weight == 1.0:
                loss = loss_ctc
            else:
                loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att

            # Collect Attn branch stats
            stats["loss_att"] = loss_att.detach() if loss_att is not None else None
            stats["acc"] = acc_att
            stats["cer"] = cer_att
            stats["wer"] = wer_att

        # Collect total loss stats
        stats["loss"] = torch.clone(loss.detach())

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

    def collect_feats(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        if self.extract_feats_in_collect_stats:
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
        else:
            # Generate dummy stats if extract_feats_in_collect_stats is False
            logging.warning(
                "Generating dummy stats for feats and feats_lengths, "
                "because encoder_conf.extract_feats_in_collect_stats is "
                f"{self.extract_feats_in_collect_stats}"
            )
            feats, feats_lengths = speech, speech_lengths
        return {"feats": feats, "feats_lengths": feats_lengths}

    def encode(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Frontend + Encoder. Note that this method is used by asr_inference.py

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch,)
        """
        with autocast(False):
            # 1. Extract feats
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)

            # 2. Data augmentation
            if self.specaug is not None and self.training:
                feats, feats_lengths = self.specaug(feats, feats_lengths)

            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                feats, feats_lengths = self.normalize(feats, feats_lengths)

        # Pre-encoder, e.g. used for raw input data
        if self.preencoder is not None:
            feats, feats_lengths = self.preencoder(feats, feats_lengths)

        # 4. Forward encoder
        # feats: (Batch, Length, Dim) -> (Batch, Channel, Length2, Dim2)
        encoder_out, encoder_out_lens = self.encoder(feats, feats_lengths)

        # Post-encoder, e.g. NLU
        if self.postencoder is not None:
            encoder_out, encoder_out_lens = self.postencoder(
                encoder_out, encoder_out_lens
            )

        return encoder_out, encoder_out_lens
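
    # Usage sketch (shapes are illustrative): given waveforms `speech` of shape
    # (Batch, NSamples) and `speech_lengths` of shape (Batch,),
    #     encoder_out, encoder_out_lens = model.encode(speech, speech_lengths)
    # yields frame-level representations, e.g. (Batch, Frames, Dim); in this
    # model the pooling_layer and decoder would presumably map these to an
    # utterance-level speaker representation.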

    def _extract_feats(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        assert speech_lengths.dim() == 1, speech_lengths.shape

        # for data-parallel
        speech = speech[:, : speech_lengths.max()]

        if self.frontend is not None:
            # Frontend
            #  e.g. STFT and Feature extract
            #       data_loader may send time-domain signal in this case
            # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
            feats, feats_lengths = self.frontend(speech, speech_lengths)
        else:
            # No frontend and no feature extract
            feats, feats_lengths = speech, speech_lengths
        return feats, feats_lengths
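

# ---------------------------------------------------------------------------
# Illustration only, not part of FunASR: `pooling_layer` above is an arbitrary
# torch.nn.Module chosen by the configuration, and this file never documents
# its interface. As a minimal sketch, a statistics-pooling module of the kind
# commonly used for speaker embeddings could look like the following; the
# class name and the assumed (Batch, Time, Dim) input are hypothetical.
# ---------------------------------------------------------------------------
class StatsPoolingSketch(torch.nn.Module):
    """Mean + std pooling over time: (Batch, Time, Dim) -> (Batch, 2 * Dim)."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mean = x.mean(dim=1)
        std = x.std(dim=1)
        return torch.cat([mean, std], dim=-1)


if __name__ == "__main__":
    # Smoke test on dummy encoder output: 2 utterances, 100 frames, 256 dims.
    pooled = StatsPoolingSketch()(torch.randn(2, 100, 256))
    print(pooled.shape)  # torch.Size([2, 512])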