- """Parallel beam search module."""
- import logging
- from typing import Any
- from typing import Dict
- from typing import List
- from typing import NamedTuple
- from typing import Tuple
- import torch
- from torch.nn.utils.rnn import pad_sequence
- from funasr.modules.beam_search.beam_search import BeamSearch
- from funasr.modules.beam_search.beam_search import Hypothesis


class BatchHypothesis(NamedTuple):
    """Batchified/vectorized hypothesis data type.

    yseq: torch.Tensor = torch.tensor([])  # (batch, maxlen)
    score: torch.Tensor = torch.tensor([])  # (batch,)
    length: torch.Tensor = torch.tensor([])  # (batch,)
    scores: Dict[str, torch.Tensor] = dict()  # values: (batch,)
    states: Dict[str, Dict] = dict()

    def __len__(self) -> int:
        """Return the batch size."""
        return len(self.length)


class BatchBeamSearch(BeamSearch):
    """Batch beam search implementation."""

    def batchfy(self, hyps: List[Hypothesis]) -> BatchHypothesis:
        """Convert a list of hypotheses to a batch.
        if len(hyps) == 0:
            return BatchHypothesis()
        return BatchHypothesis(
            yseq=pad_sequence(
                [h.yseq for h in hyps], batch_first=True, padding_value=self.eos
            ),
            length=torch.tensor([len(h.yseq) for h in hyps], dtype=torch.int64),
            score=torch.tensor([h.score for h in hyps]),
            scores={k: torch.tensor([h.scores[k] for h in hyps]) for k in self.scorers},
            states={k: [h.states[k] for h in hyps] for k in self.scorers},
        )

    def _batch_select(self, hyps: BatchHypothesis, ids: List[int]) -> BatchHypothesis:
        """Select the hypotheses at the given batch indices."""
        return BatchHypothesis(
            yseq=hyps.yseq[ids],
            score=hyps.score[ids],
            length=hyps.length[ids],
            scores={k: v[ids] for k, v in hyps.scores.items()},
            states={
                k: [self.scorers[k].select_state(v, i) for i in ids]
                for k, v in hyps.states.items()
            },
        )

    def _select(self, hyps: BatchHypothesis, i: int) -> Hypothesis:
        """Extract the i-th hypothesis from a batch, trimming the padding."""
        return Hypothesis(
            yseq=hyps.yseq[i, : hyps.length[i]],
            score=hyps.score[i],
            scores={k: v[i] for k, v in hyps.scores.items()},
            states={
                k: self.scorers[k].select_state(v, i) for k, v in hyps.states.items()
            },
        )

    def unbatchfy(self, batch_hyps: BatchHypothesis) -> List[Hypothesis]:
        """Revert a batch to a list of hypotheses."""
        return [
            Hypothesis(
                yseq=batch_hyps.yseq[i][: batch_hyps.length[i]],
                score=batch_hyps.score[i],
                scores={k: batch_hyps.scores[k][i] for k in self.scorers},
                states={
                    k: v.select_state(batch_hyps.states[k], i)
                    for k, v in self.scorers.items()
                },
            )
            for i in range(len(batch_hyps.length))
        ]

    def batch_beam(
        self, weighted_scores: torch.Tensor, ids: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Batch-compute topk full token ids and partial token ids.

        Args:
            weighted_scores (torch.Tensor): The weighted sum scores for each token.
                Its shape is `(n_beam, self.vocab_size)`.
            ids (torch.Tensor): The partial token ids to compute topk on.
                Its shape is `(n_beam, self.pre_beam_size)`.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
                The topk full `(prev_hyp, new_token)` ids
                and partial `(prev_hyp, new_token)` ids.
                Their shapes are all `(self.beam_size,)`.
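
        Example:
            A minimal sketch of the flatten-then-divmod index arithmetic used
            below, assuming ``beam_size=2`` and ``n_vocab=3``:

            >>> import torch
            >>> weighted_scores = torch.tensor([[0.1, 0.9, 0.2],
            ...                                 [0.8, 0.3, 0.4]])
            >>> top_ids = weighted_scores.view(-1).topk(2)[1]
            >>> top_ids // 3, top_ids % 3  # (prev_hyp_ids, new_token_ids)
            (tensor([0, 1]), tensor([1, 0]))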
- """
- top_ids = weighted_scores.view(-1).topk(self.beam_size)[1]
- # Because of the flatten above, `top_ids` is organized as:
- # [hyp1 * V + token1, hyp2 * V + token2, ..., hypK * V + tokenK],
- # where V is `self.n_vocab` and K is `self.beam_size`
- prev_hyp_ids = top_ids // self.n_vocab
- new_token_ids = top_ids % self.n_vocab
- return prev_hyp_ids, new_token_ids, prev_hyp_ids, new_token_ids

    def init_hyp(self, x: torch.Tensor) -> BatchHypothesis:
        """Get an initial hypothesis data.

        Args:
            x (torch.Tensor): The encoder output feature

        Returns:
            BatchHypothesis: The initial hypothesis (a batch holding a single
                SOS hypothesis).
        """
        init_states = dict()
        init_scores = dict()
        for k, d in self.scorers.items():
            init_states[k] = d.batch_init_state(x)
            init_scores[k] = 0.0
        return self.batchfy(
            [
                Hypothesis(
                    score=0.0,
                    scores=init_scores,
                    states=init_states,
                    yseq=torch.tensor([self.sos], device=x.device),
                )
            ]
        )

    def score_full(
        self, hyp: BatchHypothesis, x: torch.Tensor
    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
        """Score the running hypotheses with `self.full_scorers`.

        Args:
            hyp (BatchHypothesis): Hypotheses with prefix tokens to score
            x (torch.Tensor): Corresponding input feature

        Returns:
            Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
                a score dict that has string keys of `self.full_scorers`
                and tensor score values of shape `(n_batch, self.n_vocab)`,
                and a state dict that has string keys
                and state values of `self.full_scorers`
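
        Example:
            A hedged sketch of the `batch_score` contract each full scorer is
            assumed to satisfy (`DummyScorer` is hypothetical, for
            illustration only):

            >>> import torch
            >>> class DummyScorer:
            ...     def batch_score(self, ys, states, xs):
            ...         # (n_batch, n_vocab) token scores plus new states
            ...         return torch.zeros(ys.size(0), 10), states
            >>> scores, states = DummyScorer().batch_score(
            ...     torch.zeros(2, 3, dtype=torch.long), [None, None], None)
            >>> scores.shape
            torch.Size([2, 10])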
- """
- scores = dict()
- states = dict()
- for k, d in self.full_scorers.items():
- scores[k], states[k] = d.batch_score(hyp.yseq, hyp.states[k], x)
- return scores, states

    def score_partial(
        self, hyp: BatchHypothesis, ids: torch.Tensor, x: torch.Tensor
    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
        """Score the running hypotheses with `self.part_scorers`.

        Args:
            hyp (BatchHypothesis): Hypotheses with prefix tokens to score
            ids (torch.Tensor): 2D tensor of new partial tokens to score
            x (torch.Tensor): Corresponding input feature

        Returns:
            Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
                a score dict that has string keys of `self.part_scorers`
                and tensor score values of shape `(n_batch, self.n_vocab)`,
                and a state dict that has string keys
                and state values of `self.part_scorers`
        """
        scores = dict()
        states = dict()
        for k, d in self.part_scorers.items():
            scores[k], states[k] = d.batch_score_partial(
                hyp.yseq, ids, hyp.states[k], x
            )
        return scores, states

    def merge_states(self, states: Any, part_states: Any, part_idx: int) -> Any:
        """Merge states for a new hypothesis.

        Args:
            states: states of `self.full_scorers`
            part_states: states of `self.part_scorers`
            part_idx (int): The new token id for `part_scores`

        Returns:
            Dict[str, Any]: The new state dict.
                Its keys are names of `self.full_scorers` and `self.part_scorers`.
                Its values are states of the scorers.
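
        Example:
            The merge below is equivalent to a plain dict union (the key sets
            of full and partial scorers are disjoint in practice):

            >>> full = {"decoder": "dstate"}
            >>> part = {"ctc": "cstate"}
            >>> {**full, **part}
            {'decoder': 'dstate', 'ctc': 'cstate'}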
- """
- new_states = dict()
- for k, v in states.items():
- new_states[k] = v
- for k, v in part_states.items():
- new_states[k] = v
- return new_states

    def search(self, running_hyps: BatchHypothesis, x: torch.Tensor) -> BatchHypothesis:
        """Search new tokens for running hypotheses and encoded speech x.

        Args:
            running_hyps (BatchHypothesis): Running hypotheses on beam
            x (torch.Tensor): Encoded speech feature (T, D)

        Returns:
            BatchHypothesis: Best sorted hypotheses
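
        Example:
            A minimal sketch of the broadcast used below to add each running
            hypothesis's accumulated score to its row of token scores:

            >>> import torch
            >>> weighted_scores = torch.zeros(2, 5)      # (n_beam, n_vocab)
            >>> prev_score = torch.tensor([-1.0, -2.0])  # running totals
            >>> (weighted_scores + prev_score.unsqueeze(1))[:, 0]
            tensor([-1., -2.])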
- """
- n_batch = len(running_hyps)
- part_ids = None # no pre-beam
- # batch scoring
- weighted_scores = torch.zeros(
- n_batch, self.n_vocab, dtype=x.dtype, device=x.device
- )
- scores, states = self.score_full(running_hyps, x.expand(n_batch, *x.shape))
- for k in self.full_scorers:
- weighted_scores += self.weights[k] * scores[k]
- # partial scoring
- if self.do_pre_beam:
- pre_beam_scores = (
- weighted_scores
- if self.pre_beam_score_key == "full"
- else scores[self.pre_beam_score_key]
- )
- part_ids = torch.topk(pre_beam_scores, self.pre_beam_size, dim=-1)[1]
- # NOTE(takaaki-hori): Unlike BeamSearch, we assume that score_partial returns
- # full-size score matrices, which has non-zero scores for part_ids and zeros
- # for others.
- part_scores, part_states = self.score_partial(running_hyps, part_ids, x)
- for k in self.part_scorers:
- weighted_scores += self.weights[k] * part_scores[k]
- # add previous hyp scores
- weighted_scores += running_hyps.score.to(
- dtype=x.dtype, device=x.device
- ).unsqueeze(1)
- # TODO(karita): do not use list. use batch instead
- # see also https://github.com/espnet/espnet/pull/1402#discussion_r354561029
- # update hyps
- best_hyps = []
- prev_hyps = self.unbatchfy(running_hyps)
        for (
            full_prev_hyp_id,
            full_new_token_id,
            part_prev_hyp_id,
            part_new_token_id,
        ) in zip(*self.batch_beam(weighted_scores, part_ids)):
            prev_hyp = prev_hyps[full_prev_hyp_id]
            best_hyps.append(
                Hypothesis(
                    score=weighted_scores[full_prev_hyp_id, full_new_token_id],
                    yseq=self.append_token(prev_hyp.yseq, full_new_token_id),
                    scores=self.merge_scores(
                        prev_hyp.scores,
                        {k: v[full_prev_hyp_id] for k, v in scores.items()},
                        full_new_token_id,
                        {k: v[part_prev_hyp_id] for k, v in part_scores.items()},
                        part_new_token_id,
                    ),
                    states=self.merge_states(
                        {
                            k: self.full_scorers[k].select_state(v, full_prev_hyp_id)
                            for k, v in states.items()
                        },
                        {
                            k: self.part_scorers[k].select_state(
                                v, part_prev_hyp_id, part_new_token_id
                            )
                            for k, v in part_states.items()
                        },
                        part_new_token_id,
                    ),
                )
            )
        return self.batchfy(best_hyps)

    def post_process(
        self,
        i: int,
        maxlen: int,
        maxlenratio: float,
        running_hyps: BatchHypothesis,
        ended_hyps: List[Hypothesis],
    ) -> BatchHypothesis:
        """Perform post-processing of beam search iterations.

        Args:
            i (int): The length of hypothesis tokens.
            maxlen (int): The maximum length of tokens in beam search.
            maxlenratio (float): The maximum length ratio in beam search.
            running_hyps (BatchHypothesis): The running hypotheses in beam search.
            ended_hyps (List[Hypothesis]): The ended hypotheses in beam search.

        Returns:
            BatchHypothesis: The new running hypotheses.
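
        Example:
            A minimal sketch of the end-of-sequence test used below, which
            gathers the last real (unpadded) token of each hypothesis
            (``eos`` assumed to be 2 here):

            >>> import torch
            >>> yseq = torch.tensor([[1, 4, 2], [1, 5, 6]])
            >>> length = torch.tensor([3, 3])
            >>> yseq[torch.arange(2), length - 1] == 2
            tensor([ True, False])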
- """
- n_batch = running_hyps.yseq.shape[0]
- logging.debug(f"the number of running hypothes: {n_batch}")
- if self.token_list is not None:
- logging.debug(
- "best hypo: "
- + "".join(
- [
- self.token_list[x]
- for x in running_hyps.yseq[0, 1 : running_hyps.length[0]]
- ]
- )
- )
- # add eos in the final loop to avoid that there are no ended hyps
- if i == maxlen - 1:
- logging.info("adding <eos> in the last position in the loop")
- yseq_eos = torch.cat(
- (
- running_hyps.yseq,
- torch.full(
- (n_batch, 1),
- self.eos,
- device=running_hyps.yseq.device,
- dtype=torch.int64,
- ),
- ),
- 1,
- )
- running_hyps.yseq.resize_as_(yseq_eos)
- running_hyps.yseq[:] = yseq_eos
- running_hyps.length[:] = yseq_eos.shape[1]
- # add ended hypotheses to a final list, and removed them from current hypotheses
- # (this will be a probmlem, number of hyps < beam)
- is_eos = (
- running_hyps.yseq[torch.arange(n_batch), running_hyps.length - 1]
- == self.eos
- )
- for b in torch.nonzero(is_eos, as_tuple=False).view(-1):
- hyp = self._select(running_hyps, b)
- ended_hyps.append(hyp)
- remained_ids = torch.nonzero(is_eos == 0, as_tuple=False).view(-1)
- return self._batch_select(running_hyps, remained_ids)