"""Subsampling layer definition.""" import torch class OnnxConv2dSubsampling(torch.nn.Module): """Convolutional 2D subsampling (to 1/4 length). Args: idim (int): Input dimension. odim (int): Output dimension. dropout_rate (float): Dropout rate. pos_enc (torch.nn.Module): Custom position encoding layer. """ def __init__(self, model): """Construct an Conv2dSubsampling object.""" super().__init__() self.conv = model.conv self.out = model.out def forward(self, x, x_mask): """Subsample x. Args: x (torch.Tensor): Input tensor (#batch, time, idim). x_mask (torch.Tensor): Input mask (#batch, 1, time). Returns: torch.Tensor: Subsampled tensor (#batch, time', odim), where time' = time // 4. torch.Tensor: Subsampled mask (#batch, 1, time'), where time' = time // 4. """ x = x.unsqueeze(1) # (b, c, t, f) x = self.conv(x) b, c, t, f = x.size() x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) if x_mask is None: return x, None return x, x_mask[:, :-2:2][:, :-2:2] def __getitem__(self, key): """Get item. When reset_parameters() is called, if use_scaled_pos_enc is used, return the positioning encoding. """ if key != -1: raise NotImplementedError("Support only `-1` (for `reset_parameters`).") return self.out[key] class OnnxConv2dSubsampling2(torch.nn.Module): """Convolutional 2D subsampling (to 1/2 length). Args: idim (int): Input dimension. odim (int): Output dimension. dropout_rate (float): Dropout rate. pos_enc (torch.nn.Module): Custom position encoding layer. """ def __init__(self, model): """Construct an Conv2dSubsampling object.""" super().__init__() self.conv = model.conv self.out = model.out def forward(self, x, x_mask): """Subsample x. Args: x (torch.Tensor): Input tensor (#batch, time, idim). x_mask (torch.Tensor): Input mask (#batch, 1, time). Returns: torch.Tensor: Subsampled tensor (#batch, time', odim), where time' = time // 2. torch.Tensor: Subsampled mask (#batch, 1, time'), where time' = time // 2. """ x = x.unsqueeze(1) # (b, c, t, f) x = self.conv(x) b, c, t, f = x.size() x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) if x_mask is None: return x, None return x, x_mask[:, :-2:2][:, :-2:1] def __getitem__(self, key): """Get item. When reset_parameters() is called, if use_scaled_pos_enc is used, return the positioning encoding. """ if key != -1: raise NotImplementedError("Support only `-1` (for `reset_parameters`).") return self.out[key] class OnnxConv2dSubsampling6(torch.nn.Module): """Convolutional 2D subsampling (to 1/6 length). Args: idim (int): Input dimension. odim (int): Output dimension. dropout_rate (float): Dropout rate. pos_enc (torch.nn.Module): Custom position encoding layer. """ def __init__(self, model): """Construct an Conv2dSubsampling object.""" super().__init__() self.conv = model.conv self.out = model.out def forward(self, x, x_mask): """Subsample x. Args: x (torch.Tensor): Input tensor (#batch, time, idim). x_mask (torch.Tensor): Input mask (#batch, 1, time). Returns: torch.Tensor: Subsampled tensor (#batch, time', odim), where time' = time // 6. torch.Tensor: Subsampled mask (#batch, 1, time'), where time' = time // 6. """ x = x.unsqueeze(1) # (b, c, t, f) x = self.conv(x) b, c, t, f = x.size() x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) if x_mask is None: return x, None return x, x_mask[:, :-2:2][:, :-4:3] class OnnxConv2dSubsampling8(torch.nn.Module): """Convolutional 2D subsampling (to 1/8 length). Args: idim (int): Input dimension. odim (int): Output dimension. dropout_rate (float): Dropout rate. 

class OnnxConv2dSubsampling8(torch.nn.Module):
    """Convolutional 2D subsampling (to 1/8 length).

    Args:
        model (torch.nn.Module): Conv2dSubsampling8 module whose ``conv`` and
            ``out`` layers are reused for ONNX export.
    """

    def __init__(self, model):
        """Construct an OnnxConv2dSubsampling8 object."""
        super().__init__()
        self.conv = model.conv
        self.out = model.out

    def forward(self, x, x_mask):
        """Subsample x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim),
                where time' = time // 8.
            torch.Tensor: Subsampled mask (#batch, 1, time'),
                where time' = time // 8.
        """
        x = x.unsqueeze(1)  # (b, c, t, f)
        x = self.conv(x)
        b, c, t, f = x.size()
        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
        if x_mask is None:
            return x, None
        return x, x_mask[:, :-2:2][:, :-2:2][:, :-2:2]
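

# --- Usage sketch (illustrative only, not part of the original module) ---
# The block below shows how one of these wrappers might be exercised and
# exported. The stand-in ``conv``/``out`` layers, tensor shapes, file name,
# and axis names are assumptions for this sketch; in practice ``model`` is an
# espnet Conv2dSubsampling instance. Note that the mask slicing above indexes
# the second tensor axis, so this sketch feeds a 2D (batch, time) mask.
if __name__ == "__main__":
    import types

    idim, odim = 80, 256

    # Minimal stand-in exposing the `conv` and `out` attributes the wrapper copies.
    dummy = types.SimpleNamespace(
        conv=torch.nn.Sequential(
            torch.nn.Conv2d(1, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 3, 2),
            torch.nn.ReLU(),
        ),
        out=torch.nn.Sequential(
            # Two stride-2 convolutions shrink the feature axis to
            # ((idim - 1) // 2 - 1) // 2 bins.
            torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim),
        ),
    )

    layer = OnnxConv2dSubsampling(dummy)
    feats = torch.randn(2, 100, idim)
    mask = torch.ones(2, 100, dtype=torch.bool)
    subsampled, subsampled_mask = layer(feats, mask)
    print(subsampled.shape, subsampled_mask.shape)  # (2, 24, 256) and (2, 24)

    # Hypothetical ONNX export call; the output path and axis names are
    # illustrative, not the project's actual export pipeline.
    torch.onnx.export(
        layer,
        (feats, mask),
        "conv2d_subsampling.onnx",
        input_names=["feats", "mask"],
        output_names=["subsampled", "subsampled_mask"],
        dynamic_axes={
            "feats": {0: "batch", 1: "time"},
            "mask": {0: "batch", 1: "time"},
        },
    )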