| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143 |
- import copy
- import logging
- import os
- from argparse import Namespace
- from typing import Optional
- from typing import Tuple
- from typing import Union
- import humanfriendly
- import torch
- from typeguard import check_argument_types
- from funasr.models.frontend.abs_frontend import AbsFrontend
- from funasr.modules.frontends.frontend import Frontend
- from funasr.modules.nets_utils import pad_list
- from funasr.utils.get_default_kwargs import get_default_kwargs
def base_s3prl_setup(args):
    """Fill in optional S3PRL options that ``args`` does not already define.

    Every option listed below is (re)assigned on ``args``: an existing value
    is kept, otherwise the listed default is used. ``args`` is modified in
    place and also returned.
    """
    option_defaults = (
        ("upstream_feature_selection", None),
        ("upstream_model_config", None),
        ("upstream_refresh", False),
        ("upstream_ckpt", None),
        ("init_ckpt", None),
        ("verbose", False),
        ("tile_factor", 1),
    )
    for option_name, fallback in option_defaults:
        current = getattr(args, option_name, fallback)
        setattr(args, option_name, current)
    return args
class S3prlFrontend(AbsFrontend):
    """Speech Pretrained Representation (S3PRL) frontend structure for ASR.

    Loads an S3PRL upstream model through ``torch.hub`` from a local s3prl
    checkout found on ``PYTHONPATH`` and pairs it with an S3PRL
    ``Featurizer``. In ``forward`` the upstream runs in eval mode under
    ``torch.no_grad()``; the featurizer is applied outside that context.
    """

    def __init__(
        self,
        fs: Union[int, str] = 16000,
        frontend_conf: Optional[dict] = get_default_kwargs(Frontend),
        download_dir: Optional[str] = None,
        multilayer_feature: bool = False,
    ):
        """Build the upstream model and its featurizer.

        Args:
            fs: Sampling rate; a string (e.g. "16k") is parsed with
                ``humanfriendly.parse_size``. NOTE(review): the parsed value
                is not stored or used anywhere in this class.
            frontend_conf: Options turned into an ``argparse.Namespace`` for
                S3PRL. It must provide ``upstream`` (the hub entry name);
                missing optional keys are filled in by ``base_s3prl_setup``.
            download_dir: If given, installed as the ``torch.hub`` directory.
            multilayer_feature: Stored and consulted when choosing the
                featurizer's feature selection (see ``_get_upstream``).
        """
        assert check_argument_types()
        super().__init__()
        if isinstance(fs, str):
            fs = humanfriendly.parse_size(fs)

        if download_dir is not None:
            torch.hub.set_dir(download_dir)

        self.multilayer_feature = multilayer_feature
        self.upstream, self.featurizer = self._get_upstream(frontend_conf)
        # Snapshot of the freshly loaded upstream weights so they can be
        # restored later by reload_pretrained_parameters().
        self.pretrained_params = copy.deepcopy(self.upstream.state_dict())
        self.output_dim = self.featurizer.output_dim
        self.frontend_type = "s3prl"
        self.hop_length = self.upstream.get_downsample_rates("key")

    def _get_upstream(self, frontend_conf):
        """Load the S3PRL upstream model and build its featurizer.

        The s3prl repository is located by scanning ``PYTHONPATH`` for the
        first entry ending in ``s3prl`` (trailing path separators ignored).
        The upstream named by ``frontend_conf["upstream"]`` is loaded from it
        with ``torch.hub.load(..., source="local")`` and moved to CPU. As a
        side effect, the completed S3PRL options are stored on ``self.args``.

        Returns:
            Tuple ``(upstream, featurizer)``.

        Raises:
            RuntimeError: If no ``s3prl`` entry is found on ``PYTHONPATH``.
        """
        s3prl_args = base_s3prl_setup(
            Namespace(**frontend_conf, device="cpu"),
        )
        self.args = s3prl_args

        # os.pathsep is ":" on POSIX and ";" on Windows.
        python_path_list = os.environ.get("PYTHONPATH", "").split(os.pathsep)
        s3prl_path = next(
            (p for p in python_path_list if p.rstrip("/\\").endswith("s3prl")),
            None,
        )
        if s3prl_path is None:
            # An explicit error (rather than assert) survives `python -O`.
            raise RuntimeError(
                "S3prlFrontend could not find the s3prl repository: add the "
                "path of a local s3prl checkout to PYTHONPATH."
            )

        s3prl_upstream = torch.hub.load(
            s3prl_path,
            s3prl_args.upstream,
            ckpt=s3prl_args.upstream_ckpt,
            model_config=s3prl_args.upstream_model_config,
            refresh=s3prl_args.upstream_refresh,
            source="local",
        ).to("cpu")

        # For Wav2Vec2 / HuBERT models, zero the encoder's layerdrop
        # (presumably to disable stochastic layer skipping of the frozen
        # upstream).
        if getattr(
            s3prl_upstream, "model", None
        ) is not None and s3prl_upstream.model.__class__.__name__ in [
            "Wav2Vec2Model",
            "HubertModel",
        ]:
            s3prl_upstream.model.encoder.layerdrop = 0.0

        from s3prl.upstream.interfaces import Featurizer

        # NOTE(review): ``multilayer_feature`` is a bool defaulting to False,
        # so this ``is None`` test is never true and "hidden_states" is
        # always selected. A truthiness check would change the selected
        # features (and possibly checkpoint compatibility) for every model
        # built with the default, so the original condition is kept on
        # purpose; confirm the intended behavior before changing it.
        if self.multilayer_feature is None:
            feature_selection = "last_hidden_state"
        else:
            feature_selection = "hidden_states"
        s3prl_featurizer = Featurizer(
            upstream=s3prl_upstream,
            feature_selection=feature_selection,
            upstream_device="cpu",
        )

        return s3prl_upstream, s3prl_featurizer

    def _tile_representations(self, feature):
        """Tile up the representations by `tile_factor`.

        Each frame is repeated ``self.args.tile_factor`` times consecutively
        along the time axis.

        Input - sequence of representations
                shape: (batch_size, seq_len, feature_dim)
        Output - sequence of tiled representations
                 shape: (batch_size, seq_len * factor, feature_dim)
        """
        assert (
            len(feature.shape) == 3
        ), "Input argument `feature` has invalid shape: {}".format(feature.shape)
        tiled_feature = feature.repeat(1, 1, self.args.tile_factor)
        tiled_feature = tiled_feature.reshape(
            feature.size(0), feature.size(1) * self.args.tile_factor, feature.size(2)
        )
        return tiled_feature

    def output_size(self) -> int:
        """Return the feature dimension reported by the featurizer."""
        return self.output_dim

    def forward(
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Extract padded S3PRL features for a batch of waveforms.

        Args:
            input: Padded waveform batch; row ``i`` is trimmed to
                ``input_lengths[i]`` samples before reaching the upstream.
            input_lengths: Number of valid samples in each row of ``input``.

        Returns:
            Tuple of (features padded with 0.0 by ``pad_list``, per-utterance
            feature lengths as a ``torch.long`` tensor).
        """
        wavs = [wav[: input_lengths[i]] for i, wav in enumerate(input)]
        self.upstream.eval()
        with torch.no_grad():
            feats = self.upstream(wavs)
        feats = self.featurizer(wavs, feats)

        if self.args.tile_factor != 1:
            # The featurizer output is a per-utterance list (it is consumed by
            # pad_list and f.shape[0] below), but _tile_representations needs
            # a 3-D batch tensor; tile each utterance as a batch of one.
            feats = [
                self._tile_representations(f.unsqueeze(0)).squeeze(0) for f in feats
            ]

        input_feats = pad_list(feats, 0.0)
        feats_lens = torch.tensor([f.shape[0] for f in feats], dtype=torch.long)

        # Saving CUDA Memory
        del feats

        return input_feats, feats_lens

    def reload_pretrained_parameters(self):
        """Restore the upstream weights captured right after loading."""
        self.upstream.load_state_dict(self.pretrained_params)
        logging.info("Pretrained S3PRL frontend model parameters reloaded!")
|