| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475 |
- from typing import Tuple
- import torch
- import torch.nn as nn
class TargetDelayTransformer(nn.Module):
    """Export-oriented wrapper around a model's ``embed``, ``encoder`` and ``decoder``.

    The wrapped model's embedding and decoder are reused as-is, while its
    encoder (which must be a ``SANMEncoder``) is replaced by the export
    variant from ``funasr.export``. The ``get_*`` helpers return dummy
    inputs, tensor names and dynamic-axis metadata; their shapes match the
    ``args`` / ``input_names`` / ``output_names`` / ``dynamic_axes``
    arguments of ``torch.onnx.export`` (presumably their consumer --
    confirm against the export driver).
    """

    def __init__(
        self,
        model,
        max_seq_len=512,
        model_name='punc_model',
        **kwargs,
    ):
        """Build the wrapper from an existing model.

        Args:
            model: Source model; must expose ``embed``, ``encoder`` and
                ``decoder`` attributes.
            max_seq_len: Accepted for interface compatibility; not used by
                this class.
            model_name: Stored on the instance as ``self.model_name``.
            **kwargs: Only ``onnx`` is read (default ``False``); it is
                forwarded to the export encoder.

        Raises:
            AssertionError: If ``model.encoder`` is not a ``SANMEncoder``.
        """
        super().__init__()
        onnx = kwargs.get("onnx", False)

        self.embed = model.embed
        self.decoder = model.decoder
        self.feats_dim = self.embed.embedding_dim
        self.num_embeddings = self.embed.num_embeddings
        self.model_name = model_name

        # Imported lazily, as in the original, so that merely importing this
        # module does not pull in the encoder implementations.
        from funasr.models.encoder.sanm_encoder import SANMEncoder
        from funasr.export.models.encoder.sanm_encoder import SANMEncoder as SANMEncoder_export

        if not isinstance(model.encoder, SANMEncoder):
            # Raised explicitly instead of ``assert False`` so the guard is
            # not stripped when Python runs with ``-O``; the exception type
            # and message are unchanged for existing callers.
            raise AssertionError("Only support samn encode.")
        self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)

    def forward(self, input: torch.Tensor, text_lengths: torch.Tensor) -> torch.Tensor:
        """Run embedding, encoder and decoder on a batch of token ids.

        Args:
            input (torch.Tensor): Input ids. (batch, len)
            text_lengths (torch.Tensor): Per-sequence lengths, passed to the
                encoder together with the embedded input. (batch,)

        Returns:
            torch.Tensor: The decoder output (exported under the name
            ``logits``, see ``get_output_names``).
        """
        x = self.embed(input)
        # The encoder returns a pair; only its first element is decoded.
        h, _ = self.encoder(x, text_lengths)
        return self.decoder(h)

    def get_dummy_inputs(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return a ``(text_indexes, text_lengths)`` example batch of size 2.

        ``text_indexes`` holds random ids in ``[0, num_embeddings)`` with
        shape ``(2, 120)``; ``text_lengths`` is the int32 tensor
        ``[100, 120]`` so that the two sequences differ in length.
        """
        length = 120
        text_indexes = torch.randint(0, self.embed.num_embeddings, (2, length))
        text_lengths = torch.tensor([length - 20, length], dtype=torch.int32)
        return (text_indexes, text_lengths)

    def get_input_names(self) -> list:
        """Return the names of the inputs, in ``forward`` argument order."""
        return ['input', 'text_lengths']

    def get_output_names(self) -> list:
        """Return the name given to the single output of ``forward``."""
        return ['logits']

    def get_dynamic_axes(self) -> dict:
        """Return, per tensor name, the axes whose size may vary and their labels."""
        return {
            'input': {
                0: 'batch_size',
                1: 'feats_length'
            },
            'text_lengths': {
                0: 'batch_size',
            },
            'logits': {
                0: 'batch_size',
                1: 'logits_length'
            },
        }
|