vad_realtime_transformer.py 2.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. from typing import Tuple
  2. import torch
  3. import torch.nn as nn
  4. from funasr.models.encoder.sanm_encoder import SANMVadEncoder
  5. from funasr.export.models.encoder.sanm_encoder import SANMVadEncoder as SANMVadEncoder_export
  6. class VadRealtimeTransformer(nn.Module):
  7. def __init__(
  8. self,
  9. model,
  10. max_seq_len=512,
  11. model_name='punc_model',
  12. **kwargs,
  13. ):
  14. super().__init__()
  15. onnx = False
  16. if "onnx" in kwargs:
  17. onnx = kwargs["onnx"]
  18. self.embed = model.embed
  19. if isinstance(model.encoder, SANMVadEncoder):
  20. self.encoder = SANMVadEncoder_export(model.encoder, onnx=onnx)
  21. else:
  22. assert False, "Only support samn encode."
  23. # self.encoder = model.encoder
  24. self.decoder = model.decoder
  25. self.model_name = model_name
  26. def forward(self, input: torch.Tensor,
  27. text_lengths: torch.Tensor,
  28. vad_indexes: torch.Tensor,
  29. sub_masks: torch.Tensor,
  30. ) -> Tuple[torch.Tensor, None]:
  31. """Compute loss value from buffer sequences.
  32. Args:
  33. input (torch.Tensor): Input ids. (batch, len)
  34. hidden (torch.Tensor): Target ids. (batch, len)
  35. """
  36. x = self.embed(input)
  37. # mask = self._target_mask(input)
  38. h, _ = self.encoder(x, text_lengths, vad_indexes, sub_masks)
  39. y = self.decoder(h)
  40. return y
  41. def with_vad(self):
  42. return True
  43. def get_dummy_inputs(self):
  44. length = 120
  45. text_indexes = torch.randint(0, self.embed.num_embeddings, (1, length))
  46. text_lengths = torch.tensor([length], dtype=torch.int32)
  47. vad_mask = torch.ones(length, length, dtype=torch.float32)[None, None, :, :]
  48. sub_masks = torch.ones(length, length, dtype=torch.float32)
  49. sub_masks = torch.tril(sub_masks).type(torch.float32)
  50. return (text_indexes, text_lengths, vad_mask, sub_masks[None, None, :, :])
  51. def get_input_names(self):
  52. return ['input', 'text_lengths', 'vad_mask', 'sub_masks']
  53. def get_output_names(self):
  54. return ['logits']
  55. def get_dynamic_axes(self):
  56. return {
  57. 'input': {
  58. 1: 'feats_length'
  59. },
  60. 'vad_mask': {
  61. 2: 'feats_length1',
  62. 3: 'feats_length2'
  63. },
  64. 'sub_masks': {
  65. 2: 'feats_length1',
  66. 3: 'feats_length2'
  67. },
  68. 'logits': {
  69. 1: 'logits_length'
  70. },
  71. }