transformer_lm.py

from typing import Any
from typing import List
from typing import Tuple

import torch
import torch.nn as nn

from funasr.modules.embedding import PositionalEncoding
from funasr.models.encoder.transformer_encoder import TransformerEncoder_s0 as Encoder
from funasr.modules.mask import subsequent_mask
from funasr.lm.abs_model import AbsLM


class TransformerLM(AbsLM):
    def __init__(
        self,
        vocab_size: int,
        pos_enc: str = None,
        embed_unit: int = 128,
        att_unit: int = 256,
        head: int = 2,
        unit: int = 1024,
        layer: int = 4,
        dropout_rate: float = 0.5,
    ):
        super().__init__()
        if pos_enc == "sinusoidal":
            pos_enc_class = PositionalEncoding
        elif pos_enc is None:

            def pos_enc_class(*args, **kwargs):
                # No positional encoding: return an identity module.
                return nn.Sequential()

        else:
            raise ValueError(f"unknown pos-enc option: {pos_enc}")

        self.embed = nn.Embedding(vocab_size, embed_unit)
        self.encoder = Encoder(
            idim=embed_unit,
            attention_dim=att_unit,
            attention_heads=head,
            linear_units=unit,
            num_blocks=layer,
            dropout_rate=dropout_rate,
            input_layer="linear",
            pos_enc_class=pos_enc_class,
        )
        self.decoder = nn.Linear(att_unit, vocab_size)

    def _target_mask(self, ys_in_pad):
        # Combine the padding mask (token id 0 is padding) with a causal
        # (subsequent) mask so each position attends only to earlier tokens.
        ys_mask = ys_in_pad != 0
        m = subsequent_mask(ys_mask.size(-1), device=ys_mask.device).unsqueeze(0)
        return ys_mask.unsqueeze(-2) & m

    def forward(self, input: torch.Tensor, hidden: None) -> Tuple[torch.Tensor, None]:
        """Compute LM logits from batched token id sequences.

        Args:
            input (torch.Tensor): Input ids. (batch, len)
            hidden (None): Unused; kept for interface compatibility.

        Returns:
            Tuple[torch.Tensor, None]: Logits over the vocabulary (batch, len, vocab_size)
                and a placeholder hidden state (always None).

        """
        x = self.embed(input)
        mask = self._target_mask(input)
        h, _ = self.encoder(x, mask)
        y = self.decoder(h)
        return y, None
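
    # Illustrative use of forward(): a sketch, not part of the original file;
    # the vocabulary size and shapes below are assumed for demonstration only.
    #
    #     lm = TransformerLM(vocab_size=5000, pos_enc="sinusoidal")
    #     tokens = torch.randint(1, 5000, (2, 10))  # (batch, len); id 0 is padding
    #     logits, _ = lm(tokens, None)              # logits: (2, 10, 5000)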

    def score(
        self, y: torch.Tensor, state: Any, x: torch.Tensor
    ) -> Tuple[torch.Tensor, Any]:
        """Score new token.

        Args:
            y (torch.Tensor): 1D torch.int64 prefix tokens.
            state: Scorer state for prefix tokens.
            x (torch.Tensor): Encoder feature that generates ys.

        Returns:
            Tuple[torch.Tensor, Any]: Tuple of
                torch.float32 scores for next token (vocab_size)
                and next state for ys.

        """
        y = y.unsqueeze(0)
        h, _, cache = self.encoder.forward_one_step(
            self.embed(y), self._target_mask(y), cache=state
        )
        h = self.decoder(h[:, -1])
        logp = h.log_softmax(dim=-1).squeeze(0)
        return logp, cache
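
    # Sketch of one step of incremental decoding with score(); values are
    # hypothetical. Note that x is not used inside this method, so a dummy
    # tensor can be passed when the LM is scored on its own.
    #
    #     prefix = torch.tensor([1, 42, 7], dtype=torch.int64)  # running hypothesis
    #     logp, cache = lm.score(prefix, state=None, x=torch.empty(0))
    #     next_token = int(logp.argmax())                       # greedy next token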

    def batch_score(
        self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
    ) -> Tuple[torch.Tensor, List[Any]]:
        """Score new token batch.

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            states (List[Any]): Scorer states for prefix tokens.
            xs (torch.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).

        Returns:
            Tuple[torch.Tensor, List[Any]]: Tuple of
                batchified scores for next token with shape of `(n_batch, vocab_size)`
                and next state list for ys.

        """
        # merge states
        n_batch = len(ys)
        n_layers = len(self.encoder.encoders)
        if states[0] is None:
            batch_state = None
        else:
            # transpose state of [batch, layer] into [layer, batch]
            batch_state = [
                torch.stack([states[b][i] for b in range(n_batch)])
                for i in range(n_layers)
            ]

        # batch decoding
        h, _, states = self.encoder.forward_one_step(
            self.embed(ys), self._target_mask(ys), cache=batch_state
        )
        h = self.decoder(h[:, -1])
        logp = h.log_softmax(dim=-1)

        # transpose state of [layer, batch] into [batch, layer]
        state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
        return logp, state_list
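

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original file. It assumes the
    # funasr imports above resolve in the current environment; the vocabulary
    # size and shapes are arbitrary, and random ids are drawn from [1, 100)
    # so they never collide with the padding id 0 used by _target_mask.
    lm = TransformerLM(vocab_size=100, pos_enc="sinusoidal", layer=2)

    # Full-sequence logits via forward().
    tokens = torch.randint(1, 100, (2, 8))           # (batch, len)
    logits, _ = lm(tokens, None)
    print(logits.shape)                              # torch.Size([2, 8, 100])

    # Batched one-step scoring, as used by beam search; xs is unused here.
    logp, states = lm.batch_score(tokens, [None, None], torch.empty(2, 0, 0))
    print(logp.shape)                                # torch.Size([2, 100])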