#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Shigeki Karita
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""Subsampling layer definition."""

import logging

import numpy as np
import torch
import torch.nn.functional as F

from funasr.modules.embedding import PositionalEncoding
from funasr.modules.streaming_utils.utils import sequence_mask


class TooShortUttError(Exception):
    """Raised when the utt is too short for subsampling.

    Args:
        message (str): Message for error catch
        actual_size (int): the short size that cannot pass the subsampling
        limit (int): the limit size for subsampling

    """

    def __init__(self, message, actual_size, limit):
        """Construct a TooShortUttError for error handler."""
        super().__init__(message)
        self.actual_size = actual_size
        self.limit = limit


def _get_pos_enc(out, key):
    """Return ``out[key]``; only ``key == -1`` (the last layer) is supported.

    Shared implementation of ``__getitem__`` for the Conv2d subsampling
    modules, whose ``out`` Sequential ends with the positional encoding.

    Raises:
        NotImplementedError: If ``key`` is anything other than ``-1``.

    """
    if key != -1:
        raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
    return out[key]


def check_short_utt(ins, size):
    """Check if the utterance is too short for subsampling.

    Args:
        ins (torch.nn.Module): Subsampling module instance.
        size (int): Number of input frames (time axis length).

    Returns:
        tuple[bool, int]: ``(True, limit)`` when ``size`` is below the minimum
        number of frames ``limit`` that ``ins`` can process, otherwise
        ``(False, -1)``. Modules not listed below are never reported as short.

    """
    # Minimum time length for which every un-padded convolution along time
    # still sees at least ``kernel_size`` frames.
    limits = (
        # k=3,s=2 then k=3,s=1: needs (T - 1) // 2 >= 3, i.e. T >= 7.
        (Conv2dSubsampling2, 7),
        # 4 zero frames are appended first: needs (T + 3) // 2 >= 3, i.e. T >= 3.
        (Conv2dSubsamplingPad, 3),
        (Conv2dSubsampling, 7),
        (Conv2dSubsampling6, 11),
        (Conv2dSubsampling8, 15),
    )
    for module_type, limit in limits:
        if isinstance(ins, module_type) and size < limit:
            return True, limit
    return False, -1


class Conv2dSubsampling(torch.nn.Module):
    """Convolutional 2D subsampling (to roughly 1/4 length).

    Two un-padded 3x3 convolutions with stride 2 followed by a linear
    projection and a positional encoding.

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate (used only for the default
            positional encoding).
        pos_enc (torch.nn.Module): Custom position encoding layer.

    """

    def __init__(self, idim, odim, dropout_rate, pos_enc=None):
        """Construct an Conv2dSubsampling object."""
        super(Conv2dSubsampling, self).__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 3, 2),
            torch.nn.ReLU(),
        )
        self.out = torch.nn.Sequential(
            # Feature axis shrinks the same way as time: ((idim - 1) // 2 - 1) // 2.
            torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim),
            pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
        )

    def forward(self, x, x_mask):
        """Subsample x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time), or None.

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim),
                where time' = ((time - 1) // 2 - 1) // 2.
            torch.Tensor: Subsampled mask (#batch, 1, time'),
                or None when ``x_mask`` is None.

        """
        x = x.unsqueeze(1)  # (b, c, t, f)
        x = self.conv(x)
        b, c, t, f = x.size()
        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
        if x_mask is None:
            return x, None
        # One ``:-2:2`` slice per (kernel 3, stride 2) convolution.
        return x, x_mask[:, :, :-2:2][:, :, :-2:2]

    def __getitem__(self, key):
        """Get item.

        When reset_parameters() is called, if use_scaled_pos_enc is used,
            return the positioning encoding.

        """
        return _get_pos_enc(self.out, key)


class Conv2dSubsamplingPad(torch.nn.Module):
    """Convolutional 2D subsampling (to roughly 1/4 length) with end padding.

    Same layer stack as ``Conv2dSubsampling``, but four zero frames are
    appended along the time axis before the convolutions, and the output mask
    is rebuilt from the per-utterance valid lengths.

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate (used only for the default
            positional encoding).
        pos_enc (torch.nn.Module): Custom position encoding layer.

    """

    def __init__(self, idim, odim, dropout_rate, pos_enc=None):
        """Construct an Conv2dSubsamplingPad object."""
        super(Conv2dSubsamplingPad, self).__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, odim, 3, 2, padding=(0, 0)),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 3, 2, padding=(0, 0)),
            torch.nn.ReLU(),
        )
        self.out = torch.nn.Sequential(
            torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim),
            pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
        )
        # Appends 4 zeros at the end of the last axis.
        self.pad_fn = torch.nn.ConstantPad1d((0, 4), 0.0)

    def forward(self, x, x_mask):
        """Subsample x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time), or None.

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim),
                where time' = ((time + 3) // 2 - 1) // 2.
            torch.Tensor: Mask (#batch, 1, L) built by ``sequence_mask`` from
                the subsampled valid lengths, or None when ``x_mask`` is None.

        """
        # Move time to the last axis so the 1D pad is applied along time.
        x = x.transpose(1, 2)
        x = self.pad_fn(x)
        x = x.transpose(1, 2)
        x = x.unsqueeze(1)  # (b, c, t, f)
        x = self.conv(x)
        b, c, t, f = x.size()
        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
        if x_mask is None:
            return x, None
        # Valid frame count per utterance, halved (rounding up) once per conv.
        x_len = torch.sum(x_mask[:, 0, :], dim=-1)
        x_len = (x_len - 1) // 2 + 1
        x_len = (x_len - 1) // 2 + 1
        # NOTE(review): this length can exceed the conv output length ``t`` by
        # one (e.g. time=9 gives t=2 but x_len=3), so the mask may be longer
        # than ``x`` along time — confirm how sequence_mask/callers handle it.
        mask = sequence_mask(x_len, None, x_len.dtype, x[0].device)
        return x, mask[:, None, :]

    def __getitem__(self, key):
        """Get item.

        When reset_parameters() is called, if use_scaled_pos_enc is used,
            return the positioning encoding.

        """
        return _get_pos_enc(self.out, key)


class Conv2dSubsampling2(torch.nn.Module):
    """Convolutional 2D subsampling (to roughly 1/2 length).

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate (used only for the default
            positional encoding).
        pos_enc (torch.nn.Module): Custom position encoding layer.

    """

    def __init__(self, idim, odim, dropout_rate, pos_enc=None):
        """Construct an Conv2dSubsampling2 object."""
        super(Conv2dSubsampling2, self).__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 3, 1),
            torch.nn.ReLU(),
        )
        self.out = torch.nn.Sequential(
            torch.nn.Linear(odim * (((idim - 1) // 2 - 2)), odim),
            pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
        )

    def forward(self, x, x_mask):
        """Subsample x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time), or None.

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim),
                where time' = (time - 1) // 2 - 2.
            torch.Tensor: Subsampled mask (#batch, 1, time'),
                or None when ``x_mask`` is None.

        """
        x = x.unsqueeze(1)  # (b, c, t, f)
        x = self.conv(x)
        b, c, t, f = x.size()
        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
        if x_mask is None:
            return x, None
        # ``:-2:2`` for the stride-2 conv, ``:-2:1`` for the stride-1 conv.
        return x, x_mask[:, :, :-2:2][:, :, :-2:1]

    def __getitem__(self, key):
        """Get item.

        When reset_parameters() is called, if use_scaled_pos_enc is used,
            return the positioning encoding.

        """
        return _get_pos_enc(self.out, key)


class Conv2dSubsampling6(torch.nn.Module):
    """Convolutional 2D subsampling (to roughly 1/6 length).

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate (used only for the default
            positional encoding).
        pos_enc (torch.nn.Module): Custom position encoding layer.

    """

    def __init__(self, idim, odim, dropout_rate, pos_enc=None):
        """Construct an Conv2dSubsampling6 object."""
        super(Conv2dSubsampling6, self).__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 5, 3),
            torch.nn.ReLU(),
        )
        self.out = torch.nn.Sequential(
            torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3), odim),
            pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
        )

    def forward(self, x, x_mask):
        """Subsample x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time), or None.

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim),
                where time' = ((time - 1) // 2 - 2) // 3.
            torch.Tensor: Subsampled mask (#batch, 1, time'),
                or None when ``x_mask`` is None.

        """
        x = x.unsqueeze(1)  # (b, c, t, f)
        x = self.conv(x)
        b, c, t, f = x.size()
        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
        if x_mask is None:
            return x, None
        # ``:-2:2`` for the (3, stride 2) conv, ``:-4:3`` for the (5, stride 3) conv.
        return x, x_mask[:, :, :-2:2][:, :, :-4:3]

    def __getitem__(self, key):
        """Get item.

        When reset_parameters() is called, if use_scaled_pos_enc is used,
            return the positioning encoding.

        """
        return _get_pos_enc(self.out, key)


class Conv2dSubsampling8(torch.nn.Module):
    """Convolutional 2D subsampling (to roughly 1/8 length).

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate (used only for the default
            positional encoding).
        pos_enc (torch.nn.Module): Custom position encoding layer.

    """

    def __init__(self, idim, odim, dropout_rate, pos_enc=None):
        """Construct an Conv2dSubsampling8 object."""
        super(Conv2dSubsampling8, self).__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 3, 2),
            torch.nn.ReLU(),
        )
        self.out = torch.nn.Sequential(
            torch.nn.Linear(odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), odim),
            pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
        )

    def forward(self, x, x_mask):
        """Subsample x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time), or None.

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim),
                where time' = (((time - 1) // 2 - 1) // 2 - 1) // 2.
            torch.Tensor: Subsampled mask (#batch, 1, time'),
                or None when ``x_mask`` is None.

        """
        x = x.unsqueeze(1)  # (b, c, t, f)
        x = self.conv(x)
        b, c, t, f = x.size()
        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
        if x_mask is None:
            return x, None
        # One ``:-2:2`` slice per (kernel 3, stride 2) convolution.
        return x, x_mask[:, :, :-2:2][:, :, :-2:2][:, :, :-2:2]

    def __getitem__(self, key):
        """Get item.

        When reset_parameters() is called, if use_scaled_pos_enc is used,
            return the positioning encoding.

        """
        return _get_pos_enc(self.out, key)


class Conv1dSubsampling(torch.nn.Module):
    """Convolutional 1D subsampling along time by a factor of ``stride``.

    The input is zero-padded along time, passed through a single ``Conv1d``
    and a ReLU. No positional encoding or dropout is applied.

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        kernel_size (int): Kernel size of the 1D convolution.
        stride (int): Stride of the 1D convolution.
        pad (int or tuple): Padding handed to ``torch.nn.ConstantPad1d``;
            applied (with zeros) along the time axis before the convolution.
        tf2torch_tensor_name_prefix_torch (str): Parameter-name prefix of this
            module on the torch side, used when converting TF checkpoints.
        tf2torch_tensor_name_prefix_tf (str): Variable-name prefix on the
            TensorFlow side, used when converting TF checkpoints.

    """

    def __init__(
        self,
        idim,
        odim,
        kernel_size,
        stride,
        pad,
        tf2torch_tensor_name_prefix_torch: str = "stride_conv",
        tf2torch_tensor_name_prefix_tf: str = "seq2seq/proj_encoder/downsampling",
    ):
        """Construct a Conv1dSubsampling object."""
        super(Conv1dSubsampling, self).__init__()
        self.conv = torch.nn.Conv1d(idim, odim, kernel_size, stride)
        self.pad_fn = torch.nn.ConstantPad1d(pad, 0.0)
        self.stride = stride
        self.odim = odim
        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf

    def output_size(self) -> int:
        """Return the output feature dimension (``odim``)."""
        return self.odim

    def forward(self, x, x_len):
        """Subsample x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_len (torch.Tensor): Valid lengths per utterance, or None.

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim).
            torch.Tensor: Subsampled lengths ``(x_len - 1) // stride + 1``,
                or None when ``x_len`` is None.

        """
        x = x.transpose(1, 2)  # (b, d ,t)
        x = self.pad_fn(x)
        x = F.relu(self.conv(x))
        x = x.transpose(1, 2)  # (b, t ,d)
        if x_len is None:
            return x, None
        x_len = (x_len - 1) // self.stride + 1
        return x, x_len

    def gen_tf2torch_map_dict(self):
        """Build the torch-parameter-name -> TF-variable conversion table.

        Returns:
            dict: Maps each torch parameter name to a spec with the TF
            variable ``name`` plus the optional ``squeeze`` axis and
            ``transpose`` permutation to apply to the TF array.

        """
        tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
        tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
        map_dict_local = {
            ## predictor
            "{}.conv.weight".format(tensor_name_prefix_torch): {
                "name": "{}/conv1d/kernel".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": (2, 1, 0),
            },  # (256,256,3),(3,256,256)
            "{}.conv.bias".format(tensor_name_prefix_torch): {
                "name": "{}/conv1d/bias".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (256,),(256,)
        }
        return map_dict_local

    def convert_tf2torch(
        self,
        var_dict_tf,
        var_dict_torch,
    ):
        """Convert this module's TF variables into torch tensors.

        Args:
            var_dict_tf (dict): TF variable name -> numpy array.
            var_dict_torch (dict): Torch parameter name -> value; only the
                keys are used, to decide which parameters to convert.

        Returns:
            dict: Torch parameter name -> float32 CPU tensor, for every key of
            ``var_dict_torch`` whose first dot-separated component equals this
            module's torch prefix.

        Raises:
            KeyError: If a matching parameter name is missing from the
                conversion table or its TF variable is missing from
                ``var_dict_tf``.

        """
        map_dict = self.gen_tf2torch_map_dict()
        var_dict_torch_update = {}
        for name in sorted(var_dict_torch):
            if name.split(".")[0] != self.tf2torch_tensor_name_prefix_torch:
                continue
            spec = map_dict[name]
            name_tf = spec["name"]
            data_tf = var_dict_tf[name_tf]
            if spec["squeeze"] is not None:
                data_tf = np.squeeze(data_tf, axis=spec["squeeze"])
            if spec["transpose"] is not None:
                data_tf = np.transpose(data_tf, spec["transpose"])
            data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")

            var_dict_torch_update[name] = data_tf
            logging.info(
                "torch tensor: %s, %s, loading from tf tensor: %s, %s",
                name,
                data_tf.size(),
                name_tf,
                var_dict_tf[name_tf].shape,
            )
        return var_dict_torch_update