#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""Subsampling layer definition."""

import logging

import numpy as np
import torch
import torch.nn.functional as F

from funasr.modules.embedding import PositionalEncoding
from funasr.modules.streaming_utils.utils import sequence_mask


class TooShortUttError(Exception):
    """Raised when the utterance is too short for subsampling.

    Args:
        message (str): Message for error catch
        actual_size (int): the input size that is too short for subsampling
        limit (int): the minimum size required for subsampling

    """

    def __init__(self, message, actual_size, limit):
        """Construct a TooShortUttError for the error handler."""
        super().__init__(message)
        self.actual_size = actual_size
        self.limit = limit


def check_short_utt(ins, size):
    """Check if the utterance is too short for subsampling.

    Returns ``(True, limit)`` with the minimum required size when the
    input is too short, and ``(False, -1)`` otherwise.
    """
    # a kernel-3/stride-2 conv followed by a kernel-3/stride-1 conv
    # needs at least 7 input frames to emit one output frame
    if isinstance(ins, Conv2dSubsampling2) and size < 7:
        return True, 7
    if isinstance(ins, Conv2dSubsampling) and size < 7:
        return True, 7
    if isinstance(ins, Conv2dSubsampling6) and size < 11:
        return True, 11
    if isinstance(ins, Conv2dSubsampling8) and size < 15:
        return True, 15
    return False, -1


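# Illustrative sketch (not part of the original module): how an encoder might
# guard its input with check_short_utt before subsampling, in the style of the
# ESPnet/FunASR encoders. `feats` with shape (batch, time, idim) is assumed.
def _example_guard_short_utt(subsample_module, feats):
    """Illustrative only: raise TooShortUttError before running subsampling."""
    too_short, limit = check_short_utt(subsample_module, feats.size(1))
    if too_short:
        raise TooShortUttError(
            f"has {feats.size(1)} frames and is too short for subsampling "
            f"(it needs at least {limit} frames)",
            feats.size(1),
            limit,
        )

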
class Conv2dSubsampling(torch.nn.Module):
    """Convolutional 2D subsampling (to 1/4 length).

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate.
        pos_enc (torch.nn.Module): Custom position encoding layer.

    """

    def __init__(self, idim, odim, dropout_rate, pos_enc=None):
        """Construct a Conv2dSubsampling object."""
        super(Conv2dSubsampling, self).__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 3, 2),
            torch.nn.ReLU(),
        )
        self.out = torch.nn.Sequential(
            torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim),
            pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
        )

    def forward(self, x, x_mask):
        """Subsample x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim),
                where time' = time // 4.
            torch.Tensor: Subsampled mask (#batch, 1, time'),
                where time' = time // 4.

        """
        x = x.unsqueeze(1)  # (b, c, t, f)
        x = self.conv(x)
        b, c, t, f = x.size()
        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
        if x_mask is None:
            return x, None
        return x, x_mask[:, :, :-2:2][:, :, :-2:2]  # halve the mask once per stride-2 conv

    def __getitem__(self, key):
        """Get item.

        When reset_parameters() is called, if use_scaled_pos_enc is used,
        return the positional encoding.

        """
        if key != -1:
            raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
        return self.out[key]


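# Illustrative sketch (shapes are assumptions, not from the original file):
# pushing a (batch=2, time=100, idim=80) batch through Conv2dSubsampling,
# assuming the default PositionalEncoding returns a single tensor.
def _example_conv2d_subsampling():
    subsample = Conv2dSubsampling(idim=80, odim=256, dropout_rate=0.1)
    x = torch.randn(2, 100, 80)                       # (batch, time, idim)
    x_mask = torch.ones(2, 1, 100, dtype=torch.bool)  # (batch, 1, time)
    y, y_mask = subsample(x, x_mask)
    # two kernel-3/stride-2 convs: time' = ((100 - 1) // 2 - 1) // 2 = 24
    assert y.shape == (2, 24, 256)
    assert y_mask.shape == (2, 1, 24)

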
class Conv2dSubsamplingPad(torch.nn.Module):
    """Convolutional 2D subsampling (to 1/4 length), with right padding in time.

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate.
        pos_enc (torch.nn.Module): Custom position encoding layer.

    """

    def __init__(self, idim, odim, dropout_rate, pos_enc=None):
        """Construct a Conv2dSubsamplingPad object."""
        super(Conv2dSubsamplingPad, self).__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, odim, 3, 2, padding=(0, 0)),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 3, 2, padding=(0, 0)),
            torch.nn.ReLU(),
        )
        self.out = torch.nn.Sequential(
            torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim),
            pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
        )
        self.pad_fn = torch.nn.ConstantPad1d((0, 4), 0.0)  # pad 4 frames on the right

    def forward(self, x, x_mask):
        """Subsample x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim),
                where time' = time // 4.
            torch.Tensor: Subsampled mask (#batch, 1, time'),
                where time' = time // 4.

        """
        x = x.transpose(1, 2)  # (b, f, t); pad the time axis on the right
        x = self.pad_fn(x)
        x = x.transpose(1, 2)
        x = x.unsqueeze(1)  # (b, c, t, f)
        x = self.conv(x)
        b, c, t, f = x.size()
        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
        if x_mask is None:
            return x, None
        # recompute valid lengths instead of slicing the mask: each stride-2
        # conv maps a valid length L to (L - 1) // 2 + 1 = ceil(L / 2)
        x_len = torch.sum(x_mask[:, 0, :], dim=-1)
        x_len = (x_len - 1) // 2 + 1
        x_len = (x_len - 1) // 2 + 1
        mask = sequence_mask(x_len, None, x_len.dtype, x[0].device)
        return x, mask[:, None, :]

    def __getitem__(self, key):
        """Get item.

        When reset_parameters() is called, if use_scaled_pos_enc is used,
        return the positional encoding.

        """
        if key != -1:
            raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
        return self.out[key]


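# Illustrative sketch (shapes are assumptions): the padded variant keeps
# ceil(time / 4) frames, so 100 input frames survive as 25 rather than 24.
def _example_conv2d_subsampling_pad():
    subsample = Conv2dSubsamplingPad(idim=80, odim=256, dropout_rate=0.1)
    x = torch.randn(2, 100, 80)                       # (batch, time, idim)
    x_mask = torch.ones(2, 1, 100, dtype=torch.bool)  # (batch, 1, time)
    y, y_mask = subsample(x, x_mask)
    # right-pad by 4 frames, then two stride-2 convs: (100 + 4) -> 51 -> 25
    assert y.shape == (2, 25, 256)
    assert y_mask.shape == (2, 1, 25)

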
class Conv2dSubsampling2(torch.nn.Module):
    """Convolutional 2D subsampling (to 1/2 length).

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate.
        pos_enc (torch.nn.Module): Custom position encoding layer.

    """

    def __init__(self, idim, odim, dropout_rate, pos_enc=None):
        """Construct a Conv2dSubsampling2 object."""
        super(Conv2dSubsampling2, self).__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 3, 1),
            torch.nn.ReLU(),
        )
        self.out = torch.nn.Sequential(
            torch.nn.Linear(odim * ((idim - 1) // 2 - 2), odim),
            pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
        )

    def forward(self, x, x_mask):
        """Subsample x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim),
                where time' = time // 2.
            torch.Tensor: Subsampled mask (#batch, 1, time'),
                where time' = time // 2.

        """
        x = x.unsqueeze(1)  # (b, c, t, f)
        x = self.conv(x)
        b, c, t, f = x.size()
        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
        if x_mask is None:
            return x, None
        return x, x_mask[:, :, :-2:2][:, :, :-2:1]  # stride-2 conv, then stride-1 conv

    def __getitem__(self, key):
        """Get item.

        When reset_parameters() is called, if use_scaled_pos_enc is used,
        return the positional encoding.

        """
        if key != -1:
            raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
        return self.out[key]


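# Illustrative sketch (shapes assumed): the 1/2-rate variant halves only once;
# the second conv has stride 1 and just trims two frames.
def _example_conv2d_subsampling2():
    subsample = Conv2dSubsampling2(idim=80, odim=256, dropout_rate=0.1)
    x = torch.randn(2, 100, 80)                       # (batch, time, idim)
    x_mask = torch.ones(2, 1, 100, dtype=torch.bool)  # (batch, 1, time)
    y, y_mask = subsample(x, x_mask)
    # (100 - 1) // 2 - 2 = 47 frames: halve once, then trim 2
    assert y.shape == (2, 47, 256)
    assert y_mask.shape == (2, 1, 47)

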
class Conv2dSubsampling6(torch.nn.Module):
    """Convolutional 2D subsampling (to 1/6 length).

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate.
        pos_enc (torch.nn.Module): Custom position encoding layer.

    """

    def __init__(self, idim, odim, dropout_rate, pos_enc=None):
        """Construct a Conv2dSubsampling6 object."""
        super(Conv2dSubsampling6, self).__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 5, 3),
            torch.nn.ReLU(),
        )
        self.out = torch.nn.Sequential(
            torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3), odim),
            pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
        )

    def forward(self, x, x_mask):
        """Subsample x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim),
                where time' = time // 6.
            torch.Tensor: Subsampled mask (#batch, 1, time'),
                where time' = time // 6.

        """
        x = x.unsqueeze(1)  # (b, c, t, f)
        x = self.conv(x)
        b, c, t, f = x.size()
        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
        if x_mask is None:
            return x, None
        return x, x_mask[:, :, :-2:2][:, :, :-4:3]  # stride-2 conv, then stride-3 conv


class Conv2dSubsampling8(torch.nn.Module):
    """Convolutional 2D subsampling (to 1/8 length).

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate.
        pos_enc (torch.nn.Module): Custom position encoding layer.

    """

    def __init__(self, idim, odim, dropout_rate, pos_enc=None):
        """Construct a Conv2dSubsampling8 object."""
        super(Conv2dSubsampling8, self).__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 3, 2),
            torch.nn.ReLU(),
        )
        self.out = torch.nn.Sequential(
            torch.nn.Linear(odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), odim),
            pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
        )

    def forward(self, x, x_mask):
        """Subsample x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim),
                where time' = time // 8.
            torch.Tensor: Subsampled mask (#batch, 1, time'),
                where time' = time // 8.

        """
        x = x.unsqueeze(1)  # (b, c, t, f)
        x = self.conv(x)
        b, c, t, f = x.size()
        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
        if x_mask is None:
            return x, None
        return x, x_mask[:, :, :-2:2][:, :, :-2:2][:, :, :-2:2]  # halve the mask three times


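# Illustrative check (dimensions are assumptions): the limits enforced by
# check_short_utt are the shortest inputs whose convolutions still emit at
# least one frame; anything shorter would feed a kernel more than its input.
def _example_minimum_lengths():
    for cls, min_len in [
        (Conv2dSubsampling2, 7),
        (Conv2dSubsampling, 7),
        (Conv2dSubsampling6, 11),
        (Conv2dSubsampling8, 15),
    ]:
        m = cls(idim=80, odim=32, dropout_rate=0.0)
        x = torch.randn(1, min_len, 80)
        y, _ = m(x, None)
        assert y.size(1) == 1  # exactly one output frame survives

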
class Conv1dSubsampling(torch.nn.Module):
    """Convolutional 1D subsampling (to 1/stride length).

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        kernel_size (int): Kernel size of the Conv1d.
        stride (int): Stride of the Conv1d (the subsampling factor).
        pad (tuple): Left/right constant padding applied to the time axis.
        tf2torch_tensor_name_prefix_torch (str): Parameter name prefix on the torch side.
        tf2torch_tensor_name_prefix_tf (str): Variable name prefix on the TensorFlow side.

    """

    def __init__(self, idim, odim, kernel_size, stride, pad,
                 tf2torch_tensor_name_prefix_torch: str = "stride_conv",
                 tf2torch_tensor_name_prefix_tf: str = "seq2seq/proj_encoder/downsampling",
                 ):
        super(Conv1dSubsampling, self).__init__()
        self.conv = torch.nn.Conv1d(idim, odim, kernel_size, stride)
        self.pad_fn = torch.nn.ConstantPad1d(pad, 0.0)
        self.stride = stride
        self.odim = odim
        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf

    def output_size(self) -> int:
        return self.odim

    def forward(self, x, x_len):
        """Subsample x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_len (torch.Tensor): Input lengths (#batch,).

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim).
            torch.Tensor: Subsampled lengths (#batch,),
                where time' = (time - 1) // stride + 1.

        """
        x = x.transpose(1, 2)  # (b, d, t)
        x = self.pad_fn(x)
        x = F.relu(self.conv(x))
        x = x.transpose(1, 2)  # (b, t, d)

        if x_len is None:
            return x, None
        x_len = (x_len - 1) // self.stride + 1
        return x, x_len

    def gen_tf2torch_map_dict(self):
        tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
        tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
        map_dict_local = {
            ## stride conv (downsampling)
            "{}.conv.weight".format(tensor_name_prefix_torch):
                {"name": "{}/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": (2, 1, 0),
                 },  # (256,256,3),(3,256,256)
            "{}.conv.bias".format(tensor_name_prefix_torch):
                {"name": "{}/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
        }
        return map_dict_local

    def convert_tf2torch(self,
                         var_dict_tf,
                         var_dict_torch,
                         ):
        map_dict = self.gen_tf2torch_map_dict()
        var_dict_torch_update = dict()
        for name in sorted(var_dict_torch.keys()):
            names = name.split('.')
            if names[0] == self.tf2torch_tensor_name_prefix_torch:
                name_tf = map_dict[name]["name"]
                data_tf = var_dict_tf[name_tf]
                if map_dict[name]["squeeze"] is not None:
                    data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
                if map_dict[name]["transpose"] is not None:
                    data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
                data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                var_dict_torch_update[name] = data_tf
                logging.info(
                    "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                        name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape))
        return var_dict_torch_update


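# Illustrative sketches (all parameter values and the fake checkpoint below
# are assumptions, not from the original file).
def _example_conv1d_subsampling():
    # a stride-2 Conv1d halving the frame rate, padded by 2 frames on each side
    subsample = Conv1dSubsampling(idim=80, odim=256, kernel_size=5, stride=2, pad=(2, 2))
    x = torch.randn(2, 100, 80)        # (batch, time, idim)
    x_len = torch.tensor([100, 80])
    y, y_len = subsample(x, x_len)
    # padded length 104, kernel 5, stride 2: (104 - 5) // 2 + 1 = 50 frames
    assert y.shape == (2, 50, 256)
    # length rule from forward(): (x_len - 1) // stride + 1
    assert y_len.tolist() == [50, 40]


def _example_convert_tf2torch():
    # convert a fake TF checkpoint entry for the stride conv; variable names
    # follow the default tf2torch prefixes of Conv1dSubsampling
    sub = Conv1dSubsampling(idim=256, odim=256, kernel_size=3, stride=2, pad=(0, 2))
    var_dict_tf = {
        "seq2seq/proj_encoder/downsampling/conv1d/kernel": np.zeros((3, 256, 256), np.float32),
        "seq2seq/proj_encoder/downsampling/conv1d/bias": np.zeros((256,), np.float32),
    }
    var_dict_torch = {"stride_conv." + k: v for k, v in sub.state_dict().items()}
    updated = sub.convert_tf2torch(var_dict_tf, var_dict_torch)
    # TF stores kernels as (kernel, in, out); torch Conv1d wants (out, in, kernel)
    assert updated["stride_conv.conv.weight"].shape == (256, 256, 3)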