# e2e_asr_mfcca.py
# MFCCA: CTC-attention hybrid encoder-decoder ASR model (FunASR).
from contextlib import contextmanager
from distutils.version import LooseVersion
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import logging
import torch
from typeguard import check_argument_types
from funasr.modules.e2e_asr_common import ErrorCalculator
from funasr.modules.nets_utils import th_accuracy
from funasr.modules.add_sos_eos import add_sos_eos
from funasr.losses.label_smoothing_loss import (
    LabelSmoothingLoss,  # noqa: H301
)
from funasr.models.ctc import CTC
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr.models.base_model import FunASRModel
from funasr.torch_utils.device_funcs import force_gatherable

# `torch.cuda.amp.autocast` only exists from torch 1.6.0 onward; on older
# torch we substitute a no-op context manager so `with autocast(False): ...`
# below still works unchanged.
# NOTE(review): `distutils.version.LooseVersion` is deprecated (distutils is
# removed in Python 3.12, PEP 632); consider `packaging.version` — confirm
# the project's supported Python range before changing.
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
else:
    # Nothing to do if torch<1.6.0
    @contextmanager
    def autocast(enabled=True):
        # No-op stand-in: mixed precision is simply not used on old torch.
        yield

import pdb  # NOTE(review): debug-only import; appears unused at runtime
import random
import math
  32. class MFCCA(FunASRModel):
  33. """CTC-attention hybrid Encoder-Decoder model"""
  34. def __init__(
  35. self,
  36. vocab_size: int,
  37. token_list: Union[Tuple[str, ...], List[str]],
  38. frontend: Optional[torch.nn.Module],
  39. specaug: Optional[torch.nn.Module],
  40. normalize: Optional[torch.nn.Module],
  41. preencoder: Optional[AbsPreEncoder],
  42. encoder: torch.nn.Module,
  43. decoder: AbsDecoder,
  44. ctc: CTC,
  45. rnnt_decoder: None,
  46. ctc_weight: float = 0.5,
  47. ignore_id: int = -1,
  48. lsm_weight: float = 0.0,
  49. mask_ratio: float = 0.0,
  50. length_normalized_loss: bool = False,
  51. report_cer: bool = True,
  52. report_wer: bool = True,
  53. sym_space: str = "<space>",
  54. sym_blank: str = "<blank>",
  55. ):
  56. assert check_argument_types()
  57. assert 0.0 <= ctc_weight <= 1.0, ctc_weight
  58. assert rnnt_decoder is None, "Not implemented"
  59. super().__init__()
  60. # note that eos is the same as sos (equivalent ID)
  61. self.sos = vocab_size - 1
  62. self.eos = vocab_size - 1
  63. self.vocab_size = vocab_size
  64. self.ignore_id = ignore_id
  65. self.ctc_weight = ctc_weight
  66. self.token_list = token_list.copy()
  67. self.mask_ratio = mask_ratio
  68. self.frontend = frontend
  69. self.specaug = specaug
  70. self.normalize = normalize
  71. self.preencoder = preencoder
  72. self.encoder = encoder
  73. # we set self.decoder = None in the CTC mode since
  74. # self.decoder parameters were never used and PyTorch complained
  75. # and threw an Exception in the multi-GPU experiment.
  76. # thanks Jeff Farris for pointing out the issue.
  77. if ctc_weight == 1.0:
  78. self.decoder = None
  79. else:
  80. self.decoder = decoder
  81. if ctc_weight == 0.0:
  82. self.ctc = None
  83. else:
  84. self.ctc = ctc
  85. self.rnnt_decoder = rnnt_decoder
  86. self.criterion_att = LabelSmoothingLoss(
  87. size=vocab_size,
  88. padding_idx=ignore_id,
  89. smoothing=lsm_weight,
  90. normalize_length=length_normalized_loss,
  91. )
  92. if report_cer or report_wer:
  93. self.error_calculator = ErrorCalculator(
  94. token_list, sym_space, sym_blank, report_cer, report_wer
  95. )
  96. else:
  97. self.error_calculator = None
  98. def forward(
  99. self,
  100. speech: torch.Tensor,
  101. speech_lengths: torch.Tensor,
  102. text: torch.Tensor,
  103. text_lengths: torch.Tensor,
  104. ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
  105. """Frontend + Encoder + Decoder + Calc loss
  106. Args:
  107. speech: (Batch, Length, ...)
  108. speech_lengths: (Batch, )
  109. text: (Batch, Length)
  110. text_lengths: (Batch,)
  111. """
  112. assert text_lengths.dim() == 1, text_lengths.shape
  113. # Check that batch_size is unified
  114. assert (
  115. speech.shape[0]
  116. == speech_lengths.shape[0]
  117. == text.shape[0]
  118. == text_lengths.shape[0]
  119. ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
  120. #pdb.set_trace()
  121. if(speech.dim()==3 and speech.size(2)==8 and self.mask_ratio !=0):
  122. rate_num = random.random()
  123. #rate_num = 0.1
  124. if(rate_num<=self.mask_ratio):
  125. retain_channel = math.ceil(random.random() *8)
  126. if(retain_channel>1):
  127. speech = speech[:,:,torch.randperm(8)[0:retain_channel].sort().values]
  128. else:
  129. speech = speech[:,:,torch.randperm(8)[0]]
  130. #pdb.set_trace()
  131. batch_size = speech.shape[0]
  132. # for data-parallel
  133. text = text[:, : text_lengths.max()]
  134. # 1. Encoder
  135. encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
  136. # 2a. Attention-decoder branch
  137. if self.ctc_weight == 1.0:
  138. loss_att, acc_att, cer_att, wer_att = None, None, None, None
  139. else:
  140. loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
  141. encoder_out, encoder_out_lens, text, text_lengths
  142. )
  143. # 2b. CTC branch
  144. if self.ctc_weight == 0.0:
  145. loss_ctc, cer_ctc = None, None
  146. else:
  147. loss_ctc, cer_ctc = self._calc_ctc_loss(
  148. encoder_out, encoder_out_lens, text, text_lengths
  149. )
  150. # 2c. RNN-T branch
  151. if self.rnnt_decoder is not None:
  152. _ = self._calc_rnnt_loss(encoder_out, encoder_out_lens, text, text_lengths)
  153. if self.ctc_weight == 0.0:
  154. loss = loss_att
  155. elif self.ctc_weight == 1.0:
  156. loss = loss_ctc
  157. else:
  158. loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att
  159. stats = dict(
  160. loss=loss.detach(),
  161. loss_att=loss_att.detach() if loss_att is not None else None,
  162. loss_ctc=loss_ctc.detach() if loss_ctc is not None else None,
  163. acc=acc_att,
  164. cer=cer_att,
  165. wer=wer_att,
  166. cer_ctc=cer_ctc,
  167. )
  168. # force_gatherable: to-device and to-tensor if scalar for DataParallel
  169. loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
  170. return loss, stats, weight
  171. def collect_feats(
  172. self,
  173. speech: torch.Tensor,
  174. speech_lengths: torch.Tensor,
  175. text: torch.Tensor,
  176. text_lengths: torch.Tensor,
  177. ) -> Dict[str, torch.Tensor]:
  178. feats, feats_lengths, channel_size = self._extract_feats(speech, speech_lengths)
  179. return {"feats": feats, "feats_lengths": feats_lengths}
  180. def encode(
  181. self, speech: torch.Tensor, speech_lengths: torch.Tensor
  182. ) -> Tuple[torch.Tensor, torch.Tensor]:
  183. """Frontend + Encoder. Note that this method is used by asr_inference.py
  184. Args:
  185. speech: (Batch, Length, ...)
  186. speech_lengths: (Batch, )
  187. """
  188. with autocast(False):
  189. # 1. Extract feats
  190. feats, feats_lengths, channel_size = self._extract_feats(speech, speech_lengths)
  191. # 2. Data augmentation
  192. if self.specaug is not None and self.training:
  193. feats, feats_lengths = self.specaug(feats, feats_lengths)
  194. # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
  195. if self.normalize is not None:
  196. feats, feats_lengths = self.normalize(feats, feats_lengths)
  197. # Pre-encoder, e.g. used for raw input data
  198. if self.preencoder is not None:
  199. feats, feats_lengths = self.preencoder(feats, feats_lengths)
  200. #pdb.set_trace()
  201. encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths, channel_size)
  202. assert encoder_out.size(0) == speech.size(0), (
  203. encoder_out.size(),
  204. speech.size(0),
  205. )
  206. if(encoder_out.dim()==4):
  207. assert encoder_out.size(2) <= encoder_out_lens.max(), (
  208. encoder_out.size(),
  209. encoder_out_lens.max(),
  210. )
  211. else:
  212. assert encoder_out.size(1) <= encoder_out_lens.max(), (
  213. encoder_out.size(),
  214. encoder_out_lens.max(),
  215. )
  216. return encoder_out, encoder_out_lens
  217. def _extract_feats(
  218. self, speech: torch.Tensor, speech_lengths: torch.Tensor
  219. ) -> Tuple[torch.Tensor, torch.Tensor]:
  220. assert speech_lengths.dim() == 1, speech_lengths.shape
  221. # for data-parallel
  222. speech = speech[:, : speech_lengths.max()]
  223. if self.frontend is not None:
  224. # Frontend
  225. # e.g. STFT and Feature extract
  226. # data_loader may send time-domain signal in this case
  227. # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
  228. feats, feats_lengths, channel_size = self.frontend(speech, speech_lengths)
  229. else:
  230. # No frontend and no feature extract
  231. feats, feats_lengths = speech, speech_lengths
  232. channel_size = 1
  233. return feats, feats_lengths, channel_size
  234. def _calc_att_loss(
  235. self,
  236. encoder_out: torch.Tensor,
  237. encoder_out_lens: torch.Tensor,
  238. ys_pad: torch.Tensor,
  239. ys_pad_lens: torch.Tensor,
  240. ):
  241. ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
  242. ys_in_lens = ys_pad_lens + 1
  243. # 1. Forward decoder
  244. decoder_out, _ = self.decoder(
  245. encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
  246. )
  247. # 2. Compute attention loss
  248. loss_att = self.criterion_att(decoder_out, ys_out_pad)
  249. acc_att = th_accuracy(
  250. decoder_out.view(-1, self.vocab_size),
  251. ys_out_pad,
  252. ignore_label=self.ignore_id,
  253. )
  254. # Compute cer/wer using attention-decoder
  255. if self.training or self.error_calculator is None:
  256. cer_att, wer_att = None, None
  257. else:
  258. ys_hat = decoder_out.argmax(dim=-1)
  259. cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
  260. return loss_att, acc_att, cer_att, wer_att
  261. def _calc_ctc_loss(
  262. self,
  263. encoder_out: torch.Tensor,
  264. encoder_out_lens: torch.Tensor,
  265. ys_pad: torch.Tensor,
  266. ys_pad_lens: torch.Tensor,
  267. ):
  268. # Calc CTC loss
  269. if(encoder_out.dim()==4):
  270. encoder_out = encoder_out.mean(1)
  271. loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
  272. # Calc CER using CTC
  273. cer_ctc = None
  274. if not self.training and self.error_calculator is not None:
  275. ys_hat = self.ctc.argmax(encoder_out).data
  276. cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
  277. return loss_ctc, cer_ctc
  278. def _calc_rnnt_loss(
  279. self,
  280. encoder_out: torch.Tensor,
  281. encoder_out_lens: torch.Tensor,
  282. ys_pad: torch.Tensor,
  283. ys_pad_lens: torch.Tensor,
  284. ):
  285. raise NotImplementedError