# e2e_sa_asr.py
  1. # Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
  2. # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
  3. import logging
  4. from contextlib import contextmanager
  5. from distutils.version import LooseVersion
  6. from typing import Dict
  7. from typing import List
  8. from typing import Optional
  9. from typing import Tuple
  10. from typing import Union
  11. import torch
  12. import torch.nn.functional as F
  13. from funasr.layers.abs_normalize import AbsNormalize
  14. from funasr.losses.label_smoothing_loss import (
  15. LabelSmoothingLoss, NllLoss # noqa: H301
  16. )
  17. from funasr.models.ctc import CTC
  18. from funasr.models.decoder.abs_decoder import AbsDecoder
  19. from funasr.models.encoder.abs_encoder import AbsEncoder
  20. from funasr.models.frontend.abs_frontend import AbsFrontend
  21. from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
  22. from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
  23. from funasr.models.specaug.abs_specaug import AbsSpecAug
  24. from funasr.modules.add_sos_eos import add_sos_eos
  25. from funasr.modules.e2e_asr_common import ErrorCalculator
  26. from funasr.modules.nets_utils import th_accuracy
  27. from funasr.torch_utils.device_funcs import force_gatherable
  28. from funasr.models.base_model import FunASRModel
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
else:
    # torch < 1.6.0 has no AMP autocast; provide a no-op stand-in so call
    # sites can use `with autocast(False):` unconditionally.
    @contextmanager
    def autocast(enabled=True):
        yield
class SAASRModel(FunASRModel):
    """Speaker-attributed CTC-attention hybrid Encoder-Decoder model.

    A standard hybrid CTC/attention ASR branch is combined with a separate
    speaker encoder (fed the raw, un-augmented features).  The decoder
    consumes both encoder outputs plus per-utterance speaker profiles and
    emits per-token speaker attention weights, trained against per-token
    speaker labels (``text_id``) with an NLL loss.  The final loss is
    ``spk_weight * loss_spk + (1 - spk_weight) * loss_asr`` where
    ``loss_asr`` is the usual CTC/attention interpolation.
    """

    def __init__(
        self,
        vocab_size: int,
        max_spk_num: int,
        token_list: Union[Tuple[str, ...], List[str]],
        frontend: Optional[AbsFrontend],
        specaug: Optional[AbsSpecAug],
        normalize: Optional[AbsNormalize],
        asr_encoder: AbsEncoder,
        spk_encoder: torch.nn.Module,
        decoder: AbsDecoder,
        ctc: CTC,
        spk_weight: float = 0.5,
        ctc_weight: float = 0.5,
        interctc_weight: float = 0.0,
        ignore_id: int = -1,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
        report_cer: bool = True,
        report_wer: bool = True,
        sym_space: str = "<space>",
        sym_blank: str = "<blank>",
        extract_feats_in_collect_stats: bool = True,
    ):
        """Build the speaker-attributed ASR model.

        Args:
            vocab_size: Size of the output token vocabulary.
            max_spk_num: Maximum number of speakers per utterance; decoder
                speaker weights are zero-padded up to this size.
            token_list: Token inventory (copied; used by the error calculator).
            frontend: Optional waveform-to-feature frontend.
            specaug: Optional SpecAugment module (applied in training only).
            normalize: Optional feature normalization (e.g. global CMVN).
            asr_encoder: Encoder for the ASR branch.
            spk_encoder: Encoder for the speaker branch.
            decoder: Joint ASR/speaker decoder.
            ctc: CTC module for the ASR branch.
            spk_weight: Weight of the speaker loss in the final loss.
            ctc_weight: CTC weight in [0, 1]; 1.0 disables the decoder,
                0.0 disables CTC.
            interctc_weight: Weight of intermediate-layer CTC losses, in [0, 1).
            ignore_id: Padding label id ignored by the losses.
            lsm_weight: Label-smoothing weight for the attention loss.
            length_normalized_loss: Normalize losses by sequence length.
            report_cer: Enable CER computation at eval time.
            report_wer: Enable WER computation at eval time.
            sym_space: Space symbol for the error calculator.
            sym_blank: Blank symbol for the error calculator.
            extract_feats_in_collect_stats: If False, collect_feats() returns
                the raw speech as dummy features.
        """
        assert 0.0 <= ctc_weight <= 1.0, ctc_weight
        assert 0.0 <= interctc_weight < 1.0, interctc_weight

        super().__init__()
        # NOTE(review): unlike ESPnet models where sos == eos, this model
        # uses distinct fixed ids (sos=1, eos=2, blank=0).
        self.blank_id = 0
        self.sos = 1
        self.eos = 2
        self.vocab_size = vocab_size
        self.max_spk_num = max_spk_num
        self.ignore_id = ignore_id
        self.spk_weight = spk_weight
        self.ctc_weight = ctc_weight
        self.interctc_weight = interctc_weight

        self.token_list = token_list.copy()

        self.frontend = frontend
        self.specaug = specaug
        self.normalize = normalize
        self.asr_encoder = asr_encoder
        self.spk_encoder = spk_encoder

        # Intermediate-CTC self-conditioning: when the encoder supports it,
        # intermediate CTC posteriors are projected back into the encoder.
        if not hasattr(self.asr_encoder, "interctc_use_conditioning"):
            self.asr_encoder.interctc_use_conditioning = False
        if self.asr_encoder.interctc_use_conditioning:
            self.asr_encoder.conditioning_layer = torch.nn.Linear(
                vocab_size, self.asr_encoder.output_size()
            )

        self.error_calculator = None

        # we set self.decoder = None in the CTC mode since
        # self.decoder parameters were never used and PyTorch complained
        # and threw an Exception in the multi-GPU experiment.
        # thanks Jeff Farris for pointing out the issue.
        if ctc_weight == 1.0:
            self.decoder = None
        else:
            self.decoder = decoder

        # Token-level attention loss with label smoothing.
        self.criterion_att = LabelSmoothingLoss(
            size=vocab_size,
            padding_idx=ignore_id,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )
        # Speaker-id loss over the (padded) speaker-weight distribution.
        self.criterion_spk = NllLoss(
            size=max_spk_num,
            padding_idx=ignore_id,
            normalize_length=length_normalized_loss,
        )

        if report_cer or report_wer:
            self.error_calculator = ErrorCalculator(
                token_list, sym_space, sym_blank, report_cer, report_wer
            )

        if ctc_weight == 0.0:
            self.ctc = None
        else:
            self.ctc = ctc

        self.extract_feats_in_collect_stats = extract_feats_in_collect_stats

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        profile: torch.Tensor,
        profile_lengths: torch.Tensor,
        text_id: torch.Tensor,
        text_id_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)
            profile: (Batch, Length, Dim)
            profile_lengths: (Batch,)
            text_id: per-token speaker labels aligned with ``text``
                — presumably (Batch, Length); TODO confirm against the loader.
            text_id_lengths: (Batch,)

        Returns:
            Tuple of (loss, stats dict, batch weight) prepared for
            DataParallel gathering via ``force_gatherable``.
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (
            speech.shape[0]
            == speech_lengths.shape[0]
            == text.shape[0]
            == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
        batch_size = speech.shape[0]

        # for data-parallel: trim padding beyond the longest text in batch
        text = text[:, : text_lengths.max()]

        # 1. Encoder (returns ASR encoder output, lengths, speaker encoder output)
        asr_encoder_out, encoder_out_lens, spk_encoder_out = self.encode(speech, speech_lengths)
        intermediate_outs = None
        if isinstance(asr_encoder_out, tuple):
            # Intermediate-CTC mode: encode() packed per-layer outputs first.
            intermediate_outs = asr_encoder_out[1]
            asr_encoder_out = asr_encoder_out[0]

        loss_att, loss_spk, acc_att, acc_spk, cer_att, wer_att = None, None, None, None, None, None
        loss_ctc, cer_ctc = None, None
        stats = dict()

        # 2a. CTC branch
        if self.ctc_weight != 0.0:
            loss_ctc, cer_ctc = self._calc_ctc_loss(
                asr_encoder_out, encoder_out_lens, text, text_lengths
            )

        # Intermediate CTC (optional)
        loss_interctc = 0.0
        if self.interctc_weight != 0.0 and intermediate_outs is not None:
            for layer_idx, intermediate_out in intermediate_outs:
                # we assume intermediate_out has the same length & padding
                # as those of encoder_out
                loss_ic, cer_ic = self._calc_ctc_loss(
                    intermediate_out, encoder_out_lens, text, text_lengths
                )
                loss_interctc = loss_interctc + loss_ic

                # Collect intermediate CTC stats
                stats["loss_interctc_layer{}".format(layer_idx)] = (
                    loss_ic.detach() if loss_ic is not None else None
                )
                stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic

            loss_interctc = loss_interctc / len(intermediate_outs)

            # calculate whole encoder loss
            loss_ctc = (
                1 - self.interctc_weight
            ) * loss_ctc + self.interctc_weight * loss_interctc

        # 2b. Attention decoder branch
        if self.ctc_weight != 1.0:
            loss_att, loss_spk, acc_att, acc_spk, cer_att, wer_att = self._calc_att_loss(
                asr_encoder_out, spk_encoder_out, encoder_out_lens, text, text_lengths, profile, profile_lengths, text_id, text_id_lengths
            )

        # 3. CTC-Att loss definition
        if self.ctc_weight == 0.0:
            loss_asr = loss_att
        elif self.ctc_weight == 1.0:
            loss_asr = loss_ctc
        else:
            loss_asr = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att

        # 4. Mix in the speaker loss.
        if self.spk_weight == 0.0:
            loss = loss_asr
        else:
            loss = self.spk_weight * loss_spk + (1 - self.spk_weight) * loss_asr

        # NOTE(review): this rebinding discards the per-layer intermediate-CTC
        # stats collected above — looks unintended; confirm before relying on
        # interctc logging.
        stats = dict(
            loss=loss.detach(),
            loss_asr=loss_asr.detach(),
            loss_att=loss_att.detach() if loss_att is not None else None,
            loss_ctc=loss_ctc.detach() if loss_ctc is not None else None,
            loss_spk=loss_spk.detach() if loss_spk is not None else None,
            acc=acc_att,
            acc_spk=acc_spk,
            cer=cer_att,
            wer=wer_att,
            cer_ctc=cer_ctc,
        )

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

    def collect_feats(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        """Extract features for statistics collection (e.g. CMVN estimation)."""
        if self.extract_feats_in_collect_stats:
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
        else:
            # Generate dummy stats if extract_feats_in_collect_stats is False
            logging.warning(
                "Generating dummy stats for feats and feats_lengths, "
                "because encoder_conf.extract_feats_in_collect_stats is "
                f"{self.extract_feats_in_collect_stats}"
            )
            feats, feats_lengths = speech, speech_lengths
        return {"feats": feats, "feats_lengths": feats_lengths}

    def encode(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Frontend + Encoder. Note that this method is used by asr_inference.py

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )

        Returns:
            ASR encoder output (or ``(output, intermediate_outs)`` when
            intermediate CTC is enabled), output lengths, and the speaker
            encoder output resampled to the ASR encoder's frame count.
        """
        # Feature extraction / augmentation / normalization run in fp32 even
        # under AMP.
        with autocast(False):
            # 1. Extract feats
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)

            # Keep an un-augmented, un-normalized copy for the speaker encoder.
            feats_raw = feats.clone()

            # 2. Data augmentation
            if self.specaug is not None and self.training:
                feats, feats_lengths = self.specaug(feats, feats_lengths)

            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                feats, feats_lengths = self.normalize(feats, feats_lengths)

        # 4. Forward encoder
        # feats: (Batch, Length, Dim)
        # -> encoder_out: (Batch, Length2, Dim2)
        if self.asr_encoder.interctc_use_conditioning:
            encoder_out, encoder_out_lens, _ = self.asr_encoder(
                feats, feats_lengths, ctc=self.ctc
            )
        else:
            encoder_out, encoder_out_lens, _ = self.asr_encoder(feats, feats_lengths)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]

        encoder_out_spk_ori = self.spk_encoder(feats_raw, feats_lengths)[0]
        # Align the speaker encoder output to the ASR encoder frame count via
        # nearest-neighbor resampling along the time axis.
        if encoder_out_spk_ori.size(1) != encoder_out.size(1):
            encoder_out_spk = F.interpolate(encoder_out_spk_ori.transpose(-2, -1), size=(encoder_out.size(1)), mode='nearest').transpose(-2, -1)
        else:
            encoder_out_spk = encoder_out_spk_ori

        assert encoder_out.size(0) == speech.size(0), (
            encoder_out.size(),
            speech.size(0),
        )
        assert encoder_out.size(1) <= encoder_out_lens.max(), (
            encoder_out.size(),
            encoder_out_lens.max(),
        )
        assert encoder_out_spk.size(0) == speech.size(0), (
            encoder_out_spk.size(),
            speech.size(0),
        )

        if intermediate_outs is not None:
            return (encoder_out, intermediate_outs), encoder_out_lens, encoder_out_spk

        return encoder_out, encoder_out_lens, encoder_out_spk

    def _extract_feats(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the frontend (if any) to raw speech.

        Returns the features and their lengths; passes speech through
        unchanged when no frontend is configured.
        """
        assert speech_lengths.dim() == 1, speech_lengths.shape

        # for data-parallel: trim padding beyond the longest utterance
        speech = speech[:, : speech_lengths.max()]

        if self.frontend is not None:
            # Frontend
            # e.g. STFT and Feature extract
            # data_loader may send time-domain signal in this case
            # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
            feats, feats_lengths = self.frontend(speech, speech_lengths)
        else:
            # No frontend and no feature extract
            feats, feats_lengths = speech, speech_lengths
        return feats, feats_lengths

    def nll(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ) -> torch.Tensor:
        """Compute negative log likelihood (nll) from transformer-decoder.

        Normally, this function is called in batchify_nll.

        Args:
            encoder_out: (Batch, Length, Dim)
            encoder_out_lens: (Batch,)
            ys_pad: (Batch, Length)
            ys_pad_lens: (Batch,)

        Returns:
            Per-utterance nll summed over tokens, shape (Batch,).
        """
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        ys_in_lens = ys_pad_lens + 1

        # 1. Forward decoder
        decoder_out, _ = self.decoder(
            encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
        )  # [batch, seqlen, dim]
        batch_size = decoder_out.size(0)
        decoder_num_class = decoder_out.size(2)

        # nll: negative log-likelihood; padding positions (ignore_id) are
        # excluded via ignore_index.
        nll = torch.nn.functional.cross_entropy(
            decoder_out.view(-1, decoder_num_class),
            ys_out_pad.view(-1),
            ignore_index=self.ignore_id,
            reduction="none",
        )
        nll = nll.view(batch_size, -1)
        nll = nll.sum(dim=1)
        assert nll.size(0) == batch_size
        return nll

    def batchify_nll(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
        batch_size: int = 100,
    ):
        """Compute negative log likelihood (nll) from transformer-decoder.

        To avoid OOM, this function separates the input into batches.
        Then calls nll for each batch and combines and returns results.

        Args:
            encoder_out: (Batch, Length, Dim)
            encoder_out_lens: (Batch,)
            ys_pad: (Batch, Length)
            ys_pad_lens: (Batch,)
            batch_size: int, samples each batch contain when computing nll,
                        you may change this to avoid OOM or increase
                        GPU memory usage
        """
        total_num = encoder_out.size(0)
        if total_num <= batch_size:
            # Small enough to process in one shot.
            nll = self.nll(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
        else:
            nll = []
            start_idx = 0
            while True:
                end_idx = min(start_idx + batch_size, total_num)
                batch_encoder_out = encoder_out[start_idx:end_idx, :, :]
                batch_encoder_out_lens = encoder_out_lens[start_idx:end_idx]
                batch_ys_pad = ys_pad[start_idx:end_idx, :]
                batch_ys_pad_lens = ys_pad_lens[start_idx:end_idx]
                batch_nll = self.nll(
                    batch_encoder_out,
                    batch_encoder_out_lens,
                    batch_ys_pad,
                    batch_ys_pad_lens,
                )
                nll.append(batch_nll)
                start_idx = end_idx
                if start_idx == total_num:
                    break
            nll = torch.cat(nll)
        assert nll.size(0) == total_num
        return nll

    def _calc_att_loss(
        self,
        asr_encoder_out: torch.Tensor,
        spk_encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
        profile: torch.Tensor,
        profile_lens: torch.Tensor,
        text_id: torch.Tensor,
        text_id_lengths: torch.Tensor
    ):
        """Attention-decoder loss for both the token and the speaker branch.

        Returns:
            (loss_att, loss_spk, acc_att, acc_spk, cer_att, wer_att); the
            CER/WER entries are None in training or without an error
            calculator.
        """
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        ys_in_lens = ys_pad_lens + 1

        # 1. Forward decoder; it also returns per-token speaker attention
        # weights over the profile entries.
        decoder_out, weights_no_pad, _ = self.decoder(
            asr_encoder_out, spk_encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens, profile, profile_lens
        )
        # Zero-pad the speaker axis up to max_spk_num so the NLL loss and
        # accuracy see a fixed-size class dimension.
        spk_num_no_pad = weights_no_pad.size(-1)
        pad = (0, self.max_spk_num - spk_num_no_pad)
        weights = F.pad(weights_no_pad, pad, mode='constant', value=0)

        # NOTE(review): disabled speaker-attributed accuracy computation,
        # kept for reference.
        # pre_id=weights.argmax(-1)
        # pre_text=decoder_out.argmax(-1)
        # id_mask=(pre_id==text_id).to(dtype=text_id.dtype)
        # pre_text_mask=pre_text*id_mask+1-id_mask
        #   (keep tokens where the predicted speaker matches; set mismatched
        #    positions to 1, i.e. <unk>)
        # padding_mask= ys_out_pad != self.ignore_id
        # numerator = torch.sum(pre_text_mask.masked_select(padding_mask) == ys_out_pad.masked_select(padding_mask))
        # denominator = torch.sum(padding_mask)
        # sd_acc = float(numerator) / float(denominator)

        # 2. Compute attention loss
        loss_att = self.criterion_att(decoder_out, ys_out_pad)
        # NOTE(review): padded speaker slots have weight 0, so torch.log()
        # yields -inf there; assumes text_id never points at a padded slot —
        # TODO confirm.
        loss_spk = self.criterion_spk(torch.log(weights), text_id)
        acc_spk = th_accuracy(
            weights.view(-1, self.max_spk_num),
            text_id,
            ignore_label=self.ignore_id,
        )
        acc_att = th_accuracy(
            decoder_out.view(-1, self.vocab_size),
            ys_out_pad,
            ignore_label=self.ignore_id,
        )

        # Compute cer/wer using attention-decoder
        if self.training or self.error_calculator is None:
            cer_att, wer_att = None, None
        else:
            ys_hat = decoder_out.argmax(dim=-1)
            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        return loss_att, loss_spk, acc_att, acc_spk, cer_att, wer_att

    def _calc_ctc_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        """CTC loss plus (eval-only) CER computed from CTC argmax paths."""
        # Calc CTC loss
        loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)

        # Calc CER using CTC
        cer_ctc = None
        if not self.training and self.error_calculator is not None:
            ys_hat = self.ctc.argmax(encoder_out).data
            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
        return loss_ctc, cer_ctc