# e2e_asr.py
  1. # Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
  2. # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
  3. import logging
  4. from contextlib import contextmanager
  5. from distutils.version import LooseVersion
  6. from typing import Dict
  7. from typing import List
  8. from typing import Optional
  9. from typing import Tuple
  10. from typing import Union
  11. import torch
  12. from typeguard import check_argument_types
  13. from funasr.losses.label_smoothing_loss import (
  14. LabelSmoothingLoss, # noqa: H301
  15. )
  16. from funasr.models.ctc import CTC
  17. from funasr.models.decoder.abs_decoder import AbsDecoder
  18. from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
  19. from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
  20. from funasr.models.base_model import FunASRModel
  21. from funasr.modules.add_sos_eos import add_sos_eos
  22. from funasr.modules.e2e_asr_common import ErrorCalculator
  23. from funasr.modules.nets_utils import th_accuracy
  24. from funasr.torch_utils.device_funcs import force_gatherable
  25. if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
  26. from torch.cuda.amp import autocast
  27. else:
  28. # Nothing to do if torch<1.6.0
  29. @contextmanager
  30. def autocast(enabled=True):
  31. yield
  32. class ESPnetASRModel(FunASRModel):
  33. """CTC-attention hybrid Encoder-Decoder model"""
  34. def __init__(
  35. self,
  36. vocab_size: int,
  37. token_list: Union[Tuple[str, ...], List[str]],
  38. frontend: Optional[torch.nn.Module],
  39. specaug: Optional[torch.nn.Module],
  40. normalize: Optional[torch.nn.Module],
  41. preencoder: Optional[AbsPreEncoder],
  42. encoder: torch.nn.Module,
  43. postencoder: Optional[AbsPostEncoder],
  44. decoder: AbsDecoder,
  45. ctc: CTC,
  46. ctc_weight: float = 0.5,
  47. interctc_weight: float = 0.0,
  48. ignore_id: int = -1,
  49. lsm_weight: float = 0.0,
  50. length_normalized_loss: bool = False,
  51. report_cer: bool = True,
  52. report_wer: bool = True,
  53. sym_space: str = "<space>",
  54. sym_blank: str = "<blank>",
  55. extract_feats_in_collect_stats: bool = True,
  56. ):
  57. assert check_argument_types()
  58. assert 0.0 <= ctc_weight <= 1.0, ctc_weight
  59. assert 0.0 <= interctc_weight < 1.0, interctc_weight
  60. super().__init__()
  61. # note that eos is the same as sos (equivalent ID)
  62. self.blank_id = 0
  63. self.sos = 1
  64. self.eos = 2
  65. self.vocab_size = vocab_size
  66. self.ignore_id = ignore_id
  67. self.ctc_weight = ctc_weight
  68. self.interctc_weight = interctc_weight
  69. self.token_list = token_list.copy()
  70. self.frontend = frontend
  71. self.specaug = specaug
  72. self.normalize = normalize
  73. self.preencoder = preencoder
  74. self.postencoder = postencoder
  75. self.encoder = encoder
  76. if not hasattr(self.encoder, "interctc_use_conditioning"):
  77. self.encoder.interctc_use_conditioning = False
  78. if self.encoder.interctc_use_conditioning:
  79. self.encoder.conditioning_layer = torch.nn.Linear(
  80. vocab_size, self.encoder.output_size()
  81. )
  82. self.error_calculator = None
  83. # we set self.decoder = None in the CTC mode since
  84. # self.decoder parameters were never used and PyTorch complained
  85. # and threw an Exception in the multi-GPU experiment.
  86. # thanks Jeff Farris for pointing out the issue.
  87. if ctc_weight == 1.0:
  88. self.decoder = None
  89. else:
  90. self.decoder = decoder
  91. self.criterion_att = LabelSmoothingLoss(
  92. size=vocab_size,
  93. padding_idx=ignore_id,
  94. smoothing=lsm_weight,
  95. normalize_length=length_normalized_loss,
  96. )
  97. if report_cer or report_wer:
  98. self.error_calculator = ErrorCalculator(
  99. token_list, sym_space, sym_blank, report_cer, report_wer
  100. )
  101. if ctc_weight == 0.0:
  102. self.ctc = None
  103. else:
  104. self.ctc = ctc
  105. self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)

        Returns:
            loss: weighted combination of CTC and attention losses
            stats: dict of detached values for reporting (loss_ctc, cer_ctc,
                per-layer intermediate-CTC stats, loss_att, acc, cer, wer, loss)
            weight: batch size, used as the averaging weight for DataParallel
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (
            speech.shape[0]
            == speech_lengths.shape[0]
            == text.shape[0]
            == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
        batch_size = speech.shape[0]

        # for data-parallel: drop padding beyond the longest target in batch
        text = text[:, : text_lengths.max()]

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            # encode() returned ((final_out, [(layer_idx, layer_out), ...]), lens)
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]

        loss_att, acc_att, cer_att, wer_att = None, None, None, None
        loss_ctc, cer_ctc = None, None
        stats = dict()

        # 1. CTC branch
        if self.ctc_weight != 0.0:
            loss_ctc, cer_ctc = self._calc_ctc_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )

            # Collect CTC branch stats
            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
            stats["cer_ctc"] = cer_ctc

        # Intermediate CTC (optional): average CTC loss over intermediate
        # encoder layers, then blend it into the main CTC loss.
        loss_interctc = 0.0
        if self.interctc_weight != 0.0 and intermediate_outs is not None:
            for layer_idx, intermediate_out in intermediate_outs:
                # we assume intermediate_out has the same length & padding
                # as those of encoder_out
                loss_ic, cer_ic = self._calc_ctc_loss(
                    intermediate_out, encoder_out_lens, text, text_lengths
                )
                loss_interctc = loss_interctc + loss_ic

                # Collect Intermediate CTC stats
                stats["loss_interctc_layer{}".format(layer_idx)] = (
                    loss_ic.detach() if loss_ic is not None else None
                )
                stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic

            loss_interctc = loss_interctc / len(intermediate_outs)

            # calculate whole encoder loss
            loss_ctc = (
                1 - self.interctc_weight
            ) * loss_ctc + self.interctc_weight * loss_interctc

        # 2b. Attention decoder branch
        if self.ctc_weight != 1.0:
            loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )

        # 3. CTC-Att loss definition: pure attention / pure CTC / hybrid
        if self.ctc_weight == 0.0:
            loss = loss_att
        elif self.ctc_weight == 1.0:
            loss = loss_ctc
        else:
            loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att

        # Collect Attn branch stats
        stats["loss_att"] = loss_att.detach() if loss_att is not None else None
        stats["acc"] = acc_att
        stats["cer"] = cer_att
        stats["wer"] = wer_att

        # Collect total loss stats
        stats["loss"] = torch.clone(loss.detach())

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight
  190. def collect_feats(
  191. self,
  192. speech: torch.Tensor,
  193. speech_lengths: torch.Tensor,
  194. text: torch.Tensor,
  195. text_lengths: torch.Tensor,
  196. ) -> Dict[str, torch.Tensor]:
  197. if self.extract_feats_in_collect_stats:
  198. feats, feats_lengths = self._extract_feats(speech, speech_lengths)
  199. else:
  200. # Generate dummy stats if extract_feats_in_collect_stats is False
  201. logging.warning(
  202. "Generating dummy stats for feats and feats_lengths, "
  203. "because encoder_conf.extract_feats_in_collect_stats is "
  204. f"{self.extract_feats_in_collect_stats}"
  205. )
  206. feats, feats_lengths = speech, speech_lengths
  207. return {"feats": feats, "feats_lengths": feats_lengths}
    def encode(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Frontend + Encoder. Note that this method is used by asr_inference.py

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )

        Returns:
            encoder_out and encoder_out_lens; when the encoder emits
            intermediate outputs, encoder_out is the tuple
            (final_out, intermediate_outs) instead of a plain tensor.
        """
        # Run the frontend in full precision even under AMP (autocast disabled).
        with autocast(False):
            # 1. Extract feats
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)

            # 2. Data augmentation (training mode only)
            if self.specaug is not None and self.training:
                feats, feats_lengths = self.specaug(feats, feats_lengths)

            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                feats, feats_lengths = self.normalize(feats, feats_lengths)

        # Pre-encoder, e.g. used for raw input data
        if self.preencoder is not None:
            feats, feats_lengths = self.preencoder(feats, feats_lengths)

        # 4. Forward encoder
        # feats: (Batch, Length, Dim)
        # -> encoder_out: (Batch, Length2, Dim2)
        if self.encoder.interctc_use_conditioning:
            # Pass the CTC module so the encoder can condition on
            # intermediate CTC posteriors.
            encoder_out, encoder_out_lens, _ = self.encoder(
                feats, feats_lengths, ctc=self.ctc
            )
        else:
            encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            # Encoder returned (final_out, [(layer_idx, layer_out), ...]).
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]

        # Post-encoder, e.g. NLU
        if self.postencoder is not None:
            encoder_out, encoder_out_lens = self.postencoder(
                encoder_out, encoder_out_lens
            )

        # Sanity checks: batch preserved, and no valid length exceeds the
        # padded time axis.
        assert encoder_out.size(0) == speech.size(0), (
            encoder_out.size(),
            speech.size(0),
        )
        assert encoder_out.size(1) <= encoder_out_lens.max(), (
            encoder_out.size(),
            encoder_out_lens.max(),
        )

        if intermediate_outs is not None:
            return (encoder_out, intermediate_outs), encoder_out_lens

        return encoder_out, encoder_out_lens
  257. def _extract_feats(
  258. self, speech: torch.Tensor, speech_lengths: torch.Tensor
  259. ) -> Tuple[torch.Tensor, torch.Tensor]:
  260. assert speech_lengths.dim() == 1, speech_lengths.shape
  261. # for data-parallel
  262. speech = speech[:, : speech_lengths.max()]
  263. if self.frontend is not None:
  264. # Frontend
  265. # e.g. STFT and Feature extract
  266. # data_loader may send time-domain signal in this case
  267. # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
  268. feats, feats_lengths = self.frontend(speech, speech_lengths)
  269. else:
  270. # No frontend and no feature extract
  271. feats, feats_lengths = speech, speech_lengths
  272. return feats, feats_lengths
  273. def nll(
  274. self,
  275. encoder_out: torch.Tensor,
  276. encoder_out_lens: torch.Tensor,
  277. ys_pad: torch.Tensor,
  278. ys_pad_lens: torch.Tensor,
  279. ) -> torch.Tensor:
  280. """Compute negative log likelihood(nll) from transformer-decoder
  281. Normally, this function is called in batchify_nll.
  282. Args:
  283. encoder_out: (Batch, Length, Dim)
  284. encoder_out_lens: (Batch,)
  285. ys_pad: (Batch, Length)
  286. ys_pad_lens: (Batch,)
  287. """
  288. ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
  289. ys_in_lens = ys_pad_lens + 1
  290. # 1. Forward decoder
  291. decoder_out, _ = self.decoder(
  292. encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
  293. ) # [batch, seqlen, dim]
  294. batch_size = decoder_out.size(0)
  295. decoder_num_class = decoder_out.size(2)
  296. # nll: negative log-likelihood
  297. nll = torch.nn.functional.cross_entropy(
  298. decoder_out.view(-1, decoder_num_class),
  299. ys_out_pad.view(-1),
  300. ignore_index=self.ignore_id,
  301. reduction="none",
  302. )
  303. nll = nll.view(batch_size, -1)
  304. nll = nll.sum(dim=1)
  305. assert nll.size(0) == batch_size
  306. return nll
  307. def batchify_nll(
  308. self,
  309. encoder_out: torch.Tensor,
  310. encoder_out_lens: torch.Tensor,
  311. ys_pad: torch.Tensor,
  312. ys_pad_lens: torch.Tensor,
  313. batch_size: int = 100,
  314. ):
  315. """Compute negative log likelihood(nll) from transformer-decoder
  316. To avoid OOM, this fuction seperate the input into batches.
  317. Then call nll for each batch and combine and return results.
  318. Args:
  319. encoder_out: (Batch, Length, Dim)
  320. encoder_out_lens: (Batch,)
  321. ys_pad: (Batch, Length)
  322. ys_pad_lens: (Batch,)
  323. batch_size: int, samples each batch contain when computing nll,
  324. you may change this to avoid OOM or increase
  325. GPU memory usage
  326. """
  327. total_num = encoder_out.size(0)
  328. if total_num <= batch_size:
  329. nll = self.nll(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
  330. else:
  331. nll = []
  332. start_idx = 0
  333. while True:
  334. end_idx = min(start_idx + batch_size, total_num)
  335. batch_encoder_out = encoder_out[start_idx:end_idx, :, :]
  336. batch_encoder_out_lens = encoder_out_lens[start_idx:end_idx]
  337. batch_ys_pad = ys_pad[start_idx:end_idx, :]
  338. batch_ys_pad_lens = ys_pad_lens[start_idx:end_idx]
  339. batch_nll = self.nll(
  340. batch_encoder_out,
  341. batch_encoder_out_lens,
  342. batch_ys_pad,
  343. batch_ys_pad_lens,
  344. )
  345. nll.append(batch_nll)
  346. start_idx = end_idx
  347. if start_idx == total_num:
  348. break
  349. nll = torch.cat(nll)
  350. assert nll.size(0) == total_num
  351. return nll
  352. def _calc_att_loss(
  353. self,
  354. encoder_out: torch.Tensor,
  355. encoder_out_lens: torch.Tensor,
  356. ys_pad: torch.Tensor,
  357. ys_pad_lens: torch.Tensor,
  358. ):
  359. ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
  360. ys_in_lens = ys_pad_lens + 1
  361. # 1. Forward decoder
  362. decoder_out, _ = self.decoder(
  363. encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
  364. )
  365. # 2. Compute attention loss
  366. loss_att = self.criterion_att(decoder_out, ys_out_pad)
  367. acc_att = th_accuracy(
  368. decoder_out.view(-1, self.vocab_size),
  369. ys_out_pad,
  370. ignore_label=self.ignore_id,
  371. )
  372. # Compute cer/wer using attention-decoder
  373. if self.training or self.error_calculator is None:
  374. cer_att, wer_att = None, None
  375. else:
  376. ys_hat = decoder_out.argmax(dim=-1)
  377. cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
  378. return loss_att, acc_att, cer_att, wer_att
  379. def _calc_ctc_loss(
  380. self,
  381. encoder_out: torch.Tensor,
  382. encoder_out_lens: torch.Tensor,
  383. ys_pad: torch.Tensor,
  384. ys_pad_lens: torch.Tensor,
  385. ):
  386. # Calc CTC loss
  387. loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
  388. # Calc CER using CTC
  389. cer_ctc = None
  390. if not self.training and self.error_calculator is not None:
  391. ys_hat = self.ctc.argmax(encoder_out).data
  392. cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
  393. return loss_ctc, cer_ctc