# length_bonus.py
  1. """Length bonus module."""
  2. from typing import Any
  3. from typing import List
  4. from typing import Tuple
  5. import torch
  6. from funasr.modules.scorers.scorer_interface import BatchScorerInterface
  7. class LengthBonus(BatchScorerInterface):
  8. """Length bonus in beam search."""
  9. def __init__(self, n_vocab: int):
  10. """Initialize class.
  11. Args:
  12. n_vocab (int): The number of tokens in vocabulary for beam search
  13. """
  14. self.n = n_vocab
  15. def score(self, y, state, x):
  16. """Score new token.
  17. Args:
  18. y (torch.Tensor): 1D torch.int64 prefix tokens.
  19. state: Scorer state for prefix tokens
  20. x (torch.Tensor): 2D encoder feature that generates ys.
  21. Returns:
  22. tuple[torch.Tensor, Any]: Tuple of
  23. torch.float32 scores for next token (n_vocab)
  24. and None
  25. """
  26. return torch.tensor([1.0], device=x.device, dtype=x.dtype).expand(self.n), None
  27. def batch_score(
  28. self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
  29. ) -> Tuple[torch.Tensor, List[Any]]:
  30. """Score new token batch.
  31. Args:
  32. ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
  33. states (List[Any]): Scorer states for prefix tokens.
  34. xs (torch.Tensor):
  35. The encoder feature that generates ys (n_batch, xlen, n_feat).
  36. Returns:
  37. tuple[torch.Tensor, List[Any]]: Tuple of
  38. batchfied scores for next token with shape of `(n_batch, n_vocab)`
  39. and next state list for ys.
  40. """
  41. return (
  42. torch.tensor([1.0], device=xs.device, dtype=xs.dtype).expand(
  43. ys.shape[0], self.n
  44. ),
  45. None,
  46. )