beam_search.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496
  1. """Beam search module."""
  2. from itertools import chain
  3. import logging
  4. from typing import Any
  5. from typing import Dict
  6. from typing import List
  7. from typing import NamedTuple
  8. from typing import Tuple
  9. from typing import Union
  10. import torch
  11. from funasr.metrics.common import end_detect
  12. from funasr.models.transformer.scorers.scorer_interface import PartialScorerInterface
  13. from funasr.models.transformer.scorers.scorer_interface import ScorerInterface
  14. class Hypothesis(NamedTuple):
  15. """Hypothesis data type."""
  16. yseq: torch.Tensor
  17. score: Union[float, torch.Tensor] = 0
  18. scores: Dict[str, Union[float, torch.Tensor]] = dict()
  19. states: Dict[str, Any] = dict()
  20. def asdict(self) -> dict:
  21. """Convert data to JSON-friendly dict."""
  22. return self._replace(
  23. yseq=self.yseq.tolist(),
  24. score=float(self.score),
  25. scores={k: float(v) for k, v in self.scores.items()},
  26. )._asdict()
  27. class BeamSearchScama(torch.nn.Module):
  28. """Beam search implementation."""
  29. def __init__(
  30. self,
  31. scorers: Dict[str, ScorerInterface],
  32. weights: Dict[str, float],
  33. beam_size: int,
  34. vocab_size: int,
  35. sos: int,
  36. eos: int,
  37. token_list: List[str] = None,
  38. pre_beam_ratio: float = 1.5,
  39. pre_beam_score_key: str = None,
  40. ):
  41. """Initialize beam search.
  42. Args:
  43. scorers (dict[str, ScorerInterface]): Dict of decoder modules
  44. e.g., Decoder, CTCPrefixScorer, LM
  45. The scorer will be ignored if it is `None`
  46. weights (dict[str, float]): Dict of weights for each scorers
  47. The scorer will be ignored if its weight is 0
  48. beam_size (int): The number of hypotheses kept during search
  49. vocab_size (int): The number of vocabulary
  50. sos (int): Start of sequence id
  51. eos (int): End of sequence id
  52. token_list (list[str]): List of tokens for debug log
  53. pre_beam_score_key (str): key of scores to perform pre-beam search
  54. pre_beam_ratio (float): beam size in the pre-beam search
  55. will be `int(pre_beam_ratio * beam_size)`
  56. """
  57. super().__init__()
  58. # set scorers
  59. self.weights = weights
  60. self.scorers = dict()
  61. self.full_scorers = dict()
  62. self.part_scorers = dict()
  63. # this module dict is required for recursive cast
  64. # `self.to(device, dtype)` in `recog.py`
  65. self.nn_dict = torch.nn.ModuleDict()
  66. for k, v in scorers.items():
  67. w = weights.get(k, 0)
  68. if w == 0 or v is None:
  69. continue
  70. assert isinstance(
  71. v, ScorerInterface
  72. ), f"{k} ({type(v)}) does not implement ScorerInterface"
  73. self.scorers[k] = v
  74. if isinstance(v, PartialScorerInterface):
  75. self.part_scorers[k] = v
  76. else:
  77. self.full_scorers[k] = v
  78. if isinstance(v, torch.nn.Module):
  79. self.nn_dict[k] = v
  80. # set configurations
  81. self.sos = sos
  82. self.eos = eos
  83. self.token_list = token_list
  84. self.pre_beam_size = int(pre_beam_ratio * beam_size)
  85. self.beam_size = beam_size
  86. self.n_vocab = vocab_size
  87. if (
  88. pre_beam_score_key is not None
  89. and pre_beam_score_key != "full"
  90. and pre_beam_score_key not in self.full_scorers
  91. ):
  92. raise KeyError(f"{pre_beam_score_key} is not found in {self.full_scorers}")
  93. self.pre_beam_score_key = pre_beam_score_key
  94. self.do_pre_beam = (
  95. self.pre_beam_score_key is not None
  96. and self.pre_beam_size < self.n_vocab
  97. and len(self.part_scorers) > 0
  98. )
  99. def init_hyp(self, x: torch.Tensor) -> List[Hypothesis]:
  100. """Get an initial hypothesis data.
  101. Args:
  102. x (torch.Tensor): The encoder output feature
  103. Returns:
  104. Hypothesis: The initial hypothesis.
  105. """
  106. init_states = dict()
  107. init_scores = dict()
  108. for k, d in self.scorers.items():
  109. init_states[k] = d.init_state(x)
  110. init_scores[k] = 0.0
  111. return [
  112. Hypothesis(
  113. score=0.0,
  114. scores=init_scores,
  115. states=init_states,
  116. yseq=torch.tensor([self.sos], device=x.device),
  117. )
  118. ]
  119. @staticmethod
  120. def append_token(xs: torch.Tensor, x: int) -> torch.Tensor:
  121. """Append new token to prefix tokens.
  122. Args:
  123. xs (torch.Tensor): The prefix token
  124. x (int): The new token to append
  125. Returns:
  126. torch.Tensor: New tensor contains: xs + [x] with xs.dtype and xs.device
  127. """
  128. x = torch.tensor([x], dtype=xs.dtype, device=xs.device)
  129. return torch.cat((xs, x))
  130. def score_full(
  131. self, hyp: Hypothesis,
  132. x: torch.Tensor,
  133. x_mask: torch.Tensor = None,
  134. pre_acoustic_embeds: torch.Tensor = None,
  135. ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
  136. """Score new hypothesis by `self.full_scorers`.
  137. Args:
  138. hyp (Hypothesis): Hypothesis with prefix tokens to score
  139. x (torch.Tensor): Corresponding input feature
  140. Returns:
  141. Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
  142. score dict of `hyp` that has string keys of `self.full_scorers`
  143. and tensor score values of shape: `(self.n_vocab,)`,
  144. and state dict that has string keys
  145. and state values of `self.full_scorers`
  146. """
  147. scores = dict()
  148. states = dict()
  149. for k, d in self.full_scorers.items():
  150. scores[k], states[k] = d.score(hyp.yseq, hyp.states[k], x, x_mask=x_mask, pre_acoustic_embeds=pre_acoustic_embeds)
  151. return scores, states
  152. def score_partial(
  153. self, hyp: Hypothesis, ids: torch.Tensor, x: torch.Tensor
  154. ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
  155. """Score new hypothesis by `self.part_scorers`.
  156. Args:
  157. hyp (Hypothesis): Hypothesis with prefix tokens to score
  158. ids (torch.Tensor): 1D tensor of new partial tokens to score
  159. x (torch.Tensor): Corresponding input feature
  160. Returns:
  161. Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
  162. score dict of `hyp` that has string keys of `self.part_scorers`
  163. and tensor score values of shape: `(len(ids),)`,
  164. and state dict that has string keys
  165. and state values of `self.part_scorers`
  166. """
  167. scores = dict()
  168. states = dict()
  169. for k, d in self.part_scorers.items():
  170. scores[k], states[k] = d.score_partial(hyp.yseq, ids, hyp.states[k], x)
  171. return scores, states
  172. def beam(
  173. self, weighted_scores: torch.Tensor, ids: torch.Tensor
  174. ) -> Tuple[torch.Tensor, torch.Tensor]:
  175. """Compute topk full token ids and partial token ids.
  176. Args:
  177. weighted_scores (torch.Tensor): The weighted sum scores for each tokens.
  178. Its shape is `(self.n_vocab,)`.
  179. ids (torch.Tensor): The partial token ids to compute topk
  180. Returns:
  181. Tuple[torch.Tensor, torch.Tensor]:
  182. The topk full token ids and partial token ids.
  183. Their shapes are `(self.beam_size,)`
  184. """
  185. # no pre beam performed
  186. if weighted_scores.size(0) == ids.size(0):
  187. top_ids = weighted_scores.topk(self.beam_size)[1]
  188. return top_ids, top_ids
  189. # mask pruned in pre-beam not to select in topk
  190. tmp = weighted_scores[ids]
  191. weighted_scores[:] = -float("inf")
  192. weighted_scores[ids] = tmp
  193. top_ids = weighted_scores.topk(self.beam_size)[1]
  194. local_ids = weighted_scores[ids].topk(self.beam_size)[1]
  195. return top_ids, local_ids
  196. @staticmethod
  197. def merge_scores(
  198. prev_scores: Dict[str, float],
  199. next_full_scores: Dict[str, torch.Tensor],
  200. full_idx: int,
  201. next_part_scores: Dict[str, torch.Tensor],
  202. part_idx: int,
  203. ) -> Dict[str, torch.Tensor]:
  204. """Merge scores for new hypothesis.
  205. Args:
  206. prev_scores (Dict[str, float]):
  207. The previous hypothesis scores by `self.scorers`
  208. next_full_scores (Dict[str, torch.Tensor]): scores by `self.full_scorers`
  209. full_idx (int): The next token id for `next_full_scores`
  210. next_part_scores (Dict[str, torch.Tensor]):
  211. scores of partial tokens by `self.part_scorers`
  212. part_idx (int): The new token id for `next_part_scores`
  213. Returns:
  214. Dict[str, torch.Tensor]: The new score dict.
  215. Its keys are names of `self.full_scorers` and `self.part_scorers`.
  216. Its values are scalar tensors by the scorers.
  217. """
  218. new_scores = dict()
  219. for k, v in next_full_scores.items():
  220. new_scores[k] = prev_scores[k] + v[full_idx]
  221. for k, v in next_part_scores.items():
  222. new_scores[k] = prev_scores[k] + v[part_idx]
  223. return new_scores
  224. def merge_states(self, states: Any, part_states: Any, part_idx: int) -> Any:
  225. """Merge states for new hypothesis.
  226. Args:
  227. states: states of `self.full_scorers`
  228. part_states: states of `self.part_scorers`
  229. part_idx (int): The new token id for `part_scores`
  230. Returns:
  231. Dict[str, torch.Tensor]: The new score dict.
  232. Its keys are names of `self.full_scorers` and `self.part_scorers`.
  233. Its values are states of the scorers.
  234. """
  235. new_states = dict()
  236. for k, v in states.items():
  237. new_states[k] = v
  238. for k, d in self.part_scorers.items():
  239. new_states[k] = d.select_state(part_states[k], part_idx)
  240. return new_states
  241. def search(
  242. self, running_hyps: List[Hypothesis],
  243. x: torch.Tensor,
  244. x_mask: torch.Tensor = None,
  245. pre_acoustic_embeds: torch.Tensor = None,
  246. ) -> List[Hypothesis]:
  247. """Search new tokens for running hypotheses and encoded speech x.
  248. Args:
  249. running_hyps (List[Hypothesis]): Running hypotheses on beam
  250. x (torch.Tensor): Encoded speech feature (T, D)
  251. Returns:
  252. List[Hypotheses]: Best sorted hypotheses
  253. """
  254. best_hyps = []
  255. part_ids = torch.arange(self.n_vocab, device=x.device) # no pre-beam
  256. for hyp in running_hyps:
  257. # scoring
  258. weighted_scores = torch.zeros(self.n_vocab, dtype=x.dtype, device=x.device)
  259. scores, states = self.score_full(hyp, x, x_mask=x_mask, pre_acoustic_embeds=pre_acoustic_embeds)
  260. for k in self.full_scorers:
  261. weighted_scores += self.weights[k] * scores[k]
  262. # partial scoring
  263. if self.do_pre_beam:
  264. pre_beam_scores = (
  265. weighted_scores
  266. if self.pre_beam_score_key == "full"
  267. else scores[self.pre_beam_score_key]
  268. )
  269. part_ids = torch.topk(pre_beam_scores, self.pre_beam_size)[1]
  270. part_scores, part_states = self.score_partial(hyp, part_ids, x)
  271. for k in self.part_scorers:
  272. weighted_scores[part_ids] += self.weights[k] * part_scores[k]
  273. # add previous hyp score
  274. weighted_scores += hyp.score
  275. # update hyps
  276. for j, part_j in zip(*self.beam(weighted_scores, part_ids)):
  277. # will be (2 x beam at most)
  278. best_hyps.append(
  279. Hypothesis(
  280. score=weighted_scores[j],
  281. yseq=self.append_token(hyp.yseq, j),
  282. scores=self.merge_scores(
  283. hyp.scores, scores, j, part_scores, part_j
  284. ),
  285. states=self.merge_states(states, part_states, part_j),
  286. )
  287. )
  288. # sort and prune 2 x beam -> beam
  289. best_hyps = sorted(best_hyps, key=lambda x: x.score, reverse=True)[
  290. : min(len(best_hyps), self.beam_size)
  291. ]
  292. return best_hyps
  293. def forward(
  294. self, x: torch.Tensor,
  295. scama_mask: torch.Tensor = None,
  296. pre_acoustic_embeds: torch.Tensor = None,
  297. maxlenratio: float = 0.0,
  298. minlenratio: float = 0.0,
  299. maxlen: int = None,
  300. minlen: int = 0,
  301. ) -> List[Hypothesis]:
  302. """Perform beam search.
  303. Args:
  304. x (torch.Tensor): Encoded speech feature (T, D)
  305. maxlenratio (float): Input length ratio to obtain max output length.
  306. If maxlenratio=0.0 (default), it uses a end-detect function
  307. to automatically find maximum hypothesis lengths
  308. If maxlenratio<0.0, its absolute value is interpreted
  309. as a constant max output length.
  310. minlenratio (float): Input length ratio to obtain min output length.
  311. Returns:
  312. list[Hypothesis]: N-best decoding results
  313. """
  314. if maxlen is None:
  315. # set length bounds
  316. if maxlenratio == 0:
  317. maxlen = x.shape[0]
  318. elif maxlenratio < 0:
  319. maxlen = -1 * int(maxlenratio)
  320. else:
  321. maxlen = max(1, int(maxlenratio * x.size(0)))
  322. minlen = int(minlenratio * x.size(0))
  323. logging.info("decoder input length: " + str(x.shape[0]))
  324. logging.info("max output length: " + str(maxlen))
  325. logging.info("min output length: " + str(minlen))
  326. # main loop of prefix search
  327. running_hyps = self.init_hyp(x)
  328. ended_hyps = []
  329. for i in range(maxlen):
  330. logging.debug("position " + str(i))
  331. mask_enc = None
  332. if scama_mask is not None:
  333. token_num_predictor = scama_mask.size(1)
  334. token_id_slice = min(i, token_num_predictor-1)
  335. mask_enc = scama_mask[:, token_id_slice:token_id_slice+1, :]
  336. # if mask_enc.size(1) == 0:
  337. # mask_enc = scama_mask[:, -2:-1, :]
  338. # # mask_enc = torch.zeros_like(mask_enc)
  339. pre_acoustic_embeds_cur = None
  340. if pre_acoustic_embeds is not None:
  341. b, t, d = pre_acoustic_embeds.size()
  342. pad = torch.zeros((b, 1, d), dtype=pre_acoustic_embeds.dtype).to(device=pre_acoustic_embeds.device)
  343. pre_acoustic_embeds = torch.cat((pre_acoustic_embeds, pad), dim=1)
  344. token_id_slice = min(i, t)
  345. pre_acoustic_embeds_cur = pre_acoustic_embeds[:, token_id_slice:token_id_slice+1, :]
  346. best = self.search(running_hyps, x, x_mask=mask_enc, pre_acoustic_embeds=pre_acoustic_embeds_cur)
  347. # post process of one iteration
  348. running_hyps = self.post_process(i, maxlen, maxlenratio, best, ended_hyps)
  349. # end detection
  350. if maxlenratio == 0.0 and end_detect([h.asdict() for h in ended_hyps], i):
  351. logging.info(f"end detected at {i}")
  352. break
  353. if len(running_hyps) == 0:
  354. logging.info("no hypothesis. Finish decoding.")
  355. break
  356. else:
  357. logging.debug(f"remained hypotheses: {len(running_hyps)}")
  358. nbest_hyps = sorted(ended_hyps, key=lambda x: x.score, reverse=True)
  359. # check the number of hypotheses reaching to eos
  360. if len(nbest_hyps) == 0:
  361. logging.warning(
  362. "there is no N-best results, perform recognition "
  363. "again with smaller minlenratio."
  364. )
  365. return (
  366. []
  367. if minlenratio < 0.1
  368. else self.forward(x, maxlenratio, max(0.0, minlenratio - 0.1))
  369. )
  370. # report the best result
  371. for x in nbest_hyps:
  372. yseq = "".join([self.token_list[x] for x in x.yseq])
  373. logging.debug("nbest: y: {}, yseq: {}, score: {}".format(x.yseq, yseq, x.score))
  374. best = nbest_hyps[0]
  375. for k, v in best.scores.items():
  376. logging.info(
  377. f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}"
  378. )
  379. logging.info(f"total log probability: {best.score:.2f}")
  380. logging.info(f"normalized log probability: {best.score / len(best.yseq):.2f}")
  381. logging.info(f"total number of ended hypotheses: {len(nbest_hyps)}")
  382. if self.token_list is not None:
  383. logging.info(
  384. "best hypo: "
  385. + "".join([self.token_list[x] for x in best.yseq[1:-1]])
  386. + "\n"
  387. )
  388. return nbest_hyps
  389. def post_process(
  390. self,
  391. i: int,
  392. maxlen: int,
  393. maxlenratio: float,
  394. running_hyps: List[Hypothesis],
  395. ended_hyps: List[Hypothesis],
  396. ) -> List[Hypothesis]:
  397. """Perform post-processing of beam search iterations.
  398. Args:
  399. i (int): The length of hypothesis tokens.
  400. maxlen (int): The maximum length of tokens in beam search.
  401. maxlenratio (int): The maximum length ratio in beam search.
  402. running_hyps (List[Hypothesis]): The running hypotheses in beam search.
  403. ended_hyps (List[Hypothesis]): The ended hypotheses in beam search.
  404. Returns:
  405. List[Hypothesis]: The new running hypotheses.
  406. """
  407. logging.debug(f"the number of running hypotheses: {len(running_hyps)}")
  408. if self.token_list is not None:
  409. logging.debug(
  410. "best hypo: "
  411. + "".join([self.token_list[x] for x in running_hyps[0].yseq[1:]])
  412. )
  413. # add eos in the final loop to avoid that there are no ended hyps
  414. if i == maxlen - 1:
  415. logging.info("adding <eos> in the last position in the loop")
  416. running_hyps = [
  417. h._replace(yseq=self.append_token(h.yseq, self.eos))
  418. for h in running_hyps
  419. ]
  420. # add ended hypotheses to a final list, and removed them from current hypotheses
  421. # (this will be a problem, number of hyps < beam)
  422. remained_hyps = []
  423. for hyp in running_hyps:
  424. if hyp.yseq[-1] == self.eos:
  425. # e.g., Word LM needs to add final <eos> score
  426. for k, d in chain(self.full_scorers.items(), self.part_scorers.items()):
  427. s = d.final_score(hyp.states[k])
  428. hyp.scores[k] += s
  429. hyp = hyp._replace(score=hyp.score + self.weights[k] * s)
  430. ended_hyps.append(hyp)
  431. else:
  432. remained_hyps.append(hyp)
  433. return remained_hyps