scorer_interface.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. """Scorer interface module."""
import warnings
from typing import Any
from typing import List
from typing import Optional
from typing import Tuple

import torch
  7. class ScorerInterface:
  8. """Scorer interface for beam search.
  9. The scorer performs scoring of the all tokens in vocabulary.
  10. Examples:
  11. * Search heuristics
  12. * :class:`espnet.nets.scorers.length_bonus.LengthBonus`
  13. * Decoder networks of the sequence-to-sequence models
  14. * :class:`espnet.nets.pytorch_backend.nets.transformer.decoder.Decoder`
  15. * :class:`espnet.nets.pytorch_backend.nets.rnn.decoders.Decoder`
  16. * Neural language models
  17. * :class:`espnet.nets.pytorch_backend.lm.transformer.TransformerLM`
  18. * :class:`espnet.nets.pytorch_backend.lm.default.DefaultRNNLM`
  19. * :class:`espnet.nets.pytorch_backend.lm.seq_rnn.SequentialRNNLM`
  20. """
  21. def init_state(self, x: torch.Tensor) -> Any:
  22. """Get an initial state for decoding (optional).
  23. Args:
  24. x (torch.Tensor): The encoded feature tensor
  25. Returns: initial state
  26. """
  27. return None
  28. def select_state(self, state: Any, i: int, new_id: int = None) -> Any:
  29. """Select state with relative ids in the main beam search.
  30. Args:
  31. state: Decoder state for prefix tokens
  32. i (int): Index to select a state in the main beam search
  33. new_id (int): New label index to select a state if necessary
  34. Returns:
  35. state: pruned state
  36. """
  37. return None if state is None else state[i]
  38. def score(
  39. self, y: torch.Tensor, state: Any, x: torch.Tensor
  40. ) -> Tuple[torch.Tensor, Any]:
  41. """Score new token (required).
  42. Args:
  43. y (torch.Tensor): 1D torch.int64 prefix tokens.
  44. state: Scorer state for prefix tokens
  45. x (torch.Tensor): The encoder feature that generates ys.
  46. Returns:
  47. tuple[torch.Tensor, Any]: Tuple of
  48. scores for next token that has a shape of `(n_vocab)`
  49. and next state for ys
  50. """
  51. raise NotImplementedError
  52. def final_score(self, state: Any) -> float:
  53. """Score eos (optional).
  54. Args:
  55. state: Scorer state for prefix tokens
  56. Returns:
  57. float: final score
  58. """
  59. return 0.0
  60. class BatchScorerInterface(ScorerInterface):
  61. """Batch scorer interface."""
  62. def batch_init_state(self, x: torch.Tensor) -> Any:
  63. """Get an initial state for decoding (optional).
  64. Args:
  65. x (torch.Tensor): The encoded feature tensor
  66. Returns: initial state
  67. """
  68. return self.init_state(x)
  69. def batch_score(
  70. self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
  71. ) -> Tuple[torch.Tensor, List[Any]]:
  72. """Score new token batch (required).
  73. Args:
  74. ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
  75. states (List[Any]): Scorer states for prefix tokens.
  76. xs (torch.Tensor):
  77. The encoder feature that generates ys (n_batch, xlen, n_feat).
  78. Returns:
  79. tuple[torch.Tensor, List[Any]]: Tuple of
  80. batchfied scores for next token with shape of `(n_batch, n_vocab)`
  81. and next state list for ys.
  82. """
  83. warnings.warn(
  84. "{} batch score is implemented through for loop not parallelized".format(
  85. self.__class__.__name__
  86. )
  87. )
  88. scores = list()
  89. outstates = list()
  90. for i, (y, state, x) in enumerate(zip(ys, states, xs)):
  91. score, outstate = self.score(y, state, x)
  92. outstates.append(outstate)
  93. scores.append(score)
  94. scores = torch.cat(scores, 0).view(ys.shape[0], -1)
  95. return scores, outstates
  96. class PartialScorerInterface(ScorerInterface):
  97. """Partial scorer interface for beam search.
  98. The partial scorer performs scoring when non-partial scorer finished scoring,
  99. and receives pre-pruned next tokens to score because it is too heavy to score
  100. all the tokens.
  101. Examples:
  102. * Prefix search for connectionist-temporal-classification models
  103. * :class:`espnet.nets.scorers.ctc.CTCPrefixScorer`
  104. """
  105. def score_partial(
  106. self, y: torch.Tensor, next_tokens: torch.Tensor, state: Any, x: torch.Tensor
  107. ) -> Tuple[torch.Tensor, Any]:
  108. """Score new token (required).
  109. Args:
  110. y (torch.Tensor): 1D prefix token
  111. next_tokens (torch.Tensor): torch.int64 next token to score
  112. state: decoder state for prefix tokens
  113. x (torch.Tensor): The encoder feature that generates ys
  114. Returns:
  115. tuple[torch.Tensor, Any]:
  116. Tuple of a score tensor for y that has a shape `(len(next_tokens),)`
  117. and next state for ys
  118. """
  119. raise NotImplementedError
  120. class BatchPartialScorerInterface(BatchScorerInterface, PartialScorerInterface):
  121. """Batch partial scorer interface for beam search."""
  122. def batch_score_partial(
  123. self,
  124. ys: torch.Tensor,
  125. next_tokens: torch.Tensor,
  126. states: List[Any],
  127. xs: torch.Tensor,
  128. ) -> Tuple[torch.Tensor, Any]:
  129. """Score new token (required).
  130. Args:
  131. ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
  132. next_tokens (torch.Tensor): torch.int64 tokens to score (n_batch, n_token).
  133. states (List[Any]): Scorer states for prefix tokens.
  134. xs (torch.Tensor):
  135. The encoder feature that generates ys (n_batch, xlen, n_feat).
  136. Returns:
  137. tuple[torch.Tensor, Any]:
  138. Tuple of a score tensor for ys that has a shape `(n_batch, n_vocab)`
  139. and next states for ys
  140. """
  141. raise NotImplementedError