from typing import Tuple

import numpy as np
import torch

from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.frontend.default import DefaultFrontend
from funasr.models.frontend.s3prl import S3prlFrontend


class FusedFrontends(AbsFrontend):
    """Fuse several frontends (e.g. default mel features and s3prl SSL features)
    by projecting each one to a common dimension and aligning their frame rates."""

    def __init__(
        self, frontends=None, align_method="linear_projection", proj_dim=100, fs=16000
    ):
        super().__init__()
        # Fusion method; only "linear_projection" is supported for now.
        self.align_method = align_method
        self.proj_dim = proj_dim  # dimension each frontend is projected to
        self.frontends = []  # list of the frontends to combine

        for i, frontend in enumerate(frontends):
            frontend_type = frontend["frontend_type"]
            if frontend_type == "default":
                n_mels, fs, n_fft, win_length, hop_length = (
                    frontend.get("n_mels", 80),
                    fs,
                    frontend.get("n_fft", 512),
                    frontend.get("win_length"),
                    frontend.get("hop_length", 128),
                )
                window, center, normalized, onesided = (
                    frontend.get("window", "hann"),
                    frontend.get("center", True),
                    frontend.get("normalized", False),
                    frontend.get("onesided", True),
                )
                fmin, fmax, htk, apply_stft = (
                    frontend.get("fmin", None),
                    frontend.get("fmax", None),
                    frontend.get("htk", False),
                    frontend.get("apply_stft", True),
                )
                self.frontends.append(
                    DefaultFrontend(
                        n_mels=n_mels,
                        n_fft=n_fft,
                        fs=fs,
                        win_length=win_length,
                        hop_length=hop_length,
                        window=window,
                        center=center,
                        normalized=normalized,
                        onesided=onesided,
                        fmin=fmin,
                        fmax=fmax,
                        htk=htk,
                        apply_stft=apply_stft,
                    )
                )
            elif frontend_type == "s3prl":
                frontend_conf, download_dir, multilayer_feature = (
                    frontend.get("frontend_conf"),
                    frontend.get("download_dir"),
                    frontend.get("multilayer_feature"),
                )
                self.frontends.append(
                    S3prlFrontend(
                        fs=fs,
                        frontend_conf=frontend_conf,
                        download_dir=download_dir,
                        multilayer_feature=multilayer_feature,
                    )
                )
            else:
                # Only the two frontend types above are supported.
                raise NotImplementedError(f"unsupported frontend_type: {frontend_type}")

        self.frontends = torch.nn.ModuleList(self.frontends)

        # The frontends run at different frame rates; the gcd of their hop
        # lengths gives the common time base used for alignment.
        self.gcd = np.gcd.reduce([frontend.hop_length for frontend in self.frontends])
        self.factors = [frontend.hop_length // self.gcd for frontend in self.frontends]
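        # Worked example (illustrative, not from the original source): hop
        # lengths 128 and 160 give gcd 32, so factors == [4, 5]; each frame of
        # a frontend is later split into `factor` sub-frames on the gcd-rate axis.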

        dev = "cuda" if torch.cuda.is_available() else "cpu"

        if self.align_method == "linear_projection":
            self.projection_layers = [
                torch.nn.Linear(
                    in_features=frontend.output_size(),
                    out_features=self.factors[i] * self.proj_dim,
                )
                for i, frontend in enumerate(self.frontends)
            ]
            self.projection_layers = torch.nn.ModuleList(self.projection_layers)
            self.projection_layers = self.projection_layers.to(torch.device(dev))

    def output_size(self) -> int:
        # The fused feature concatenates one proj_dim-sized block per frontend.
        return len(self.frontends) * self.proj_dim

    def forward(
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Step 0: run every frontend on the raw input. The frontends are kept
        # frozen (no_grad); only the projection layers receive gradients here.
        self.feats = []
        for frontend in self.frontends:
            with torch.no_grad():
                input_feats, feats_lens = frontend.forward(input, input_lengths)
            self.feats.append([input_feats, feats_lens])

        if (
            self.align_method == "linear_projection"
        ):  # TODO(Dan): add other align methods
            # Step 1: project each frontend's features to factors[i] * proj_dim.
            self.feats_proj = []
            for i, frontend in enumerate(self.frontends):
                input_feats = self.feats[i][0]
                self.feats_proj.append(self.projection_layers[i](input_feats))

            # Step 2: split each frame into factors[i] sub-frames of size
            # proj_dim so that every stream ends up at the same frame rate.
            self.feats_reshaped = []
            for i, frontend in enumerate(self.frontends):
                input_feats_proj = self.feats_proj[i]
                bs, nf, dim = input_feats_proj.shape
                input_feats_reshaped = torch.reshape(
                    input_feats_proj, (bs, nf * self.factors[i], dim // self.factors[i])
                )
                self.feats_reshaped.append(input_feats_reshaped)
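            # Shape example (illustrative, not from the original source): with
            # factor 4 and proj_dim 100, a (bs, nf, 400) projection reshapes
            # to (bs, 4 * nf, 100).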

            # Step 3: the streams may differ by a few trailing frames; truncate
            # to the shortest, then concatenate along the feature dimension.
            # The downstream preencoder must expect proj_dim * n_frontends features.
            m = min([x.shape[1] for x in self.feats_reshaped])
            self.feats_final = [x[:, :m, :] for x in self.feats_reshaped]
            input_feats = torch.cat(self.feats_final, dim=-1)
            feats_lens = torch.ones_like(self.feats[0][1]) * m
        else:
            raise NotImplementedError(f"unsupported align_method: {self.align_method}")

        return input_feats, feats_lens
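

if __name__ == "__main__":
    # Minimal usage sketch (an illustrative assumption, not part of the
    # original module): fuse two "default" mel frontends with different hop
    # lengths; an "s3prl" entry would additionally need a valid frontend_conf.
    frontends = [
        {"frontend_type": "default", "n_mels": 80, "hop_length": 128},
        {"frontend_type": "default", "n_mels": 40, "hop_length": 160},
    ]
    device = "cuda" if torch.cuda.is_available() else "cpu"  # match __init__'s choice
    fused = FusedFrontends(frontends=frontends, proj_dim=100, fs=16000).to(device)

    wav = torch.randn(2, 16000, device=device)  # (batch, samples): 1 s at 16 kHz
    wav_lens = torch.tensor([16000, 16000], device=device)
    feats, feats_lens = fused(wav, wav_lens)
    # feats: (2, n_frames, 200), i.e. proj_dim * n_frontends features per frame;
    # feats_lens: all entries equal to n_frames after truncation.
    print(feats.shape, feats_lens)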