ctc.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. """ScorerInterface implementation for CTC."""
  2. import numpy as np
  3. import torch
  4. from funasr.modules.scorers.ctc_prefix_score import CTCPrefixScore
  5. from funasr.modules.scorers.ctc_prefix_score import CTCPrefixScoreTH
  6. from funasr.modules.scorers.scorer_interface import BatchPartialScorerInterface
  7. class CTCPrefixScorer(BatchPartialScorerInterface):
  8. """Decoder interface wrapper for CTCPrefixScore."""
  9. def __init__(self, ctc: torch.nn.Module, eos: int):
  10. """Initialize class.
  11. Args:
  12. ctc (torch.nn.Module): The CTC implementation.
  13. For example, :class:`espnet.nets.pytorch_backend.ctc.CTC`
  14. eos (int): The end-of-sequence id.
  15. """
  16. self.ctc = ctc
  17. self.eos = eos
  18. self.impl = None
  19. def init_state(self, x: torch.Tensor):
  20. """Get an initial state for decoding.
  21. Args:
  22. x (torch.Tensor): The encoded feature tensor
  23. Returns: initial state
  24. """
  25. logp = self.ctc.log_softmax(x.unsqueeze(0)).detach().squeeze(0).cpu().numpy()
  26. # TODO(karita): use CTCPrefixScoreTH
  27. self.impl = CTCPrefixScore(logp, 0, self.eos, np)
  28. return 0, self.impl.initial_state()
  29. def select_state(self, state, i, new_id=None):
  30. """Select state with relative ids in the main beam search.
  31. Args:
  32. state: Decoder state for prefix tokens
  33. i (int): Index to select a state in the main beam search
  34. new_id (int): New label id to select a state if necessary
  35. Returns:
  36. state: pruned state
  37. """
  38. if type(state) == tuple:
  39. if len(state) == 2: # for CTCPrefixScore
  40. sc, st = state
  41. return sc[i], st[i]
  42. else: # for CTCPrefixScoreTH (need new_id > 0)
  43. r, log_psi, f_min, f_max, scoring_idmap = state
  44. s = log_psi[i, new_id].expand(log_psi.size(1))
  45. if scoring_idmap is not None:
  46. return r[:, :, i, scoring_idmap[i, new_id]], s, f_min, f_max
  47. else:
  48. return r[:, :, i, new_id], s, f_min, f_max
  49. return None if state is None else state[i]
  50. def score_partial(self, y, ids, state, x):
  51. """Score new token.
  52. Args:
  53. y (torch.Tensor): 1D prefix token
  54. next_tokens (torch.Tensor): torch.int64 next token to score
  55. state: decoder state for prefix tokens
  56. x (torch.Tensor): 2D encoder feature that generates ys
  57. Returns:
  58. tuple[torch.Tensor, Any]:
  59. Tuple of a score tensor for y that has a shape `(len(next_tokens),)`
  60. and next state for ys
  61. """
  62. prev_score, state = state
  63. presub_score, new_st = self.impl(y.cpu(), ids.cpu(), state)
  64. tscore = torch.as_tensor(
  65. presub_score - prev_score, device=x.device, dtype=x.dtype
  66. )
  67. return tscore, (presub_score, new_st)
  68. def batch_init_state(self, x: torch.Tensor):
  69. """Get an initial state for decoding.
  70. Args:
  71. x (torch.Tensor): The encoded feature tensor
  72. Returns: initial state
  73. """
  74. logp = self.ctc.log_softmax(x.unsqueeze(0)) # assuming batch_size = 1
  75. xlen = torch.tensor([logp.size(1)])
  76. self.impl = CTCPrefixScoreTH(logp, xlen, 0, self.eos)
  77. return None
  78. def batch_score_partial(self, y, ids, state, x):
  79. """Score new token.
  80. Args:
  81. y (torch.Tensor): 1D prefix token
  82. ids (torch.Tensor): torch.int64 next token to score
  83. state: decoder state for prefix tokens
  84. x (torch.Tensor): 2D encoder feature that generates ys
  85. Returns:
  86. tuple[torch.Tensor, Any]:
  87. Tuple of a score tensor for y that has a shape `(len(next_tokens),)`
  88. and next state for ys
  89. """
  90. batch_state = (
  91. (
  92. torch.stack([s[0] for s in state], dim=2),
  93. torch.stack([s[1] for s in state]),
  94. state[0][2],
  95. state[0][3],
  96. )
  97. if state[0] is not None
  98. else None
  99. )
  100. return self.impl(y, batch_state, ids)
  101. def extend_prob(self, x: torch.Tensor):
  102. """Extend probs for decoding.
  103. This extension is for streaming decoding
  104. as in Eq (14) in https://arxiv.org/abs/2006.14941
  105. Args:
  106. x (torch.Tensor): The encoded feature tensor
  107. """
  108. logp = self.ctc.log_softmax(x.unsqueeze(0))
  109. self.impl.extend_prob(logp)
  110. def extend_state(self, state):
  111. """Extend state for decoding.
  112. This extension is for streaming decoding
  113. as in Eq (14) in https://arxiv.org/abs/2006.14941
  114. Args:
  115. state: The states of hyps
  116. Returns: exteded state
  117. """
  118. new_state = []
  119. for s in state:
  120. new_state.append(self.impl.extend_state(s))
  121. return new_state