| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162 |
- from typing import Tuple
- import torch
- import torch.nn as nn
- from funasr.models.encoder.sanm_encoder import SANMEncoder
- from funasr.export.models.encoder.sanm_encoder import SANMEncoder as SANMEncoder_export
- from funasr.models.encoder.sanm_encoder import SANMVadEncoder
- from funasr.export.models.encoder.sanm_encoder import SANMVadEncoder as SANMVadEncoder_export
- class CT_Transformer(nn.Module):
- """
- Author: Speech Lab of DAMO Academy, Alibaba Group
- CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
- https://arxiv.org/pdf/2003.01309.pdf
- """
- def __init__(
- self,
- model,
- max_seq_len=512,
- model_name='punc_model',
- **kwargs,
- ):
- super().__init__()
- onnx = False
- if "onnx" in kwargs:
- onnx = kwargs["onnx"]
- self.embed = model.embed
- self.decoder = model.decoder
- # self.model = model
- self.feats_dim = self.embed.embedding_dim
- self.num_embeddings = self.embed.num_embeddings
- self.model_name = model_name
- if isinstance(model.encoder, SANMEncoder):
- self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)
- else:
- assert False, "Only support samn encode."
- def forward(self, inputs: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
- """Compute loss value from buffer sequences.
- Args:
- input (torch.Tensor): Input ids. (batch, len)
- hidden (torch.Tensor): Target ids. (batch, len)
- """
- x = self.embed(inputs)
- # mask = self._target_mask(input)
- h, _ = self.encoder(x, text_lengths)
- y = self.decoder(h)
- return y
- def get_dummy_inputs(self):
- length = 120
- text_indexes = torch.randint(0, self.embed.num_embeddings, (2, length)).type(torch.int32)
- text_lengths = torch.tensor([length-20, length], dtype=torch.int32)
- return (text_indexes, text_lengths)
- def get_input_names(self):
- return ['inputs', 'text_lengths']
- def get_output_names(self):
- return ['logits']
- def get_dynamic_axes(self):
- return {
- 'inputs': {
- 0: 'batch_size',
- 1: 'feats_length'
- },
- 'text_lengths': {
- 0: 'batch_size',
- },
- 'logits': {
- 0: 'batch_size',
- 1: 'logits_length'
- },
- }
- class CT_Transformer_VadRealtime(nn.Module):
- """
- Author: Speech Lab of DAMO Academy, Alibaba Group
- CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
- https://arxiv.org/pdf/2003.01309.pdf
- """
- def __init__(
- self,
- model,
- max_seq_len=512,
- model_name='punc_model',
- **kwargs,
- ):
- super().__init__()
- onnx = False
- if "onnx" in kwargs:
- onnx = kwargs["onnx"]
- self.embed = model.embed
- if isinstance(model.encoder, SANMVadEncoder):
- self.encoder = SANMVadEncoder_export(model.encoder, onnx=onnx)
- else:
- assert False, "Only support samn encode."
- self.decoder = model.decoder
- self.model_name = model_name
- def forward(self, inputs: torch.Tensor,
- text_lengths: torch.Tensor,
- vad_indexes: torch.Tensor,
- sub_masks: torch.Tensor,
- ) -> Tuple[torch.Tensor, None]:
- """Compute loss value from buffer sequences.
- Args:
- input (torch.Tensor): Input ids. (batch, len)
- hidden (torch.Tensor): Target ids. (batch, len)
- """
- x = self.embed(inputs)
- # mask = self._target_mask(input)
- h, _ = self.encoder(x, text_lengths, vad_indexes, sub_masks)
- y = self.decoder(h)
- return y
- def with_vad(self):
- return True
- def get_dummy_inputs(self):
- length = 120
- text_indexes = torch.randint(0, self.embed.num_embeddings, (1, length)).type(torch.int32)
- text_lengths = torch.tensor([length], dtype=torch.int32)
- vad_mask = torch.ones(length, length, dtype=torch.float32)[None, None, :, :]
- sub_masks = torch.ones(length, length, dtype=torch.float32)
- sub_masks = torch.tril(sub_masks).type(torch.float32)
- return (text_indexes, text_lengths, vad_mask, sub_masks[None, None, :, :])
- def get_input_names(self):
- return ['inputs', 'text_lengths', 'vad_masks', 'sub_masks']
- def get_output_names(self):
- return ['logits']
- def get_dynamic_axes(self):
- return {
- 'inputs': {
- 1: 'feats_length'
- },
- 'vad_masks': {
- 2: 'feats_length1',
- 3: 'feats_length2'
- },
- 'sub_masks': {
- 2: 'feats_length1',
- 3: 'feats_length2'
- },
- 'logits': {
- 1: 'logits_length'
- },
- }
|