# e2e_asr_mfcca.py
  1. from contextlib import contextmanager
  2. from distutils.version import LooseVersion
  3. from typing import Dict
  4. from typing import List
  5. from typing import Optional
  6. from typing import Tuple
  7. from typing import Union
  8. import logging
  9. import torch
  10. from typeguard import check_argument_types
  11. from funasr.modules.e2e_asr_common import ErrorCalculator
  12. from funasr.modules.nets_utils import th_accuracy
  13. from funasr.modules.add_sos_eos import add_sos_eos
  14. from funasr.losses.label_smoothing_loss import (
  15. LabelSmoothingLoss, # noqa: H301
  16. )
  17. from funasr.models.ctc import CTC
  18. from funasr.models.decoder.abs_decoder import AbsDecoder
  19. from funasr.models.encoder.abs_encoder import AbsEncoder
  20. from funasr.models.frontend.abs_frontend import AbsFrontend
  21. from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
  22. from funasr.models.specaug.abs_specaug import AbsSpecAug
  23. from funasr.layers.abs_normalize import AbsNormalize
  24. from funasr.torch_utils.device_funcs import force_gatherable
  25. from funasr.train.abs_espnet_model import AbsESPnetModel
  26. if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
  27. from torch.cuda.amp import autocast
  28. else:
  29. # Nothing to do if torch<1.6.0
  30. @contextmanager
  31. def autocast(enabled=True):
  32. yield
  33. import pdb
  34. import random
  35. import math
class MFCCA(AbsESPnetModel):
    """CTC-attention hybrid Encoder-Decoder model"""

    def __init__(
        self,
        vocab_size: int,
        token_list: Union[Tuple[str, ...], List[str]],
        frontend: Optional[AbsFrontend],
        specaug: Optional[AbsSpecAug],
        normalize: Optional[AbsNormalize],
        preencoder: Optional[AbsPreEncoder],
        encoder: AbsEncoder,
        decoder: AbsDecoder,
        ctc: CTC,
        rnnt_decoder: None,
        ctc_weight: float = 0.5,
        ignore_id: int = -1,
        lsm_weight: float = 0.0,
        mask_ratio: float = 0.0,
        length_normalized_loss: bool = False,
        report_cer: bool = True,
        report_wer: bool = True,
        sym_space: str = "<space>",
        sym_blank: str = "<blank>",
    ):
        """Build the hybrid CTC/attention model.

        Args:
            vocab_size: output vocabulary size; the last index doubles as sos/eos.
            token_list: id-to-token mapping used by the error calculator.
            frontend/specaug/normalize/preencoder: optional pipeline stages.
            encoder/decoder/ctc: core model components.
            rnnt_decoder: must be None (RNN-T branch is not implemented).
            ctc_weight: interpolation weight in [0, 1]; 1.0 disables the
                attention decoder, 0.0 disables CTC.
            ignore_id: padding id ignored by the losses.
            lsm_weight: label-smoothing weight for the attention loss.
            mask_ratio: probability of applying random channel masking to
                8-channel input in forward() (0 disables it).
            length_normalized_loss: normalize loss by sequence length.
            report_cer/report_wer: enable CER/WER computation at eval time.
            sym_space/sym_blank: special tokens used by the error calculator.
        """
        assert check_argument_types()
        assert 0.0 <= ctc_weight <= 1.0, ctc_weight
        assert rnnt_decoder is None, "Not implemented"
        super().__init__()
        # note that eos is the same as sos (equivalent ID)
        self.sos = vocab_size - 1
        self.eos = vocab_size - 1
        self.vocab_size = vocab_size
        self.ignore_id = ignore_id
        self.ctc_weight = ctc_weight
        # copy so later mutation by the caller cannot change our token list
        self.token_list = token_list.copy()
        self.mask_ratio = mask_ratio
        self.frontend = frontend
        self.specaug = specaug
        self.normalize = normalize
        self.preencoder = preencoder
        self.encoder = encoder
        # we set self.decoder = None in the CTC mode since
        # self.decoder parameters were never used and PyTorch complained
        # and threw an Exception in the multi-GPU experiment.
        # thanks Jeff Farris for pointing out the issue.
        if ctc_weight == 1.0:
            self.decoder = None
        else:
            self.decoder = decoder
        if ctc_weight == 0.0:
            self.ctc = None
        else:
            self.ctc = ctc
        self.rnnt_decoder = rnnt_decoder
        self.criterion_att = LabelSmoothingLoss(
            size=vocab_size,
            padding_idx=ignore_id,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )
        if report_cer or report_wer:
            self.error_calculator = ErrorCalculator(
                token_list, sym_space, sym_blank, report_cer, report_wer
            )
        else:
            self.error_calculator = None
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)

        Returns:
            loss, stats dict, and batch weight (all made gatherable for
            DataParallel).
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (
            speech.shape[0]
            == speech_lengths.shape[0]
            == text.shape[0]
            == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
        # pdb.set_trace()
        # Channel-dropout augmentation: for 8-channel input, with probability
        # mask_ratio keep only a random non-empty subset of channels
        # (channels are drawn via randperm, then sorted to keep array order).
        if(speech.dim()==3 and speech.size(2)==8 and self.mask_ratio !=0):
            rate_num = random.random()
            # rate_num = 0.1
            if(rate_num<=self.mask_ratio):
                retain_channel = math.ceil(random.random() *8)
                if(retain_channel>1):
                    speech = speech[:,:,torch.randperm(8)[0:retain_channel].sort().values]
                else:
                    # Single retained channel: indexing with a 0-d tensor
                    # squeezes the channel axis -> (Batch, Length).
                    speech = speech[:,:,torch.randperm(8)[0]]
        # pdb.set_trace()
        batch_size = speech.shape[0]
        # for data-parallel
        text = text[:, : text_lengths.max()]
        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        # 2a. Attention-decoder branch (skipped in pure-CTC mode)
        if self.ctc_weight == 1.0:
            loss_att, acc_att, cer_att, wer_att = None, None, None, None
        else:
            loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )
        # 2b. CTC branch (skipped in pure-attention mode)
        if self.ctc_weight == 0.0:
            loss_ctc, cer_ctc = None, None
        else:
            loss_ctc, cer_ctc = self._calc_ctc_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )
        # 2c. RNN-T branch (currently always raises NotImplementedError)
        if self.rnnt_decoder is not None:
            _ = self._calc_rnnt_loss(encoder_out, encoder_out_lens, text, text_lengths)
        # Interpolate the two losses by ctc_weight.
        if self.ctc_weight == 0.0:
            loss = loss_att
        elif self.ctc_weight == 1.0:
            loss = loss_ctc
        else:
            loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att
        stats = dict(
            loss=loss.detach(),
            loss_att=loss_att.detach() if loss_att is not None else None,
            loss_ctc=loss_ctc.detach() if loss_ctc is not None else None,
            acc=acc_att,
            cer=cer_att,
            wer=wer_att,
            cer_ctc=cer_ctc,
        )
        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight
  175. def collect_feats(
  176. self,
  177. speech: torch.Tensor,
  178. speech_lengths: torch.Tensor,
  179. text: torch.Tensor,
  180. text_lengths: torch.Tensor,
  181. ) -> Dict[str, torch.Tensor]:
  182. feats, feats_lengths, channel_size = self._extract_feats(speech, speech_lengths)
  183. return {"feats": feats, "feats_lengths": feats_lengths}
    def encode(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Frontend + Encoder. Note that this method is used by asr_inference.py

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
        """
        # Feature extraction / augmentation / CMVN run in fp32 even under AMP.
        with autocast(False):
            # 1. Extract feats
            feats, feats_lengths, channel_size = self._extract_feats(speech, speech_lengths)
            # 2. Data augmentation (training only)
            if self.specaug is not None and self.training:
                feats, feats_lengths = self.specaug(feats, feats_lengths)
            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                feats, feats_lengths = self.normalize(feats, feats_lengths)
        # Pre-encoder, e.g. used for raw input data
        if self.preencoder is not None:
            feats, feats_lengths = self.preencoder(feats, feats_lengths)
        # pdb.set_trace()
        # The MFCCA encoder additionally takes the channel count.
        encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths, channel_size)
        assert encoder_out.size(0) == speech.size(0), (
            encoder_out.size(),
            speech.size(0),
        )
        # Sanity-check the time axis: for 4-D output it is dim 2
        # (presumably (Batch, Channel, Time, Dim) — confirm against encoder),
        # otherwise dim 1.
        if(encoder_out.dim()==4):
            assert encoder_out.size(2) <= encoder_out_lens.max(), (
                encoder_out.size(),
                encoder_out_lens.max(),
            )
        else:
            assert encoder_out.size(1) <= encoder_out_lens.max(), (
                encoder_out.size(),
                encoder_out_lens.max(),
            )
        return encoder_out, encoder_out_lens
  221. def _extract_feats(
  222. self, speech: torch.Tensor, speech_lengths: torch.Tensor
  223. ) -> Tuple[torch.Tensor, torch.Tensor]:
  224. assert speech_lengths.dim() == 1, speech_lengths.shape
  225. # for data-parallel
  226. speech = speech[:, : speech_lengths.max()]
  227. if self.frontend is not None:
  228. # Frontend
  229. # e.g. STFT and Feature extract
  230. # data_loader may send time-domain signal in this case
  231. # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
  232. feats, feats_lengths, channel_size = self.frontend(speech, speech_lengths)
  233. else:
  234. # No frontend and no feature extract
  235. feats, feats_lengths = speech, speech_lengths
  236. channel_size = 1
  237. return feats, feats_lengths, channel_size
  238. def _calc_att_loss(
  239. self,
  240. encoder_out: torch.Tensor,
  241. encoder_out_lens: torch.Tensor,
  242. ys_pad: torch.Tensor,
  243. ys_pad_lens: torch.Tensor,
  244. ):
  245. ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
  246. ys_in_lens = ys_pad_lens + 1
  247. # 1. Forward decoder
  248. decoder_out, _ = self.decoder(
  249. encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
  250. )
  251. # 2. Compute attention loss
  252. loss_att = self.criterion_att(decoder_out, ys_out_pad)
  253. acc_att = th_accuracy(
  254. decoder_out.view(-1, self.vocab_size),
  255. ys_out_pad,
  256. ignore_label=self.ignore_id,
  257. )
  258. # Compute cer/wer using attention-decoder
  259. if self.training or self.error_calculator is None:
  260. cer_att, wer_att = None, None
  261. else:
  262. ys_hat = decoder_out.argmax(dim=-1)
  263. cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
  264. return loss_att, acc_att, cer_att, wer_att
  265. def _calc_ctc_loss(
  266. self,
  267. encoder_out: torch.Tensor,
  268. encoder_out_lens: torch.Tensor,
  269. ys_pad: torch.Tensor,
  270. ys_pad_lens: torch.Tensor,
  271. ):
  272. # Calc CTC loss
  273. if(encoder_out.dim()==4):
  274. encoder_out = encoder_out.mean(1)
  275. loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
  276. # Calc CER using CTC
  277. cer_ctc = None
  278. if not self.training and self.error_calculator is not None:
  279. ys_hat = self.ctc.argmax(encoder_out).data
  280. cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
  281. return loss_ctc, cer_ctc
    def _calc_rnnt_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        # The RNN-T branch is declared in the interface (and guarded by an
        # assert in __init__) but has no implementation in this model.
        raise NotImplementedError