# CT_Transformer.py — ONNX-export wrappers for the CT-Transformer punctuation model.
  1. from typing import Tuple
  2. import torch
  3. import torch.nn as nn
  4. from funasr.models.encoder.sanm_encoder import SANMEncoder
  5. from funasr.export.models.encoder.sanm_encoder import SANMEncoder as SANMEncoder_export
  6. from funasr.models.encoder.sanm_encoder import SANMVadEncoder
  7. from funasr.export.models.encoder.sanm_encoder import SANMVadEncoder as SANMVadEncoder_export
  8. class CT_Transformer(nn.Module):
  9. """
  10. Author: Speech Lab of DAMO Academy, Alibaba Group
  11. CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
  12. https://arxiv.org/pdf/2003.01309.pdf
  13. """
  14. def __init__(
  15. self,
  16. model,
  17. max_seq_len=512,
  18. model_name='punc_model',
  19. **kwargs,
  20. ):
  21. super().__init__()
  22. onnx = False
  23. if "onnx" in kwargs:
  24. onnx = kwargs["onnx"]
  25. self.embed = model.embed
  26. self.decoder = model.decoder
  27. # self.model = model
  28. self.feats_dim = self.embed.embedding_dim
  29. self.num_embeddings = self.embed.num_embeddings
  30. self.model_name = model_name
  31. if isinstance(model.encoder, SANMEncoder):
  32. self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)
  33. else:
  34. assert False, "Only support samn encode."
  35. def forward(self, inputs: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
  36. """Compute loss value from buffer sequences.
  37. Args:
  38. input (torch.Tensor): Input ids. (batch, len)
  39. hidden (torch.Tensor): Target ids. (batch, len)
  40. """
  41. x = self.embed(inputs)
  42. # mask = self._target_mask(input)
  43. h, _ = self.encoder(x, text_lengths)
  44. y = self.decoder(h)
  45. return y
  46. def get_dummy_inputs(self):
  47. length = 120
  48. text_indexes = torch.randint(0, self.embed.num_embeddings, (2, length)).type(torch.int32)
  49. text_lengths = torch.tensor([length-20, length], dtype=torch.int32)
  50. return (text_indexes, text_lengths)
  51. def get_input_names(self):
  52. return ['inputs', 'text_lengths']
  53. def get_output_names(self):
  54. return ['logits']
  55. def get_dynamic_axes(self):
  56. return {
  57. 'inputs': {
  58. 0: 'batch_size',
  59. 1: 'feats_length'
  60. },
  61. 'text_lengths': {
  62. 0: 'batch_size',
  63. },
  64. 'logits': {
  65. 0: 'batch_size',
  66. 1: 'logits_length'
  67. },
  68. }
  69. class CT_Transformer_VadRealtime(nn.Module):
  70. """
  71. Author: Speech Lab of DAMO Academy, Alibaba Group
  72. CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
  73. https://arxiv.org/pdf/2003.01309.pdf
  74. """
  75. def __init__(
  76. self,
  77. model,
  78. max_seq_len=512,
  79. model_name='punc_model',
  80. **kwargs,
  81. ):
  82. super().__init__()
  83. onnx = False
  84. if "onnx" in kwargs:
  85. onnx = kwargs["onnx"]
  86. self.embed = model.embed
  87. if isinstance(model.encoder, SANMVadEncoder):
  88. self.encoder = SANMVadEncoder_export(model.encoder, onnx=onnx)
  89. else:
  90. assert False, "Only support samn encode."
  91. self.decoder = model.decoder
  92. self.model_name = model_name
  93. def forward(self, inputs: torch.Tensor,
  94. text_lengths: torch.Tensor,
  95. vad_indexes: torch.Tensor,
  96. sub_masks: torch.Tensor,
  97. ) -> Tuple[torch.Tensor, None]:
  98. """Compute loss value from buffer sequences.
  99. Args:
  100. input (torch.Tensor): Input ids. (batch, len)
  101. hidden (torch.Tensor): Target ids. (batch, len)
  102. """
  103. x = self.embed(inputs)
  104. # mask = self._target_mask(input)
  105. h, _ = self.encoder(x, text_lengths, vad_indexes, sub_masks)
  106. y = self.decoder(h)
  107. return y
  108. def with_vad(self):
  109. return True
  110. def get_dummy_inputs(self):
  111. length = 120
  112. text_indexes = torch.randint(0, self.embed.num_embeddings, (1, length)).type(torch.int32)
  113. text_lengths = torch.tensor([length], dtype=torch.int32)
  114. vad_mask = torch.ones(length, length, dtype=torch.float32)[None, None, :, :]
  115. sub_masks = torch.ones(length, length, dtype=torch.float32)
  116. sub_masks = torch.tril(sub_masks).type(torch.float32)
  117. return (text_indexes, text_lengths, vad_mask, sub_masks[None, None, :, :])
  118. def get_input_names(self):
  119. return ['inputs', 'text_lengths', 'vad_masks', 'sub_masks']
  120. def get_output_names(self):
  121. return ['logits']
  122. def get_dynamic_axes(self):
  123. return {
  124. 'inputs': {
  125. 1: 'feats_length'
  126. },
  127. 'vad_masks': {
  128. 2: 'feats_length1',
  129. 3: 'feats_length2'
  130. },
  131. 'sub_masks': {
  132. 2: 'feats_length1',
  133. 3: 'feats_length2'
  134. },
  135. 'logits': {
  136. 1: 'logits_length'
  137. },
  138. }