search.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
import logging
from itertools import chain
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union

import torch

from funasr.metrics.common import end_detect
from funasr.models.transformer.scorers.scorer_interface import PartialScorerInterface, ScorerInterface
  11. class Hypothesis(NamedTuple):
  12. """Hypothesis data type."""
  13. yseq: torch.Tensor
  14. score: Union[float, torch.Tensor] = 0
  15. scores: Dict[str, Union[float, torch.Tensor]] = dict()
  16. states: Dict[str, Any] = dict()
  17. def asdict(self) -> dict:
  18. """Convert data to JSON-friendly dict."""
  19. return self._replace(
  20. yseq=self.yseq.tolist(),
  21. score=float(self.score),
  22. scores={k: float(v) for k, v in self.scores.items()},
  23. )._asdict()
  24. class BeamSearchPara(torch.nn.Module):
  25. """Beam search implementation."""
  26. def __init__(
  27. self,
  28. scorers: Dict[str, ScorerInterface],
  29. weights: Dict[str, float],
  30. beam_size: int,
  31. vocab_size: int,
  32. sos: int,
  33. eos: int,
  34. token_list: List[str] = None,
  35. pre_beam_ratio: float = 1.5,
  36. pre_beam_score_key: str = None,
  37. ):
  38. """Initialize beam search.
  39. Args:
  40. scorers (dict[str, ScorerInterface]): Dict of decoder modules
  41. e.g., Decoder, CTCPrefixScorer, LM
  42. The scorer will be ignored if it is `None`
  43. weights (dict[str, float]): Dict of weights for each scorers
  44. The scorer will be ignored if its weight is 0
  45. beam_size (int): The number of hypotheses kept during search
  46. vocab_size (int): The number of vocabulary
  47. sos (int): Start of sequence id
  48. eos (int): End of sequence id
  49. token_list (list[str]): List of tokens for debug log
  50. pre_beam_score_key (str): key of scores to perform pre-beam search
  51. pre_beam_ratio (float): beam size in the pre-beam search
  52. will be `int(pre_beam_ratio * beam_size)`
  53. """
  54. super().__init__()
  55. # set scorers
  56. self.weights = weights
  57. self.scorers = dict()
  58. self.full_scorers = dict()
  59. self.part_scorers = dict()
  60. # this module dict is required for recursive cast
  61. # `self.to(device, dtype)` in `recog.py`
  62. self.nn_dict = torch.nn.ModuleDict()
  63. for k, v in scorers.items():
  64. w = weights.get(k, 0)
  65. if w == 0 or v is None:
  66. continue
  67. assert isinstance(
  68. v, ScorerInterface
  69. ), f"{k} ({type(v)}) does not implement ScorerInterface"
  70. self.scorers[k] = v
  71. if isinstance(v, PartialScorerInterface):
  72. self.part_scorers[k] = v
  73. else:
  74. self.full_scorers[k] = v
  75. if isinstance(v, torch.nn.Module):
  76. self.nn_dict[k] = v
  77. # set configurations
  78. self.sos = sos
  79. self.eos = eos
  80. self.token_list = token_list
  81. self.pre_beam_size = int(pre_beam_ratio * beam_size)
  82. self.beam_size = beam_size
  83. self.n_vocab = vocab_size
  84. if (
  85. pre_beam_score_key is not None
  86. and pre_beam_score_key != "full"
  87. and pre_beam_score_key not in self.full_scorers
  88. ):
  89. raise KeyError(f"{pre_beam_score_key} is not found in {self.full_scorers}")
  90. self.pre_beam_score_key = pre_beam_score_key
  91. self.do_pre_beam = (
  92. self.pre_beam_score_key is not None
  93. and self.pre_beam_size < self.n_vocab
  94. and len(self.part_scorers) > 0
  95. )
  96. def init_hyp(self, x: torch.Tensor) -> List[Hypothesis]:
  97. """Get an initial hypothesis data.
  98. Args:
  99. x (torch.Tensor): The encoder output feature
  100. Returns:
  101. Hypothesis: The initial hypothesis.
  102. """
  103. init_states = dict()
  104. init_scores = dict()
  105. for k, d in self.scorers.items():
  106. init_states[k] = d.init_state(x)
  107. init_scores[k] = 0.0
  108. return [
  109. Hypothesis(
  110. score=0.0,
  111. scores=init_scores,
  112. states=init_states,
  113. yseq=torch.tensor([self.sos], device=x.device),
  114. )
  115. ]
  116. @staticmethod
  117. def append_token(xs: torch.Tensor, x: int) -> torch.Tensor:
  118. """Append new token to prefix tokens.
  119. Args:
  120. xs (torch.Tensor): The prefix token
  121. x (int): The new token to append
  122. Returns:
  123. torch.Tensor: New tensor contains: xs + [x] with xs.dtype and xs.device
  124. """
  125. x = torch.tensor([x], dtype=xs.dtype, device=xs.device)
  126. return torch.cat((xs, x))
  127. def score_full(
  128. self, hyp: Hypothesis, x: torch.Tensor
  129. ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
  130. """Score new hypothesis by `self.full_scorers`.
  131. Args:
  132. hyp (Hypothesis): Hypothesis with prefix tokens to score
  133. x (torch.Tensor): Corresponding input feature
  134. Returns:
  135. Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
  136. score dict of `hyp` that has string keys of `self.full_scorers`
  137. and tensor score values of shape: `(self.n_vocab,)`,
  138. and state dict that has string keys
  139. and state values of `self.full_scorers`
  140. """
  141. scores = dict()
  142. states = dict()
  143. for k, d in self.full_scorers.items():
  144. scores[k], states[k] = d.score(hyp.yseq, hyp.states[k], x)
  145. return scores, states
  146. def score_partial(
  147. self, hyp: Hypothesis, ids: torch.Tensor, x: torch.Tensor
  148. ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
  149. """Score new hypothesis by `self.part_scorers`.
  150. Args:
  151. hyp (Hypothesis): Hypothesis with prefix tokens to score
  152. ids (torch.Tensor): 1D tensor of new partial tokens to score
  153. x (torch.Tensor): Corresponding input feature
  154. Returns:
  155. Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
  156. score dict of `hyp` that has string keys of `self.part_scorers`
  157. and tensor score values of shape: `(len(ids),)`,
  158. and state dict that has string keys
  159. and state values of `self.part_scorers`
  160. """
  161. scores = dict()
  162. states = dict()
  163. for k, d in self.part_scorers.items():
  164. scores[k], states[k] = d.score_partial(hyp.yseq, ids, hyp.states[k], x)
  165. return scores, states
  166. def beam(
  167. self, weighted_scores: torch.Tensor, ids: torch.Tensor
  168. ) -> Tuple[torch.Tensor, torch.Tensor]:
  169. """Compute topk full token ids and partial token ids.
  170. Args:
  171. weighted_scores (torch.Tensor): The weighted sum scores for each tokens.
  172. Its shape is `(self.n_vocab,)`.
  173. ids (torch.Tensor): The partial token ids to compute topk
  174. Returns:
  175. Tuple[torch.Tensor, torch.Tensor]:
  176. The topk full token ids and partial token ids.
  177. Their shapes are `(self.beam_size,)`
  178. """
  179. # no pre beam performed
  180. if weighted_scores.size(0) == ids.size(0):
  181. top_ids = weighted_scores.topk(self.beam_size)[1]
  182. return top_ids, top_ids
  183. # mask pruned in pre-beam not to select in topk
  184. tmp = weighted_scores[ids]
  185. weighted_scores[:] = -float("inf")
  186. weighted_scores[ids] = tmp
  187. top_ids = weighted_scores.topk(self.beam_size)[1]
  188. local_ids = weighted_scores[ids].topk(self.beam_size)[1]
  189. return top_ids, local_ids
  190. @staticmethod
  191. def merge_scores(
  192. prev_scores: Dict[str, float],
  193. next_full_scores: Dict[str, torch.Tensor],
  194. full_idx: int,
  195. next_part_scores: Dict[str, torch.Tensor],
  196. part_idx: int,
  197. ) -> Dict[str, torch.Tensor]:
  198. """Merge scores for new hypothesis.
  199. Args:
  200. prev_scores (Dict[str, float]):
  201. The previous hypothesis scores by `self.scorers`
  202. next_full_scores (Dict[str, torch.Tensor]): scores by `self.full_scorers`
  203. full_idx (int): The next token id for `next_full_scores`
  204. next_part_scores (Dict[str, torch.Tensor]):
  205. scores of partial tokens by `self.part_scorers`
  206. part_idx (int): The new token id for `next_part_scores`
  207. Returns:
  208. Dict[str, torch.Tensor]: The new score dict.
  209. Its keys are names of `self.full_scorers` and `self.part_scorers`.
  210. Its values are scalar tensors by the scorers.
  211. """
  212. new_scores = dict()
  213. for k, v in next_full_scores.items():
  214. new_scores[k] = prev_scores[k] + v[full_idx]
  215. for k, v in next_part_scores.items():
  216. new_scores[k] = prev_scores[k] + v[part_idx]
  217. return new_scores
  218. def merge_states(self, states: Any, part_states: Any, part_idx: int) -> Any:
  219. """Merge states for new hypothesis.
  220. Args:
  221. states: states of `self.full_scorers`
  222. part_states: states of `self.part_scorers`
  223. part_idx (int): The new token id for `part_scores`
  224. Returns:
  225. Dict[str, torch.Tensor]: The new score dict.
  226. Its keys are names of `self.full_scorers` and `self.part_scorers`.
  227. Its values are states of the scorers.
  228. """
  229. new_states = dict()
  230. for k, v in states.items():
  231. new_states[k] = v
  232. for k, d in self.part_scorers.items():
  233. new_states[k] = d.select_state(part_states[k], part_idx)
  234. return new_states
  235. def search(
  236. self, running_hyps: List[Hypothesis], x: torch.Tensor, am_score: torch.Tensor
  237. ) -> List[Hypothesis]:
  238. """Search new tokens for running hypotheses and encoded speech x.
  239. Args:
  240. running_hyps (List[Hypothesis]): Running hypotheses on beam
  241. x (torch.Tensor): Encoded speech feature (T, D)
  242. Returns:
  243. List[Hypotheses]: Best sorted hypotheses
  244. """
  245. best_hyps = []
  246. part_ids = torch.arange(self.n_vocab, device=x.device) # no pre-beam
  247. for hyp in running_hyps:
  248. # scoring
  249. weighted_scores = torch.zeros(self.n_vocab, dtype=x.dtype, device=x.device)
  250. weighted_scores += am_score
  251. scores, states = self.score_full(hyp, x)
  252. for k in self.full_scorers:
  253. weighted_scores += self.weights[k] * scores[k]
  254. # partial scoring
  255. if self.do_pre_beam:
  256. pre_beam_scores = (
  257. weighted_scores
  258. if self.pre_beam_score_key == "full"
  259. else scores[self.pre_beam_score_key]
  260. )
  261. part_ids = torch.topk(pre_beam_scores, self.pre_beam_size)[1]
  262. part_scores, part_states = self.score_partial(hyp, part_ids, x)
  263. for k in self.part_scorers:
  264. weighted_scores[part_ids] += self.weights[k] * part_scores[k]
  265. # add previous hyp score
  266. weighted_scores += hyp.score
  267. # update hyps
  268. for j, part_j in zip(*self.beam(weighted_scores, part_ids)):
  269. # will be (2 x beam at most)
  270. best_hyps.append(
  271. Hypothesis(
  272. score=weighted_scores[j],
  273. yseq=self.append_token(hyp.yseq, j),
  274. scores=self.merge_scores(
  275. hyp.scores, scores, j, part_scores, part_j
  276. ),
  277. states=self.merge_states(states, part_states, part_j),
  278. )
  279. )
  280. # sort and prune 2 x beam -> beam
  281. best_hyps = sorted(best_hyps, key=lambda x: x.score, reverse=True)[
  282. : min(len(best_hyps), self.beam_size)
  283. ]
  284. return best_hyps
  285. def forward(
  286. self, x: torch.Tensor, am_scores: torch.Tensor, maxlenratio: float = 0.0, minlenratio: float = 0.0
  287. ) -> List[Hypothesis]:
  288. """Perform beam search.
  289. Args:
  290. x (torch.Tensor): Encoded speech feature (T, D)
  291. maxlenratio (float): Input length ratio to obtain max output length.
  292. If maxlenratio=0.0 (default), it uses a end-detect function
  293. to automatically find maximum hypothesis lengths
  294. If maxlenratio<0.0, its absolute value is interpreted
  295. as a constant max output length.
  296. minlenratio (float): Input length ratio to obtain min output length.
  297. Returns:
  298. list[Hypothesis]: N-best decoding results
  299. """
  300. # set length bounds
  301. maxlen = am_scores.shape[0]
  302. logging.info("decoder input length: " + str(x.shape[0]))
  303. logging.info("max output length: " + str(maxlen))
  304. # main loop of prefix search
  305. running_hyps = self.init_hyp(x)
  306. ended_hyps = []
  307. for i in range(maxlen):
  308. logging.debug("position " + str(i))
  309. best = self.search(running_hyps, x, am_scores[i])
  310. # post process of one iteration
  311. running_hyps = self.post_process(i, maxlen, maxlenratio, best, ended_hyps)
  312. # end detection
  313. if maxlenratio == 0.0 and end_detect([h.asdict() for h in ended_hyps], i):
  314. logging.info(f"end detected at {i}")
  315. break
  316. if len(running_hyps) == 0:
  317. logging.info("no hypothesis. Finish decoding.")
  318. break
  319. else:
  320. logging.debug(f"remained hypotheses: {len(running_hyps)}")
  321. nbest_hyps = sorted(ended_hyps, key=lambda x: x.score, reverse=True)
  322. # check the number of hypotheses reaching to eos
  323. if len(nbest_hyps) == 0:
  324. logging.warning(
  325. "there is no N-best results, perform recognition "
  326. "again with smaller minlenratio."
  327. )
  328. return (
  329. []
  330. if minlenratio < 0.1
  331. else self.forward(x, maxlenratio, max(0.0, minlenratio - 0.1))
  332. )
  333. # report the best result
  334. best = nbest_hyps[0]
  335. for k, v in best.scores.items():
  336. logging.info(
  337. f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}"
  338. )
  339. logging.info(f"total log probability: {best.score:.2f}")
  340. logging.info(f"normalized log probability: {best.score / len(best.yseq):.2f}")
  341. logging.info(f"total number of ended hypotheses: {len(nbest_hyps)}")
  342. if self.token_list is not None:
  343. logging.info(
  344. "best hypo: "
  345. + "".join([self.token_list[x.item()] for x in best.yseq[1:-1]])
  346. + "\n"
  347. )
  348. return nbest_hyps
  349. def post_process(
  350. self,
  351. i: int,
  352. maxlen: int,
  353. maxlenratio: float,
  354. running_hyps: List[Hypothesis],
  355. ended_hyps: List[Hypothesis],
  356. ) -> List[Hypothesis]:
  357. """Perform post-processing of beam search iterations.
  358. Args:
  359. i (int): The length of hypothesis tokens.
  360. maxlen (int): The maximum length of tokens in beam search.
  361. maxlenratio (int): The maximum length ratio in beam search.
  362. running_hyps (List[Hypothesis]): The running hypotheses in beam search.
  363. ended_hyps (List[Hypothesis]): The ended hypotheses in beam search.
  364. Returns:
  365. List[Hypothesis]: The new running hypotheses.
  366. """
  367. logging.debug(f"the number of running hypotheses: {len(running_hyps)}")
  368. if self.token_list is not None:
  369. logging.debug(
  370. "best hypo: "
  371. + "".join([self.token_list[x.item()] for x in running_hyps[0].yseq[1:]])
  372. )
  373. # add eos in the final loop to avoid that there are no ended hyps
  374. if i == maxlen - 1:
  375. logging.info("adding <eos> in the last position in the loop")
  376. running_hyps = [
  377. h._replace(yseq=self.append_token(h.yseq, self.eos))
  378. for h in running_hyps
  379. ]
  380. # add ended hypotheses to a final list, and removed them from current hypotheses
  381. # (this will be a problem, number of hyps < beam)
  382. remained_hyps = []
  383. for hyp in running_hyps:
  384. if hyp.yseq[-1] == self.eos:
  385. # e.g., Word LM needs to add final <eos> score
  386. for k, d in chain(self.full_scorers.items(), self.part_scorers.items()):
  387. s = d.final_score(hyp.states[k])
  388. hyp.scores[k] += s
  389. hyp = hyp._replace(score=hyp.score + self.weights[k] * s)
  390. ended_hyps.append(hyp)
  391. else:
  392. remained_hyps.append(hyp)
  393. return remained_hyps