target_delay_transformer.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. from typing import Tuple
  2. import torch
  3. import torch.nn as nn
  4. from funasr.models.encoder.sanm_encoder import SANMEncoder
  5. from funasr.export.models.encoder.sanm_encoder import SANMEncoder as SANMEncoder_export
  6. from funasr.models.encoder.sanm_encoder import SANMVadEncoder
  7. from funasr.export.models.encoder.sanm_encoder import SANMVadEncoder as SANMVadEncoder_export
  8. class CT_Transformer(nn.Module):
  9. def __init__(
  10. self,
  11. model,
  12. max_seq_len=512,
  13. model_name='punc_model',
  14. **kwargs,
  15. ):
  16. super().__init__()
  17. onnx = False
  18. if "onnx" in kwargs:
  19. onnx = kwargs["onnx"]
  20. self.embed = model.embed
  21. self.decoder = model.decoder
  22. # self.model = model
  23. self.feats_dim = self.embed.embedding_dim
  24. self.num_embeddings = self.embed.num_embeddings
  25. self.model_name = model_name
  26. if isinstance(model.encoder, SANMEncoder):
  27. self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)
  28. else:
  29. assert False, "Only support samn encode."
  30. def forward(self, inputs: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
  31. """Compute loss value from buffer sequences.
  32. Args:
  33. input (torch.Tensor): Input ids. (batch, len)
  34. hidden (torch.Tensor): Target ids. (batch, len)
  35. """
  36. x = self.embed(inputs)
  37. # mask = self._target_mask(input)
  38. h, _ = self.encoder(x, text_lengths)
  39. y = self.decoder(h)
  40. return y
  41. def get_dummy_inputs(self):
  42. length = 120
  43. text_indexes = torch.randint(0, self.embed.num_embeddings, (2, length))
  44. text_lengths = torch.tensor([length-20, length], dtype=torch.int32)
  45. return (text_indexes, text_lengths)
  46. def get_input_names(self):
  47. return ['inputs', 'text_lengths']
  48. def get_output_names(self):
  49. return ['logits']
  50. def get_dynamic_axes(self):
  51. return {
  52. 'inputs': {
  53. 0: 'batch_size',
  54. 1: 'feats_length'
  55. },
  56. 'text_lengths': {
  57. 0: 'batch_size',
  58. },
  59. 'logits': {
  60. 0: 'batch_size',
  61. 1: 'logits_length'
  62. },
  63. }
  64. class CT_Transformer_VadRealtime(nn.Module):
  65. def __init__(
  66. self,
  67. model,
  68. max_seq_len=512,
  69. model_name='punc_model',
  70. **kwargs,
  71. ):
  72. super().__init__()
  73. onnx = False
  74. if "onnx" in kwargs:
  75. onnx = kwargs["onnx"]
  76. self.embed = model.embed
  77. if isinstance(model.encoder, SANMVadEncoder):
  78. self.encoder = SANMVadEncoder_export(model.encoder, onnx=onnx)
  79. else:
  80. assert False, "Only support samn encode."
  81. self.decoder = model.decoder
  82. self.model_name = model_name
  83. def forward(self, inputs: torch.Tensor,
  84. text_lengths: torch.Tensor,
  85. vad_indexes: torch.Tensor,
  86. sub_masks: torch.Tensor,
  87. ) -> Tuple[torch.Tensor, None]:
  88. """Compute loss value from buffer sequences.
  89. Args:
  90. input (torch.Tensor): Input ids. (batch, len)
  91. hidden (torch.Tensor): Target ids. (batch, len)
  92. """
  93. x = self.embed(inputs)
  94. # mask = self._target_mask(input)
  95. h, _ = self.encoder(x, text_lengths, vad_indexes, sub_masks)
  96. y = self.decoder(h)
  97. return y
  98. def with_vad(self):
  99. return True
  100. def get_dummy_inputs(self):
  101. length = 120
  102. text_indexes = torch.randint(0, self.embed.num_embeddings, (1, length))
  103. text_lengths = torch.tensor([length], dtype=torch.int32)
  104. vad_mask = torch.ones(length, length, dtype=torch.float32)[None, None, :, :]
  105. sub_masks = torch.ones(length, length, dtype=torch.float32)
  106. sub_masks = torch.tril(sub_masks).type(torch.float32)
  107. return (text_indexes, text_lengths, vad_mask, sub_masks[None, None, :, :])
  108. def get_input_names(self):
  109. return ['inputs', 'text_lengths', 'vad_masks', 'sub_masks']
  110. def get_output_names(self):
  111. return ['logits']
  112. def get_dynamic_axes(self):
  113. return {
  114. 'inputs': {
  115. 1: 'feats_length'
  116. },
  117. 'vad_masks': {
  118. 2: 'feats_length1',
  119. 3: 'feats_length2'
  120. },
  121. 'sub_masks': {
  122. 2: 'feats_length1',
  123. 3: 'feats_length2'
  124. },
  125. 'logits': {
  126. 1: 'logits_length'
  127. },
  128. }