import copy
import logging
import os
from argparse import Namespace
from typing import Optional
from typing import Tuple
from typing import Union

import humanfriendly
import torch

from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.modules.frontends.frontend import Frontend
from funasr.modules.nets_utils import pad_list
from funasr.utils.get_default_kwargs import get_default_kwargs


def base_s3prl_setup(args):
    """Fill in default values for optional S3PRL settings on ``args``.

    Every attribute listed below is set on ``args``: an existing value is
    kept, a missing one receives the default.

    Args:
        args: Attribute container (e.g. ``argparse.Namespace``) to complete.

    Returns:
        The same ``args`` object, mutated in place.
    """
    defaults = (
        ("upstream_feature_selection", None),
        ("upstream_model_config", None),
        ("upstream_refresh", False),
        ("upstream_ckpt", None),
        ("init_ckpt", None),
        ("verbose", False),
        ("tile_factor", 1),
    )
    for name, default in defaults:
        setattr(args, name, getattr(args, name, default))
    return args


class S3prlFrontend(AbsFrontend):
    """Speech Pretrained Representation frontend structure for ASR."""

    def __init__(
        self,
        fs: Union[int, str] = 16000,
        # NOTE: the default dict is built once, at definition time, and shared
        # between instances; this class only reads it (it is unpacked into a
        # Namespace), it never mutates it.
        frontend_conf: Optional[dict] = get_default_kwargs(Frontend),
        download_dir: str = None,
        multilayer_feature: bool = False,
    ):
        """Build the S3PRL upstream model and its featurizer.

        Args:
            fs: Sampling rate, as an int or a human-readable string. A string
                is parsed with ``humanfriendly.parse_size``; the value is not
                used any further inside this class.
            frontend_conf: Keyword settings turned into the S3PRL argument
                namespace. It must provide ``upstream`` (the model name given
                to ``torch.hub.load``) and must not contain ``device``, which
                is always forced to ``"cpu"`` here.
            download_dir: If given, passed to ``torch.hub.set_dir``.
            multilayer_feature: If true, the featurizer selects
                ``"hidden_states"``; otherwise ``"last_hidden_state"``.
        """
        super().__init__()
        if isinstance(fs, str):
            fs = humanfriendly.parse_size(fs)

        if download_dir is not None:
            torch.hub.set_dir(download_dir)

        self.multilayer_feature = multilayer_feature
        self.upstream, self.featurizer = self._get_upstream(frontend_conf)
        # Snapshot of the freshly loaded weights, restored by
        # reload_pretrained_parameters().
        self.pretrained_params = copy.deepcopy(self.upstream.state_dict())
        self.output_dim = self.featurizer.output_dim
        self.frontend_type = "s3prl"
        self.hop_length = self.upstream.get_downsample_rates("key")

    def _get_upstream(self, frontend_conf):
        """Get S3PRL upstream model.

        The local S3PRL checkout is located by scanning ``PYTHONPATH`` for an
        entry ending in ``"s3prl"``; the model is then loaded from it through
        ``torch.hub.load(..., source="local")`` on CPU.

        Args:
            frontend_conf: Mapping of S3PRL settings (see ``__init__``).

        Returns:
            Tuple ``(upstream, featurizer)``.

        Raises:
            AssertionError: If no ``PYTHONPATH`` entry ends with ``"s3prl"``.
        """
        s3prl_args = base_s3prl_setup(
            Namespace(**frontend_conf, device="cpu"),
        )
        self.args = s3prl_args

        # os.pathsep instead of a hard-coded ":" so the lookup also works on
        # platforms whose PYTHONPATH separator differs (";" on Windows).
        search_paths = os.environ.get("PYTHONPATH", "(None)").split(os.pathsep)
        s3prl_path = next((p for p in search_paths if p.endswith("s3prl")), None)
        assert s3prl_path is not None, (
            "S3PRL was not found: add the path of the s3prl checkout "
            "(a directory whose path ends with 's3prl') to PYTHONPATH"
        )

        s3prl_upstream = torch.hub.load(
            s3prl_path,
            s3prl_args.upstream,
            ckpt=s3prl_args.upstream_ckpt,
            model_config=s3prl_args.upstream_model_config,
            refresh=s3prl_args.upstream_refresh,
            source="local",
        ).to("cpu")

        # Turn off layerdrop for the wav2vec 2.0 / HuBERT style encoders.
        inner_model = getattr(s3prl_upstream, "model", None)
        if inner_model is not None and inner_model.__class__.__name__ in [
            "Wav2Vec2Model",
            "HubertModel",
        ]:
            s3prl_upstream.model.encoder.layerdrop = 0.0

        from s3prl.upstream.interfaces import Featurizer

        # The previous test was `self.multilayer_feature is None`, which is
        # never true for the documented bool flag (default False), so the
        # single-layer branch was unreachable. Use truthiness instead; None
        # still maps to the last hidden state.
        if not self.multilayer_feature:
            feature_selection = "last_hidden_state"
        else:
            feature_selection = "hidden_states"

        s3prl_featurizer = Featurizer(
            upstream=s3prl_upstream,
            feature_selection=feature_selection,
            upstream_device="cpu",
        )

        return s3prl_upstream, s3prl_featurizer

    def _tile_representations(self, feature):
        """Tile up the representations by `tile_factor`.

        Input - sequence of representations
                shape: (batch_size, seq_len, feature_dim)
        Output - sequence of tiled representations
                 shape: (batch_size, seq_len * factor, feature_dim)
        """
        assert (
            len(feature.shape) == 3
        ), "Input argument `feature` has invalid shape: {}".format(feature.shape)
        factor = self.args.tile_factor
        # Repeating along the last axis and then folding the copies back into
        # the time axis makes each frame appear `factor` times in a row.
        tiled_feature = feature.repeat(1, 1, factor)
        tiled_feature = tiled_feature.reshape(
            feature.size(0), feature.size(1) * factor, feature.size(2)
        )
        return tiled_feature

    def output_size(self) -> int:
        """Return the featurizer output dimension stored at construction."""
        return self.output_dim

    def forward(
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Extract upstream features for a batch of waveforms.

        Args:
            input: Batch of waveforms; item ``i`` is cut to
                ``input_lengths[i]`` before being fed to the upstream model.
            input_lengths: Valid length of each waveform in ``input``.

        Returns:
            Tuple ``(input_feats, feats_lens)``: the features padded with 0.0
            by ``pad_list`` and a ``torch.long`` tensor of per-utterance
            feature lengths. When ``tile_factor != 1`` both are expanded by
            that factor along the time axis.
        """
        wavs = [wav[: input_lengths[i]] for i, wav in enumerate(input)]
        # The upstream model is always run in eval mode and without gradients.
        self.upstream.eval()
        with torch.no_grad():
            feats = self.upstream(wavs)
        feats = self.featurizer(wavs, feats)

        # Measure lengths and pad BEFORE tiling: `_tile_representations` needs
        # a 3-D tensor, whereas the featurizer output is consumed here as a
        # sequence of per-utterance tensors (presumably a list of (time, dim)
        # tensors -- confirm against the S3PRL Featurizer). Tiling the raw
        # output first, as was done before, cannot work on such a sequence.
        feats_lens = torch.tensor([f.shape[0] for f in feats], dtype=torch.long)
        input_feats = pad_list(feats, 0.0)

        # Saving CUDA Memory
        del feats

        tile_factor = self.args.tile_factor
        if tile_factor != 1:
            input_feats = self._tile_representations(input_feats)
            # Every frame is repeated tile_factor times, so lengths scale too.
            feats_lens = feats_lens * tile_factor

        return input_feats, feats_lens

    def reload_pretrained_parameters(self):
        """Restore the upstream weights captured when the frontend was built."""
        self.upstream.load_state_dict(self.pretrained_params)
        logging.info("Pretrained S3PRL frontend model parameters reloaded!")