# e2e_asr.py
  1. # Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
  2. # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
  3. import logging
  4. from contextlib import contextmanager
  5. from distutils.version import LooseVersion
  6. from typing import Dict
  7. from typing import List
  8. from typing import Optional
  9. from typing import Tuple
  10. from typing import Union
  11. import torch
  12. from typeguard import check_argument_types
  13. from funasr.layers.abs_normalize import AbsNormalize
  14. from funasr.losses.label_smoothing_loss import (
  15. LabelSmoothingLoss, # noqa: H301
  16. )
  17. from funasr.models.ctc import CTC
  18. from funasr.models.decoder.abs_decoder import AbsDecoder
  19. from funasr.models.encoder.abs_encoder import AbsEncoder
  20. from funasr.models.frontend.abs_frontend import AbsFrontend
  21. from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
  22. from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
  23. from funasr.models.specaug.abs_specaug import AbsSpecAug
  24. from funasr.modules.add_sos_eos import add_sos_eos
  25. from funasr.modules.e2e_asr_common import ErrorCalculator
  26. from funasr.modules.nets_utils import th_accuracy
  27. from funasr.torch_utils.device_funcs import force_gatherable
  28. from funasr.train.abs_espnet_model import AbsESPnetModel
  29. if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
  30. from torch.cuda.amp import autocast
  31. else:
  32. # Nothing to do if torch<1.6.0
  33. @contextmanager
  34. def autocast(enabled=True):
  35. yield
  36. class ESPnetASRModel(AbsESPnetModel):
  37. """CTC-attention hybrid Encoder-Decoder model"""
  38. def __init__(
  39. self,
  40. vocab_size: int,
  41. token_list: Union[Tuple[str, ...], List[str]],
  42. frontend: Optional[AbsFrontend],
  43. specaug: Optional[AbsSpecAug],
  44. normalize: Optional[AbsNormalize],
  45. preencoder: Optional[AbsPreEncoder],
  46. encoder: AbsEncoder,
  47. postencoder: Optional[AbsPostEncoder],
  48. decoder: AbsDecoder,
  49. ctc: CTC,
  50. ctc_weight: float = 0.5,
  51. interctc_weight: float = 0.0,
  52. ignore_id: int = -1,
  53. lsm_weight: float = 0.0,
  54. length_normalized_loss: bool = False,
  55. report_cer: bool = True,
  56. report_wer: bool = True,
  57. sym_space: str = "<space>",
  58. sym_blank: str = "<blank>",
  59. extract_feats_in_collect_stats: bool = True,
  60. ):
  61. assert check_argument_types()
  62. assert 0.0 <= ctc_weight <= 1.0, ctc_weight
  63. assert 0.0 <= interctc_weight < 1.0, interctc_weight
  64. super().__init__()
  65. # note that eos is the same as sos (equivalent ID)
  66. self.blank_id = 0
  67. self.sos = 1
  68. self.eos = 2
  69. self.vocab_size = vocab_size
  70. self.ignore_id = ignore_id
  71. self.ctc_weight = ctc_weight
  72. self.interctc_weight = interctc_weight
  73. self.token_list = token_list.copy()
  74. self.frontend = frontend
  75. self.specaug = specaug
  76. self.normalize = normalize
  77. self.preencoder = preencoder
  78. self.postencoder = postencoder
  79. self.encoder = encoder
  80. if not hasattr(self.encoder, "interctc_use_conditioning"):
  81. self.encoder.interctc_use_conditioning = False
  82. if self.encoder.interctc_use_conditioning:
  83. self.encoder.conditioning_layer = torch.nn.Linear(
  84. vocab_size, self.encoder.output_size()
  85. )
  86. self.error_calculator = None
  87. # we set self.decoder = None in the CTC mode since
  88. # self.decoder parameters were never used and PyTorch complained
  89. # and threw an Exception in the multi-GPU experiment.
  90. # thanks Jeff Farris for pointing out the issue.
  91. if ctc_weight == 1.0:
  92. self.decoder = None
  93. else:
  94. self.decoder = decoder
  95. self.criterion_att = LabelSmoothingLoss(
  96. size=vocab_size,
  97. padding_idx=ignore_id,
  98. smoothing=lsm_weight,
  99. normalize_length=length_normalized_loss,
  100. )
  101. if report_cer or report_wer:
  102. self.error_calculator = ErrorCalculator(
  103. token_list, sym_space, sym_blank, report_cer, report_wer
  104. )
  105. if ctc_weight == 0.0:
  106. self.ctc = None
  107. else:
  108. self.ctc = ctc
  109. self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
  110. def forward(
  111. self,
  112. speech: torch.Tensor,
  113. speech_lengths: torch.Tensor,
  114. text: torch.Tensor,
  115. text_lengths: torch.Tensor,
  116. ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
  117. """Frontend + Encoder + Decoder + Calc loss
  118. Args:
  119. speech: (Batch, Length, ...)
  120. speech_lengths: (Batch, )
  121. text: (Batch, Length)
  122. text_lengths: (Batch,)
  123. """
  124. assert text_lengths.dim() == 1, text_lengths.shape
  125. # Check that batch_size is unified
  126. assert (
  127. speech.shape[0]
  128. == speech_lengths.shape[0]
  129. == text.shape[0]
  130. == text_lengths.shape[0]
  131. ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
  132. batch_size = speech.shape[0]
  133. # for data-parallel
  134. text = text[:, : text_lengths.max()]
  135. # 1. Encoder
  136. encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
  137. intermediate_outs = None
  138. if isinstance(encoder_out, tuple):
  139. intermediate_outs = encoder_out[1]
  140. encoder_out = encoder_out[0]
  141. loss_att, acc_att, cer_att, wer_att = None, None, None, None
  142. loss_ctc, cer_ctc = None, None
  143. stats = dict()
  144. # 1. CTC branch
  145. if self.ctc_weight != 0.0:
  146. loss_ctc, cer_ctc = self._calc_ctc_loss(
  147. encoder_out, encoder_out_lens, text, text_lengths
  148. )
  149. # Collect CTC branch stats
  150. stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
  151. stats["cer_ctc"] = cer_ctc
  152. # Intermediate CTC (optional)
  153. loss_interctc = 0.0
  154. if self.interctc_weight != 0.0 and intermediate_outs is not None:
  155. for layer_idx, intermediate_out in intermediate_outs:
  156. # we assume intermediate_out has the same length & padding
  157. # as those of encoder_out
  158. loss_ic, cer_ic = self._calc_ctc_loss(
  159. intermediate_out, encoder_out_lens, text, text_lengths
  160. )
  161. loss_interctc = loss_interctc + loss_ic
  162. # Collect Intermedaite CTC stats
  163. stats["loss_interctc_layer{}".format(layer_idx)] = (
  164. loss_ic.detach() if loss_ic is not None else None
  165. )
  166. stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
  167. loss_interctc = loss_interctc / len(intermediate_outs)
  168. # calculate whole encoder loss
  169. loss_ctc = (
  170. 1 - self.interctc_weight
  171. ) * loss_ctc + self.interctc_weight * loss_interctc
  172. # 2b. Attention decoder branch
  173. if self.ctc_weight != 1.0:
  174. loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
  175. encoder_out, encoder_out_lens, text, text_lengths
  176. )
  177. # 3. CTC-Att loss definition
  178. if self.ctc_weight == 0.0:
  179. loss = loss_att
  180. elif self.ctc_weight == 1.0:
  181. loss = loss_ctc
  182. else:
  183. loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att
  184. # Collect Attn branch stats
  185. stats["loss_att"] = loss_att.detach() if loss_att is not None else None
  186. stats["acc"] = acc_att
  187. stats["cer"] = cer_att
  188. stats["wer"] = wer_att
  189. # Collect total loss stats
  190. stats["loss"] = torch.clone(loss.detach())
  191. # force_gatherable: to-device and to-tensor if scalar for DataParallel
  192. loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
  193. return loss, stats, weight
  194. def collect_feats(
  195. self,
  196. speech: torch.Tensor,
  197. speech_lengths: torch.Tensor,
  198. text: torch.Tensor,
  199. text_lengths: torch.Tensor,
  200. ) -> Dict[str, torch.Tensor]:
  201. if self.extract_feats_in_collect_stats:
  202. feats, feats_lengths = self._extract_feats(speech, speech_lengths)
  203. else:
  204. # Generate dummy stats if extract_feats_in_collect_stats is False
  205. logging.warning(
  206. "Generating dummy stats for feats and feats_lengths, "
  207. "because encoder_conf.extract_feats_in_collect_stats is "
  208. f"{self.extract_feats_in_collect_stats}"
  209. )
  210. feats, feats_lengths = speech, speech_lengths
  211. return {"feats": feats, "feats_lengths": feats_lengths}
  212. def encode(
  213. self, speech: torch.Tensor, speech_lengths: torch.Tensor
  214. ) -> Tuple[torch.Tensor, torch.Tensor]:
  215. """Frontend + Encoder. Note that this method is used by asr_inference.py
  216. Args:
  217. speech: (Batch, Length, ...)
  218. speech_lengths: (Batch, )
  219. """
  220. with autocast(False):
  221. # 1. Extract feats
  222. feats, feats_lengths = self._extract_feats(speech, speech_lengths)
  223. # 2. Data augmentation
  224. if self.specaug is not None and self.training:
  225. feats, feats_lengths = self.specaug(feats, feats_lengths)
  226. # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
  227. if self.normalize is not None:
  228. feats, feats_lengths = self.normalize(feats, feats_lengths)
  229. # Pre-encoder, e.g. used for raw input data
  230. if self.preencoder is not None:
  231. feats, feats_lengths = self.preencoder(feats, feats_lengths)
  232. # 4. Forward encoder
  233. # feats: (Batch, Length, Dim)
  234. # -> encoder_out: (Batch, Length2, Dim2)
  235. if self.encoder.interctc_use_conditioning:
  236. encoder_out, encoder_out_lens, _ = self.encoder(
  237. feats, feats_lengths, ctc=self.ctc
  238. )
  239. else:
  240. encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
  241. intermediate_outs = None
  242. if isinstance(encoder_out, tuple):
  243. intermediate_outs = encoder_out[1]
  244. encoder_out = encoder_out[0]
  245. # Post-encoder, e.g. NLU
  246. if self.postencoder is not None:
  247. encoder_out, encoder_out_lens = self.postencoder(
  248. encoder_out, encoder_out_lens
  249. )
  250. assert encoder_out.size(0) == speech.size(0), (
  251. encoder_out.size(),
  252. speech.size(0),
  253. )
  254. assert encoder_out.size(1) <= encoder_out_lens.max(), (
  255. encoder_out.size(),
  256. encoder_out_lens.max(),
  257. )
  258. if intermediate_outs is not None:
  259. return (encoder_out, intermediate_outs), encoder_out_lens
  260. return encoder_out, encoder_out_lens
  261. def _extract_feats(
  262. self, speech: torch.Tensor, speech_lengths: torch.Tensor
  263. ) -> Tuple[torch.Tensor, torch.Tensor]:
  264. assert speech_lengths.dim() == 1, speech_lengths.shape
  265. # for data-parallel
  266. speech = speech[:, : speech_lengths.max()]
  267. if self.frontend is not None:
  268. # Frontend
  269. # e.g. STFT and Feature extract
  270. # data_loader may send time-domain signal in this case
  271. # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
  272. feats, feats_lengths = self.frontend(speech, speech_lengths)
  273. else:
  274. # No frontend and no feature extract
  275. feats, feats_lengths = speech, speech_lengths
  276. return feats, feats_lengths
  277. def nll(
  278. self,
  279. encoder_out: torch.Tensor,
  280. encoder_out_lens: torch.Tensor,
  281. ys_pad: torch.Tensor,
  282. ys_pad_lens: torch.Tensor,
  283. ) -> torch.Tensor:
  284. """Compute negative log likelihood(nll) from transformer-decoder
  285. Normally, this function is called in batchify_nll.
  286. Args:
  287. encoder_out: (Batch, Length, Dim)
  288. encoder_out_lens: (Batch,)
  289. ys_pad: (Batch, Length)
  290. ys_pad_lens: (Batch,)
  291. """
  292. ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
  293. ys_in_lens = ys_pad_lens + 1
  294. # 1. Forward decoder
  295. decoder_out, _ = self.decoder(
  296. encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
  297. ) # [batch, seqlen, dim]
  298. batch_size = decoder_out.size(0)
  299. decoder_num_class = decoder_out.size(2)
  300. # nll: negative log-likelihood
  301. nll = torch.nn.functional.cross_entropy(
  302. decoder_out.view(-1, decoder_num_class),
  303. ys_out_pad.view(-1),
  304. ignore_index=self.ignore_id,
  305. reduction="none",
  306. )
  307. nll = nll.view(batch_size, -1)
  308. nll = nll.sum(dim=1)
  309. assert nll.size(0) == batch_size
  310. return nll
  311. def batchify_nll(
  312. self,
  313. encoder_out: torch.Tensor,
  314. encoder_out_lens: torch.Tensor,
  315. ys_pad: torch.Tensor,
  316. ys_pad_lens: torch.Tensor,
  317. batch_size: int = 100,
  318. ):
  319. """Compute negative log likelihood(nll) from transformer-decoder
  320. To avoid OOM, this fuction seperate the input into batches.
  321. Then call nll for each batch and combine and return results.
  322. Args:
  323. encoder_out: (Batch, Length, Dim)
  324. encoder_out_lens: (Batch,)
  325. ys_pad: (Batch, Length)
  326. ys_pad_lens: (Batch,)
  327. batch_size: int, samples each batch contain when computing nll,
  328. you may change this to avoid OOM or increase
  329. GPU memory usage
  330. """
  331. total_num = encoder_out.size(0)
  332. if total_num <= batch_size:
  333. nll = self.nll(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
  334. else:
  335. nll = []
  336. start_idx = 0
  337. while True:
  338. end_idx = min(start_idx + batch_size, total_num)
  339. batch_encoder_out = encoder_out[start_idx:end_idx, :, :]
  340. batch_encoder_out_lens = encoder_out_lens[start_idx:end_idx]
  341. batch_ys_pad = ys_pad[start_idx:end_idx, :]
  342. batch_ys_pad_lens = ys_pad_lens[start_idx:end_idx]
  343. batch_nll = self.nll(
  344. batch_encoder_out,
  345. batch_encoder_out_lens,
  346. batch_ys_pad,
  347. batch_ys_pad_lens,
  348. )
  349. nll.append(batch_nll)
  350. start_idx = end_idx
  351. if start_idx == total_num:
  352. break
  353. nll = torch.cat(nll)
  354. assert nll.size(0) == total_num
  355. return nll
  356. def _calc_att_loss(
  357. self,
  358. encoder_out: torch.Tensor,
  359. encoder_out_lens: torch.Tensor,
  360. ys_pad: torch.Tensor,
  361. ys_pad_lens: torch.Tensor,
  362. ):
  363. ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
  364. ys_in_lens = ys_pad_lens + 1
  365. # 1. Forward decoder
  366. decoder_out, _ = self.decoder(
  367. encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
  368. )
  369. # 2. Compute attention loss
  370. loss_att = self.criterion_att(decoder_out, ys_out_pad)
  371. acc_att = th_accuracy(
  372. decoder_out.view(-1, self.vocab_size),
  373. ys_out_pad,
  374. ignore_label=self.ignore_id,
  375. )
  376. # Compute cer/wer using attention-decoder
  377. if self.training or self.error_calculator is None:
  378. cer_att, wer_att = None, None
  379. else:
  380. ys_hat = decoder_out.argmax(dim=-1)
  381. cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
  382. return loss_att, acc_att, cer_att, wer_att
  383. def _calc_ctc_loss(
  384. self,
  385. encoder_out: torch.Tensor,
  386. encoder_out_lens: torch.Tensor,
  387. ys_pad: torch.Tensor,
  388. ys_pad_lens: torch.Tensor,
  389. ):
  390. # Calc CTC loss
  391. loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
  392. # Calc CER using CTC
  393. cer_ctc = None
  394. if not self.training and self.error_calculator is not None:
  395. ys_hat = self.ctc.argmax(encoder_out).data
  396. cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
  397. return loss_ctc, cer_ctc