target_delay_transformer.py

from typing import Any
from typing import List
from typing import Tuple

import torch
import torch.nn as nn

from funasr.modules.embedding import PositionalEncoding
from funasr.modules.embedding import SinusoidalPositionEncoder

# from funasr.models.encoder.transformer_encoder import TransformerEncoder as Encoder
from funasr.models.encoder.sanm_encoder import SANMEncoder as Encoder

# from funasr.modules.mask import subsequent_n_mask
from funasr.punctuation.abs_model import AbsPunctuation


class TargetDelayTransformer(AbsPunctuation):
    def __init__(
        self,
        vocab_size: int,
        punc_size: int,
        pos_enc: str = None,
        embed_unit: int = 128,
        att_unit: int = 256,
        head: int = 2,
        unit: int = 1024,
        layer: int = 4,
        dropout_rate: float = 0.5,
    ):
        super().__init__()
        if pos_enc == "sinusoidal":
            # pos_enc_class = PositionalEncoding
            pos_enc_class = SinusoidalPositionEncoder
        elif pos_enc is None:

            def pos_enc_class(*args, **kwargs):
                return nn.Sequential()  # identity

        else:
            raise ValueError(f"unknown pos-enc option: {pos_enc}")

        self.embed = nn.Embedding(vocab_size, embed_unit)
        self.encoder = Encoder(
            input_size=embed_unit,
            output_size=att_unit,
            attention_heads=head,
            linear_units=unit,
            num_blocks=layer,
            dropout_rate=dropout_rate,
            input_layer="pe",
            # pos_enc_class=pos_enc_class,
            padding_idx=0,
        )
        self.decoder = nn.Linear(att_unit, punc_size)

    # def _target_mask(self, ys_in_pad):
    #     ys_mask = ys_in_pad != 0
    #     m = subsequent_n_mask(ys_mask.size(-1), 5, device=ys_mask.device).unsqueeze(0)
    #     return ys_mask.unsqueeze(-2) & m

    def forward(self, input: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
        """Compute punctuation logits for padded token sequences.

        Args:
            input (torch.Tensor): Input token ids. (batch, len)
            text_lengths (torch.Tensor): Lengths of the input sequences. (batch,)
        """
        x = self.embed(input)
        # mask = self._target_mask(input)
        h, _, _ = self.encoder(x, text_lengths)
        y = self.decoder(h)
        return y, None

    def score(self, y: torch.Tensor, state: Any, x: torch.Tensor) -> Tuple[torch.Tensor, Any]:
        """Score new token.

        Args:
            y (torch.Tensor): 1D torch.int64 prefix tokens.
            state: Scorer state for prefix tokens.
            x (torch.Tensor): Encoder feature that generates ys.

        Returns:
            tuple[torch.Tensor, Any]: Tuple of
                torch.float32 scores for next token (vocab_size)
                and next state for ys.
        """
        y = y.unsqueeze(0)
        h, _, cache = self.encoder.forward_one_step(self.embed(y), self._target_mask(y), cache=state)
        h = self.decoder(h[:, -1])
        logp = h.log_softmax(dim=-1).squeeze(0)
        return logp, cache

    def batch_score(self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor) -> Tuple[torch.Tensor, List[Any]]:
        """Score new token batch.

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            states (List[Any]): Scorer states for prefix tokens.
            xs (torch.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).

        Returns:
            tuple[torch.Tensor, List[Any]]: Tuple of
                batchified scores for next token with shape of `(n_batch, vocab_size)`
                and next state list for ys.
        """
        # merge states
        n_batch = len(ys)
        n_layers = len(self.encoder.encoders)
        if states[0] is None:
            batch_state = None
        else:
            # transpose state of [batch, layer] into [layer, batch]
            batch_state = [torch.stack([states[b][i] for b in range(n_batch)]) for i in range(n_layers)]

        # batch decoding
        h, _, states = self.encoder.forward_one_step(self.embed(ys), self._target_mask(ys), cache=batch_state)
        h = self.decoder(h[:, -1])
        logp = h.log_softmax(dim=-1)

        # transpose state of [layer, batch] into [batch, layer]
        state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
        return logp, state_list
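

# --- Usage sketch (not part of the original module) ------------------------
# A minimal smoke test of the forward path, assuming funasr is installed and
# that the SANMEncoder returns features of shape (batch, len, att_unit).
# The vocabulary size, punctuation inventory, token ids, and lengths below
# are made-up illustration values. The score()/batch_score() streaming paths
# are not exercised here because they rely on the _target_mask helper that
# is commented out above.
if __name__ == "__main__":
    model = TargetDelayTransformer(vocab_size=300, punc_size=5)

    # A batch of two padded token-id sequences (0 is the padding index).
    tokens = torch.tensor(
        [
            [11, 42, 7, 93, 5, 0, 0, 0],
            [8, 3, 120, 56, 77, 21, 9, 4],
        ],
        dtype=torch.int64,
    )
    lengths = torch.tensor([5, 8], dtype=torch.int32)

    logits, _ = model(tokens, lengths)
    # Expected shape: (batch, len, punc_size) -> one punctuation
    # distribution per input token.
    print(logits.shape)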