# model.py — OpenAI-Whisper-based ASR and LID models for funasr.
import logging
import time
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
from torch.cuda.amp import autocast

from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
from funasr.metrics.compute_acc import th_accuracy
from funasr.models.ctc.ctc import CTC
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.register import tables
from funasr.train_utils.device_funcs import force_gatherable
from funasr.utils import postprocess_utils
from funasr.utils.datadir_writer import DatadirWriter
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
@tables.register("model_classes", "OpenAIWhisperModel")
class OpenAIWhisperModel(nn.Module):
    """CTC-attention hybrid Encoder-Decoder model.

    Combines a Whisper-style encoder with an attention decoder and an
    optional CTC branch; the total loss is a weighted sum controlled by
    ``ctc_weight`` (0.0 = attention only, 1.0 = CTC only).
    """

    def __init__(
        self,
        specaug: str = None,
        specaug_conf: dict = None,
        normalize: str = None,
        normalize_conf: dict = None,
        encoder: str = None,
        encoder_conf: dict = None,
        decoder: str = None,
        decoder_conf: dict = None,
        ctc: str = None,
        ctc_conf: dict = None,
        ctc_weight: float = 0.5,
        interctc_weight: float = 0.0,
        input_size: int = 80,
        vocab_size: int = -1,
        ignore_id: int = -1,
        blank_id: int = 0,
        sos: int = 1,
        eos: int = 2,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
        report_cer: bool = True,
        report_wer: bool = True,
        sym_space: str = "<space>",
        sym_blank: str = "<blank>",
        # extract_feats_in_collect_stats: bool = True,
        share_embedding: bool = False,
        # preencoder: Optional[AbsPreEncoder] = None,
        # postencoder: Optional[AbsPostEncoder] = None,
        **kwargs,
    ):
        """Build the model from registry names + config dicts.

        Args:
            specaug / normalize / encoder / decoder: registry names looked up
                in ``tables``; each is instantiated from its ``*_conf`` dict.
            ctc_weight: weight of the CTC loss in [0, 1].
            interctc_weight: weight of intermediate-layer CTC losses.
            input_size: feature dimension fed to the encoder (e.g. 80 fbank).
            vocab_size: output vocabulary size (also CTC output dim).
            ignore_id: label id ignored by the attention criterion.
            blank_id / sos / eos: special token ids; sos/eos fall back to
                ``vocab_size - 1`` when passed as None.
            lsm_weight: label-smoothing factor for the attention loss.
            length_normalized_loss: normalize loss by token count.
            share_embedding: if True, drop the decoder's own embedding
                (weight tying handled elsewhere).
        """
        super().__init__()

        # Instantiate optional data augmentation / normalization modules.
        if specaug is not None:
            specaug_class = tables.specaug_classes.get(specaug)
            specaug = specaug_class(**specaug_conf)
        if normalize is not None:
            normalize_class = tables.normalize_classes.get(normalize)
            normalize = normalize_class(**normalize_conf)

        encoder_class = tables.encoder_classes.get(encoder)
        encoder = encoder_class(input_size=input_size, **encoder_conf)
        encoder_output_size = encoder.output_size()

        if decoder is not None:
            decoder_class = tables.decoder_classes.get(decoder)
            # NOTE(review): conf dict is passed positionally (not **decoder_conf)
            # — presumably the Whisper decoder wrapper takes a single config
            # argument; confirm against the registered decoder class.
            decoder = decoder_class(decoder_conf)
        if ctc_weight > 0.0:
            if ctc_conf is None:
                ctc_conf = {}
            ctc = CTC(
                odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf
            )

        self.blank_id = blank_id
        # sos/eos default to the last vocabulary entry when not given.
        self.sos = sos if sos is not None else vocab_size - 1
        self.eos = eos if eos is not None else vocab_size - 1
        self.vocab_size = vocab_size
        self.ignore_id = ignore_id
        self.ctc_weight = ctc_weight
        self.specaug = specaug
        self.normalize = normalize
        self.encoder = encoder

        # Encoders without intermediate-CTC support get the flag defaulted off.
        if not hasattr(self.encoder, "interctc_use_conditioning"):
            self.encoder.interctc_use_conditioning = False
        if self.encoder.interctc_use_conditioning:
            # Projects CTC posteriors back into the encoder hidden space
            # for self-conditioned CTC.
            self.encoder.conditioning_layer = torch.nn.Linear(
                vocab_size, self.encoder.output_size()
            )
        self.interctc_weight = interctc_weight

        # self.error_calculator = None
        # Pure-CTC training needs no attention decoder.
        if ctc_weight == 1.0:
            self.decoder = None
        else:
            self.decoder = decoder
        self.criterion_att = LabelSmoothingLoss(
            size=vocab_size,
            padding_idx=ignore_id,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )
        #
        # if report_cer or report_wer:
        #     self.error_calculator = ErrorCalculator(
        #         token_list, sym_space, sym_blank, report_cer, report_wer
        #     )
        #
        self.error_calculator = None

        if ctc_weight == 0.0:
            self.ctc = None
        else:
            self.ctc = ctc

        self.share_embedding = share_embedding
        if self.share_embedding:
            self.decoder.embed = None
        self.length_normalized_loss = length_normalized_loss
        # Beam search is built lazily on first inference() call.
        self.beam_search = None

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Encoder + Decoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)

        Returns:
            (loss, stats, weight) as produced by ``force_gatherable`` for
            DataParallel-safe aggregation.
        """
        # import pdb;
        # pdb.set_trace()
        # Length tensors may arrive as (B, 1); squeeze to (B,).
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]
        batch_size = speech.shape[0]

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            # encode() returns ((final, intermediates), lens) when
            # intermediate CTC layers are configured.
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]

        loss_att, acc_att, cer_att, wer_att = None, None, None, None
        loss_ctc, cer_ctc = None, None
        stats = dict()

        # decoder: CTC branch
        if self.ctc_weight != 0.0:
            loss_ctc, cer_ctc = self._calc_ctc_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )
            # Collect CTC branch stats
            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
            stats["cer_ctc"] = cer_ctc

        # Intermediate CTC (optional)
        loss_interctc = 0.0
        if self.interctc_weight != 0.0 and intermediate_outs is not None:
            for layer_idx, intermediate_out in intermediate_outs:
                # we assume intermediate_out has the same length & padding
                # as those of encoder_out
                loss_ic, cer_ic = self._calc_ctc_loss(
                    intermediate_out, encoder_out_lens, text, text_lengths
                )
                loss_interctc = loss_interctc + loss_ic
                # Collect Intermedaite CTC stats
                stats["loss_interctc_layer{}".format(layer_idx)] = (
                    loss_ic.detach() if loss_ic is not None else None
                )
                stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
            loss_interctc = loss_interctc / len(intermediate_outs)
            # calculate whole encoder loss: blend final-layer CTC with
            # the averaged intermediate-layer CTC losses.
            loss_ctc = (
                1 - self.interctc_weight
            ) * loss_ctc + self.interctc_weight * loss_interctc

        # decoder: Attention decoder branch
        # NOTE(review): called unconditionally; with ctc_weight == 1.0 the
        # decoder is None, so that configuration would fail here — confirm
        # pure-CTC training is not used with this model.
        loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
            encoder_out, encoder_out_lens, text, text_lengths
        )

        # 3. CTC-Att loss definition
        if self.ctc_weight == 0.0:
            loss = loss_att
        elif self.ctc_weight == 1.0:
            loss = loss_ctc
        else:
            loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att

        # Collect Attn branch stats
        stats["loss_att"] = loss_att.detach() if loss_att is not None else None
        stats["acc"] = acc_att
        stats["cer"] = cer_att
        stats["wer"] = wer_att
        # Collect total loss stats
        stats["loss"] = torch.clone(loss.detach())

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        if self.length_normalized_loss:
            # Weight by token count (+1 for the appended eos) instead of
            # batch size when the loss is length-normalized.
            batch_size = int((text_lengths + 1).sum())
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

    def encode(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Frontend + Encoder. Note that this method is used by asr_inference.py

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            ind: int
        """
        # Augmentation/normalization run in fp32 even under AMP.
        with autocast(False):
            # Data augmentation
            if self.specaug is not None and self.training:
                speech, speech_lengths = self.specaug(speech, speech_lengths)
            # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                speech, speech_lengths = self.normalize(speech, speech_lengths)

        # Forward encoder
        # feats: (Batch, Length, Dim)
        # -> encoder_out: (Batch, Length2, Dim2)
        if self.encoder.interctc_use_conditioning:
            # Self-conditioned CTC: the encoder needs the CTC module to
            # compute per-layer posteriors.
            encoder_out, encoder_out_lens, _ = self.encoder(
                speech, speech_lengths, ctc=self.ctc
            )
        else:
            encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)

        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]
        if intermediate_outs is not None:
            return (encoder_out, intermediate_outs), encoder_out_lens
        return encoder_out, encoder_out_lens

    def _calc_att_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        """Attention-decoder loss: label-smoothed CE + teacher-forced accuracy."""
        # Prepend sos to the decoder input, append eos to the target.
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        ys_in_lens = ys_pad_lens + 1

        # 1. Forward decoder
        decoder_out, _ = self.decoder(
            encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
        )

        # 2. Compute attention loss
        loss_att = self.criterion_att(decoder_out, ys_out_pad)
        acc_att = th_accuracy(
            decoder_out.view(-1, self.vocab_size),
            ys_out_pad,
            ignore_label=self.ignore_id,
        )

        # Compute cer/wer using attention-decoder
        if self.training or self.error_calculator is None:
            # error_calculator is always None in this file, so cer/wer stay None.
            cer_att, wer_att = None, None
        else:
            ys_hat = decoder_out.argmax(dim=-1)
            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
        return loss_att, acc_att, cer_att, wer_att

    def _calc_ctc_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        """CTC loss over encoder outputs; optional greedy-CTC CER at eval time."""
        # Calc CTC loss
        loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)

        # Calc CER using CTC
        cer_ctc = None
        if not self.training and self.error_calculator is not None:
            ys_hat = self.ctc.argmax(encoder_out).data
            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
        return loss_ctc, cer_ctc

    def init_beam_search(self,
                         **kwargs,
                         ):
        """Lazily construct the joint CTC/attention beam search.

        Expects ``token_list`` in kwargs; decoding weights come from
        ``decoding_ctc_weight``, ``lm_weight``, ``ngram_weight``, ``penalty``.
        """
        from funasr.models.transformer.search import BeamSearch
        from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
        from funasr.models.transformer.scorers.length_bonus import LengthBonus

        # 1. Build ASR model
        scorers = {}
        if self.ctc != None:
            ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
            scorers.update(
                ctc=ctc
            )
        token_list = kwargs.get("token_list")
        scorers.update(
            decoder=self.decoder,
            length_bonus=LengthBonus(len(token_list)),
        )

        # 3. Build ngram model
        # ngram is not supported now
        ngram = None
        scorers["ngram"] = ngram

        weights = dict(
            decoder=1.0 - kwargs.get("decoding_ctc_weight", 0.5),
            ctc=kwargs.get("decoding_ctc_weight", 0.5),
            lm=kwargs.get("lm_weight", 0.0),
            ngram=kwargs.get("ngram_weight", 0.0),
            length_bonus=kwargs.get("penalty", 0.0),
        )
        beam_search = BeamSearch(
            beam_size=kwargs.get("beam_size", 10),
            weights=weights,
            scorers=scorers,
            sos=self.sos,
            eos=self.eos,
            vocab_size=len(token_list),
            token_list=token_list,
            # Pure-CTC decoding has no "full" attention score to pre-prune on.
            pre_beam_score_key=None if self.ctc_weight == 1.0 else "full",
        )
        self.beam_search = beam_search

    def inference(self,
                  data_in,
                  data_lengths=None,
                  key: list = None,
                  tokenizer=None,
                  frontend=None,
                  **kwargs,
                  ):
        """Decode a single utterance with beam search.

        Args:
            data_in: raw audio (path/array) or a precomputed fbank tensor
                when ``data_type == "fbank"``.
            key: list of utterance ids (used for result keys and writers).
            tokenizer: maps token ids to tokens/text.
            frontend: feature-extraction frontend (fs, frame_shift, lfr_n).

        Returns:
            (results, meta_data): list of {"key", "token", "text"} dicts and
            timing metadata.
        """
        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")

        # init beamsearch (lazy, on first call only)
        if self.beam_search is None:
            logging.info("enable beam_search")
            self.init_beam_search(**kwargs)
            self.nbest = kwargs.get("nbest", 1)

        meta_data = {}
        if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank":  # fbank
            speech, speech_lengths = data_in, data_lengths
            if len(speech.shape) < 3:
                speech = speech[None, :, :]
            if speech_lengths is None:
                # NOTE(review): this assigns a plain int, but .to(device=...)
                # below expects a tensor — looks like it would fail for fbank
                # input without lengths; confirm callers always pass lengths.
                speech_lengths = speech.shape[1]
        else:
            # extract fbank feats
            time1 = time.perf_counter()
            audio_sample_list = load_audio_text_image_video(
                data_in,
                fs=frontend.fs,
                audio_fs=kwargs.get("fs", 16000),
                data_type=kwargs.get("data_type", "sound"),
                tokenizer=tokenizer,
            )
            time2 = time.perf_counter()
            meta_data["load_data"] = f"{time2 - time1:0.3f}"
            speech, speech_lengths = extract_fbank(
                audio_sample_list,
                data_type=kwargs.get("data_type", "sound"),
                frontend=frontend,
            )
            time3 = time.perf_counter()
            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
            # Audio duration in seconds, derived from frame shift and LFR rate.
            meta_data["batch_data_time"] = (
                speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
            )

        speech = speech.to(device=kwargs["device"])
        speech_lengths = speech_lengths.to(device=kwargs["device"])

        # Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        if isinstance(encoder_out, tuple):
            encoder_out = encoder_out[0]

        # c. Passed the encoder result and the beam search
        # (single-utterance decoding: only batch element 0 is searched)
        nbest_hyps = self.beam_search(
            x=encoder_out[0],
            maxlenratio=kwargs.get("maxlenratio", 0.0),
            minlenratio=kwargs.get("minlenratio", 0.0),
        )
        nbest_hyps = nbest_hyps[: self.nbest]

        results = []
        b, n, d = encoder_out.size()
        for i in range(b):
            for nbest_idx, hyp in enumerate(nbest_hyps):
                ibest_writer = None
                if kwargs.get("output_dir") is not None:
                    if not hasattr(self, "writer"):
                        self.writer = DatadirWriter(kwargs.get("output_dir"))
                    ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"]

                # remove sos/eos and get results
                last_pos = -1
                if isinstance(hyp.yseq, list):
                    token_int = hyp.yseq[1:last_pos]
                else:
                    token_int = hyp.yseq[1:last_pos].tolist()

                # remove blank symbol id, which is assumed to be 0
                token_int = list(
                    filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int)
                )

                # Change integer-ids to tokens
                token = tokenizer.ids2tokens(token_int)
                text = tokenizer.tokens2text(token)
                text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)

                result_i = {"key": key[i], "token": token, "text": text_postprocessed}
                results.append(result_i)

                if ibest_writer is not None:
                    ibest_writer["token"][key[i]] = " ".join(token)
                    ibest_writer["text"][key[i]] = text_postprocessed

        return results, meta_data
@tables.register("model_classes", "OpenAIWhisperLIDModel")
class OpenAIWhisperLIDModel(nn.Module):
    """WhisperEncoder and EResNet based LID Model.

    Language identification: encoder frames are (optionally) clipped to a
    fixed length, pooled by ``lid_predictor`` into an utterance embedding,
    and classified into ``vocab_size`` languages.
    """

    def __init__(
        self,
        vocab_size: int,
        specaug: str = None,
        specaug_conf: dict = None,
        encoder: str = None,
        encoder_conf: dict = None,
        lid_predictor: str = None,
        lid_predictor_conf: dict = None,
        proj_dim: int = None,
        clip_frames: int = None,
        random_clip: bool = False,
        **kwargs,
    ):
        """Build the LID model from registry names + config dicts.

        Args:
            vocab_size: number of language classes.
            specaug / encoder / lid_predictor: registry names looked up in
                ``tables``; each is instantiated from its ``*_conf`` dict.
            proj_dim: target dim for an optional linear projection between
                encoder and predictor (added only when the dims differ).
            clip_frames: fixed number of encoder frames to keep; None keeps
                the full (padded-to-max) sequence.
            random_clip: if True, pick a random window of ``clip_frames``
                during training instead of the sequence prefix.
        """
        super().__init__()
        if specaug is not None:
            specaug_class = tables.specaug_classes.get(specaug)
            specaug = specaug_class(**specaug_conf)
        encoder_class = tables.encoder_classes.get(encoder)
        encoder = encoder_class(**encoder_conf)
        lid_predictor_class = tables.lid_predictor_classes.get(lid_predictor)
        lid_predictor = lid_predictor_class(**lid_predictor_conf)

        # Bridge encoder output dim to the predictor's expected dim when needed.
        if encoder.output_size() != proj_dim:
            self.proj_layer = torch.nn.Linear(encoder.output_size(), proj_dim)
        else:
            self.proj_layer = None
        self.output_layer = torch.nn.Linear(lid_predictor.output_size(), vocab_size)
        # smoothing=0.0 makes this plain cross-entropy over language labels.
        self.criterion_lid = LabelSmoothingLoss(
            size=vocab_size,
            padding_idx=-1,
            smoothing=0.0,
            normalize_length=False,
        )
        self.specaug = specaug
        self.encoder = encoder
        self.lid_predictor = lid_predictor
        self.clip_frames = clip_frames
        self.random_clip = random_clip
        self.normalize = None
        self.beam_search = None
        if not hasattr(self.encoder, "interctc_use_conditioning"):
            self.encoder.interctc_use_conditioning = False

    def forward(self,
                speech: torch.Tensor,  # may be padding
                speech_lengths: torch.Tensor,  # actual length
                lid: torch.Tensor,  # lid label, (batch_size, 1)
                lid_lengths: torch.Tensor,  # unused; kept for dataloader compatibility
                ):
        """Encode, clip/pool, classify, and return (loss, stats, weight)."""
        assert lid.shape[1] == 1
        batch_size = speech.shape[0]
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

        # re-generate encoder_out: copy valid frames into a zero-padded buffer,
        # optionally clipped to self.clip_frames per utterance.
        if self.clip_frames is None:
            reduced_encoder_out = torch.zeros(batch_size, encoder_out_lens.max(), encoder_out.shape[-1]).to(encoder_out.dtype).to(encoder_out.device)
            for i, enc_length in enumerate(encoder_out_lens):
                reduced_encoder_out[i, :enc_length] = encoder_out[i, :enc_length]
        else:
            reduced_encoder_out = torch.zeros(batch_size, self.clip_frames, encoder_out.shape[-1]).to(encoder_out.dtype).to(encoder_out.device)
            if self.random_clip:
                # Training-time augmentation: take a random clip_frames window
                # from utterances longer than the clip length.
                for i, enc_length in enumerate(encoder_out_lens):
                    if enc_length <= self.clip_frames:
                        reduced_encoder_out[i, :enc_length] = encoder_out[i, :enc_length]
                        encoder_out_lens[i] = enc_length
                    else:
                        max_start_index = enc_length.item() - self.clip_frames
                        start_index = np.random.randint(0, max_start_index + 1)
                        reduced_encoder_out[i, :self.clip_frames] = encoder_out[i, start_index:start_index + self.clip_frames]
                        encoder_out_lens[i] = self.clip_frames
            else:
                # Deterministic: keep the sequence prefix, capped at clip_frames.
                # NOTE: encoder_out_lens is mutated in place to the clipped lengths.
                for i, enc_length in enumerate(encoder_out_lens):
                    enc_length = self.clip_frames if enc_length >= self.clip_frames else enc_length
                    reduced_encoder_out[i, :enc_length] = encoder_out[i, :enc_length]
                    encoder_out_lens[i] = enc_length

        if self.proj_layer is not None:
            reduced_encoder_out = self.proj_layer(reduced_encoder_out)
        lid_output = self.lid_predictor(reduced_encoder_out, encoder_out_lens)  # (B, D)
        lid_logits = self.output_layer(lid_output)  # (B, num_classes)
        # criterion expects a length dim: (B, 1, num_classes) vs (B, 1) labels.
        loss = self.criterion_lid(lid_logits[:, None, :], lid)

        # Batch accuracy for logging only (no gradients).
        with torch.no_grad():
            _, predicted_lid = torch.max(lid_logits, 1)
            correct = (predicted_lid == lid[:, 0]).sum().item()
            lid_acc = correct * 1.0 / lid_logits.shape[0]

        stats = dict()
        stats["batch_size"] = batch_size
        stats["loss"] = torch.clone(loss.detach())
        stats["acc"] = lid_acc
        stats["token_length"] = speech_lengths.max()
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

    def encode(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Frontend + Encoder. Note that this method is used by asr_inference.py

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
        """
        with autocast(False):
            # Data augmentation
            if self.specaug is not None and self.training:
                # SpecAug expects (B, T, F); Whisper features arrive transposed,
                # so permute in, augment, and permute back.
                speech = speech.permute(0, 2, 1)
                # suit for whisper padding: Whisper pads every utterance to the
                # same length, so feed full-length lengths to specaug.
                padded_speech_lengths = torch.ones_like(speech_lengths) * speech.shape[1]
                speech, padded_speech_lengths = self.specaug(speech, padded_speech_lengths)
                speech = speech.permute(0, 2, 1)
            # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                speech, speech_lengths = self.normalize(speech, speech_lengths)

        # Forward encoder
        # feats: (Batch, Length, Dim)
        # -> encoder_out: (Batch, Length2, Dim2)
        if self.encoder.interctc_use_conditioning:
            # NOTE(review): this class defines no self.ctc; the branch appears
            # dead because __init__ forces interctc_use_conditioning=False when
            # absent — confirm no encoder registers with it enabled.
            encoder_out, encoder_out_lens, _ = self.encoder(
                speech, speech_lengths, ctc=self.ctc
            )
        else:
            encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)

        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]
        if intermediate_outs is not None:
            return (encoder_out, intermediate_outs), encoder_out_lens
        return encoder_out, encoder_out_lens

    def inference(self,
                  data_in,
                  data_lengths=None,
                  key: list = None,
                  tokenizer=None,
                  frontend=None,
                  **kwargs,
                  ):
        """Predict the language id of a single utterance.

        Args:
            data_in: raw audio (path/array) or precomputed fbank tensor when
                ``data_type == "fbank"``.
            key: list of utterance ids; only key[0] is used.
            tokenizer: maps the predicted class index to a language token.
            frontend: feature-extraction frontend (fs, frame_shift, lfr_n).

        Returns:
            (results, meta_data): [{"key", "lid"}] and timing metadata.
        """
        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")

        meta_data = {}
        if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank":  # fbank
            speech, speech_lengths = data_in, data_lengths
            if len(speech.shape) < 3:
                speech = speech[None, :, :]
            if speech_lengths is None:
                # NOTE(review): assigns a plain int; .to(device=...) below
                # expects a tensor — confirm callers always pass lengths.
                speech_lengths = speech.shape[1]
        else:
            # extract fbank feats
            time1 = time.perf_counter()
            audio_sample_list = load_audio_text_image_video(
                data_in,
                fs=frontend.fs,
                audio_fs=kwargs.get("fs", 16000),
                data_type=kwargs.get("data_type", "sound"),
                tokenizer=tokenizer,
            )
            time2 = time.perf_counter()
            meta_data["load_data"] = f"{time2 - time1:0.3f}"
            speech, speech_lengths = extract_fbank(
                audio_sample_list,
                data_type=kwargs.get("data_type", "sound"),
                frontend=frontend,
            )
            time3 = time.perf_counter()
            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
            # Audio duration in seconds, derived from frame shift and LFR rate.
            meta_data["batch_data_time"] = (
                speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
            )

        speech = speech.to(device=kwargs["device"])
        speech_lengths = speech_lengths.to(device=kwargs["device"])

        # Encoder
        enc, enc_out_lens = self.encode(speech, speech_lengths)

        # Clip encoder frames the same way as training; an explicit
        # inference_clip_length kwarg overrides the model's clip_frames.
        inference_clip_length = kwargs.get("inference_clip_length", None)
        if self.clip_frames is not None:
            if inference_clip_length is None:
                reduced_enc = torch.zeros(enc.shape[0], self.clip_frames, enc.shape[-1]).to(enc.dtype).to(enc.device)
                for i, enc_length in enumerate(enc_out_lens):
                    enc_length = self.clip_frames if enc_length >= self.clip_frames else enc_length
                    reduced_enc[i, :enc_length] = enc[i, :enc_length]
                    enc_out_lens[i] = enc_length
            else:
                assert inference_clip_length > 0, "inference_clip_length must be larger than 0"
                reduced_enc = torch.zeros(enc.shape[0], inference_clip_length, enc.shape[-1]).to(enc.dtype).to(enc.device)
                for i, enc_length in enumerate(enc_out_lens):
                    enc_length = inference_clip_length if enc_length >= inference_clip_length else enc_length
                    reduced_enc[i, :enc_length] = enc[i, :enc_length]
                    enc_out_lens[i] = enc_length
        else:
            reduced_enc = torch.zeros(enc.shape[0], enc_out_lens.max(), enc.shape[-1]).to(enc.dtype).to(enc.device)
            for i, enc_length in enumerate(enc_out_lens):
                reduced_enc[i, :enc_length] = enc[i, :enc_length]

        if self.proj_layer is not None:
            reduced_enc = self.proj_layer(reduced_enc)
        lid_output = self.lid_predictor(reduced_enc, enc_out_lens)  # (B, D)
        lid_logits = self.output_layer(lid_output)  # (B, num_classes)
        _, predicted_lid_index = torch.max(lid_logits, 1)
        # Single-utterance decoding: only batch element 0 is reported.
        predicted_lid = tokenizer.ids2tokens([predicted_lid_index[0].cpu()])[0]

        if kwargs.get("output_dir") is not None:
            if not hasattr(self, "writer"):
                self.writer = DatadirWriter(kwargs.get("output_dir"))
            lid_writer = self.writer["lid"]
            lid_writer[key[0]] = predicted_lid

        results = [{"key": key[0], "lid": predicted_lid}]
        return results, meta_data