# fused.py
  1. from funasr.frontends.default import DefaultFrontend
  2. from funasr.frontends.s3prl import S3prlFrontend
  3. import numpy as np
  4. import torch
  5. import torch.nn as nn
  6. from typing import Tuple
  7. class FusedFrontends(nn.Module):
  8. def __init__(
  9. self, frontends=None, align_method="linear_projection", proj_dim=100, fs=16000
  10. ):
  11. super().__init__()
  12. self.align_method = (
  13. align_method # fusing method : linear_projection only for now
  14. )
  15. self.proj_dim = proj_dim # dim of the projection done on each frontend
  16. self.frontends = [] # list of the frontends to combine
  17. for i, frontend in enumerate(frontends):
  18. frontend_type = frontend["frontend_type"]
  19. if frontend_type == "default":
  20. n_mels, fs, n_fft, win_length, hop_length = (
  21. frontend.get("n_mels", 80),
  22. fs,
  23. frontend.get("n_fft", 512),
  24. frontend.get("win_length"),
  25. frontend.get("hop_length", 128),
  26. )
  27. window, center, normalized, onesided = (
  28. frontend.get("window", "hann"),
  29. frontend.get("center", True),
  30. frontend.get("normalized", False),
  31. frontend.get("onesided", True),
  32. )
  33. fmin, fmax, htk, apply_stft = (
  34. frontend.get("fmin", None),
  35. frontend.get("fmax", None),
  36. frontend.get("htk", False),
  37. frontend.get("apply_stft", True),
  38. )
  39. self.frontends.append(
  40. DefaultFrontend(
  41. n_mels=n_mels,
  42. n_fft=n_fft,
  43. fs=fs,
  44. win_length=win_length,
  45. hop_length=hop_length,
  46. window=window,
  47. center=center,
  48. normalized=normalized,
  49. onesided=onesided,
  50. fmin=fmin,
  51. fmax=fmax,
  52. htk=htk,
  53. apply_stft=apply_stft,
  54. )
  55. )
  56. elif frontend_type == "s3prl":
  57. frontend_conf, download_dir, multilayer_feature = (
  58. frontend.get("frontend_conf"),
  59. frontend.get("download_dir"),
  60. frontend.get("multilayer_feature"),
  61. )
  62. self.frontends.append(
  63. S3prlFrontend(
  64. fs=fs,
  65. frontend_conf=frontend_conf,
  66. download_dir=download_dir,
  67. multilayer_feature=multilayer_feature,
  68. )
  69. )
  70. else:
  71. raise NotImplementedError # frontends are only default or s3prl
  72. self.frontends = torch.nn.ModuleList(self.frontends)
  73. self.gcd = np.gcd.reduce([frontend.hop_length for frontend in self.frontends])
  74. self.factors = [frontend.hop_length // self.gcd for frontend in self.frontends]
  75. if torch.cuda.is_available():
  76. dev = "cuda"
  77. else:
  78. dev = "cpu"
  79. if self.align_method == "linear_projection":
  80. self.projection_layers = [
  81. torch.nn.Linear(
  82. in_features=frontend.output_size(),
  83. out_features=self.factors[i] * self.proj_dim,
  84. )
  85. for i, frontend in enumerate(self.frontends)
  86. ]
  87. self.projection_layers = torch.nn.ModuleList(self.projection_layers)
  88. self.projection_layers = self.projection_layers.to(torch.device(dev))
  89. def output_size(self) -> int:
  90. return len(self.frontends) * self.proj_dim
  91. def forward(
  92. self, input: torch.Tensor, input_lengths: torch.Tensor
  93. ) -> Tuple[torch.Tensor, torch.Tensor]:
  94. # step 0 : get all frontends features
  95. self.feats = []
  96. for frontend in self.frontends:
  97. with torch.no_grad():
  98. input_feats, feats_lens = frontend.forward(input, input_lengths)
  99. self.feats.append([input_feats, feats_lens])
  100. if (
  101. self.align_method == "linear_projection"
  102. ): # TODO(Dan): to add other align methods
  103. # first step : projections
  104. self.feats_proj = []
  105. for i, frontend in enumerate(self.frontends):
  106. input_feats = self.feats[i][0]
  107. self.feats_proj.append(self.projection_layers[i](input_feats))
  108. # 2nd step : reshape
  109. self.feats_reshaped = []
  110. for i, frontend in enumerate(self.frontends):
  111. input_feats_proj = self.feats_proj[i]
  112. bs, nf, dim = input_feats_proj.shape
  113. input_feats_reshaped = torch.reshape(
  114. input_feats_proj, (bs, nf * self.factors[i], dim // self.factors[i])
  115. )
  116. self.feats_reshaped.append(input_feats_reshaped)
  117. # 3rd step : drop the few last frames
  118. m = min([x.shape[1] for x in self.feats_reshaped])
  119. self.feats_final = [x[:, :m, :] for x in self.feats_reshaped]
  120. input_feats = torch.cat(
  121. self.feats_final, dim=-1
  122. ) # change the input size of the preencoder : proj_dim * n_frontends
  123. feats_lens = torch.ones_like(self.feats[0][1]) * (m)
  124. else:
  125. raise NotImplementedError
  126. return input_feats, feats_lens