e2e_sv.py

import logging
from contextlib import contextmanager
from distutils.version import LooseVersion
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import torch
from typeguard import check_argument_types

from funasr.layers.abs_normalize import AbsNormalize
from funasr.losses.label_smoothing_loss import (
    LabelSmoothingLoss,  # noqa: H301
)
from funasr.models.ctc import CTC
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.modules.add_sos_eos import add_sos_eos
from funasr.modules.e2e_asr_common import ErrorCalculator
from funasr.modules.nets_utils import th_accuracy
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.train.abs_espnet_model import AbsESPnetModel

if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
else:
    # Nothing to do if torch<1.6.0
    @contextmanager
    def autocast(enabled=True):
        yield
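
# The version shim above lets the rest of this file write `with autocast(False):`
# uniformly: on torch>=1.6 it is the real torch.cuda.amp.autocast, while on older
# torch it degrades to a no-op context manager. Illustrative sketch (the names
# `frontend` and `speech` below are placeholders, not defined in this file):
#
#     with autocast(False):          # no-op on torch<1.6
#         feats = frontend(speech)   # runs in full precision either way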


class ESPnetSVModel(AbsESPnetModel):
    """CTC-attention hybrid Encoder-Decoder model"""

    def __init__(
        self,
        vocab_size: int,
        token_list: Union[Tuple[str, ...], List[str]],
        frontend: Optional[AbsFrontend],
        specaug: Optional[AbsSpecAug],
        normalize: Optional[AbsNormalize],
        preencoder: Optional[AbsPreEncoder],
        encoder: AbsEncoder,
        postencoder: Optional[AbsPostEncoder],
        pooling_layer: torch.nn.Module,
        decoder: AbsDecoder,
    ):
        assert check_argument_types()

        super().__init__()
        # note that eos is the same as sos (equivalent ID)
        self.vocab_size = vocab_size
        self.token_list = token_list.copy()

        self.frontend = frontend
        self.specaug = specaug
        self.normalize = normalize
        self.preencoder = preencoder
        self.postencoder = postencoder
        self.encoder = encoder
        self.pooling_layer = pooling_layer
        self.decoder = decoder
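
    # NOTE: forward() below also reads self.ctc_weight, self.interctc_weight,
    # self.use_transducer_decoder and self.extract_feats_in_collect_stats, and
    # calls self._calc_ctc_loss / self._calc_att_loss / self._calc_transducer_loss.
    # None of these are defined in this file; they are assumed to be supplied by
    # the configuration or subclass that instantiates the model.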

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (
            speech.shape[0]
            == speech_lengths.shape[0]
            == text.shape[0]
            == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
        batch_size = speech.shape[0]

        # for data-parallel
        text = text[:, : text_lengths.max()]

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]

        loss_att, acc_att, cer_att, wer_att = None, None, None, None
        loss_ctc, cer_ctc = None, None
        loss_transducer, cer_transducer, wer_transducer = None, None, None
        stats = dict()

        # 2. CTC branch
        if self.ctc_weight != 0.0:
            loss_ctc, cer_ctc = self._calc_ctc_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )

            # Collect CTC branch stats
            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
            stats["cer_ctc"] = cer_ctc

        # Intermediate CTC (optional)
        loss_interctc = 0.0
        if self.interctc_weight != 0.0 and intermediate_outs is not None:
            for layer_idx, intermediate_out in intermediate_outs:
                # we assume intermediate_out has the same length & padding
                # as those of encoder_out
                loss_ic, cer_ic = self._calc_ctc_loss(
                    intermediate_out, encoder_out_lens, text, text_lengths
                )
                loss_interctc = loss_interctc + loss_ic

                # Collect Intermediate CTC stats
                stats["loss_interctc_layer{}".format(layer_idx)] = (
                    loss_ic.detach() if loss_ic is not None else None
                )
                stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic

            loss_interctc = loss_interctc / len(intermediate_outs)

            # calculate whole encoder loss
            loss_ctc = (
                1 - self.interctc_weight
            ) * loss_ctc + self.interctc_weight * loss_interctc

        if self.use_transducer_decoder:
            # 3a. Transducer decoder branch
            (
                loss_transducer,
                cer_transducer,
                wer_transducer,
            ) = self._calc_transducer_loss(
                encoder_out,
                encoder_out_lens,
                text,
            )

            if loss_ctc is not None:
                loss = loss_transducer + (self.ctc_weight * loss_ctc)
            else:
                loss = loss_transducer

            # Collect Transducer branch stats
            stats["loss_transducer"] = (
                loss_transducer.detach() if loss_transducer is not None else None
            )
            stats["cer_transducer"] = cer_transducer
            stats["wer_transducer"] = wer_transducer
        else:
            # 3b. Attention decoder branch
            if self.ctc_weight != 1.0:
                loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
                    encoder_out, encoder_out_lens, text, text_lengths
                )

            # 4. CTC-Att loss definition
            if self.ctc_weight == 0.0:
                loss = loss_att
            elif self.ctc_weight == 1.0:
                loss = loss_ctc
            else:
                loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att

            # Collect Attn branch stats
            stats["loss_att"] = loss_att.detach() if loss_att is not None else None
            stats["acc"] = acc_att
            stats["cer"] = cer_att
            stats["wer"] = wer_att

        # Collect total loss stats
        stats["loss"] = torch.clone(loss.detach())

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight
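
    # Worked example of the loss interpolation above (assumed weights, purely
    # illustrative): with interctc_weight=0.5 and two intermediate CTC losses of
    # 2.0 and 4.0, loss_interctc averages to 3.0 and loss_ctc becomes
    # 0.5 * loss_ctc + 0.5 * 3.0; with ctc_weight=0.3 in the attention branch,
    # the final loss is 0.3 * loss_ctc + 0.7 * loss_att.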

    def collect_feats(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        if self.extract_feats_in_collect_stats:
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
        else:
            # Generate dummy stats if extract_feats_in_collect_stats is False
            logging.warning(
                "Generating dummy stats for feats and feats_lengths, "
                "because encoder_conf.extract_feats_in_collect_stats is "
                f"{self.extract_feats_in_collect_stats}"
            )
            feats, feats_lengths = speech, speech_lengths
        return {"feats": feats, "feats_lengths": feats_lengths}
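
    # Illustrative return value (assuming an 80-dim fbank frontend and a batch
    # of two utterances; the shapes are examples, not guarantees):
    #     {"feats": FloatTensor of shape (2, NFrames, 80),
    #      "feats_lengths": LongTensor([n_frames_1, n_frames_2])}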

    def encode(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Frontend + Encoder. Note that this method is used by asr_inference.py

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
        """
        with autocast(False):
            # 1. Extract feats
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)

            # 2. Data augmentation
            if self.specaug is not None and self.training:
                feats, feats_lengths = self.specaug(feats, feats_lengths)

            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                feats, feats_lengths = self.normalize(feats, feats_lengths)

        # Pre-encoder, e.g. used for raw input data
        if self.preencoder is not None:
            feats, feats_lengths = self.preencoder(feats, feats_lengths)

        # 4. Forward encoder
        # feats: (Batch, Length, Dim) -> (Batch, Channel, Length2, Dim2)
        encoder_out, encoder_out_lens = self.encoder(feats, feats_lengths)

        # Post-encoder, e.g. NLU
        if self.postencoder is not None:
            encoder_out, encoder_out_lens = self.postencoder(
                encoder_out, encoder_out_lens
            )

        return encoder_out, encoder_out_lens
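
    # Minimal inference sketch (assumption: `model` is a fully constructed
    # ESPnetSVModel; the shapes are illustrative):
    #
    #     speech = torch.randn(2, 16000)                 # (Batch, NSamples)
    #     speech_lengths = torch.tensor([16000, 12000])  # (Batch,)
    #     enc, enc_lens = model.encode(speech, speech_lengths)
    #     # enc then feeds self.pooling_layer / self.decoder downstream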

    def _extract_feats(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        assert speech_lengths.dim() == 1, speech_lengths.shape

        # for data-parallel
        speech = speech[:, : speech_lengths.max()]

        if self.frontend is not None:
            # Frontend
            #  e.g. STFT and Feature extract
            #       data_loader may send time-domain signal in this case
            # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
            feats, feats_lengths = self.frontend(speech, speech_lengths)
        else:
            # No frontend and no feature extract
            feats, feats_lengths = speech, speech_lengths
        return feats, feats_lengths
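

# Standalone sketch of the data-parallel trimming used by forward() and
# _extract_feats(): each DataParallel replica may receive a batch padded to the
# global max length, so inputs are cut back to this replica's own max. The
# tensors below are made up for illustration; only torch (imported above) is
# required to run this.
if __name__ == "__main__":
    speech = torch.zeros(2, 10)            # padded to a global max of 10
    speech_lengths = torch.tensor([6, 4])  # true lengths within this replica
    speech = speech[:, : speech_lengths.max()]
    print(speech.shape)                    # torch.Size([2, 6])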