mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
add sond model
This commit is contained in:
parent
f6a1cdaf34
commit
7fe447185c
@ -63,6 +63,58 @@ class MultiLayeredConv1d(torch.nn.Module):
|
||||
return self.w_2(self.dropout(x).transpose(-1, 1)).transpose(-1, 1)
|
||||
|
||||
|
||||
class FsmnFeedForward(torch.nn.Module):
|
||||
"""Position-wise feed forward for FSMN blocks.
|
||||
|
||||
This is a module of multi-leyered conv1d designed
|
||||
to replace position-wise feed-forward network
|
||||
in FSMN block.
|
||||
"""
|
||||
|
||||
def __init__(self, in_chans, hidden_chans, out_chans, kernel_size, dropout_rate):
|
||||
"""Initialize FsmnFeedForward module.
|
||||
|
||||
Args:
|
||||
in_chans (int): Number of input channels.
|
||||
hidden_chans (int): Number of hidden channels.
|
||||
out_chans (int): Number of output channels.
|
||||
kernel_size (int): Kernel size of conv1d.
|
||||
dropout_rate (float): Dropout rate.
|
||||
|
||||
"""
|
||||
super(FsmnFeedForward, self).__init__()
|
||||
self.w_1 = torch.nn.Conv1d(
|
||||
in_chans,
|
||||
hidden_chans,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
padding=(kernel_size - 1) // 2,
|
||||
)
|
||||
self.w_2 = torch.nn.Conv1d(
|
||||
hidden_chans,
|
||||
out_chans,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
padding=(kernel_size - 1) // 2,
|
||||
bias=False
|
||||
)
|
||||
self.norm = torch.nn.LayerNorm(hidden_chans)
|
||||
self.dropout = torch.nn.Dropout(dropout_rate)
|
||||
|
||||
def forward(self, x, ilens=None):
|
||||
"""Calculate forward propagation.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Batch of input tensors (B, T, in_chans).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Batch of output tensors (B, T, out_chans).
|
||||
|
||||
"""
|
||||
x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1)
|
||||
return self.w_2(self.norm(self.dropout(x)).transpose(-1, 1)).transpose(-1, 1), ilens
|
||||
|
||||
|
||||
class Conv1dLinear(torch.nn.Module):
|
||||
"""Conv1D + Linear for Transformer block.
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user