# e2e_asr.py
  1. # Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
  2. # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
  3. import logging
  4. from contextlib import contextmanager
  5. from distutils.version import LooseVersion
  6. from typing import Dict
  7. from typing import List
  8. from typing import Optional
  9. from typing import Tuple
  10. from typing import Union
  11. import torch
  12. from funasr.layers.abs_normalize import AbsNormalize
  13. from funasr.losses.label_smoothing_loss import (
  14. LabelSmoothingLoss, # noqa: H301
  15. )
  16. from funasr.models.ctc import CTC
  17. from funasr.models.decoder.abs_decoder import AbsDecoder
  18. from funasr.models.encoder.abs_encoder import AbsEncoder
  19. from funasr.models.frontend.abs_frontend import AbsFrontend
  20. from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
  21. from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
  22. from funasr.models.specaug.abs_specaug import AbsSpecAug
  23. from funasr.modules.add_sos_eos import add_sos_eos
  24. from funasr.modules.e2e_asr_common import ErrorCalculator
  25. from funasr.modules.nets_utils import th_accuracy
  26. from funasr.torch_utils.device_funcs import force_gatherable
  27. from funasr.models.base_model import FunASRModel
# AMP compatibility shim: torch.cuda.amp.autocast exists only from torch 1.6.0.
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
else:
    # torch < 1.6.0 has no AMP; provide a no-op context manager with the
    # same call signature so `with autocast(False):` works everywhere.
    @contextmanager
    def autocast(enabled=True):
        yield
class ASRModel(FunASRModel):
    """CTC-attention hybrid Encoder-Decoder model.

    The training objective is a weighted sum of a CTC loss on the encoder
    output and a label-smoothed cross-entropy loss from the attention
    decoder::

        loss = ctc_weight * loss_ctc + (1 - ctc_weight) * loss_att

    Optionally, intermediate CTC losses computed on inner encoder layers
    are mixed into the CTC term with weight ``interctc_weight``.
    """

    def __init__(
        self,
        vocab_size: int,
        token_list: Union[Tuple[str, ...], List[str]],
        frontend: Optional[AbsFrontend],
        specaug: Optional[AbsSpecAug],
        normalize: Optional[AbsNormalize],
        encoder: AbsEncoder,
        decoder: AbsDecoder,
        ctc: CTC,
        ctc_weight: float = 0.5,
        interctc_weight: float = 0.0,
        ignore_id: int = -1,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
        report_cer: bool = True,
        report_wer: bool = True,
        sym_space: str = "<space>",
        sym_blank: str = "<blank>",
        extract_feats_in_collect_stats: bool = True,
        preencoder: Optional[AbsPreEncoder] = None,
        postencoder: Optional[AbsPostEncoder] = None,
    ):
        """Build the hybrid model.

        Args:
            vocab_size: size of the output vocabulary (softmax dimension).
            token_list: token inventory; copied and used for CER/WER reporting.
            frontend: optional waveform-to-feature frontend (e.g. STFT).
            specaug: optional SpecAugment module, applied only in training.
            normalize: optional feature normalization (e.g. CMVN).
            encoder: main acoustic encoder.
            decoder: attention decoder (discarded when ``ctc_weight == 1.0``).
            ctc: CTC head (discarded when ``ctc_weight == 0.0``).
            ctc_weight: CTC loss weight in [0, 1].
            interctc_weight: intermediate-CTC loss weight in [0, 1).
            ignore_id: padding label id ignored by the losses.
            lsm_weight: label-smoothing weight for the attention loss.
            length_normalized_loss: normalize attention loss by target length.
            report_cer: compute CER during evaluation.
            report_wer: compute WER during evaluation.
            sym_space: space symbol passed to the error calculator.
            sym_blank: blank symbol passed to the error calculator.
            extract_feats_in_collect_stats: if False, ``collect_feats``
                returns raw speech instead of extracted features.
            preencoder: optional module applied before the encoder
                (e.g. for raw input data).
            postencoder: optional module applied after the encoder (e.g. NLU).
        """
        assert 0.0 <= ctc_weight <= 1.0, ctc_weight
        assert 0.0 <= interctc_weight < 1.0, interctc_weight

        super().__init__()
        # Fixed special-token ids used by this model:
        # 0 = CTC blank, 1 = sos, 2 = eos.
        # NOTE(review): the upstream comment claimed "eos is the same as sos",
        # but the ids assigned here differ — kept the code, fixed the comment.
        self.blank_id = 0
        self.sos = 1
        self.eos = 2
        self.vocab_size = vocab_size
        self.ignore_id = ignore_id
        self.ctc_weight = ctc_weight
        self.interctc_weight = interctc_weight
        # Copy so later mutation of the caller's list cannot affect the model.
        self.token_list = token_list.copy()

        self.frontend = frontend
        self.specaug = specaug
        self.normalize = normalize
        self.preencoder = preencoder
        self.postencoder = postencoder
        self.encoder = encoder

        # Self-conditioned (intermediate) CTC support: default the flag off
        # for encoders that do not define it, and add the conditioning
        # projection when the encoder asks for it.
        if not hasattr(self.encoder, "interctc_use_conditioning"):
            self.encoder.interctc_use_conditioning = False
        if self.encoder.interctc_use_conditioning:
            self.encoder.conditioning_layer = torch.nn.Linear(
                vocab_size, self.encoder.output_size()
            )

        self.error_calculator = None

        # we set self.decoder = None in the CTC mode since
        # self.decoder parameters were never used and PyTorch complained
        # and threw an Exception in the multi-GPU experiment.
        # thanks Jeff Farris for pointing out the issue.
        if ctc_weight == 1.0:
            self.decoder = None
        else:
            self.decoder = decoder

            self.criterion_att = LabelSmoothingLoss(
                size=vocab_size,
                padding_idx=ignore_id,
                smoothing=lsm_weight,
                normalize_length=length_normalized_loss,
            )

            if report_cer or report_wer:
                self.error_calculator = ErrorCalculator(
                    token_list, sym_space, sym_blank, report_cer, report_wer
                )

        # Symmetrically, drop the CTC head in pure-attention mode.
        if ctc_weight == 0.0:
            self.ctc = None
        else:
            self.ctc = ctc

        self.extract_feats_in_collect_stats = extract_feats_in_collect_stats

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)

        Returns:
            Tuple of (loss, stats dict, batch weight), passed through
            ``force_gatherable`` for DataParallel compatibility.
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (
            speech.shape[0]
            == speech_lengths.shape[0]
            == text.shape[0]
            == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
        batch_size = speech.shape[0]

        # for data-parallel: trim padding beyond the longest text in the batch
        text = text[:, : text_lengths.max()]

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        intermediate_outs = None
        # encode() returns a tuple only when intermediate CTC outputs exist.
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]

        loss_att, acc_att, cer_att, wer_att = None, None, None, None
        loss_ctc, cer_ctc = None, None
        stats = dict()

        # 2a. CTC branch
        if self.ctc_weight != 0.0:
            loss_ctc, cer_ctc = self._calc_ctc_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )

            # Collect CTC branch stats
            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
            stats["cer_ctc"] = cer_ctc

        # Intermediate CTC (optional)
        loss_interctc = 0.0
        if self.interctc_weight != 0.0 and intermediate_outs is not None:
            for layer_idx, intermediate_out in intermediate_outs:
                # we assume intermediate_out has the same length & padding
                # as those of encoder_out
                loss_ic, cer_ic = self._calc_ctc_loss(
                    intermediate_out, encoder_out_lens, text, text_lengths
                )
                loss_interctc = loss_interctc + loss_ic

                # Collect intermediate CTC stats
                stats["loss_interctc_layer{}".format(layer_idx)] = (
                    loss_ic.detach() if loss_ic is not None else None
                )
                stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic

            loss_interctc = loss_interctc / len(intermediate_outs)

            # calculate whole encoder loss: blend main CTC with intermediate CTC
            loss_ctc = (
                1 - self.interctc_weight
            ) * loss_ctc + self.interctc_weight * loss_interctc

        # 2b. Attention decoder branch
        if self.ctc_weight != 1.0:
            loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )

        # 3. CTC-Att loss definition
        if self.ctc_weight == 0.0:
            loss = loss_att
        elif self.ctc_weight == 1.0:
            loss = loss_ctc
        else:
            loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att

        # Collect Attn branch stats
        stats["loss_att"] = loss_att.detach() if loss_att is not None else None
        stats["acc"] = acc_att
        stats["cer"] = cer_att
        stats["wer"] = wer_att

        # Collect total loss stats
        stats["loss"] = torch.clone(loss.detach())

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

    def collect_feats(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        """Return features (and their lengths) for statistics collection.

        When ``extract_feats_in_collect_stats`` is False, the raw speech is
        returned unchanged as a dummy substitute.
        """
        if self.extract_feats_in_collect_stats:
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
        else:
            # Generate dummy stats if extract_feats_in_collect_stats is False
            logging.warning(
                "Generating dummy stats for feats and feats_lengths, "
                "because encoder_conf.extract_feats_in_collect_stats is "
                f"{self.extract_feats_in_collect_stats}"
            )
            feats, feats_lengths = speech, speech_lengths
        return {"feats": feats, "feats_lengths": feats_lengths}

    def encode(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Frontend + Encoder. Note that this method is used by asr_inference.py

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )

        Returns:
            ``(encoder_out, encoder_out_lens)``; when the encoder produced
            intermediate CTC outputs, ``encoder_out`` is the tuple
            ``(final_out, intermediate_outs)``.
        """
        # Feature extraction / augmentation / normalization run in fp32
        # even under AMP (autocast disabled for this section).
        with autocast(False):
            # 1. Extract feats
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)

            # 2. Data augmentation (training only)
            if self.specaug is not None and self.training:
                feats, feats_lengths = self.specaug(feats, feats_lengths)

            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                feats, feats_lengths = self.normalize(feats, feats_lengths)

        # Pre-encoder, e.g. used for raw input data
        if self.preencoder is not None:
            feats, feats_lengths = self.preencoder(feats, feats_lengths)

        # 4. Forward encoder
        # feats: (Batch, Length, Dim)
        # -> encoder_out: (Batch, Length2, Dim2)
        if self.encoder.interctc_use_conditioning:
            # Self-conditioned CTC: the encoder needs the CTC head.
            encoder_out, encoder_out_lens, _ = self.encoder(
                feats, feats_lengths, ctc=self.ctc
            )
        else:
            encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]

        # Post-encoder, e.g. NLU
        if self.postencoder is not None:
            encoder_out, encoder_out_lens = self.postencoder(
                encoder_out, encoder_out_lens
            )

        assert encoder_out.size(0) == speech.size(0), (
            encoder_out.size(),
            speech.size(0),
        )
        assert encoder_out.size(1) <= encoder_out_lens.max(), (
            encoder_out.size(),
            encoder_out_lens.max(),
        )

        if intermediate_outs is not None:
            return (encoder_out, intermediate_outs), encoder_out_lens

        return encoder_out, encoder_out_lens

    def _extract_feats(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the frontend (if any) to turn speech into features."""
        assert speech_lengths.dim() == 1, speech_lengths.shape

        # for data-parallel: trim padding beyond the longest utterance
        speech = speech[:, : speech_lengths.max()]

        if self.frontend is not None:
            # Frontend
            #  e.g. STFT and Feature extract
            #       data_loader may send time-domain signal in this case
            # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
            feats, feats_lengths = self.frontend(speech, speech_lengths)
        else:
            # No frontend and no feature extract
            feats, feats_lengths = speech, speech_lengths
        return feats, feats_lengths

    def nll(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ) -> torch.Tensor:
        """Compute negative log likelihood (nll) from transformer-decoder.

        Normally, this function is called in batchify_nll.

        Args:
            encoder_out: (Batch, Length, Dim)
            encoder_out_lens: (Batch,)
            ys_pad: (Batch, Length)
            ys_pad_lens: (Batch,)

        Returns:
            Per-utterance summed token nll, shape (Batch,).
        """
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        ys_in_lens = ys_pad_lens + 1  # +1 for the prepended sos

        # 1. Forward decoder
        decoder_out, _ = self.decoder(
            encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
        )  # [batch, seqlen, dim]
        batch_size = decoder_out.size(0)
        decoder_num_class = decoder_out.size(2)
        # nll: negative log-likelihood; padding positions contribute 0 via
        # ignore_index, so the per-row sum is over real tokens only.
        nll = torch.nn.functional.cross_entropy(
            decoder_out.view(-1, decoder_num_class),
            ys_out_pad.view(-1),
            ignore_index=self.ignore_id,
            reduction="none",
        )
        nll = nll.view(batch_size, -1)
        nll = nll.sum(dim=1)
        assert nll.size(0) == batch_size
        return nll

    def batchify_nll(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
        batch_size: int = 100,
    ):
        """Compute negative log likelihood (nll) from transformer-decoder.

        To avoid OOM, this function separates the input into batches.
        It then calls nll for each batch and combines and returns results.

        Args:
            encoder_out: (Batch, Length, Dim)
            encoder_out_lens: (Batch,)
            ys_pad: (Batch, Length)
            ys_pad_lens: (Batch,)
            batch_size: int, samples each batch contain when computing nll,
                        you may change this to avoid OOM or increase
                        GPU memory usage
        """
        total_num = encoder_out.size(0)
        if total_num <= batch_size:
            nll = self.nll(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
        else:
            nll = []
            start_idx = 0
            while True:
                end_idx = min(start_idx + batch_size, total_num)
                batch_encoder_out = encoder_out[start_idx:end_idx, :, :]
                batch_encoder_out_lens = encoder_out_lens[start_idx:end_idx]
                batch_ys_pad = ys_pad[start_idx:end_idx, :]
                batch_ys_pad_lens = ys_pad_lens[start_idx:end_idx]
                batch_nll = self.nll(
                    batch_encoder_out,
                    batch_encoder_out_lens,
                    batch_ys_pad,
                    batch_ys_pad_lens,
                )
                nll.append(batch_nll)
                start_idx = end_idx
                if start_idx == total_num:
                    break
            nll = torch.cat(nll)
        assert nll.size(0) == total_num
        return nll

    def _calc_att_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        """Attention-decoder loss: returns (loss_att, acc_att, cer_att, wer_att).

        CER/WER are computed only at evaluation time and only when an error
        calculator was configured; otherwise they are None.
        """
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        ys_in_lens = ys_pad_lens + 1  # +1 for the prepended sos

        # 1. Forward decoder
        decoder_out, _ = self.decoder(
            encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
        )

        # 2. Compute attention loss
        loss_att = self.criterion_att(decoder_out, ys_out_pad)
        acc_att = th_accuracy(
            decoder_out.view(-1, self.vocab_size),
            ys_out_pad,
            ignore_label=self.ignore_id,
        )

        # Compute cer/wer using attention-decoder
        if self.training or self.error_calculator is None:
            cer_att, wer_att = None, None
        else:
            ys_hat = decoder_out.argmax(dim=-1)
            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        return loss_att, acc_att, cer_att, wer_att

    def _calc_ctc_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        """CTC loss on ``encoder_out``: returns (loss_ctc, cer_ctc).

        cer_ctc is computed from the CTC greedy path only at evaluation time
        and only when an error calculator was configured; otherwise None.
        """
        # Calc CTC loss
        loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)

        # Calc CER using CTC
        cer_ctc = None
        if not self.training and self.error_calculator is not None:
            ys_hat = self.ctc.argmax(encoder_out).data
            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
        return loss_ctc, cer_ctc