# fused.py (5.6 KB, 146 lines)
  1. from funasr.models.frontend.abs_frontend import AbsFrontend
  2. from funasr.models.frontend.default import DefaultFrontend
  3. from funasr.models.frontend.s3prl import S3prlFrontend
  4. import numpy as np
  5. import torch
  6. from typeguard import check_argument_types
  7. from typing import Tuple
  8. class FusedFrontends(AbsFrontend):
  9. def __init__(
  10. self, frontends=None, align_method="linear_projection", proj_dim=100, fs=16000
  11. ):
  12. assert check_argument_types()
  13. super().__init__()
  14. self.align_method = (
  15. align_method # fusing method : linear_projection only for now
  16. )
  17. self.proj_dim = proj_dim # dim of the projection done on each frontend
  18. self.frontends = [] # list of the frontends to combine
  19. for i, frontend in enumerate(frontends):
  20. frontend_type = frontend["frontend_type"]
  21. if frontend_type == "default":
  22. n_mels, fs, n_fft, win_length, hop_length = (
  23. frontend.get("n_mels", 80),
  24. fs,
  25. frontend.get("n_fft", 512),
  26. frontend.get("win_length"),
  27. frontend.get("hop_length", 128),
  28. )
  29. window, center, normalized, onesided = (
  30. frontend.get("window", "hann"),
  31. frontend.get("center", True),
  32. frontend.get("normalized", False),
  33. frontend.get("onesided", True),
  34. )
  35. fmin, fmax, htk, apply_stft = (
  36. frontend.get("fmin", None),
  37. frontend.get("fmax", None),
  38. frontend.get("htk", False),
  39. frontend.get("apply_stft", True),
  40. )
  41. self.frontends.append(
  42. DefaultFrontend(
  43. n_mels=n_mels,
  44. n_fft=n_fft,
  45. fs=fs,
  46. win_length=win_length,
  47. hop_length=hop_length,
  48. window=window,
  49. center=center,
  50. normalized=normalized,
  51. onesided=onesided,
  52. fmin=fmin,
  53. fmax=fmax,
  54. htk=htk,
  55. apply_stft=apply_stft,
  56. )
  57. )
  58. elif frontend_type == "s3prl":
  59. frontend_conf, download_dir, multilayer_feature = (
  60. frontend.get("frontend_conf"),
  61. frontend.get("download_dir"),
  62. frontend.get("multilayer_feature"),
  63. )
  64. self.frontends.append(
  65. S3prlFrontend(
  66. fs=fs,
  67. frontend_conf=frontend_conf,
  68. download_dir=download_dir,
  69. multilayer_feature=multilayer_feature,
  70. )
  71. )
  72. else:
  73. raise NotImplementedError # frontends are only default or s3prl
  74. self.frontends = torch.nn.ModuleList(self.frontends)
  75. self.gcd = np.gcd.reduce([frontend.hop_length for frontend in self.frontends])
  76. self.factors = [frontend.hop_length // self.gcd for frontend in self.frontends]
  77. if torch.cuda.is_available():
  78. dev = "cuda"
  79. else:
  80. dev = "cpu"
  81. if self.align_method == "linear_projection":
  82. self.projection_layers = [
  83. torch.nn.Linear(
  84. in_features=frontend.output_size(),
  85. out_features=self.factors[i] * self.proj_dim,
  86. )
  87. for i, frontend in enumerate(self.frontends)
  88. ]
  89. self.projection_layers = torch.nn.ModuleList(self.projection_layers)
  90. self.projection_layers = self.projection_layers.to(torch.device(dev))
  91. def output_size(self) -> int:
  92. return len(self.frontends) * self.proj_dim
  93. def forward(
  94. self, input: torch.Tensor, input_lengths: torch.Tensor
  95. ) -> Tuple[torch.Tensor, torch.Tensor]:
  96. # step 0 : get all frontends features
  97. self.feats = []
  98. for frontend in self.frontends:
  99. with torch.no_grad():
  100. input_feats, feats_lens = frontend.forward(input, input_lengths)
  101. self.feats.append([input_feats, feats_lens])
  102. if (
  103. self.align_method == "linear_projection"
  104. ): # TODO(Dan): to add other align methods
  105. # first step : projections
  106. self.feats_proj = []
  107. for i, frontend in enumerate(self.frontends):
  108. input_feats = self.feats[i][0]
  109. self.feats_proj.append(self.projection_layers[i](input_feats))
  110. # 2nd step : reshape
  111. self.feats_reshaped = []
  112. for i, frontend in enumerate(self.frontends):
  113. input_feats_proj = self.feats_proj[i]
  114. bs, nf, dim = input_feats_proj.shape
  115. input_feats_reshaped = torch.reshape(
  116. input_feats_proj, (bs, nf * self.factors[i], dim // self.factors[i])
  117. )
  118. self.feats_reshaped.append(input_feats_reshaped)
  119. # 3rd step : drop the few last frames
  120. m = min([x.shape[1] for x in self.feats_reshaped])
  121. self.feats_final = [x[:, :m, :] for x in self.feats_reshaped]
  122. input_feats = torch.cat(
  123. self.feats_final, dim=-1
  124. ) # change the input size of the preencoder : proj_dim * n_frontends
  125. feats_lens = torch.ones_like(self.feats[0][1]) * (m)
  126. else:
  127. raise NotImplementedError
  128. return input_feats, feats_lens