#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)

import os
import re
import time
import torch
import codecs
import logging
import tempfile
import requests
import numpy as np
from typing import Dict, Tuple
from contextlib import contextmanager
from distutils.version import LooseVersion

from funasr.register import tables
from funasr.losses.label_smoothing_loss import (
    LabelSmoothingLoss,  # noqa: H301
)
from funasr.utils import postprocess_utils
from funasr.metrics.compute_acc import th_accuracy
from funasr.models.paraformer.model import Paraformer
from funasr.utils.datadir_writer import DatadirWriter
from funasr.models.paraformer.search import Hypothesis
from funasr.train_utils.device_funcs import force_gatherable
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank

if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
else:
    # Nothing to do if torch < 1.6.0
    @contextmanager
    def autocast(enabled=True):
        yield


@tables.register("model_classes", "ContextualParaformer")
class ContextualParaformer(Paraformer):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    FunASR: A Fundamental End-to-End Speech Recognition Toolkit
    https://arxiv.org/abs/2305.11013
    """
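
    # ContextualParaformer extends Paraformer with a hotword-biasing branch:
    # a small bias encoder (an LSTM, or mean pooling over embeddings) turns a
    # user-supplied hotword list into contextual vectors that the decoder
    # attends to, improving recall of those phrases during decoding.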

    def __init__(
        self,
        *args,
        **kwargs,
    ):
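        """Read bias-encoder options from ``kwargs`` (defaults mirror the
        ``kwargs.get`` calls below):

        - target_buffer_length: streaming hotword-buffer size (-1 disables it)
        - inner_dim: width of the bias embedding / LSTM
        - bias_encoder_type: "lstm" or "mean"
        - use_decoder_embedding: reuse the decoder token embedding for hotwords
        - crit_attn_weight / crit_attn_smooth: optional attention-guidance loss
        - bias_encoder_dropout_rate: dropout inside the bias LSTM
        """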
        super().__init__(*args, **kwargs)

        self.target_buffer_length = kwargs.get("target_buffer_length", -1)
        inner_dim = kwargs.get("inner_dim", 256)
        bias_encoder_type = kwargs.get("bias_encoder_type", "lstm")
        use_decoder_embedding = kwargs.get("use_decoder_embedding", False)
        crit_attn_weight = kwargs.get("crit_attn_weight", 0.0)
        crit_attn_smooth = kwargs.get("crit_attn_smooth", 0.0)
        bias_encoder_dropout_rate = kwargs.get("bias_encoder_dropout_rate", 0.0)

        if bias_encoder_type == "lstm":
            self.bias_encoder = torch.nn.LSTM(
                inner_dim, inner_dim, 1, batch_first=True, dropout=bias_encoder_dropout_rate
            )
            self.bias_embed = torch.nn.Embedding(self.vocab_size, inner_dim)
        elif bias_encoder_type == "mean":
            self.bias_embed = torch.nn.Embedding(self.vocab_size, inner_dim)
        else:
            logging.error("Unsupported bias encoder type: {}".format(bias_encoder_type))

        if self.target_buffer_length > 0:
            self.hotword_buffer = None
            self.length_record = []
            self.current_buffer_length = 0

        self.use_decoder_embedding = use_decoder_embedding
        self.crit_attn_weight = crit_attn_weight
        if self.crit_attn_weight > 0:
            self.attn_loss = torch.nn.L1Loss()
        self.crit_attn_smooth = crit_attn_smooth

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch,)
            text: (Batch, Length)
            text_lengths: (Batch,)
        """
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]

        batch_size = speech.shape[0]

        hotword_pad = kwargs.get("hotword_pad")
        hotword_lengths = kwargs.get("hotword_lengths")
        dha_pad = kwargs.get("dha_pad")

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

        loss_ctc, cer_ctc = None, None
        stats = dict()

        # 2a. CTC branch
        if self.ctc_weight != 0.0:
            loss_ctc, cer_ctc = self._calc_ctc_loss(encoder_out, encoder_out_lens, text, text_lengths)

            # Collect CTC branch stats
            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
            stats["cer_ctc"] = cer_ctc

        # 2b. Attention decoder branch
        loss_att, acc_att, cer_att, wer_att, loss_pre, loss_ideal = self._calc_att_clas_loss(
            encoder_out, encoder_out_lens, text, text_lengths, hotword_pad, hotword_lengths
        )

        # 3. CTC-Att loss definition
        if self.ctc_weight == 0.0:
            loss = loss_att + loss_pre * self.predictor_weight
        else:
            loss = (
                self.ctc_weight * loss_ctc
                + (1 - self.ctc_weight) * loss_att
                + loss_pre * self.predictor_weight
            )
        if loss_ideal is not None:
            loss = loss + loss_ideal * self.crit_attn_weight
            stats["loss_ideal"] = loss_ideal.detach().cpu()

        # Collect attention branch stats
        stats["loss_att"] = loss_att.detach() if loss_att is not None else None
        stats["acc"] = acc_att
        stats["cer"] = cer_att
        stats["wer"] = wer_att
        stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None

        stats["loss"] = torch.clone(loss.detach())

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        if self.length_normalized_loss:
            batch_size = int((text_lengths + self.predictor_bias).sum())
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

    def _calc_att_clas_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
        hotword_pad: torch.Tensor,
        hotword_lengths: torch.Tensor,
    ):
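        """Attention-decoder loss with contextual (hotword) biasing.

        Encodes the padded hotword batch with the bias encoder, broadcasts the
        resulting contextual vectors to every utterance in the batch, and
        feeds them to the decoder alongside the predictor's acoustic
        embeddings.

        Returns:
            (loss_att, acc_att, cer_att, wer_att, loss_pre, loss_ideal)
        """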
        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
            encoder_out.device
        )
        if self.predictor_bias == 1:
            _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
            ys_pad_lens = ys_pad_lens + self.predictor_bias
        pre_acoustic_embeds, pre_token_length, _, _ = self.predictor(
            encoder_out, ys_pad, encoder_out_mask, ignore_id=self.ignore_id
        )

        # -1. bias encoder
        if self.use_decoder_embedding:
            hw_embed = self.decoder.embed(hotword_pad)
        else:
            hw_embed = self.bias_embed(hotword_pad)
        hw_embed, (_, _) = self.bias_encoder(hw_embed)
        _ind = np.arange(0, hotword_pad.shape[0]).tolist()
        selected = hw_embed[_ind, [i - 1 for i in hotword_lengths.detach().cpu().tolist()]]
        contextual_info = selected.squeeze(0).repeat(ys_pad.shape[0], 1, 1).to(ys_pad.device)

        # 0. sampler
        decoder_out_1st = None
        if self.sampling_ratio > 0.0:
            sematic_embeds, decoder_out_1st = self.sampler(
                encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds, contextual_info
            )
        else:
            sematic_embeds = pre_acoustic_embeds

        # 1. Forward decoder
        decoder_outs = self.decoder(
            encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=contextual_info
        )
        decoder_out, _ = decoder_outs[0], decoder_outs[1]

        # Attention-guidance loss (disabled; kept for reference):
        # if self.crit_attn_weight > 0 and attn.shape[-1] > 1:
        #     ideal_attn = ideal_attn + self.crit_attn_smooth / (self.crit_attn_smooth + 1.0)
        #     attn_non_blank = attn[:, :, :, :-1]
        #     ideal_attn_non_blank = ideal_attn[:, :, :-1]
        #     loss_ideal = self.attn_loss(attn_non_blank.max(1)[0], ideal_attn_non_blank.to(attn.device))
        # else:
        #     loss_ideal = None
        loss_ideal = None

        if decoder_out_1st is None:
            decoder_out_1st = decoder_out

        # 2. Compute attention loss
        loss_att = self.criterion_att(decoder_out, ys_pad)
        acc_att = th_accuracy(
            decoder_out_1st.view(-1, self.vocab_size),
            ys_pad,
            ignore_label=self.ignore_id,
        )
        loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)

        # Compute cer/wer using attention-decoder
        if self.training or self.error_calculator is None:
            cer_att, wer_att = None, None
        else:
            ys_hat = decoder_out_1st.argmax(dim=-1)
            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        return loss_att, acc_att, cer_att, wer_att, loss_pre, loss_ideal

    def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds, contextual_info):
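        """Glancing-style sampler used during training: decode once without
        gradients, count the mispredicted positions, then zero out a random
        subset of positions (roughly ``sampling_ratio`` times the error count)
        in the acoustic embeddings and fill them with ground-truth token
        embeddings, returning the mixed ``sematic_embeds`` for the second
        decoding pass."""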
        tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
        ys_pad = ys_pad * tgt_mask[:, :, 0]
        if self.share_embedding:
            ys_pad_embed = self.decoder.output_layer.weight[ys_pad]
        else:
            ys_pad_embed = self.decoder.embed(ys_pad)
        with torch.no_grad():
            decoder_outs = self.decoder(
                encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens, contextual_info=contextual_info
            )
            decoder_out, _ = decoder_outs[0], decoder_outs[1]
            pred_tokens = decoder_out.argmax(-1)
            nonpad_positions = ys_pad.ne(self.ignore_id)
            seq_lens = (nonpad_positions).sum(1)
            same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
            input_mask = torch.ones_like(nonpad_positions)
            bsz, seq_len = ys_pad.size()
            for li in range(bsz):
                target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long()
                if target_num > 0:
                    input_mask[li].scatter_(
                        dim=0,
                        index=torch.randperm(seq_lens[li])[:target_num].to(pre_acoustic_embeds.device),
                        value=0,
                    )
            input_mask = input_mask.eq(1)
            input_mask = input_mask.masked_fill(~nonpad_positions, False)
            input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)

        sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill(
            input_mask_expand_dim, 0
        )
        return sematic_embeds * tgt_mask, decoder_out * tgt_mask

    def cal_decoder_with_predictor(
        self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, hw_list=None, clas_scale=1.0
    ):
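        """Decode with hotword biasing at inference time.

        ``hw_list`` is a list of token-id sequences, one per hotword; when it
        is ``None``, a single placeholder entry is encoded so the decoder
        still receives a contextual input. ``clas_scale`` is forwarded to the
        decoder to scale the contextual attention contribution."""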
        if hw_list is None:
            hw_list = [torch.Tensor([1]).long().to(encoder_out.device)]  # empty hotword list
            hw_list_pad = pad_list(hw_list, 0)
            if self.use_decoder_embedding:
                hw_embed = self.decoder.embed(hw_list_pad)
            else:
                hw_embed = self.bias_embed(hw_list_pad)
            hw_embed, (h_n, _) = self.bias_encoder(hw_embed)
            hw_embed = h_n.repeat(encoder_out.shape[0], 1, 1)
        else:
            hw_lengths = [len(i) for i in hw_list]
            hw_list_pad = pad_list([torch.Tensor(i).long() for i in hw_list], 0).to(encoder_out.device)
            if self.use_decoder_embedding:
                hw_embed = self.decoder.embed(hw_list_pad)
            else:
                hw_embed = self.bias_embed(hw_list_pad)
            hw_embed = torch.nn.utils.rnn.pack_padded_sequence(
                hw_embed, hw_lengths, batch_first=True, enforce_sorted=False
            )
            _, (h_n, _) = self.bias_encoder(hw_embed)
            hw_embed = h_n.repeat(encoder_out.shape[0], 1, 1)
        decoder_outs = self.decoder(
            encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=hw_embed, clas_scale=clas_scale
        )
        decoder_out = decoder_outs[0]
        decoder_out = torch.log_softmax(decoder_out, dim=-1)
        return decoder_out, ys_pad_lens

    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
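        """Transcribe raw audio with optional hotword biasing: extract fbank
        features, run the encoder and the predictor, decode with the hotword
        bias, then pick hypotheses by greedy argmax or beam search."""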
        # init beamsearch
        is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc is not None
        is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
        if self.beam_search is None and (is_use_lm or is_use_ctc):
            logging.info("enable beam_search")
            self.init_beam_search(**kwargs)
            self.nbest = kwargs.get("nbest", 1)

        meta_data = {}

        # extract fbank feats
        time1 = time.perf_counter()
        audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000))
        time2 = time.perf_counter()
        meta_data["load_data"] = f"{time2 - time1:0.3f}"
        speech, speech_lengths = extract_fbank(
            audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend
        )
        time3 = time.perf_counter()
        meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
        meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000

        speech = speech.to(device=kwargs["device"])
        speech_lengths = speech_lengths.to(device=kwargs["device"])

        # hotword
        self.hotword_list = self.generate_hotwords_list(
            kwargs.get("hotword", None), tokenizer=tokenizer, frontend=frontend
        )

        # Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        if isinstance(encoder_out, tuple):
            encoder_out = encoder_out[0]

        # predictor
        predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens)
        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = (
            predictor_outs[0],
            predictor_outs[1],
            predictor_outs[2],
            predictor_outs[3],
        )
        pre_token_length = pre_token_length.round().long()
        if torch.max(pre_token_length) < 1:
            return []
        decoder_outs = self.cal_decoder_with_predictor(
            encoder_out,
            encoder_out_lens,
            pre_acoustic_embeds,
            pre_token_length,
            hw_list=self.hotword_list,
            clas_scale=kwargs.get("clas_scale", 1.0),
        )
        decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]

        results = []
        b, n, d = decoder_out.size()
        for i in range(b):
            x = encoder_out[i, : encoder_out_lens[i], :]
            am_scores = decoder_out[i, : pre_token_length[i], :]
            if self.beam_search is not None:
                nbest_hyps = self.beam_search(
                    x=x,
                    am_scores=am_scores,
                    maxlenratio=kwargs.get("maxlenratio", 0.0),
                    minlenratio=kwargs.get("minlenratio", 0.0),
                )
                nbest_hyps = nbest_hyps[: self.nbest]
            else:
                yseq = am_scores.argmax(dim=-1)
                score = am_scores.max(dim=-1)[0]
                score = torch.sum(score, dim=-1)
                # pad with mask tokens to ensure compatibility with sos/eos tokens
                yseq = torch.tensor([self.sos] + yseq.tolist() + [self.eos], device=yseq.device)
                nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
            for nbest_idx, hyp in enumerate(nbest_hyps):
                ibest_writer = None
                if kwargs.get("output_dir") is not None:
                    if not hasattr(self, "writer"):
                        self.writer = DatadirWriter(kwargs.get("output_dir"))
                    ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"]

                # remove sos/eos and get results
                last_pos = -1
                if isinstance(hyp.yseq, list):
                    token_int = hyp.yseq[1:last_pos]
                else:
                    token_int = hyp.yseq[1:last_pos].tolist()

                # remove blank symbol id, which is assumed to be 0
                token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))

                if tokenizer is not None:
                    # Change integer-ids to tokens
                    token = tokenizer.ids2tokens(token_int)
                    text = tokenizer.tokens2text(token)

                    text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
                    result_i = {"key": key[i], "text": text_postprocessed}

                    if ibest_writer is not None:
                        ibest_writer["token"][key[i]] = " ".join(token)
                        ibest_writer["text"][key[i]] = text
                        ibest_writer["text_postprocessed"][key[i]] = text_postprocessed
                else:
                    result_i = {"key": key[i], "token_int": token_int}
                results.append(result_i)

        return results, meta_data

    def generate_hotwords_list(self, hotword_list_or_file, tokenizer=None, frontend=None):
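        """Build hotword token-id lists from one of several inputs: ``None``,
        a local ``.txt`` file (one hotword per line), an http(s) URL pointing
        at such a file, or a plain whitespace-separated string. A final
        ``<s>`` entry is always appended to the list."""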
        def load_seg_dict(seg_dict_file):
            seg_dict = {}
            assert isinstance(seg_dict_file, str)
            with open(seg_dict_file, "r", encoding="utf8") as f:
                lines = f.readlines()
                for line in lines:
                    s = line.strip().split()
                    key = s[0]
                    value = s[1:]
                    seg_dict[key] = " ".join(value)
            return seg_dict

        def seg_tokenize(txt, seg_dict):
            pattern = re.compile(r"^[\u4E00-\u9FA50-9]+$")
            out_txt = ""
            for word in txt:
                word = word.lower()
                if word in seg_dict:
                    out_txt += seg_dict[word] + " "
                else:
                    if pattern.match(word):
                        for char in word:
                            if char in seg_dict:
                                out_txt += seg_dict[char] + " "
                            else:
                                out_txt += "<unk>" + " "
                    else:
                        out_txt += "<unk>" + " "
            return out_txt.strip().split()
        seg_dict = None
        if frontend.cmvn_file is not None:
            model_dir = os.path.dirname(frontend.cmvn_file)
            seg_dict_file = os.path.join(model_dir, "seg_dict")
            if os.path.exists(seg_dict_file):
                seg_dict = load_seg_dict(seg_dict_file)
            else:
                seg_dict = None

        # for None
        if hotword_list_or_file is None:
            hotword_list = None
        # for local txt inputs
        elif os.path.exists(hotword_list_or_file) and hotword_list_or_file.endswith(".txt"):
            logging.info("Attempting to parse hotwords from local txt...")
            hotword_list = []
            hotword_str_list = []
            with codecs.open(hotword_list_or_file, "r") as fin:
                for line in fin.readlines():
                    hw = line.strip()
                    hw_list = hw.split()
                    if seg_dict is not None:
                        hw_list = seg_tokenize(hw_list, seg_dict)
                    hotword_str_list.append(hw)
                    hotword_list.append(tokenizer.tokens2ids(hw_list))
            hotword_list.append([self.sos])
            hotword_str_list.append("<s>")
            logging.info(
                "Initialized hotword list from file: {}, hotword list: {}.".format(
                    hotword_list_or_file, hotword_str_list
                )
            )
        # for url, download and generate txt
        elif hotword_list_or_file.startswith("http"):
            logging.info("Attempting to parse hotwords from url...")
            # mkdtemp keeps the directory alive; TemporaryDirectory().name may
            # be cleaned up as soon as the object is garbage collected
            work_dir = tempfile.mkdtemp()
            text_file_path = os.path.join(work_dir, os.path.basename(hotword_list_or_file))
            local_file = requests.get(hotword_list_or_file)
            with open(text_file_path, "wb") as fout:
                fout.write(local_file.content)
            hotword_list_or_file = text_file_path
            hotword_list = []
            hotword_str_list = []
            with codecs.open(hotword_list_or_file, "r") as fin:
                for line in fin.readlines():
                    hw = line.strip()
                    hw_list = hw.split()
                    if seg_dict is not None:
                        hw_list = seg_tokenize(hw_list, seg_dict)
                    hotword_str_list.append(hw)
                    hotword_list.append(tokenizer.tokens2ids(hw_list))
            hotword_list.append([self.sos])
            hotword_str_list.append("<s>")
            logging.info(
                "Initialized hotword list from file: {}, hotword list: {}.".format(
                    hotword_list_or_file, hotword_str_list
                )
            )
        # for text str input
        elif not hotword_list_or_file.endswith(".txt"):
            logging.info("Attempting to parse hotwords as str...")
            hotword_list = []
            hotword_str_list = []
            for hw in hotword_list_or_file.strip().split():
                hotword_str_list.append(hw)
                hw_list = hw.strip().split()
                if seg_dict is not None:
                    hw_list = seg_tokenize(hw_list, seg_dict)
                hotword_list.append(tokenizer.tokens2ids(hw_list))
            hotword_list.append([self.sos])
            hotword_str_list.append("<s>")
            logging.info("Hotword list: {}.".format(hotword_str_list))
        else:
            hotword_list = None
        return hotword_list
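

# Usage sketch (a minimal example, not part of this module; assumes FunASR is
# installed and that the model name below — an illustrative contextual
# Paraformer checkpoint — is available locally or via ModelScope):
#
#     from funasr import AutoModel
#
#     model = AutoModel(model="iic/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404")
#     # "hotword" accepts a whitespace-separated string, a local .txt path,
#     # or a URL — see generate_hotwords_list() above.
#     res = model.generate(input="example.wav", hotword="魔搭 阿里巴巴")
#     print(res)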