target_delay_transformer.py

import torch
import torch.nn as nn


class TargetDelayTransformer(nn.Module):
    def __init__(
        self,
        model,
        max_seq_len=512,
        model_name='punc_model',
        **kwargs,
    ):
        super().__init__()
        onnx = kwargs.get("onnx", False)
        # Reuse the embedding and decoder of the original punctuation model as-is.
        self.embed = model.embed
        self.decoder = model.decoder
        # self.model = model
        self.feats_dim = self.embed.embedding_dim
        self.num_embeddings = self.embed.num_embeddings
        self.model_name = model_name

        # Swap the encoder for its export-friendly counterpart.
        from funasr.models.encoder.sanm_encoder import SANMEncoder
        from funasr.export.models.encoder.sanm_encoder import SANMEncoder as SANMEncoder_export
        if isinstance(model.encoder, SANMEncoder):
            self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)
        else:
            raise TypeError("Only SANMEncoder is supported.")
    def forward(self, input: torch.Tensor, text_lengths: torch.Tensor) -> torch.Tensor:
        """Compute punctuation logits for padded token sequences.

        Args:
            input (torch.Tensor): Input token ids. (batch, len)
            text_lengths (torch.Tensor): Valid length of each sequence. (batch,)

        Returns:
            torch.Tensor: Punctuation logits. (batch, len, num_classes)
        """
        x = self.embed(input)
        h, _ = self.encoder(x, text_lengths)
        y = self.decoder(h)
        return y
    def get_dummy_inputs(self):
        # Two random token sequences with different valid lengths, used as
        # example inputs when tracing the model for ONNX export.
        length = 120
        text_indexes = torch.randint(0, self.embed.num_embeddings, (2, length))
        text_lengths = torch.tensor([length - 20, length], dtype=torch.int32)
        return (text_indexes, text_lengths)

    def get_input_names(self):
        return ['input', 'text_lengths']

    def get_output_names(self):
        return ['logits']
    def get_dynamic_axes(self):
        # Mark batch and sequence dimensions as dynamic so the exported model
        # accepts variable batch sizes and sequence lengths.
        return {
            'input': {
                0: 'batch_size',
                1: 'feats_length'
            },
            'text_lengths': {
                0: 'batch_size',
            },
            'logits': {
                0: 'batch_size',
                1: 'logits_length'
            },
        }
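

# Hedged usage sketch, not part of the original file: how the helper methods
# above would typically feed torch.onnx.export. `load_punc_model` is a
# hypothetical stand-in for however the underlying punctuation model is
# loaded in FunASR; the output filename and opset version are assumptions.
if __name__ == "__main__":
    punc_model = load_punc_model()  # hypothetical loader: must expose .embed/.encoder/.decoder
    wrapper = TargetDelayTransformer(punc_model, onnx=True)
    wrapper.eval()
    torch.onnx.export(
        wrapper,
        wrapper.get_dummy_inputs(),  # (text_indexes, text_lengths)
        f"{wrapper.model_name}.onnx",
        input_names=wrapper.get_input_names(),
        output_names=wrapper.get_output_names(),
        dynamic_axes=wrapper.get_dynamic_axes(),
        opset_version=13,
    )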