e2e_vad.py (2.0 KB)
  1. from enum import Enum
  2. from typing import List, Tuple, Dict, Any
  3. import torch
  4. from torch import nn
  5. import math
  6. from funasr.models.encoder.fsmn_encoder import FSMN
  7. from funasr.export.models.encoder.fsmn_encoder import FSMN as FSMN_export
  8. class E2EVadModel(nn.Module):
  9. def __init__(self, model,
  10. max_seq_len=512,
  11. feats_dim=400,
  12. model_name='model',
  13. **kwargs,):
  14. super(E2EVadModel, self).__init__()
  15. self.feats_dim = feats_dim
  16. self.max_seq_len = max_seq_len
  17. self.model_name = model_name
  18. if isinstance(model.encoder, FSMN):
  19. self.encoder = FSMN_export(model.encoder)
  20. else:
  21. raise "unsupported encoder"
  22. def forward(self, feats: torch.Tensor, *args, ):
  23. scores, out_caches = self.encoder(feats, *args)
  24. return scores, out_caches
  25. def get_dummy_inputs(self, frame=30):
  26. speech = torch.randn(1, frame, self.feats_dim)
  27. in_cache0 = torch.randn(1, 128, 19, 1)
  28. in_cache1 = torch.randn(1, 128, 19, 1)
  29. in_cache2 = torch.randn(1, 128, 19, 1)
  30. in_cache3 = torch.randn(1, 128, 19, 1)
  31. return (speech, in_cache0, in_cache1, in_cache2, in_cache3)
  32. # def get_dummy_inputs_txt(self, txt_file: str = "/mnt/workspace/data_fbank/0207/12345.wav.fea.txt"):
  33. # import numpy as np
  34. # fbank = np.loadtxt(txt_file)
  35. # fbank_lengths = np.array([fbank.shape[0], ], dtype=np.int32)
  36. # speech = torch.from_numpy(fbank[None, :, :].astype(np.float32))
  37. # speech_lengths = torch.from_numpy(fbank_lengths.astype(np.int32))
  38. # return (speech, speech_lengths)
  39. def get_input_names(self):
  40. return ['speech', 'in_cache0', 'in_cache1', 'in_cache2', 'in_cache3']
  41. def get_output_names(self):
  42. return ['logits', 'out_cache0', 'out_cache1', 'out_cache2', 'out_cache3']
  43. def get_dynamic_axes(self):
  44. return {
  45. 'speech': {
  46. 1: 'feats_length'
  47. },
  48. }