# e2e_asr_transducer.py
  1. """ESPnet2 ASR Transducer model."""
  2. import logging
  3. from contextlib import contextmanager
  4. from typing import Dict, List, Optional, Tuple, Union
  5. import torch
  6. from packaging.version import parse as V
  7. from typeguard import check_argument_types
  8. from funasr.models.frontend.abs_frontend import AbsFrontend
  9. from funasr.models.specaug.abs_specaug import AbsSpecAug
  10. from funasr.models.decoder.rnnt_decoder import RNNTDecoder
  11. from funasr.models.decoder.abs_decoder import AbsDecoder as AbsAttDecoder
  12. from funasr.models.encoder.conformer_encoder import ConformerChunkEncoder as Encoder
  13. from funasr.models.joint_net.joint_network import JointNetwork
  14. from funasr.modules.nets_utils import get_transducer_task_io
  15. from funasr.layers.abs_normalize import AbsNormalize
  16. from funasr.torch_utils.device_funcs import force_gatherable
  17. from funasr.train.abs_espnet_model import AbsESPnetModel
  18. if V(torch.__version__) >= V("1.6.0"):
  19. from torch.cuda.amp import autocast
  20. else:
  21. @contextmanager
  22. def autocast(enabled=True):
  23. yield
class TransducerModel(AbsESPnetModel):
    """ESPnet2ASRTransducerModel module definition.

    Pipeline: frontend -> (specaug) -> (normalize) -> encoder -> RNNT
    decoder + joint network, trained with the Transducer loss plus optional
    auxiliary CTC and LM losses.

    Args:
        vocab_size: Size of complete vocabulary (w/ EOS and blank included).
        token_list: List of tokens.
        frontend: Frontend module.
        specaug: SpecAugment module.
        normalize: Normalization module.
        encoder: Encoder module.
        decoder: Decoder module.
        joint_network: Joint Network module.
        att_decoder: Optional attention decoder (accepted but unused here).
        transducer_weight: Weight of the Transducer loss.
        fastemit_lambda: FastEmit lambda value.
        auxiliary_ctc_weight: Weight of auxiliary CTC loss.
        auxiliary_ctc_dropout_rate: Dropout rate for auxiliary CTC loss inputs.
        auxiliary_lm_loss_weight: Weight of auxiliary LM loss.
        auxiliary_lm_loss_smoothing: Smoothing rate for LM loss' label smoothing.
        ignore_id: Initial padding ID.
        sym_space: Space symbol.
        sym_blank: Blank Symbol
        report_cer: Whether to report Character Error Rate during validation.
        report_wer: Whether to report Word Error Rate during validation.
        extract_feats_in_collect_stats: Whether to use extract_feats stats collection.
    """

    def __init__(
        self,
        vocab_size: int,
        token_list: Union[Tuple[str, ...], List[str]],
        frontend: Optional[AbsFrontend],
        specaug: Optional[AbsSpecAug],
        normalize: Optional[AbsNormalize],
        encoder: Encoder,
        decoder: RNNTDecoder,
        joint_network: JointNetwork,
        att_decoder: Optional[AbsAttDecoder] = None,
        transducer_weight: float = 1.0,
        fastemit_lambda: float = 0.0,
        auxiliary_ctc_weight: float = 0.0,
        auxiliary_ctc_dropout_rate: float = 0.0,
        auxiliary_lm_loss_weight: float = 0.0,
        auxiliary_lm_loss_smoothing: float = 0.0,
        ignore_id: int = -1,
        sym_space: str = "<space>",
        sym_blank: str = "<blank>",
        report_cer: bool = True,
        report_wer: bool = True,
        extract_feats_in_collect_stats: bool = True,
    ) -> None:
        """Construct an ESPnetASRTransducerModel object."""
        super().__init__()
        assert check_argument_types()
        # The following labels ID are reserved: 0 (blank) and vocab_size - 1 (sos/eos)
        self.blank_id = 0
        self.vocab_size = vocab_size
        self.ignore_id = ignore_id
        self.token_list = token_list.copy()
        self.sym_space = sym_space
        self.sym_blank = sym_blank
        self.frontend = frontend
        self.specaug = specaug
        self.normalize = normalize
        self.encoder = encoder
        self.decoder = decoder
        self.joint_network = joint_network
        # Both are created lazily on first use (see _calc_transducer_loss).
        self.criterion_transducer = None
        self.error_calculator = None
        self.use_auxiliary_ctc = auxiliary_ctc_weight > 0
        self.use_auxiliary_lm_loss = auxiliary_lm_loss_weight > 0
        if self.use_auxiliary_ctc:
            # NOTE(review): assumes ``encoder.output_size`` is an int
            # attribute (not a zero-arg method) on this encoder — confirm.
            self.ctc_lin = torch.nn.Linear(encoder.output_size, vocab_size)
            self.ctc_dropout_rate = auxiliary_ctc_dropout_rate
        if self.use_auxiliary_lm_loss:
            self.lm_lin = torch.nn.Linear(decoder.output_size, vocab_size)
            self.lm_loss_smoothing = auxiliary_lm_loss_smoothing
        self.transducer_weight = transducer_weight
        self.fastemit_lambda = fastemit_lambda
        self.auxiliary_ctc_weight = auxiliary_ctc_weight
        self.auxiliary_lm_loss_weight = auxiliary_lm_loss_weight
        self.report_cer = report_cer
        self.report_wer = report_wer
        self.extract_feats_in_collect_stats = extract_feats_in_collect_stats

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Forward architecture and compute loss(es).

        Args:
            speech: Speech sequences. (B, S)
            speech_lengths: Speech sequences lengths. (B,)
            text: Label ID sequences. (B, L)
            text_lengths: Label ID sequences lengths. (B,)
            kwargs: Contains "utts_id".

        Return:
            loss: Main loss value.
            stats: Task statistics.
            weight: Task weights.
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        assert (
            speech.shape[0]
            == speech_lengths.shape[0]
            == text.shape[0]
            == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
        batch_size = speech.shape[0]
        # Trim label padding introduced by data-parallel batching.
        text = text[:, : text_lengths.max()]
        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        # 2. Transducer-related I/O preparation
        decoder_in, target, t_len, u_len = get_transducer_task_io(
            text,
            encoder_out_lens,
            ignore_id=self.ignore_id,
        )
        # 3. Decoder
        self.decoder.set_device(encoder_out.device)
        decoder_out = self.decoder(decoder_in, u_len)
        # 4. Joint Network: broadcast encoder (B, T, 1, D) against
        # decoder (B, 1, U, D) to get the (B, T, U, V) lattice.
        joint_out = self.joint_network(
            encoder_out.unsqueeze(2), decoder_out.unsqueeze(1)
        )
        # 5. Losses
        loss_trans, cer_trans, wer_trans = self._calc_transducer_loss(
            encoder_out,
            joint_out,
            target,
            t_len,
            u_len,
        )
        loss_ctc, loss_lm = 0.0, 0.0
        if self.use_auxiliary_ctc:
            loss_ctc = self._calc_ctc_loss(
                encoder_out,
                target,
                t_len,
                u_len,
            )
        if self.use_auxiliary_lm_loss:
            loss_lm = self._calc_lm_loss(decoder_out, target)
        loss = (
            self.transducer_weight * loss_trans
            + self.auxiliary_ctc_weight * loss_ctc
            + self.auxiliary_lm_loss_weight * loss_lm
        )
        # Auxiliary entries are None when disabled (their value is the
        # float 0.0, not a tensor, in that case).
        stats = dict(
            loss=loss.detach(),
            loss_transducer=loss_trans.detach(),
            aux_ctc_loss=loss_ctc.detach() if loss_ctc > 0.0 else None,
            aux_lm_loss=loss_lm.detach() if loss_lm > 0.0 else None,
            cer_transducer=cer_trans,
            wer_transducer=wer_trans,
        )
        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

    def collect_feats(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """Collect features sequences and features lengths sequences.

        Args:
            speech: Speech sequences. (B, S)
            speech_lengths: Speech sequences lengths. (B,)
            text: Label ID sequences. (B, L)
            text_lengths: Label ID sequences lengths. (B,)
            kwargs: Contains "utts_id".

        Return:
            {}: "feats": Features sequences. (B, T, D_feats),
                "feats_lengths": Features sequences lengths. (B,)
        """
        if self.extract_feats_in_collect_stats:
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
        else:
            # Generate dummy stats if extract_feats_in_collect_stats is False
            logging.warning(
                "Generating dummy stats for feats and feats_lengths, "
                "because encoder_conf.extract_feats_in_collect_stats is "
                f"{self.extract_feats_in_collect_stats}"
            )
            feats, feats_lengths = speech, speech_lengths
        return {"feats": feats, "feats_lengths": feats_lengths}

    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encoder speech sequences.

        Args:
            speech: Speech sequences. (B, S)
            speech_lengths: Speech sequences lengths. (B,)

        Return:
            encoder_out: Encoder outputs. (B, T, D_enc)
            encoder_out_lens: Encoder outputs lengths. (B,)
        """
        # Feature extraction/augmentation runs in fp32 even under AMP.
        with autocast(False):
            # 1. Extract feats
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
            # 2. Data augmentation (training only)
            if self.specaug is not None and self.training:
                feats, feats_lengths = self.specaug(feats, feats_lengths)
            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                feats, feats_lengths = self.normalize(feats, feats_lengths)
        # 4. Forward encoder
        encoder_out, encoder_out_lens = self.encoder(feats, feats_lengths)
        assert encoder_out.size(0) == speech.size(0), (
            encoder_out.size(),
            speech.size(0),
        )
        assert encoder_out.size(1) <= encoder_out_lens.max(), (
            encoder_out.size(),
            encoder_out_lens.max(),
        )
        return encoder_out, encoder_out_lens

    def _extract_feats(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Extract features sequences and features sequences lengths.

        Args:
            speech: Speech sequences. (B, S)
            speech_lengths: Speech sequences lengths. (B,)

        Return:
            feats: Features sequences. (B, T, D_feats)
            feats_lengths: Features sequences lengths. (B,)
        """
        assert speech_lengths.dim() == 1, speech_lengths.shape
        # for data-parallel
        speech = speech[:, : speech_lengths.max()]
        if self.frontend is not None:
            feats, feats_lengths = self.frontend(speech, speech_lengths)
        else:
            # No frontend: treat the raw input as the features.
            feats, feats_lengths = speech, speech_lengths
        return feats, feats_lengths

    def _calc_transducer_loss(
        self,
        encoder_out: torch.Tensor,
        joint_out: torch.Tensor,
        target: torch.Tensor,
        t_len: torch.Tensor,
        u_len: torch.Tensor,
    ) -> Tuple[torch.Tensor, Optional[float], Optional[float]]:
        """Compute Transducer loss.

        Args:
            encoder_out: Encoder output sequences. (B, T, D_enc)
            joint_out: Joint Network output sequences (B, T, U, D_joint)
            target: Target label ID sequences. (B, L)
            t_len: Encoder output sequences lengths. (B,)
            u_len: Target label ID sequences lengths. (B,)

        Return:
            loss_transducer: Transducer loss value.
            cer_transducer: Character error rate for Transducer.
            wer_transducer: Word Error Rate for Transducer.
        """
        # Bind the loss function lazily so the model can be constructed
        # without warp-rnnt installed (e.g. for stats collection).
        if self.criterion_transducer is None:
            try:
                # Alternative backend kept for reference:
                # from warprnnt_pytorch import RNNTLoss
                # self.criterion_transducer = RNNTLoss(
                #     reduction="mean",
                #     fastemit_lambda=self.fastemit_lambda,
                # )
                from warp_rnnt import rnnt_loss as RNNTLoss
                self.criterion_transducer = RNNTLoss
            except ImportError:
                logging.error(
                    "warp-rnnt was not installed."
                    "Please consult the installation documentation."
                )
                # NOTE(review): hard process exit inside a model method;
                # raising would be friendlier to callers.
                exit(1)
        # warp_rnnt expects log-probabilities; gather=True saves memory by
        # gathering target scores before the alphas/betas recursion.
        log_probs = torch.log_softmax(joint_out, dim=-1)
        loss_transducer = self.criterion_transducer(
            log_probs,
            target,
            t_len,
            u_len,
            reduction="mean",
            blank=self.blank_id,
            fastemit_lambda=self.fastemit_lambda,
            gather=True,
        )
        # CER/WER are only computed during validation (decoding is slow).
        if not self.training and (self.report_cer or self.report_wer):
            if self.error_calculator is None:
                from espnet2.asr_transducer.error_calculator import ErrorCalculator
                self.error_calculator = ErrorCalculator(
                    self.decoder,
                    self.joint_network,
                    self.token_list,
                    self.sym_space,
                    self.sym_blank,
                    report_cer=self.report_cer,
                    report_wer=self.report_wer,
                )
            cer_transducer, wer_transducer = self.error_calculator(encoder_out, target)
            return loss_transducer, cer_transducer, wer_transducer
        return loss_transducer, None, None

    def _calc_ctc_loss(
        self,
        encoder_out: torch.Tensor,
        target: torch.Tensor,
        t_len: torch.Tensor,
        u_len: torch.Tensor,
    ) -> torch.Tensor:
        """Compute CTC loss.

        Args:
            encoder_out: Encoder output sequences. (B, T, D_enc)
            target: Target label ID sequences. (B, L)
            t_len: Encoder output sequences lengths. (B,)
            u_len: Target label ID sequences lengths. (B,)

        Return:
            loss_ctc: CTC loss value.
        """
        ctc_in = self.ctc_lin(
            torch.nn.functional.dropout(encoder_out, p=self.ctc_dropout_rate)
        )
        # ctc_loss wants (T, B, V) log-probs.
        ctc_in = torch.log_softmax(ctc_in.transpose(0, 1), dim=-1)
        # Flatten non-blank labels; assumes padding/blank id 0 in target.
        target_mask = target != 0
        ctc_target = target[target_mask].cpu()
        # Deterministic cudnn kernels for reproducible CTC gradients.
        with torch.backends.cudnn.flags(deterministic=True):
            loss_ctc = torch.nn.functional.ctc_loss(
                ctc_in,
                ctc_target,
                t_len,
                u_len,
                zero_infinity=True,
                reduction="sum",
            )
        # Normalize the summed loss by batch size.
        loss_ctc /= target.size(0)
        return loss_ctc

    def _calc_lm_loss(
        self,
        decoder_out: torch.Tensor,
        target: torch.Tensor,
    ) -> torch.Tensor:
        """Compute LM loss (label-smoothed KL divergence on decoder output).

        Args:
            decoder_out: Decoder output sequences. (B, U, D_dec)
            target: Target label ID sequences. (B, L)

        Return:
            loss_lm: LM loss value.
        """
        # Drop the last decoder step so predictions align with targets.
        lm_loss_in = self.lm_lin(decoder_out[:, :-1, :]).view(-1, self.vocab_size)
        lm_target = target.view(-1).type(torch.int64)
        with torch.no_grad():
            # Build the smoothed one-hot target distribution.
            true_dist = lm_loss_in.clone()
            true_dist.fill_(self.lm_loss_smoothing / (self.vocab_size - 1))
            # Ignore blank ID (0)
            ignore = lm_target == 0
            lm_target = lm_target.masked_fill(ignore, 0)
            true_dist.scatter_(1, lm_target.unsqueeze(1), (1 - self.lm_loss_smoothing))
        loss_lm = torch.nn.functional.kl_div(
            torch.log_softmax(lm_loss_in, dim=1),
            true_dist,
            reduction="none",
        )
        # Zero out ignored positions, then normalize by batch size.
        loss_lm = loss_lm.masked_fill(ignore.unsqueeze(1), 0).sum() / decoder_out.size(
            0
        )
        return loss_lm
  395. class UnifiedTransducerModel(AbsESPnetModel):
  396. """ESPnet2ASRTransducerModel module definition.
  397. Args:
  398. vocab_size: Size of complete vocabulary (w/ EOS and blank included).
  399. token_list: List of token
  400. frontend: Frontend module.
  401. specaug: SpecAugment module.
  402. normalize: Normalization module.
  403. encoder: Encoder module.
  404. decoder: Decoder module.
  405. joint_network: Joint Network module.
  406. transducer_weight: Weight of the Transducer loss.
  407. fastemit_lambda: FastEmit lambda value.
  408. auxiliary_ctc_weight: Weight of auxiliary CTC loss.
  409. auxiliary_ctc_dropout_rate: Dropout rate for auxiliary CTC loss inputs.
  410. auxiliary_lm_loss_weight: Weight of auxiliary LM loss.
  411. auxiliary_lm_loss_smoothing: Smoothing rate for LM loss' label smoothing.
  412. ignore_id: Initial padding ID.
  413. sym_space: Space symbol.
  414. sym_blank: Blank Symbol
  415. report_cer: Whether to report Character Error Rate during validation.
  416. report_wer: Whether to report Word Error Rate during validation.
  417. extract_feats_in_collect_stats: Whether to use extract_feats stats collection.
  418. """
  419. def __init__(
  420. self,
  421. vocab_size: int,
  422. token_list: Union[Tuple[str, ...], List[str]],
  423. frontend: Optional[AbsFrontend],
  424. specaug: Optional[AbsSpecAug],
  425. normalize: Optional[AbsNormalize],
  426. encoder: Encoder,
  427. decoder: RNNTDecoder,
  428. joint_network: JointNetwork,
  429. att_decoder: Optional[AbsAttDecoder] = None,
  430. transducer_weight: float = 1.0,
  431. fastemit_lambda: float = 0.0,
  432. auxiliary_ctc_weight: float = 0.0,
  433. auxiliary_att_weight: float = 0.0,
  434. auxiliary_ctc_dropout_rate: float = 0.0,
  435. auxiliary_lm_loss_weight: float = 0.0,
  436. auxiliary_lm_loss_smoothing: float = 0.0,
  437. ignore_id: int = -1,
  438. sym_space: str = "<space>",
  439. sym_blank: str = "<blank>",
  440. report_cer: bool = True,
  441. report_wer: bool = True,
  442. sym_sos: str = "<sos/eos>",
  443. sym_eos: str = "<sos/eos>",
  444. extract_feats_in_collect_stats: bool = True,
  445. lsm_weight: float = 0.0,
  446. length_normalized_loss: bool = False,
  447. ) -> None:
  448. """Construct an ESPnetASRTransducerModel object."""
  449. super().__init__()
  450. assert check_argument_types()
  451. # The following labels ID are reserved: 0 (blank) and vocab_size - 1 (sos/eos)
  452. self.blank_id = 0
  453. if sym_sos in token_list:
  454. self.sos = token_list.index(sym_sos)
  455. else:
  456. self.sos = vocab_size - 1
  457. if sym_eos in token_list:
  458. self.eos = token_list.index(sym_eos)
  459. else:
  460. self.eos = vocab_size - 1
  461. self.vocab_size = vocab_size
  462. self.ignore_id = ignore_id
  463. self.token_list = token_list.copy()
  464. self.sym_space = sym_space
  465. self.sym_blank = sym_blank
  466. self.frontend = frontend
  467. self.specaug = specaug
  468. self.normalize = normalize
  469. self.encoder = encoder
  470. self.decoder = decoder
  471. self.joint_network = joint_network
  472. self.criterion_transducer = None
  473. self.error_calculator = None
  474. self.use_auxiliary_ctc = auxiliary_ctc_weight > 0
  475. self.use_auxiliary_att = auxiliary_att_weight > 0
  476. self.use_auxiliary_lm_loss = auxiliary_lm_loss_weight > 0
  477. if self.use_auxiliary_ctc:
  478. self.ctc_lin = torch.nn.Linear(encoder.output_size, vocab_size)
  479. self.ctc_dropout_rate = auxiliary_ctc_dropout_rate
  480. if self.use_auxiliary_att:
  481. self.att_decoder = att_decoder
  482. self.criterion_att = LabelSmoothingLoss(
  483. size=vocab_size,
  484. padding_idx=ignore_id,
  485. smoothing=lsm_weight,
  486. normalize_length=length_normalized_loss,
  487. )
  488. if self.use_auxiliary_lm_loss:
  489. self.lm_lin = torch.nn.Linear(decoder.output_size, vocab_size)
  490. self.lm_loss_smoothing = auxiliary_lm_loss_smoothing
  491. self.transducer_weight = transducer_weight
  492. self.fastemit_lambda = fastemit_lambda
  493. self.auxiliary_ctc_weight = auxiliary_ctc_weight
  494. self.auxiliary_att_weight = auxiliary_att_weight
  495. self.auxiliary_lm_loss_weight = auxiliary_lm_loss_weight
  496. self.report_cer = report_cer
  497. self.report_wer = report_wer
  498. self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
  499. def forward(
  500. self,
  501. speech: torch.Tensor,
  502. speech_lengths: torch.Tensor,
  503. text: torch.Tensor,
  504. text_lengths: torch.Tensor,
  505. **kwargs,
  506. ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
  507. """Forward architecture and compute loss(es).
  508. Args:
  509. speech: Speech sequences. (B, S)
  510. speech_lengths: Speech sequences lengths. (B,)
  511. text: Label ID sequences. (B, L)
  512. text_lengths: Label ID sequences lengths. (B,)
  513. kwargs: Contains "utts_id".
  514. Return:
  515. loss: Main loss value.
  516. stats: Task statistics.
  517. weight: Task weights.
  518. """
  519. assert text_lengths.dim() == 1, text_lengths.shape
  520. assert (
  521. speech.shape[0]
  522. == speech_lengths.shape[0]
  523. == text.shape[0]
  524. == text_lengths.shape[0]
  525. ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
  526. batch_size = speech.shape[0]
  527. text = text[:, : text_lengths.max()]
  528. #print(speech.shape)
  529. # 1. Encoder
  530. encoder_out, encoder_out_chunk, encoder_out_lens = self.encode(speech, speech_lengths)
  531. loss_att, loss_att_chunk = 0.0, 0.0
  532. if self.use_auxiliary_att:
  533. loss_att, _ = self._calc_att_loss(
  534. encoder_out, encoder_out_lens, text, text_lengths
  535. )
  536. loss_att_chunk, _ = self._calc_att_loss(
  537. encoder_out_chunk, encoder_out_lens, text, text_lengths
  538. )
  539. # 2. Transducer-related I/O preparation
  540. decoder_in, target, t_len, u_len = get_transducer_task_io(
  541. text,
  542. encoder_out_lens,
  543. ignore_id=self.ignore_id,
  544. )
  545. # 3. Decoder
  546. self.decoder.set_device(encoder_out.device)
  547. decoder_out = self.decoder(decoder_in, u_len)
  548. # 4. Joint Network
  549. joint_out = self.joint_network(
  550. encoder_out.unsqueeze(2), decoder_out.unsqueeze(1)
  551. )
  552. joint_out_chunk = self.joint_network(
  553. encoder_out_chunk.unsqueeze(2), decoder_out.unsqueeze(1)
  554. )
  555. # 5. Losses
  556. loss_trans_utt, cer_trans, wer_trans = self._calc_transducer_loss(
  557. encoder_out,
  558. joint_out,
  559. target,
  560. t_len,
  561. u_len,
  562. )
  563. loss_trans_chunk, cer_trans_chunk, wer_trans_chunk = self._calc_transducer_loss(
  564. encoder_out_chunk,
  565. joint_out_chunk,
  566. target,
  567. t_len,
  568. u_len,
  569. )
  570. loss_ctc, loss_ctc_chunk, loss_lm = 0.0, 0.0, 0.0
  571. if self.use_auxiliary_ctc:
  572. loss_ctc = self._calc_ctc_loss(
  573. encoder_out,
  574. target,
  575. t_len,
  576. u_len,
  577. )
  578. loss_ctc_chunk = self._calc_ctc_loss(
  579. encoder_out_chunk,
  580. target,
  581. t_len,
  582. u_len,
  583. )
  584. if self.use_auxiliary_lm_loss:
  585. loss_lm = self._calc_lm_loss(decoder_out, target)
  586. loss_trans = loss_trans_utt + loss_trans_chunk
  587. loss_ctc = loss_ctc + loss_ctc_chunk
  588. loss_ctc = loss_att + loss_att_chunk
  589. loss = (
  590. self.transducer_weight * loss_trans
  591. + self.auxiliary_ctc_weight * loss_ctc
  592. + self.auxiliary_att_weight * loss_att
  593. + self.auxiliary_lm_loss_weight * loss_lm
  594. )
  595. stats = dict(
  596. loss=loss.detach(),
  597. loss_transducer=loss_trans_utt.detach(),
  598. loss_transducer_chunk=loss_trans_chunk.detach(),
  599. aux_ctc_loss=loss_ctc.detach() if loss_ctc > 0.0 else None,
  600. aux_ctc_loss_chunk=loss_ctc_chunk.detach() if loss_ctc_chunk > 0.0 else None,
  601. aux_att_loss=loss_att.detach() if loss_att > 0.0 else None,
  602. aux_att_loss_chunk=loss_att_chunk.detach() if loss_att_chunk > 0.0 else None,
  603. aux_lm_loss=loss_lm.detach() if loss_lm > 0.0 else None,
  604. cer_transducer=cer_trans,
  605. wer_transducer=wer_trans,
  606. cer_transducer_chunk=cer_trans_chunk,
  607. wer_transducer_chunk=wer_trans_chunk,
  608. )
  609. # force_gatherable: to-device and to-tensor if scalar for DataParallel
  610. loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
  611. return loss, stats, weight
  612. def collect_feats(
  613. self,
  614. speech: torch.Tensor,
  615. speech_lengths: torch.Tensor,
  616. text: torch.Tensor,
  617. text_lengths: torch.Tensor,
  618. **kwargs,
  619. ) -> Dict[str, torch.Tensor]:
  620. """Collect features sequences and features lengths sequences.
  621. Args:
  622. speech: Speech sequences. (B, S)
  623. speech_lengths: Speech sequences lengths. (B,)
  624. text: Label ID sequences. (B, L)
  625. text_lengths: Label ID sequences lengths. (B,)
  626. kwargs: Contains "utts_id".
  627. Return:
  628. {}: "feats": Features sequences. (B, T, D_feats),
  629. "feats_lengths": Features sequences lengths. (B,)
  630. """
  631. if self.extract_feats_in_collect_stats:
  632. feats, feats_lengths = self._extract_feats(speech, speech_lengths)
  633. else:
  634. # Generate dummy stats if extract_feats_in_collect_stats is False
  635. logging.warning(
  636. "Generating dummy stats for feats and feats_lengths, "
  637. "because encoder_conf.extract_feats_in_collect_stats is "
  638. f"{self.extract_feats_in_collect_stats}"
  639. )
  640. feats, feats_lengths = speech, speech_lengths
  641. return {"feats": feats, "feats_lengths": feats_lengths}
  642. def encode(
  643. self,
  644. speech: torch.Tensor,
  645. speech_lengths: torch.Tensor,
  646. ) -> Tuple[torch.Tensor, torch.Tensor]:
  647. """Encoder speech sequences.
  648. Args:
  649. speech: Speech sequences. (B, S)
  650. speech_lengths: Speech sequences lengths. (B,)
  651. Return:
  652. encoder_out: Encoder outputs. (B, T, D_enc)
  653. encoder_out_lens: Encoder outputs lengths. (B,)
  654. """
  655. with autocast(False):
  656. # 1. Extract feats
  657. feats, feats_lengths = self._extract_feats(speech, speech_lengths)
  658. # 2. Data augmentation
  659. if self.specaug is not None and self.training:
  660. feats, feats_lengths = self.specaug(feats, feats_lengths)
  661. # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
  662. if self.normalize is not None:
  663. feats, feats_lengths = self.normalize(feats, feats_lengths)
  664. # 4. Forward encoder
  665. encoder_out, encoder_out_chunk, encoder_out_lens = self.encoder(feats, feats_lengths)
  666. assert encoder_out.size(0) == speech.size(0), (
  667. encoder_out.size(),
  668. speech.size(0),
  669. )
  670. assert encoder_out.size(1) <= encoder_out_lens.max(), (
  671. encoder_out.size(),
  672. encoder_out_lens.max(),
  673. )
  674. return encoder_out, encoder_out_chunk, encoder_out_lens
  675. def _extract_feats(
  676. self, speech: torch.Tensor, speech_lengths: torch.Tensor
  677. ) -> Tuple[torch.Tensor, torch.Tensor]:
  678. """Extract features sequences and features sequences lengths.
  679. Args:
  680. speech: Speech sequences. (B, S)
  681. speech_lengths: Speech sequences lengths. (B,)
  682. Return:
  683. feats: Features sequences. (B, T, D_feats)
  684. feats_lengths: Features sequences lengths. (B,)
  685. """
  686. assert speech_lengths.dim() == 1, speech_lengths.shape
  687. # for data-parallel
  688. speech = speech[:, : speech_lengths.max()]
  689. if self.frontend is not None:
  690. feats, feats_lengths = self.frontend(speech, speech_lengths)
  691. else:
  692. feats, feats_lengths = speech, speech_lengths
  693. return feats, feats_lengths
  694. def _calc_transducer_loss(
  695. self,
  696. encoder_out: torch.Tensor,
  697. joint_out: torch.Tensor,
  698. target: torch.Tensor,
  699. t_len: torch.Tensor,
  700. u_len: torch.Tensor,
  701. ) -> Tuple[torch.Tensor, Optional[float], Optional[float]]:
  702. """Compute Transducer loss.
  703. Args:
  704. encoder_out: Encoder output sequences. (B, T, D_enc)
  705. joint_out: Joint Network output sequences (B, T, U, D_joint)
  706. target: Target label ID sequences. (B, L)
  707. t_len: Encoder output sequences lengths. (B,)
  708. u_len: Target label ID sequences lengths. (B,)
  709. Return:
  710. loss_transducer: Transducer loss value.
  711. cer_transducer: Character error rate for Transducer.
  712. wer_transducer: Word Error Rate for Transducer.
  713. """
  714. if self.criterion_transducer is None:
  715. try:
  716. # from warprnnt_pytorch import RNNTLoss
  717. # self.criterion_transducer = RNNTLoss(
  718. # reduction="mean",
  719. # fastemit_lambda=self.fastemit_lambda,
  720. # )
  721. from warp_rnnt import rnnt_loss as RNNTLoss
  722. self.criterion_transducer = RNNTLoss
  723. except ImportError:
  724. logging.error(
  725. "warp-rnnt was not installed."
  726. "Please consult the installation documentation."
  727. )
  728. exit(1)
  729. # loss_transducer = self.criterion_transducer(
  730. # joint_out,
  731. # target,
  732. # t_len,
  733. # u_len,
  734. # )
  735. log_probs = torch.log_softmax(joint_out, dim=-1)
  736. loss_transducer = self.criterion_transducer(
  737. log_probs,
  738. target,
  739. t_len,
  740. u_len,
  741. reduction="mean",
  742. blank=self.blank_id,
  743. fastemit_lambda=self.fastemit_lambda,
  744. gather=True,
  745. )
  746. if not self.training and (self.report_cer or self.report_wer):
  747. if self.error_calculator is None:
  748. self.error_calculator = ErrorCalculator(
  749. self.decoder,
  750. self.joint_network,
  751. self.token_list,
  752. self.sym_space,
  753. self.sym_blank,
  754. report_cer=self.report_cer,
  755. report_wer=self.report_wer,
  756. )
  757. cer_transducer, wer_transducer = self.error_calculator(encoder_out, target)
  758. return loss_transducer, cer_transducer, wer_transducer
  759. return loss_transducer, None, None
  760. def _calc_ctc_loss(
  761. self,
  762. encoder_out: torch.Tensor,
  763. target: torch.Tensor,
  764. t_len: torch.Tensor,
  765. u_len: torch.Tensor,
  766. ) -> torch.Tensor:
  767. """Compute CTC loss.
  768. Args:
  769. encoder_out: Encoder output sequences. (B, T, D_enc)
  770. target: Target label ID sequences. (B, L)
  771. t_len: Encoder output sequences lengths. (B,)
  772. u_len: Target label ID sequences lengths. (B,)
  773. Return:
  774. loss_ctc: CTC loss value.
  775. """
  776. ctc_in = self.ctc_lin(
  777. torch.nn.functional.dropout(encoder_out, p=self.ctc_dropout_rate)
  778. )
  779. ctc_in = torch.log_softmax(ctc_in.transpose(0, 1), dim=-1)
  780. target_mask = target != 0
  781. ctc_target = target[target_mask].cpu()
  782. with torch.backends.cudnn.flags(deterministic=True):
  783. loss_ctc = torch.nn.functional.ctc_loss(
  784. ctc_in,
  785. ctc_target,
  786. t_len,
  787. u_len,
  788. zero_infinity=True,
  789. reduction="sum",
  790. )
  791. loss_ctc /= target.size(0)
  792. return loss_ctc
  793. def _calc_lm_loss(
  794. self,
  795. decoder_out: torch.Tensor,
  796. target: torch.Tensor,
  797. ) -> torch.Tensor:
  798. """Compute LM loss.
  799. Args:
  800. decoder_out: Decoder output sequences. (B, U, D_dec)
  801. target: Target label ID sequences. (B, L)
  802. Return:
  803. loss_lm: LM loss value.
  804. """
  805. lm_loss_in = self.lm_lin(decoder_out[:, :-1, :]).view(-1, self.vocab_size)
  806. lm_target = target.view(-1).type(torch.int64)
  807. with torch.no_grad():
  808. true_dist = lm_loss_in.clone()
  809. true_dist.fill_(self.lm_loss_smoothing / (self.vocab_size - 1))
  810. # Ignore blank ID (0)
  811. ignore = lm_target == 0
  812. lm_target = lm_target.masked_fill(ignore, 0)
  813. true_dist.scatter_(1, lm_target.unsqueeze(1), (1 - self.lm_loss_smoothing))
  814. loss_lm = torch.nn.functional.kl_div(
  815. torch.log_softmax(lm_loss_in, dim=1),
  816. true_dist,
  817. reduction="none",
  818. )
  819. loss_lm = loss_lm.masked_fill(ignore.unsqueeze(1), 0).sum() / decoder_out.size(
  820. 0
  821. )
  822. return loss_lm
  823. def _calc_att_loss(
  824. self,
  825. encoder_out: torch.Tensor,
  826. encoder_out_lens: torch.Tensor,
  827. ys_pad: torch.Tensor,
  828. ys_pad_lens: torch.Tensor,
  829. ):
  830. if hasattr(self, "lang_token_id") and self.lang_token_id is not None:
  831. ys_pad = torch.cat(
  832. [
  833. self.lang_token_id.repeat(ys_pad.size(0), 1).to(ys_pad.device),
  834. ys_pad,
  835. ],
  836. dim=1,
  837. )
  838. ys_pad_lens += 1
  839. ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
  840. ys_in_lens = ys_pad_lens + 1
  841. # 1. Forward decoder
  842. decoder_out, _ = self.att_decoder(
  843. encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
  844. )
  845. # 2. Compute attention loss
  846. loss_att = self.criterion_att(decoder_out, ys_out_pad)
  847. acc_att = th_accuracy(
  848. decoder_out.view(-1, self.vocab_size),
  849. ys_out_pad,
  850. ignore_label=self.ignore_id,
  851. )
  852. return loss_att, acc_att