|
|
@@ -5,7 +5,7 @@ import torch
|
|
|
import torch.nn as nn
|
|
|
|
|
|
from funasr.export.utils.torch_function import MakePadMask
|
|
|
-from funasr.train.abs_espnet_model import AbsESPnetModel
|
|
|
+from funasr.export.utils.torch_function import sequence_mask
|
|
|
from funasr.models.encoder.sanm_encoder import SANMEncoder
|
|
|
from funasr.export.models.encoder.sanm_encoder import SANMEncoder as SANMEncoder_export
|
|
|
from funasr.models.predictor.cif import CifPredictorV2
|
|
|
@@ -29,19 +29,24 @@ class Paraformer(nn.Module):
|
|
|
**kwargs,
|
|
|
):
|
|
|
super().__init__()
|
|
|
+ onnx = False
|
|
|
+ if "onnx" in kwargs:
|
|
|
+ onnx = kwargs["onnx"]
|
|
|
if isinstance(model.encoder, SANMEncoder):
|
|
|
- self.encoder = SANMEncoder_export(model.encoder)
|
|
|
+ self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)
|
|
|
if isinstance(model.predictor, CifPredictorV2):
|
|
|
self.predictor = CifPredictorV2_export(model.predictor)
|
|
|
if isinstance(model.decoder, ParaformerSANMDecoder):
|
|
|
- self.decoder = ParaformerSANMDecoder_export(model.decoder)
|
|
|
- self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
|
|
|
+ self.decoder = ParaformerSANMDecoder_export(model.decoder, onnx=onnx)
|
|
|
+
|
|
|
self.feats_dim = feats_dim
|
|
|
self.model_name = model_name
|
|
|
- self.onnx = False
|
|
|
- if "onnx" in kwargs:
|
|
|
- self.onnx = kwargs["onnx"]
|
|
|
-
|
|
|
+
|
|
|
+ if onnx:
|
|
|
+ self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
|
|
|
+ else:
|
|
|
+ self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
|
|
|
+
|
|
|
def forward(
|
|
|
self,
|
|
|
speech: torch.Tensor,
|
|
|
@@ -66,7 +71,7 @@ class Paraformer(nn.Module):
|
|
|
|
|
|
def get_dummy_inputs(self):
|
|
|
speech = torch.randn(2, 30, self.feats_dim)
|
|
|
- speech_lengths = torch.tensor([6, 30]).long()
|
|
|
+ speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
|
|
|
return (speech, speech_lengths)
|
|
|
|
|
|
def get_input_names(self):
|