| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061 |
- """Length bonus module."""
- from typing import Any
- from typing import List
- from typing import Tuple
- import torch
- from funasr.modules.scorers.scorer_interface import BatchScorerInterface
class LengthBonus(BatchScorerInterface):
    """Beam-search scorer that assigns a score of one to every token."""

    def __init__(self, n_vocab: int):
        """Store the vocabulary size.

        Args:
            n_vocab (int): Number of tokens in the vocabulary; it sets the
                length of the score vector produced per hypothesis.
        """
        self.n = n_vocab

    def score(self, y, state, x):
        """Return a constant score of 1.0 for each candidate next token.

        Args:
            y: Prefix tokens (not read by this scorer).
            state: Scorer state for the prefix (not read by this scorer).
            x (torch.Tensor): Encoder feature; only its device and dtype
                are used.

        Returns:
            tuple[torch.Tensor, None]: A tensor of ones with shape
                `(n_vocab,)`, created on the device and in the dtype of
                `x`, and None as the next state.
        """
        one = torch.tensor([1.0], device=x.device, dtype=x.dtype)
        # expand() yields a broadcast view of the single element rather
        # than materialising n_vocab separate values.
        bonus = one.expand(self.n)
        return bonus, None

    def batch_score(
        self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
    ) -> Tuple[torch.Tensor, List[Any]]:
        """Return a constant score of 1.0 for each token of each hypothesis.

        Args:
            ys (torch.Tensor): Batched prefix tokens; only the size of the
                first dimension (the batch size) is read.
            states (List[Any]): Scorer states for the prefixes (not read by
                this scorer).
            xs (torch.Tensor): Encoder features; only the device and dtype
                are used.

        Returns:
            tuple[torch.Tensor, None]: A tensor of ones with shape
                `(n_batch, n_vocab)`, created on the device and in the
                dtype of `xs`, and None in place of the next state list.
        """
        n_batch = ys.shape[0]
        one = torch.tensor([1.0], device=xs.device, dtype=xs.dtype)
        # Same broadcast view as in score(), with a leading batch dimension.
        return one.expand(n_batch, self.n), None
|