# vad_realtime_transformer.py

from typing import Any, List, Optional, Tuple

import torch
import torch.nn as nn

from funasr.models.encoder.sanm_encoder import SANMVadEncoder as Encoder
from funasr.modules.embedding import SinusoidalPositionEncoder
from funasr.train.abs_model import AbsPunctuation
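

# score() and batch_score() below rely on _target_mask, whose restored body calls
# subsequent_n_mask. This helper is a minimal local sketch of what that function
# plausibly computes per the CT-Transformer paper (controllable time delay): a
# causal mask relaxed so every position may also attend to up to `n` future
# positions. It is an assumption, not FunASR's own implementation.
def subsequent_n_mask(size: int, n: int, device="cpu") -> torch.Tensor:
    """Boolean (size, size) mask allowing position i to attend to j <= i + n."""
    # torch.tril with diagonal=n keeps entries on or below the n-th superdiagonal,
    # i.e. exactly the pairs (i, j) with j - i <= n.
    return torch.tril(torch.ones(size, size, dtype=torch.bool, device=device), diagonal=n)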


class VadRealtimeTransformer(AbsPunctuation):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    CT-Transformer: Controllable time-delay transformer for real-time punctuation
    prediction and disfluency detection
    https://arxiv.org/pdf/2003.01309.pdf
    """

    def __init__(
        self,
        vocab_size: int,
        punc_size: int,
        pos_enc: Optional[str] = None,
        embed_unit: int = 128,
        att_unit: int = 256,
        head: int = 2,
        unit: int = 1024,
        layer: int = 4,
        dropout_rate: float = 0.5,
        kernel_size: int = 11,
        sanm_shfit: int = 0,  # (sic) spelling kept to match the encoder's kwarg
    ):
        super().__init__()
        # NOTE: pos_enc_class is validated here but otherwise unused; the encoder
        # builds its own positional encoding via input_layer="pe" below.
        if pos_enc == "sinusoidal":
            pos_enc_class = SinusoidalPositionEncoder
        elif pos_enc is None:

            def pos_enc_class(*args, **kwargs):
                return nn.Sequential()  # identity

        else:
            raise ValueError(f"unknown pos-enc option: {pos_enc}")

        # Token embedding -> SANM encoder with controllable look-ahead -> linear
        # projection onto the punctuation inventory.
        self.embed = nn.Embedding(vocab_size, embed_unit)
        self.encoder = Encoder(
            input_size=embed_unit,
            output_size=att_unit,
            attention_heads=head,
            linear_units=unit,
            num_blocks=layer,
            dropout_rate=dropout_rate,
            input_layer="pe",
            padding_idx=0,
            kernel_size=kernel_size,
            sanm_shfit=sanm_shfit,
        )
        self.decoder = nn.Linear(att_unit, punc_size)

    def _target_mask(self, ys_in_pad):
        # Restored from the commented-out original because score() and
        # batch_score() below call it: combines the padding mask with the
        # 5-step look-ahead mask, so each token attends to the full past
        # plus at most 5 future tokens.
        ys_mask = ys_in_pad != 0
        m = subsequent_n_mask(ys_mask.size(-1), 5, device=ys_mask.device).unsqueeze(0)
        return ys_mask.unsqueeze(-2) & m

    def forward(
        self,
        input: torch.Tensor,
        text_lengths: torch.Tensor,
        vad_indexes: torch.Tensor,
    ) -> Tuple[torch.Tensor, None]:
        """Compute punctuation logits for padded token sequences.

        Args:
            input (torch.Tensor): Input token ids. (batch, len)
            text_lengths (torch.Tensor): Valid length of each sequence. (batch,)
            vad_indexes (torch.Tensor): VAD split points consumed by the encoder.

        Returns:
            Tuple[torch.Tensor, None]: Punctuation logits (batch, len, punc_size)
            and a placeholder state (always None).
        """
        x = self.embed(input)
        h, _, _ = self.encoder(x, text_lengths, vad_indexes)
        y = self.decoder(h)
        return y, None

    def with_vad(self):
        # This punctuation model consumes VAD split points (see forward()).
        return True

    def score(self, y: torch.Tensor, state: Any, x: torch.Tensor) -> Tuple[torch.Tensor, Any]:
        """Score new token.

        Args:
            y (torch.Tensor): 1D torch.int64 prefix tokens.
            state: Scorer state for prefix tokens.
            x (torch.Tensor): Encoder feature that generates ys.

        Returns:
            Tuple[torch.Tensor, Any]: Tuple of
                torch.float32 scores for the next token (punc_size,)
                and next state for ys.
        """
        y = y.unsqueeze(0)  # add a batch dimension for the encoder
        h, _, cache = self.encoder.forward_one_step(self.embed(y), self._target_mask(y), cache=state)
        # Project the last position's hidden state onto the punctuation
        # inventory and normalize to log-probabilities.
        h = self.decoder(h[:, -1])
        logp = h.log_softmax(dim=-1).squeeze(0)
        return logp, cache

    def batch_score(
        self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
    ) -> Tuple[torch.Tensor, List[Any]]:
        """Score new token batch.

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            states (List[Any]): Scorer states for prefix tokens.
            xs (torch.Tensor): Encoder feature that generates ys (n_batch, xlen, n_feat).

        Returns:
            Tuple[torch.Tensor, List[Any]]: Tuple of
                batchified scores for the next token with shape (n_batch, punc_size)
                and the next state list for ys.
        """
        # Merge states: transpose the per-hypothesis caches from [batch, layer]
        # into [layer, batch], stacking one tensor per layer (see the round-trip
        # demo at the end of this file).
        n_batch = len(ys)
        n_layers = len(self.encoder.encoders)
        if states[0] is None:
            batch_state = None
        else:
            batch_state = [torch.stack([states[b][i] for b in range(n_batch)]) for i in range(n_layers)]

        # Batch decoding: one incremental encoder step over all hypotheses.
        h, _, states = self.encoder.forward_one_step(self.embed(ys), self._target_mask(ys), cache=batch_state)
        h = self.decoder(h[:, -1])
        logp = h.log_softmax(dim=-1)

        # Transpose states back from [layer, batch] into [batch, layer].
        state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
        return logp, state_list
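

# A minimal smoke test, assuming FunASR's SANMVadEncoder and AbsPunctuation are
# importable. The vocabulary/punctuation sizes are made-up illustration values,
# and the exact format SANMVadEncoder expects for `vad_indexes` is not documented
# here, so the forward call is sketched in comments rather than executed.
if __name__ == "__main__":
    model = VadRealtimeTransformer(vocab_size=5000, punc_size=5)
    print("parameters:", sum(p.numel() for p in model.parameters()))

    tokens = torch.randint(1, 5000, (2, 16))  # (batch, len) token ids; 0 is padding
    lengths = torch.tensor([16, 12])          # valid length per sequence
    # vad_indexes format is an assumption; consult SANMVadEncoder before running:
    # logits, _ = model(tokens, lengths, vad_indexes)
    # print(logits.shape)  # expected: (2, 16, 5)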
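

# Pure-torch illustration of the cache bookkeeping in batch_score(): the
# [batch, layer] -> [layer, batch] stacking and the inverse unbatching form a
# lossless round trip. The per-layer cache shape (7, 8) is a made-up stand-in
# for whatever the encoder actually stores.
if __name__ == "__main__":
    n_batch, n_layers = 3, 4
    per_hyp = [[torch.randn(7, 8) for _ in range(n_layers)] for _ in range(n_batch)]  # [batch][layer]
    stacked = [torch.stack([per_hyp[b][i] for b in range(n_batch)]) for i in range(n_layers)]  # [layer] tensors
    assert len(stacked) == n_layers and stacked[0].shape == (n_batch, 7, 8)
    unstacked = [[stacked[i][b] for i in range(n_layers)] for b in range(n_batch)]  # back to [batch][layer]
    assert torch.equal(unstacked[1][2], per_hyp[1][2])  # every cache survives the round trip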