mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
分角色语音识别支持更多的模型
This commit is contained in:
parent
73613cefc9
commit
18b1449d1f
@ -51,10 +51,10 @@ from funasr.utils.vad_utils import slice_padding_fbank
|
|||||||
from funasr.utils.speaker_utils import (check_audio_list,
|
from funasr.utils.speaker_utils import (check_audio_list,
|
||||||
sv_preprocess,
|
sv_preprocess,
|
||||||
sv_chunk,
|
sv_chunk,
|
||||||
CAMPPlus,
|
|
||||||
extract_feature,
|
extract_feature,
|
||||||
postprocess,
|
postprocess,
|
||||||
distribute_spk, ERes2Net)
|
distribute_spk)
|
||||||
|
import funasr.modules.cnn as sv_module
|
||||||
from funasr.build_utils.build_model_from_file import build_model_from_file
|
from funasr.build_utils.build_model_from_file import build_model_from_file
|
||||||
from funasr.utils.cluster_backend import ClusterBackend
|
from funasr.utils.cluster_backend import ClusterBackend
|
||||||
from funasr.utils.modelscope_utils import get_cache_dir
|
from funasr.utils.modelscope_utils import get_cache_dir
|
||||||
@ -818,11 +818,15 @@ def inference_paraformer_vad_speaker(
|
|||||||
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
|
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
|
||||||
)
|
)
|
||||||
|
|
||||||
sv_model_file = asr_model_file.replace("model.pb", "campplus_cn_common.bin")
|
sv_model_config_path = asr_model_file.replace("model.pb", "sv_model_config.yaml")
|
||||||
if not os.path.exists(sv_model_file):
|
if not os.path.exists(sv_model_config_path):
|
||||||
sv_model_file = asr_model_file.replace("model.pb", "pretrained_eres2net_aug.ckpt")
|
sv_model_config = {'sv_model_class': 'CAMPPlus','sv_model_file': 'campplus_cn_common.bin', 'models_config': {}}
|
||||||
if not os.path.exists(sv_model_file):
|
else:
|
||||||
raise FileNotFoundError("sv_model_file not found: {}".format(sv_model_file))
|
with open(sv_model_config_path, 'r') as f:
|
||||||
|
sv_model_config = yaml.load(f, Loader=yaml.FullLoader)
|
||||||
|
if sv_model_config['models_config'] is None:
|
||||||
|
sv_model_config['models_config'] = {}
|
||||||
|
sv_model_file = asr_model_file.replace("model.pb", sv_model_config['sv_model_file'])
|
||||||
|
|
||||||
if param_dict is not None:
|
if param_dict is not None:
|
||||||
hotword_list_or_file = param_dict.get('hotword')
|
hotword_list_or_file = param_dict.get('hotword')
|
||||||
@ -949,14 +953,11 @@ def inference_paraformer_vad_speaker(
|
|||||||
##################################
|
##################################
|
||||||
# load sv model
|
# load sv model
|
||||||
sv_model_dict = torch.load(sv_model_file)
|
sv_model_dict = torch.load(sv_model_file)
|
||||||
print(f'load sv model params: {sv_model_file}')
|
sv_model = getattr(sv_module, sv_model_config['sv_model_class'])(**sv_model_config['models_config'])
|
||||||
if os.path.basename(sv_model_file) == "campplus_cn_common.bin":
|
|
||||||
sv_model = CAMPPlus()
|
|
||||||
else:
|
|
||||||
sv_model = ERes2Net()
|
|
||||||
if ngpu > 0:
|
if ngpu > 0:
|
||||||
sv_model.cuda()
|
sv_model.cuda()
|
||||||
sv_model.load_state_dict(sv_model_dict)
|
sv_model.load_state_dict(sv_model_dict)
|
||||||
|
print(f'load sv model params: {sv_model_file}')
|
||||||
sv_model.eval()
|
sv_model.eval()
|
||||||
cb_model = ClusterBackend()
|
cb_model = ClusterBackend()
|
||||||
vad_segments = []
|
vad_segments = []
|
||||||
|
|||||||
108
funasr/models/pooling/pooling_layers.py
Normal file
108
funasr/models/pooling/pooling_layers.py
Normal file
@ -0,0 +1,108 @@
|
|||||||
|
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
|
||||||
|
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||||
|
|
||||||
|
""" This implementation is adapted from https://github.com/wenet-e2e/wespeaker."""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
class TAP(nn.Module):
|
||||||
|
"""
|
||||||
|
Temporal average pooling, only first-order mean is considered
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super(TAP, self).__init__()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
pooling_mean = x.mean(dim=-1)
|
||||||
|
# To be compatable with 2D input
|
||||||
|
pooling_mean = pooling_mean.flatten(start_dim=1)
|
||||||
|
return pooling_mean
|
||||||
|
|
||||||
|
|
||||||
|
class TSDP(nn.Module):
|
||||||
|
"""
|
||||||
|
Temporal standard deviation pooling, only second-order std is considered
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super(TSDP, self).__init__()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
# The last dimension is the temporal axis
|
||||||
|
pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-8)
|
||||||
|
pooling_std = pooling_std.flatten(start_dim=1)
|
||||||
|
return pooling_std
|
||||||
|
|
||||||
|
|
||||||
|
class TSTP(nn.Module):
|
||||||
|
"""
|
||||||
|
Temporal statistics pooling, concatenate mean and std, which is used in
|
||||||
|
x-vector
|
||||||
|
Comment: simple concatenation can not make full use of both statistics
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super(TSTP, self).__init__()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
# The last dimension is the temporal axis
|
||||||
|
pooling_mean = x.mean(dim=-1)
|
||||||
|
pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-8)
|
||||||
|
pooling_mean = pooling_mean.flatten(start_dim=1)
|
||||||
|
pooling_std = pooling_std.flatten(start_dim=1)
|
||||||
|
|
||||||
|
stats = torch.cat((pooling_mean, pooling_std), 1)
|
||||||
|
return stats
|
||||||
|
|
||||||
|
|
||||||
|
class ASTP(nn.Module):
|
||||||
|
""" Attentive statistics pooling: Channel- and context-dependent
|
||||||
|
statistics pooling, first used in ECAPA_TDNN.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, in_dim, bottleneck_dim=128, global_context_att=False):
|
||||||
|
super(ASTP, self).__init__()
|
||||||
|
self.global_context_att = global_context_att
|
||||||
|
|
||||||
|
# Use Conv1d with stride == 1 rather than Linear, then we don't
|
||||||
|
# need to transpose inputs.
|
||||||
|
if global_context_att:
|
||||||
|
self.linear1 = nn.Conv1d(
|
||||||
|
in_dim * 3, bottleneck_dim,
|
||||||
|
kernel_size=1) # equals W and b in the paper
|
||||||
|
else:
|
||||||
|
self.linear1 = nn.Conv1d(
|
||||||
|
in_dim, bottleneck_dim,
|
||||||
|
kernel_size=1) # equals W and b in the paper
|
||||||
|
self.linear2 = nn.Conv1d(bottleneck_dim, in_dim,
|
||||||
|
kernel_size=1) # equals V and k in the paper
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
"""
|
||||||
|
x: a 3-dimensional tensor in tdnn-based architecture (B,F,T)
|
||||||
|
or a 4-dimensional tensor in resnet architecture (B,C,F,T)
|
||||||
|
0-dim: batch-dimension, last-dim: time-dimension (frame-dimension)
|
||||||
|
"""
|
||||||
|
if len(x.shape) == 4:
|
||||||
|
x = x.reshape(x.shape[0], x.shape[1] * x.shape[2], x.shape[3])
|
||||||
|
assert len(x.shape) == 3
|
||||||
|
|
||||||
|
if self.global_context_att:
|
||||||
|
context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
|
||||||
|
context_std = torch.sqrt(
|
||||||
|
torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
|
||||||
|
x_in = torch.cat((x, context_mean, context_std), dim=1)
|
||||||
|
else:
|
||||||
|
x_in = x
|
||||||
|
|
||||||
|
# DON'T use ReLU here! ReLU may be hard to converge.
|
||||||
|
alpha = torch.tanh(
|
||||||
|
self.linear1(x_in)) # alpha = F.relu(self.linear1(x_in))
|
||||||
|
alpha = torch.softmax(self.linear2(alpha), dim=2)
|
||||||
|
mean = torch.sum(alpha * x, dim=2)
|
||||||
|
var = torch.sum(alpha * (x ** 2), dim=2) - mean ** 2
|
||||||
|
std = torch.sqrt(var.clamp(min=1e-10))
|
||||||
|
return torch.cat([mean, std], dim=1)
|
||||||
124
funasr/modules/cnn/DTDNN.py
Normal file
124
funasr/modules/cnn/DTDNN.py
Normal file
@ -0,0 +1,124 @@
|
|||||||
|
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
|
||||||
|
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||||
|
|
||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from funasr.modules.cnn.layers import DenseLayer, StatsPool, TDNNLayer, CAMDenseTDNNBlock, TransitLayer, \
|
||||||
|
BasicResBlock, get_nonlinear
|
||||||
|
|
||||||
|
|
||||||
|
class FCM(nn.Module):
|
||||||
|
def __init__(self,
|
||||||
|
block=BasicResBlock,
|
||||||
|
num_blocks=[2, 2],
|
||||||
|
m_channels=32,
|
||||||
|
feat_dim=80):
|
||||||
|
super(FCM, self).__init__()
|
||||||
|
self.in_planes = m_channels
|
||||||
|
self.conv1 = nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
|
||||||
|
self.bn1 = nn.BatchNorm2d(m_channels)
|
||||||
|
|
||||||
|
self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=2)
|
||||||
|
self.layer2 = self._make_layer(block, m_channels, num_blocks[0], stride=2)
|
||||||
|
|
||||||
|
self.conv2 = nn.Conv2d(m_channels, m_channels, kernel_size=3, stride=(2, 1), padding=1, bias=False)
|
||||||
|
self.bn2 = nn.BatchNorm2d(m_channels)
|
||||||
|
self.out_channels = m_channels * (feat_dim // 8)
|
||||||
|
|
||||||
|
def _make_layer(self, block, planes, num_blocks, stride):
|
||||||
|
strides = [stride] + [1] * (num_blocks - 1)
|
||||||
|
layers = []
|
||||||
|
for stride in strides:
|
||||||
|
layers.append(block(self.in_planes, planes, stride))
|
||||||
|
self.in_planes = planes * block.expansion
|
||||||
|
return nn.Sequential(*layers)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = x.unsqueeze(1)
|
||||||
|
out = F.relu(self.bn1(self.conv1(x)))
|
||||||
|
out = self.layer1(out)
|
||||||
|
out = self.layer2(out)
|
||||||
|
out = F.relu(self.bn2(self.conv2(out)))
|
||||||
|
|
||||||
|
shape = out.shape
|
||||||
|
out = out.reshape(shape[0], shape[1] * shape[2], shape[3])
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class CAMPPlus(nn.Module):
|
||||||
|
def __init__(self,
|
||||||
|
feat_dim=80,
|
||||||
|
embedding_size=192,
|
||||||
|
growth_rate=32,
|
||||||
|
bn_size=4,
|
||||||
|
init_channels=128,
|
||||||
|
config_str='batchnorm-relu',
|
||||||
|
memory_efficient=True,
|
||||||
|
output_level='segment'):
|
||||||
|
super(CAMPPlus, self).__init__()
|
||||||
|
|
||||||
|
self.head = FCM(feat_dim=feat_dim)
|
||||||
|
channels = self.head.out_channels
|
||||||
|
self.output_level = output_level
|
||||||
|
|
||||||
|
self.xvector = nn.Sequential(
|
||||||
|
OrderedDict([
|
||||||
|
|
||||||
|
('tdnn',
|
||||||
|
TDNNLayer(channels,
|
||||||
|
init_channels,
|
||||||
|
5,
|
||||||
|
stride=2,
|
||||||
|
dilation=1,
|
||||||
|
padding=-1,
|
||||||
|
config_str=config_str)),
|
||||||
|
]))
|
||||||
|
channels = init_channels
|
||||||
|
for i, (num_layers, kernel_size,
|
||||||
|
dilation) in enumerate(zip((12, 24, 16), (3, 3, 3), (1, 2, 2))):
|
||||||
|
block = CAMDenseTDNNBlock(num_layers=num_layers,
|
||||||
|
in_channels=channels,
|
||||||
|
out_channels=growth_rate,
|
||||||
|
bn_channels=bn_size * growth_rate,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
dilation=dilation,
|
||||||
|
config_str=config_str,
|
||||||
|
memory_efficient=memory_efficient)
|
||||||
|
self.xvector.add_module('block%d' % (i + 1), block)
|
||||||
|
channels = channels + num_layers * growth_rate
|
||||||
|
self.xvector.add_module(
|
||||||
|
'transit%d' % (i + 1),
|
||||||
|
TransitLayer(channels,
|
||||||
|
channels // 2,
|
||||||
|
bias=False,
|
||||||
|
config_str=config_str))
|
||||||
|
channels //= 2
|
||||||
|
|
||||||
|
self.xvector.add_module(
|
||||||
|
'out_nonlinear', get_nonlinear(config_str, channels))
|
||||||
|
|
||||||
|
if self.output_level == 'segment':
|
||||||
|
self.xvector.add_module('stats', StatsPool())
|
||||||
|
self.xvector.add_module(
|
||||||
|
'dense',
|
||||||
|
DenseLayer(
|
||||||
|
channels * 2, embedding_size, config_str='batchnorm_'))
|
||||||
|
else:
|
||||||
|
assert self.output_level == 'frame', '`output_level` should be set to \'segment\' or \'frame\'. '
|
||||||
|
|
||||||
|
for m in self.modules():
|
||||||
|
if isinstance(m, (nn.Conv1d, nn.Linear)):
|
||||||
|
nn.init.kaiming_normal_(m.weight.data)
|
||||||
|
if m.bias is not None:
|
||||||
|
nn.init.zeros_(m.bias)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
|
||||||
|
x = self.head(x)
|
||||||
|
x = self.xvector(x)
|
||||||
|
if self.output_level == 'frame':
|
||||||
|
x = x.transpose(1, 2)
|
||||||
|
return x
|
||||||
420
funasr/modules/cnn/ResNet.py
Normal file
420
funasr/modules/cnn/ResNet.py
Normal file
@ -0,0 +1,420 @@
|
|||||||
|
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
|
||||||
|
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||||
|
|
||||||
|
""" Res2Net implementation is adapted from https://github.com/wenet-e2e/wespeaker.
|
||||||
|
ERes2Net incorporates both local and global feature fusion techniques to improve the performance.
|
||||||
|
The local feature fusion (LFF) fuses the features within one single residual block to extract the local signal.
|
||||||
|
The global feature fusion (GFF) takes acoustic features of different scales as input to aggregate global signal.
|
||||||
|
ERes2Net-Large is an upgraded version of ERes2Net that uses a larger number of parameters to achieve better
|
||||||
|
recognition performance. Parameters expansion, baseWidth, and scale can be modified to obtain optimal performance.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
import funasr.models.pooling.pooling_layers as pooling_layers
|
||||||
|
from funasr.modules.cnn.fusion import AFF
|
||||||
|
|
||||||
|
|
||||||
|
class ReLU(nn.Hardtanh):
|
||||||
|
|
||||||
|
def __init__(self, inplace=False):
|
||||||
|
super(ReLU, self).__init__(0, 20, inplace)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
inplace_str = 'inplace' if self.inplace else ''
|
||||||
|
return self.__class__.__name__ + ' (' \
|
||||||
|
+ inplace_str + ')'
|
||||||
|
|
||||||
|
|
||||||
|
def conv1x1(in_planes, out_planes, stride=1):
|
||||||
|
"1x1 convolution without padding"
|
||||||
|
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
|
||||||
|
padding=0, bias=False)
|
||||||
|
|
||||||
|
|
||||||
|
def conv3x3(in_planes, out_planes, stride=1):
|
||||||
|
"3x3 convolution with padding"
|
||||||
|
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
||||||
|
padding=1, bias=False)
|
||||||
|
|
||||||
|
|
||||||
|
class BasicBlockERes2Net(nn.Module):
|
||||||
|
expansion = 2
|
||||||
|
|
||||||
|
def __init__(self, in_planes, planes, stride=1, baseWidth=32, scale=2):
|
||||||
|
super(BasicBlockERes2Net, self).__init__()
|
||||||
|
width = int(math.floor(planes * (baseWidth / 64.0)))
|
||||||
|
self.conv1 = conv1x1(in_planes, width * scale, stride)
|
||||||
|
self.bn1 = nn.BatchNorm2d(width * scale)
|
||||||
|
self.nums = scale
|
||||||
|
|
||||||
|
convs = []
|
||||||
|
bns = []
|
||||||
|
for i in range(self.nums):
|
||||||
|
convs.append(conv3x3(width, width))
|
||||||
|
bns.append(nn.BatchNorm2d(width))
|
||||||
|
self.convs = nn.ModuleList(convs)
|
||||||
|
self.bns = nn.ModuleList(bns)
|
||||||
|
self.relu = ReLU(inplace=True)
|
||||||
|
|
||||||
|
self.conv3 = conv1x1(width * scale, planes * self.expansion)
|
||||||
|
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
|
||||||
|
self.shortcut = nn.Sequential()
|
||||||
|
if stride != 1 or in_planes != self.expansion * planes:
|
||||||
|
self.shortcut = nn.Sequential(
|
||||||
|
nn.Conv2d(in_planes,
|
||||||
|
self.expansion * planes,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=stride,
|
||||||
|
bias=False),
|
||||||
|
nn.BatchNorm2d(self.expansion * planes))
|
||||||
|
self.stride = stride
|
||||||
|
self.width = width
|
||||||
|
self.scale = scale
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
residual = x
|
||||||
|
|
||||||
|
out = self.conv1(x)
|
||||||
|
out = self.bn1(out)
|
||||||
|
out = self.relu(out)
|
||||||
|
spx = torch.split(out, self.width, 1)
|
||||||
|
for i in range(self.nums):
|
||||||
|
if i == 0:
|
||||||
|
sp = spx[i]
|
||||||
|
else:
|
||||||
|
sp = sp + spx[i]
|
||||||
|
sp = self.convs[i](sp)
|
||||||
|
sp = self.relu(self.bns[i](sp))
|
||||||
|
if i == 0:
|
||||||
|
out = sp
|
||||||
|
else:
|
||||||
|
out = torch.cat((out, sp), 1)
|
||||||
|
|
||||||
|
out = self.conv3(out)
|
||||||
|
out = self.bn3(out)
|
||||||
|
|
||||||
|
residual = self.shortcut(x)
|
||||||
|
out += residual
|
||||||
|
out = self.relu(out)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class BasicBlockERes2Net_diff_AFF(nn.Module):
|
||||||
|
expansion = 2
|
||||||
|
|
||||||
|
def __init__(self, in_planes, planes, stride=1, baseWidth=32, scale=2):
|
||||||
|
super(BasicBlockERes2Net_diff_AFF, self).__init__()
|
||||||
|
width = int(math.floor(planes * (baseWidth / 64.0)))
|
||||||
|
self.conv1 = conv1x1(in_planes, width * scale, stride)
|
||||||
|
self.bn1 = nn.BatchNorm2d(width * scale)
|
||||||
|
self.nums = scale
|
||||||
|
|
||||||
|
convs = []
|
||||||
|
fuse_models = []
|
||||||
|
bns = []
|
||||||
|
for i in range(self.nums):
|
||||||
|
convs.append(conv3x3(width, width))
|
||||||
|
bns.append(nn.BatchNorm2d(width))
|
||||||
|
for j in range(self.nums - 1):
|
||||||
|
fuse_models.append(AFF(channels=width))
|
||||||
|
|
||||||
|
self.convs = nn.ModuleList(convs)
|
||||||
|
self.bns = nn.ModuleList(bns)
|
||||||
|
self.fuse_models = nn.ModuleList(fuse_models)
|
||||||
|
self.relu = ReLU(inplace=True)
|
||||||
|
|
||||||
|
self.conv3 = conv1x1(width * scale, planes * self.expansion)
|
||||||
|
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
|
||||||
|
self.shortcut = nn.Sequential()
|
||||||
|
if stride != 1 or in_planes != self.expansion * planes:
|
||||||
|
self.shortcut = nn.Sequential(
|
||||||
|
nn.Conv2d(in_planes,
|
||||||
|
self.expansion * planes,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=stride,
|
||||||
|
bias=False),
|
||||||
|
nn.BatchNorm2d(self.expansion * planes))
|
||||||
|
self.stride = stride
|
||||||
|
self.width = width
|
||||||
|
self.scale = scale
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
residual = x
|
||||||
|
|
||||||
|
out = self.conv1(x)
|
||||||
|
out = self.bn1(out)
|
||||||
|
out = self.relu(out)
|
||||||
|
spx = torch.split(out, self.width, 1)
|
||||||
|
for i in range(self.nums):
|
||||||
|
if i == 0:
|
||||||
|
sp = spx[i]
|
||||||
|
else:
|
||||||
|
sp = self.fuse_models[i - 1](sp, spx[i])
|
||||||
|
|
||||||
|
sp = self.convs[i](sp)
|
||||||
|
sp = self.relu(self.bns[i](sp))
|
||||||
|
if i == 0:
|
||||||
|
out = sp
|
||||||
|
else:
|
||||||
|
out = torch.cat((out, sp), 1)
|
||||||
|
|
||||||
|
out = self.conv3(out)
|
||||||
|
out = self.bn3(out)
|
||||||
|
|
||||||
|
residual = self.shortcut(x)
|
||||||
|
out += residual
|
||||||
|
out = self.relu(out)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class ERes2Net(nn.Module):
|
||||||
|
def __init__(self,
|
||||||
|
block=BasicBlockERes2Net,
|
||||||
|
block_fuse=BasicBlockERes2Net_diff_AFF,
|
||||||
|
num_blocks=[3, 4, 6, 3],
|
||||||
|
m_channels=32,
|
||||||
|
feat_dim=80,
|
||||||
|
embedding_size=192,
|
||||||
|
pooling_func='TSTP',
|
||||||
|
two_emb_layer=False):
|
||||||
|
super(ERes2Net, self).__init__()
|
||||||
|
self.in_planes = m_channels
|
||||||
|
self.feat_dim = feat_dim
|
||||||
|
self.embedding_size = embedding_size
|
||||||
|
self.stats_dim = int(feat_dim / 8) * m_channels * 8
|
||||||
|
self.two_emb_layer = two_emb_layer
|
||||||
|
|
||||||
|
self.conv1 = nn.Conv2d(1,
|
||||||
|
m_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
padding=1,
|
||||||
|
bias=False)
|
||||||
|
self.bn1 = nn.BatchNorm2d(m_channels)
|
||||||
|
self.layer1 = self._make_layer(block,
|
||||||
|
m_channels,
|
||||||
|
num_blocks[0],
|
||||||
|
stride=1)
|
||||||
|
self.layer2 = self._make_layer(block,
|
||||||
|
m_channels * 2,
|
||||||
|
num_blocks[1],
|
||||||
|
stride=2)
|
||||||
|
self.layer3 = self._make_layer(block_fuse,
|
||||||
|
m_channels * 4,
|
||||||
|
num_blocks[2],
|
||||||
|
stride=2)
|
||||||
|
self.layer4 = self._make_layer(block_fuse,
|
||||||
|
m_channels * 8,
|
||||||
|
num_blocks[3],
|
||||||
|
stride=2)
|
||||||
|
|
||||||
|
# Downsampling module for each layer
|
||||||
|
self.layer1_downsample = nn.Conv2d(m_channels * 2, m_channels * 4, kernel_size=3, stride=2, padding=1,
|
||||||
|
bias=False)
|
||||||
|
self.layer2_downsample = nn.Conv2d(m_channels * 4, m_channels * 8, kernel_size=3, padding=1, stride=2,
|
||||||
|
bias=False)
|
||||||
|
self.layer3_downsample = nn.Conv2d(m_channels * 8, m_channels * 16, kernel_size=3, padding=1, stride=2,
|
||||||
|
bias=False)
|
||||||
|
|
||||||
|
# Bottom-up fusion module
|
||||||
|
self.fuse_mode12 = AFF(channels=m_channels * 4)
|
||||||
|
self.fuse_mode123 = AFF(channels=m_channels * 8)
|
||||||
|
self.fuse_mode1234 = AFF(channels=m_channels * 16)
|
||||||
|
|
||||||
|
self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == "TSDP" else 2
|
||||||
|
self.pool = getattr(pooling_layers, pooling_func)(
|
||||||
|
in_dim=self.stats_dim * block.expansion)
|
||||||
|
self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats,
|
||||||
|
embedding_size)
|
||||||
|
if self.two_emb_layer:
|
||||||
|
self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
|
||||||
|
self.seg_2 = nn.Linear(embedding_size, embedding_size)
|
||||||
|
else:
|
||||||
|
self.seg_bn_1 = nn.Identity()
|
||||||
|
self.seg_2 = nn.Identity()
|
||||||
|
|
||||||
|
def _make_layer(self, block, planes, num_blocks, stride):
|
||||||
|
strides = [stride] + [1] * (num_blocks - 1)
|
||||||
|
layers = []
|
||||||
|
for stride in strides:
|
||||||
|
layers.append(block(self.in_planes, planes, stride))
|
||||||
|
self.in_planes = planes * block.expansion
|
||||||
|
return nn.Sequential(*layers)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
|
||||||
|
x = x.unsqueeze_(1)
|
||||||
|
out = F.relu(self.bn1(self.conv1(x)))
|
||||||
|
out1 = self.layer1(out)
|
||||||
|
out2 = self.layer2(out1)
|
||||||
|
out1_downsample = self.layer1_downsample(out1)
|
||||||
|
fuse_out12 = self.fuse_mode12(out2, out1_downsample)
|
||||||
|
out3 = self.layer3(out2)
|
||||||
|
fuse_out12_downsample = self.layer2_downsample(fuse_out12)
|
||||||
|
fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample)
|
||||||
|
out4 = self.layer4(out3)
|
||||||
|
fuse_out123_downsample = self.layer3_downsample(fuse_out123)
|
||||||
|
fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample)
|
||||||
|
stats = self.pool(fuse_out1234)
|
||||||
|
|
||||||
|
embed_a = self.seg_1(stats)
|
||||||
|
if self.two_emb_layer:
|
||||||
|
out = F.relu(embed_a)
|
||||||
|
out = self.seg_bn_1(out)
|
||||||
|
embed_b = self.seg_2(out)
|
||||||
|
return embed_b
|
||||||
|
else:
|
||||||
|
return embed_a
|
||||||
|
|
||||||
|
|
||||||
|
class BasicBlockRes2Net(nn.Module):
|
||||||
|
expansion = 2
|
||||||
|
|
||||||
|
def __init__(self, in_planes, planes, stride=1, baseWidth=32, scale=2):
|
||||||
|
super(BasicBlockRes2Net, self).__init__()
|
||||||
|
width = int(math.floor(planes * (baseWidth / 64.0)))
|
||||||
|
self.conv1 = conv1x1(in_planes, width * scale, stride)
|
||||||
|
self.bn1 = nn.BatchNorm2d(width * scale)
|
||||||
|
self.nums = scale - 1
|
||||||
|
convs = []
|
||||||
|
bns = []
|
||||||
|
for i in range(self.nums):
|
||||||
|
convs.append(conv3x3(width, width))
|
||||||
|
bns.append(nn.BatchNorm2d(width))
|
||||||
|
self.convs = nn.ModuleList(convs)
|
||||||
|
self.bns = nn.ModuleList(bns)
|
||||||
|
self.relu = ReLU(inplace=True)
|
||||||
|
|
||||||
|
self.conv3 = conv1x1(width * scale, planes * self.expansion)
|
||||||
|
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
|
||||||
|
self.shortcut = nn.Sequential()
|
||||||
|
if stride != 1 or in_planes != self.expansion * planes:
|
||||||
|
self.shortcut = nn.Sequential(
|
||||||
|
nn.Conv2d(in_planes,
|
||||||
|
self.expansion * planes,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=stride,
|
||||||
|
bias=False),
|
||||||
|
nn.BatchNorm2d(self.expansion * planes))
|
||||||
|
self.stride = stride
|
||||||
|
self.width = width
|
||||||
|
self.scale = scale
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
residual = x
|
||||||
|
|
||||||
|
out = self.conv1(x)
|
||||||
|
out = self.bn1(out)
|
||||||
|
out = self.relu(out)
|
||||||
|
spx = torch.split(out, self.width, 1)
|
||||||
|
for i in range(self.nums):
|
||||||
|
if i == 0:
|
||||||
|
sp = spx[i]
|
||||||
|
else:
|
||||||
|
sp = sp + spx[i]
|
||||||
|
sp = self.convs[i](sp)
|
||||||
|
sp = self.relu(self.bns[i](sp))
|
||||||
|
if i == 0:
|
||||||
|
out = sp
|
||||||
|
else:
|
||||||
|
out = torch.cat((out, sp), 1)
|
||||||
|
|
||||||
|
out = torch.cat((out, spx[self.nums]), 1)
|
||||||
|
|
||||||
|
out = self.conv3(out)
|
||||||
|
out = self.bn3(out)
|
||||||
|
|
||||||
|
residual = self.shortcut(x)
|
||||||
|
out += residual
|
||||||
|
out = self.relu(out)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class Res2Net(nn.Module):
|
||||||
|
def __init__(self,
|
||||||
|
block=BasicBlockRes2Net,
|
||||||
|
num_blocks=[3, 4, 6, 3],
|
||||||
|
m_channels=32,
|
||||||
|
feat_dim=80,
|
||||||
|
embedding_size=192,
|
||||||
|
pooling_func='TSTP',
|
||||||
|
two_emb_layer=False):
|
||||||
|
super(Res2Net, self).__init__()
|
||||||
|
self.in_planes = m_channels
|
||||||
|
self.feat_dim = feat_dim
|
||||||
|
self.embedding_size = embedding_size
|
||||||
|
self.stats_dim = int(feat_dim / 8) * m_channels * 8
|
||||||
|
self.two_emb_layer = two_emb_layer
|
||||||
|
|
||||||
|
self.conv1 = nn.Conv2d(1,
|
||||||
|
m_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
padding=1,
|
||||||
|
bias=False)
|
||||||
|
self.bn1 = nn.BatchNorm2d(m_channels)
|
||||||
|
self.layer1 = self._make_layer(block,
|
||||||
|
m_channels,
|
||||||
|
num_blocks[0],
|
||||||
|
stride=1)
|
||||||
|
self.layer2 = self._make_layer(block,
|
||||||
|
m_channels * 2,
|
||||||
|
num_blocks[1],
|
||||||
|
stride=2)
|
||||||
|
self.layer3 = self._make_layer(block,
|
||||||
|
m_channels * 4,
|
||||||
|
num_blocks[2],
|
||||||
|
stride=2)
|
||||||
|
self.layer4 = self._make_layer(block,
|
||||||
|
m_channels * 8,
|
||||||
|
num_blocks[3],
|
||||||
|
stride=2)
|
||||||
|
|
||||||
|
self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == "TSDP" else 2
|
||||||
|
self.pool = getattr(pooling_layers, pooling_func)(
|
||||||
|
in_dim=self.stats_dim * block.expansion)
|
||||||
|
self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats,
|
||||||
|
embedding_size)
|
||||||
|
if self.two_emb_layer:
|
||||||
|
self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
|
||||||
|
self.seg_2 = nn.Linear(embedding_size, embedding_size)
|
||||||
|
else:
|
||||||
|
self.seg_bn_1 = nn.Identity()
|
||||||
|
self.seg_2 = nn.Identity()
|
||||||
|
|
||||||
|
def _make_layer(self, block, planes, num_blocks, stride):
|
||||||
|
strides = [stride] + [1] * (num_blocks - 1)
|
||||||
|
layers = []
|
||||||
|
for stride in strides:
|
||||||
|
layers.append(block(self.in_planes, planes, stride))
|
||||||
|
self.in_planes = planes * block.expansion
|
||||||
|
return nn.Sequential(*layers)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
|
||||||
|
|
||||||
|
x = x.unsqueeze_(1)
|
||||||
|
out = F.relu(self.bn1(self.conv1(x)))
|
||||||
|
out = self.layer1(out)
|
||||||
|
out = self.layer2(out)
|
||||||
|
out = self.layer3(out)
|
||||||
|
out = self.layer4(out)
|
||||||
|
|
||||||
|
stats = self.pool(out)
|
||||||
|
|
||||||
|
embed_a = self.seg_1(stats)
|
||||||
|
if self.two_emb_layer:
|
||||||
|
out = F.relu(embed_a)
|
||||||
|
out = self.seg_bn_1(out)
|
||||||
|
embed_b = self.seg_2(out)
|
||||||
|
return embed_b
|
||||||
|
else:
|
||||||
|
return embed_a
|
||||||
273
funasr/modules/cnn/ResNet_aug.py
Normal file
273
funasr/modules/cnn/ResNet_aug.py
Normal file
@ -0,0 +1,273 @@
|
|||||||
|
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
|
||||||
|
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||||
|
|
||||||
|
""" Res2Net implementation is adapted from https://github.com/wenet-e2e/wespeaker.
|
||||||
|
ERes2Net incorporates both local and global feature fusion techniques to improve the performance.
|
||||||
|
The local feature fusion (LFF) fuses the features within one single residual block to extract the local signal.
|
||||||
|
The global feature fusion (GFF) takes acoustic features of different scales as input to aggregate global signal.
|
||||||
|
ERes2Net-Large is an upgraded version of ERes2Net that uses a larger number of parameters to achieve better
|
||||||
|
recognition performance. Parameters expansion, baseWidth, and scale can be modified to obtain optimal performance.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
import funasr.models.pooling.pooling_layers as pooling_layers
|
||||||
|
from funasr.modules.cnn.fusion import AFF
|
||||||
|
|
||||||
|
|
||||||
|
class ReLU(nn.Hardtanh):
|
||||||
|
|
||||||
|
def __init__(self, inplace=False):
|
||||||
|
super(ReLU, self).__init__(0, 20, inplace)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
inplace_str = 'inplace' if self.inplace else ''
|
||||||
|
return self.__class__.__name__ + ' (' \
|
||||||
|
+ inplace_str + ')'
|
||||||
|
|
||||||
|
|
||||||
|
def conv1x1(in_planes, out_planes, stride=1):
|
||||||
|
"1x1 convolution without padding"
|
||||||
|
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
|
||||||
|
padding=0, bias=False)
|
||||||
|
|
||||||
|
|
||||||
|
def conv3x3(in_planes, out_planes, stride=1):
|
||||||
|
"3x3 convolution with padding"
|
||||||
|
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
||||||
|
padding=1, bias=False)
|
||||||
|
|
||||||
|
|
||||||
|
class BasicBlockERes2Net(nn.Module):
|
||||||
|
expansion = 4
|
||||||
|
|
||||||
|
def __init__(self, in_planes, planes, stride=1, baseWidth=24, scale=3):
|
||||||
|
super(BasicBlockERes2Net, self).__init__()
|
||||||
|
width = int(math.floor(planes * (baseWidth / 64.0)))
|
||||||
|
self.conv1 = conv1x1(in_planes, width * scale, stride)
|
||||||
|
self.bn1 = nn.BatchNorm2d(width * scale)
|
||||||
|
self.nums = scale
|
||||||
|
|
||||||
|
convs = []
|
||||||
|
bns = []
|
||||||
|
for i in range(self.nums):
|
||||||
|
convs.append(conv3x3(width, width))
|
||||||
|
bns.append(nn.BatchNorm2d(width))
|
||||||
|
self.convs = nn.ModuleList(convs)
|
||||||
|
self.bns = nn.ModuleList(bns)
|
||||||
|
self.relu = ReLU(inplace=True)
|
||||||
|
|
||||||
|
self.conv3 = conv1x1(width * scale, planes * self.expansion)
|
||||||
|
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
|
||||||
|
self.shortcut = nn.Sequential()
|
||||||
|
if stride != 1 or in_planes != self.expansion * planes:
|
||||||
|
self.shortcut = nn.Sequential(
|
||||||
|
nn.Conv2d(in_planes,
|
||||||
|
self.expansion * planes,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=stride,
|
||||||
|
bias=False),
|
||||||
|
nn.BatchNorm2d(self.expansion * planes))
|
||||||
|
self.stride = stride
|
||||||
|
self.width = width
|
||||||
|
self.scale = scale
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
residual = x
|
||||||
|
|
||||||
|
out = self.conv1(x)
|
||||||
|
out = self.bn1(out)
|
||||||
|
out = self.relu(out)
|
||||||
|
spx = torch.split(out, self.width, 1)
|
||||||
|
for i in range(self.nums):
|
||||||
|
if i == 0:
|
||||||
|
sp = spx[i]
|
||||||
|
else:
|
||||||
|
sp = sp + spx[i]
|
||||||
|
sp = self.convs[i](sp)
|
||||||
|
sp = self.relu(self.bns[i](sp))
|
||||||
|
if i == 0:
|
||||||
|
out = sp
|
||||||
|
else:
|
||||||
|
out = torch.cat((out, sp), 1)
|
||||||
|
|
||||||
|
out = self.conv3(out)
|
||||||
|
out = self.bn3(out)
|
||||||
|
|
||||||
|
residual = self.shortcut(x)
|
||||||
|
out += residual
|
||||||
|
out = self.relu(out)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class BasicBlockERes2Net_diff_AFF(nn.Module):
|
||||||
|
expansion = 4
|
||||||
|
|
||||||
|
def __init__(self, in_planes, planes, stride=1, baseWidth=24, scale=3):
|
||||||
|
super(BasicBlockERes2Net_diff_AFF, self).__init__()
|
||||||
|
width = int(math.floor(planes * (baseWidth / 64.0)))
|
||||||
|
self.conv1 = conv1x1(in_planes, width * scale, stride)
|
||||||
|
self.bn1 = nn.BatchNorm2d(width * scale)
|
||||||
|
|
||||||
|
self.nums = scale
|
||||||
|
|
||||||
|
convs = []
|
||||||
|
fuse_models = []
|
||||||
|
bns = []
|
||||||
|
for i in range(self.nums):
|
||||||
|
convs.append(conv3x3(width, width))
|
||||||
|
bns.append(nn.BatchNorm2d(width))
|
||||||
|
for j in range(self.nums - 1):
|
||||||
|
fuse_models.append(AFF(channels=width))
|
||||||
|
|
||||||
|
self.convs = nn.ModuleList(convs)
|
||||||
|
self.bns = nn.ModuleList(bns)
|
||||||
|
self.fuse_models = nn.ModuleList(fuse_models)
|
||||||
|
self.relu = ReLU(inplace=True)
|
||||||
|
|
||||||
|
self.conv3 = conv1x1(width * scale, planes * self.expansion)
|
||||||
|
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
|
||||||
|
self.shortcut = nn.Sequential()
|
||||||
|
if stride != 1 or in_planes != self.expansion * planes:
|
||||||
|
self.shortcut = nn.Sequential(
|
||||||
|
nn.Conv2d(in_planes,
|
||||||
|
self.expansion * planes,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=stride,
|
||||||
|
bias=False),
|
||||||
|
nn.BatchNorm2d(self.expansion * planes))
|
||||||
|
self.stride = stride
|
||||||
|
self.width = width
|
||||||
|
self.scale = scale
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
residual = x
|
||||||
|
|
||||||
|
out = self.conv1(x)
|
||||||
|
out = self.bn1(out)
|
||||||
|
out = self.relu(out)
|
||||||
|
spx = torch.split(out, self.width, 1)
|
||||||
|
for i in range(self.nums):
|
||||||
|
if i == 0:
|
||||||
|
sp = spx[i]
|
||||||
|
else:
|
||||||
|
sp = self.fuse_models[i - 1](sp, spx[i])
|
||||||
|
|
||||||
|
sp = self.convs[i](sp)
|
||||||
|
sp = self.relu(self.bns[i](sp))
|
||||||
|
if i == 0:
|
||||||
|
out = sp
|
||||||
|
else:
|
||||||
|
out = torch.cat((out, sp), 1)
|
||||||
|
|
||||||
|
out = self.conv3(out)
|
||||||
|
out = self.bn3(out)
|
||||||
|
|
||||||
|
residual = self.shortcut(x)
|
||||||
|
out += residual
|
||||||
|
out = self.relu(out)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class ERes2NetAug(nn.Module):
|
||||||
|
def __init__(self,
|
||||||
|
block=BasicBlockERes2Net,
|
||||||
|
block_fuse=BasicBlockERes2Net_diff_AFF,
|
||||||
|
num_blocks=[3, 4, 6, 3],
|
||||||
|
m_channels=64,
|
||||||
|
feat_dim=80,
|
||||||
|
embedding_size=192,
|
||||||
|
pooling_func='TSTP',
|
||||||
|
two_emb_layer=False):
|
||||||
|
super(ERes2NetAug, self).__init__()
|
||||||
|
self.in_planes = m_channels
|
||||||
|
self.feat_dim = feat_dim
|
||||||
|
self.embedding_size = embedding_size
|
||||||
|
self.stats_dim = int(feat_dim / 8) * m_channels * 8
|
||||||
|
self.two_emb_layer = two_emb_layer
|
||||||
|
|
||||||
|
self.conv1 = nn.Conv2d(1,
|
||||||
|
m_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
padding=1,
|
||||||
|
bias=False)
|
||||||
|
self.bn1 = nn.BatchNorm2d(m_channels)
|
||||||
|
self.layer1 = self._make_layer(block,
|
||||||
|
m_channels,
|
||||||
|
num_blocks[0],
|
||||||
|
stride=1)
|
||||||
|
self.layer2 = self._make_layer(block,
|
||||||
|
m_channels * 2,
|
||||||
|
num_blocks[1],
|
||||||
|
stride=2)
|
||||||
|
self.layer3 = self._make_layer(block_fuse,
|
||||||
|
m_channels * 4,
|
||||||
|
num_blocks[2],
|
||||||
|
stride=2)
|
||||||
|
self.layer4 = self._make_layer(block_fuse,
|
||||||
|
m_channels * 8,
|
||||||
|
num_blocks[3],
|
||||||
|
stride=2)
|
||||||
|
|
||||||
|
self.layer1_downsample = nn.Conv2d(m_channels * 4, m_channels * 8, kernel_size=3, padding=1, stride=2,
|
||||||
|
bias=False)
|
||||||
|
self.layer2_downsample = nn.Conv2d(m_channels * 8, m_channels * 16, kernel_size=3, padding=1, stride=2,
|
||||||
|
bias=False)
|
||||||
|
self.layer3_downsample = nn.Conv2d(m_channels * 16, m_channels * 32, kernel_size=3, padding=1, stride=2,
|
||||||
|
bias=False)
|
||||||
|
self.fuse_mode12 = AFF(channels=m_channels * 8)
|
||||||
|
self.fuse_mode123 = AFF(channels=m_channels * 16)
|
||||||
|
self.fuse_mode1234 = AFF(channels=m_channels * 32)
|
||||||
|
|
||||||
|
self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == "TSDP" else 2
|
||||||
|
self.pool = getattr(pooling_layers, pooling_func)(
|
||||||
|
in_dim=self.stats_dim * block.expansion)
|
||||||
|
self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats,
|
||||||
|
embedding_size)
|
||||||
|
if self.two_emb_layer:
|
||||||
|
self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
|
||||||
|
self.seg_2 = nn.Linear(embedding_size, embedding_size)
|
||||||
|
else:
|
||||||
|
self.seg_bn_1 = nn.Identity()
|
||||||
|
self.seg_2 = nn.Identity()
|
||||||
|
|
||||||
|
def _make_layer(self, block, planes, num_blocks, stride):
|
||||||
|
strides = [stride] + [1] * (num_blocks - 1)
|
||||||
|
layers = []
|
||||||
|
for stride in strides:
|
||||||
|
layers.append(block(self.in_planes, planes, stride))
|
||||||
|
self.in_planes = planes * block.expansion
|
||||||
|
return nn.Sequential(*layers)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
|
||||||
|
|
||||||
|
x = x.unsqueeze_(1)
|
||||||
|
out = F.relu(self.bn1(self.conv1(x)))
|
||||||
|
out1 = self.layer1(out)
|
||||||
|
out2 = self.layer2(out1)
|
||||||
|
out1_downsample = self.layer1_downsample(out1)
|
||||||
|
fuse_out12 = self.fuse_mode12(out2, out1_downsample)
|
||||||
|
out3 = self.layer3(out2)
|
||||||
|
fuse_out12_downsample = self.layer2_downsample(fuse_out12)
|
||||||
|
fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample)
|
||||||
|
out4 = self.layer4(out3)
|
||||||
|
fuse_out123_downsample = self.layer3_downsample(fuse_out123)
|
||||||
|
fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample)
|
||||||
|
stats = self.pool(fuse_out1234)
|
||||||
|
|
||||||
|
embed_a = self.seg_1(stats)
|
||||||
|
if self.two_emb_layer:
|
||||||
|
out = F.relu(embed_a)
|
||||||
|
out = self.seg_bn_1(out)
|
||||||
|
embed_b = self.seg_2(out)
|
||||||
|
return embed_b
|
||||||
|
else:
|
||||||
|
return embed_a
|
||||||
3
funasr/modules/cnn/__init__.py
Normal file
3
funasr/modules/cnn/__init__.py
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
from .DTDNN import CAMPPlus
|
||||||
|
from .ResNet import ERes2Net
|
||||||
|
from .ResNet_aug import ERes2NetAug
|
||||||
29
funasr/modules/cnn/fusion.py
Normal file
29
funasr/modules/cnn/fusion.py
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
|
||||||
|
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
class AFF(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, channels=64, r=4):
|
||||||
|
super(AFF, self).__init__()
|
||||||
|
inter_channels = int(channels // r)
|
||||||
|
|
||||||
|
self.local_att = nn.Sequential(
|
||||||
|
nn.Conv2d(channels * 2, inter_channels, kernel_size=1, stride=1, padding=0),
|
||||||
|
nn.BatchNorm2d(inter_channels),
|
||||||
|
nn.SiLU(inplace=True),
|
||||||
|
nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
|
||||||
|
nn.BatchNorm2d(channels),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x, ds_y):
|
||||||
|
xa = torch.cat((x, ds_y), dim=1)
|
||||||
|
x_att = self.local_att(xa)
|
||||||
|
x_att = 1.0 + torch.tanh(x_att)
|
||||||
|
xo = torch.mul(x, x_att) + torch.mul(ds_y, 2.0-x_att)
|
||||||
|
|
||||||
|
return xo
|
||||||
|
|
||||||
254
funasr/modules/cnn/layers.py
Normal file
254
funasr/modules/cnn/layers.py
Normal file
@ -0,0 +1,254 @@
|
|||||||
|
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
|
||||||
|
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torch.utils.checkpoint as cp
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
|
||||||
|
def get_nonlinear(config_str, channels):
|
||||||
|
nonlinear = nn.Sequential()
|
||||||
|
for name in config_str.split('-'):
|
||||||
|
if name == 'relu':
|
||||||
|
nonlinear.add_module('relu', nn.ReLU(inplace=True))
|
||||||
|
elif name == 'prelu':
|
||||||
|
nonlinear.add_module('prelu', nn.PReLU(channels))
|
||||||
|
elif name == 'batchnorm':
|
||||||
|
nonlinear.add_module('batchnorm', nn.BatchNorm1d(channels))
|
||||||
|
elif name == 'batchnorm_':
|
||||||
|
nonlinear.add_module('batchnorm',
|
||||||
|
nn.BatchNorm1d(channels, affine=False))
|
||||||
|
else:
|
||||||
|
raise ValueError('Unexpected module ({}).'.format(name))
|
||||||
|
return nonlinear
|
||||||
|
|
||||||
|
|
||||||
|
def statistics_pooling(x, dim=-1, keepdim=False, unbiased=True, eps=1e-2):
|
||||||
|
mean = x.mean(dim=dim)
|
||||||
|
std = x.std(dim=dim, unbiased=unbiased)
|
||||||
|
stats = torch.cat([mean, std], dim=-1)
|
||||||
|
if keepdim:
|
||||||
|
stats = stats.unsqueeze(dim=dim)
|
||||||
|
return stats
|
||||||
|
|
||||||
|
|
||||||
|
class StatsPool(nn.Module):
|
||||||
|
def forward(self, x):
|
||||||
|
return statistics_pooling(x)
|
||||||
|
|
||||||
|
|
||||||
|
class TDNNLayer(nn.Module):
|
||||||
|
def __init__(self,
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
stride=1,
|
||||||
|
padding=0,
|
||||||
|
dilation=1,
|
||||||
|
bias=False,
|
||||||
|
config_str='batchnorm-relu'):
|
||||||
|
super(TDNNLayer, self).__init__()
|
||||||
|
if padding < 0:
|
||||||
|
assert kernel_size % 2 == 1, 'Expect equal paddings, but got even kernel size ({})'.format(
|
||||||
|
kernel_size)
|
||||||
|
padding = (kernel_size - 1) // 2 * dilation
|
||||||
|
self.linear = nn.Conv1d(in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
padding=padding,
|
||||||
|
dilation=dilation,
|
||||||
|
bias=bias)
|
||||||
|
self.nonlinear = get_nonlinear(config_str, out_channels)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.linear(x)
|
||||||
|
x = self.nonlinear(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class CAMLayer(nn.Module):
|
||||||
|
def __init__(self,
|
||||||
|
bn_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
stride,
|
||||||
|
padding,
|
||||||
|
dilation,
|
||||||
|
bias,
|
||||||
|
reduction=2):
|
||||||
|
super(CAMLayer, self).__init__()
|
||||||
|
self.linear_local = nn.Conv1d(bn_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
padding=padding,
|
||||||
|
dilation=dilation,
|
||||||
|
bias=bias)
|
||||||
|
self.linear1 = nn.Conv1d(bn_channels, bn_channels // reduction, 1)
|
||||||
|
self.relu = nn.ReLU(inplace=True)
|
||||||
|
self.linear2 = nn.Conv1d(bn_channels // reduction, out_channels, 1)
|
||||||
|
self.sigmoid = nn.Sigmoid()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
y = self.linear_local(x)
|
||||||
|
context = x.mean(-1, keepdim=True) + self.seg_pooling(x)
|
||||||
|
context = self.relu(self.linear1(context))
|
||||||
|
m = self.sigmoid(self.linear2(context))
|
||||||
|
return y * m
|
||||||
|
|
||||||
|
def seg_pooling(self, x, seg_len=100, stype='avg'):
|
||||||
|
if stype == 'avg':
|
||||||
|
seg = F.avg_pool1d(x, kernel_size=seg_len, stride=seg_len, ceil_mode=True)
|
||||||
|
elif stype == 'max':
|
||||||
|
seg = F.max_pool1d(x, kernel_size=seg_len, stride=seg_len, ceil_mode=True)
|
||||||
|
else:
|
||||||
|
raise ValueError('Wrong segment pooling type.')
|
||||||
|
shape = seg.shape
|
||||||
|
seg = seg.unsqueeze(-1).expand(*shape, seg_len).reshape(*shape[:-1], -1)
|
||||||
|
seg = seg[..., :x.shape[-1]]
|
||||||
|
return seg
|
||||||
|
|
||||||
|
|
||||||
|
class CAMDenseTDNNLayer(nn.Module):
|
||||||
|
def __init__(self,
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
bn_channels,
|
||||||
|
kernel_size,
|
||||||
|
stride=1,
|
||||||
|
dilation=1,
|
||||||
|
bias=False,
|
||||||
|
config_str='batchnorm-relu',
|
||||||
|
memory_efficient=False):
|
||||||
|
super(CAMDenseTDNNLayer, self).__init__()
|
||||||
|
assert kernel_size % 2 == 1, 'Expect equal paddings, but got even kernel size ({})'.format(
|
||||||
|
kernel_size)
|
||||||
|
padding = (kernel_size - 1) // 2 * dilation
|
||||||
|
self.memory_efficient = memory_efficient
|
||||||
|
self.nonlinear1 = get_nonlinear(config_str, in_channels)
|
||||||
|
self.linear1 = nn.Conv1d(in_channels, bn_channels, 1, bias=False)
|
||||||
|
self.nonlinear2 = get_nonlinear(config_str, bn_channels)
|
||||||
|
self.cam_layer = CAMLayer(bn_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
padding=padding,
|
||||||
|
dilation=dilation,
|
||||||
|
bias=bias)
|
||||||
|
|
||||||
|
def bn_function(self, x):
|
||||||
|
return self.linear1(self.nonlinear1(x))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if self.training and self.memory_efficient:
|
||||||
|
x = cp.checkpoint(self.bn_function, x)
|
||||||
|
else:
|
||||||
|
x = self.bn_function(x)
|
||||||
|
x = self.cam_layer(self.nonlinear2(x))
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class CAMDenseTDNNBlock(nn.ModuleList):
|
||||||
|
def __init__(self,
|
||||||
|
num_layers,
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
bn_channels,
|
||||||
|
kernel_size,
|
||||||
|
stride=1,
|
||||||
|
dilation=1,
|
||||||
|
bias=False,
|
||||||
|
config_str='batchnorm-relu',
|
||||||
|
memory_efficient=False):
|
||||||
|
super(CAMDenseTDNNBlock, self).__init__()
|
||||||
|
for i in range(num_layers):
|
||||||
|
layer = CAMDenseTDNNLayer(in_channels=in_channels + i * out_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
bn_channels=bn_channels,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
dilation=dilation,
|
||||||
|
bias=bias,
|
||||||
|
config_str=config_str,
|
||||||
|
memory_efficient=memory_efficient)
|
||||||
|
self.add_module('tdnnd%d' % (i + 1), layer)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
for layer in self:
|
||||||
|
x = torch.cat([x, layer(x)], dim=1)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class TransitLayer(nn.Module):
|
||||||
|
def __init__(self,
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
bias=True,
|
||||||
|
config_str='batchnorm-relu'):
|
||||||
|
super(TransitLayer, self).__init__()
|
||||||
|
self.nonlinear = get_nonlinear(config_str, in_channels)
|
||||||
|
self.linear = nn.Conv1d(in_channels, out_channels, 1, bias=bias)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.nonlinear(x)
|
||||||
|
x = self.linear(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class DenseLayer(nn.Module):
|
||||||
|
def __init__(self,
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
bias=False,
|
||||||
|
config_str='batchnorm-relu'):
|
||||||
|
super(DenseLayer, self).__init__()
|
||||||
|
self.linear = nn.Conv1d(in_channels, out_channels, 1, bias=bias)
|
||||||
|
self.nonlinear = get_nonlinear(config_str, out_channels)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if len(x.shape) == 2:
|
||||||
|
x = self.linear(x.unsqueeze(dim=-1)).squeeze(dim=-1)
|
||||||
|
else:
|
||||||
|
x = self.linear(x)
|
||||||
|
x = self.nonlinear(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class BasicResBlock(nn.Module):
|
||||||
|
expansion = 1
|
||||||
|
|
||||||
|
def __init__(self, in_planes, planes, stride=1):
|
||||||
|
super(BasicResBlock, self).__init__()
|
||||||
|
self.conv1 = nn.Conv2d(in_planes,
|
||||||
|
planes,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=(stride, 1),
|
||||||
|
padding=1,
|
||||||
|
bias=False)
|
||||||
|
self.bn1 = nn.BatchNorm2d(planes)
|
||||||
|
self.conv2 = nn.Conv2d(planes,
|
||||||
|
planes,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
padding=1,
|
||||||
|
bias=False)
|
||||||
|
self.bn2 = nn.BatchNorm2d(planes)
|
||||||
|
|
||||||
|
self.shortcut = nn.Sequential()
|
||||||
|
if stride != 1 or in_planes != self.expansion * planes:
|
||||||
|
self.shortcut = nn.Sequential(
|
||||||
|
nn.Conv2d(in_planes,
|
||||||
|
self.expansion * planes,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=(stride, 1),
|
||||||
|
bias=False),
|
||||||
|
nn.BatchNorm2d(self.expansion * planes))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
out = F.relu(self.bn1(self.conv1(x)))
|
||||||
|
out = self.bn2(self.conv2(out))
|
||||||
|
out += self.shortcut(x)
|
||||||
|
out = F.relu(out)
|
||||||
|
return out
|
||||||
@ -1,25 +1,18 @@
|
|||||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||||
""" Some implementations are adapted from https://github.com/yuyq96/D-TDNN
|
""" Some implementations are adapted from https://github.com/yuyq96/D-TDNN
|
||||||
"""
|
"""
|
||||||
import math
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
import torch.utils.checkpoint as cp
|
|
||||||
from torch import nn
|
|
||||||
|
|
||||||
import io
|
import io
|
||||||
import os
|
from typing import Union
|
||||||
from typing import Any, Dict, List, Union
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import librosa as sf
|
import librosa as sf
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torch.nn.functional as F
|
||||||
import logging
|
|
||||||
from funasr.utils.modelscope_file import File
|
|
||||||
from collections import OrderedDict
|
|
||||||
import torchaudio.compliance.kaldi as Kaldi
|
import torchaudio.compliance.kaldi as Kaldi
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from funasr.utils.modelscope_file import File
|
||||||
|
|
||||||
|
|
||||||
def check_audio_list(audio: list):
|
def check_audio_list(audio: list):
|
||||||
@ -104,230 +97,6 @@ def sv_chunk(vad_segments: list, fs = 16000) -> list:
|
|||||||
return segs
|
return segs
|
||||||
|
|
||||||
|
|
||||||
class BasicResBlock(nn.Module):
|
|
||||||
expansion = 1
|
|
||||||
|
|
||||||
def __init__(self, in_planes, planes, stride=1):
|
|
||||||
super(BasicResBlock, self).__init__()
|
|
||||||
self.conv1 = nn.Conv2d(
|
|
||||||
in_planes,
|
|
||||||
planes,
|
|
||||||
kernel_size=3,
|
|
||||||
stride=(stride, 1),
|
|
||||||
padding=1,
|
|
||||||
bias=False)
|
|
||||||
self.bn1 = nn.BatchNorm2d(planes)
|
|
||||||
self.conv2 = nn.Conv2d(
|
|
||||||
planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
|
|
||||||
self.bn2 = nn.BatchNorm2d(planes)
|
|
||||||
|
|
||||||
self.shortcut = nn.Sequential()
|
|
||||||
if stride != 1 or in_planes != self.expansion * planes:
|
|
||||||
self.shortcut = nn.Sequential(
|
|
||||||
nn.Conv2d(
|
|
||||||
in_planes,
|
|
||||||
self.expansion * planes,
|
|
||||||
kernel_size=1,
|
|
||||||
stride=(stride, 1),
|
|
||||||
bias=False), nn.BatchNorm2d(self.expansion * planes))
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
out = F.relu(self.bn1(self.conv1(x)))
|
|
||||||
out = self.bn2(self.conv2(out))
|
|
||||||
out += self.shortcut(x)
|
|
||||||
out = F.relu(out)
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
class FCM(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
block=BasicResBlock,
|
|
||||||
num_blocks=[2, 2],
|
|
||||||
m_channels=32,
|
|
||||||
feat_dim=80):
|
|
||||||
super(FCM, self).__init__()
|
|
||||||
self.in_planes = m_channels
|
|
||||||
self.conv1 = nn.Conv2d(
|
|
||||||
1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
|
|
||||||
self.bn1 = nn.BatchNorm2d(m_channels)
|
|
||||||
|
|
||||||
self.layer1 = self._make_layer(
|
|
||||||
block, m_channels, num_blocks[0], stride=2)
|
|
||||||
self.layer2 = self._make_layer(
|
|
||||||
block, m_channels, num_blocks[0], stride=2)
|
|
||||||
|
|
||||||
self.conv2 = nn.Conv2d(
|
|
||||||
m_channels,
|
|
||||||
m_channels,
|
|
||||||
kernel_size=3,
|
|
||||||
stride=(2, 1),
|
|
||||||
padding=1,
|
|
||||||
bias=False)
|
|
||||||
self.bn2 = nn.BatchNorm2d(m_channels)
|
|
||||||
self.out_channels = m_channels * (feat_dim // 8)
|
|
||||||
|
|
||||||
def _make_layer(self, block, planes, num_blocks, stride):
|
|
||||||
strides = [stride] + [1] * (num_blocks - 1)
|
|
||||||
layers = []
|
|
||||||
for stride in strides:
|
|
||||||
layers.append(block(self.in_planes, planes, stride))
|
|
||||||
self.in_planes = planes * block.expansion
|
|
||||||
return nn.Sequential(*layers)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = x.unsqueeze(1)
|
|
||||||
out = F.relu(self.bn1(self.conv1(x)))
|
|
||||||
out = self.layer1(out)
|
|
||||||
out = self.layer2(out)
|
|
||||||
out = F.relu(self.bn2(self.conv2(out)))
|
|
||||||
|
|
||||||
shape = out.shape
|
|
||||||
out = out.reshape(shape[0], shape[1] * shape[2], shape[3])
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
class CAMPPlus(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
feat_dim=80,
|
|
||||||
embedding_size=192,
|
|
||||||
growth_rate=32,
|
|
||||||
bn_size=4,
|
|
||||||
init_channels=128,
|
|
||||||
config_str='batchnorm-relu',
|
|
||||||
memory_efficient=True,
|
|
||||||
output_level='segment'):
|
|
||||||
super(CAMPPlus, self).__init__()
|
|
||||||
|
|
||||||
self.head = FCM(feat_dim=feat_dim)
|
|
||||||
channels = self.head.out_channels
|
|
||||||
self.output_level = output_level
|
|
||||||
|
|
||||||
self.xvector = nn.Sequential(
|
|
||||||
OrderedDict([
|
|
||||||
('tdnn',
|
|
||||||
TDNNLayer(
|
|
||||||
channels,
|
|
||||||
init_channels,
|
|
||||||
5,
|
|
||||||
stride=2,
|
|
||||||
dilation=1,
|
|
||||||
padding=-1,
|
|
||||||
config_str=config_str)),
|
|
||||||
]))
|
|
||||||
channels = init_channels
|
|
||||||
for i, (num_layers, kernel_size, dilation) in enumerate(
|
|
||||||
zip((12, 24, 16), (3, 3, 3), (1, 2, 2))):
|
|
||||||
block = CAMDenseTDNNBlock(
|
|
||||||
num_layers=num_layers,
|
|
||||||
in_channels=channels,
|
|
||||||
out_channels=growth_rate,
|
|
||||||
bn_channels=bn_size * growth_rate,
|
|
||||||
kernel_size=kernel_size,
|
|
||||||
dilation=dilation,
|
|
||||||
config_str=config_str,
|
|
||||||
memory_efficient=memory_efficient)
|
|
||||||
self.xvector.add_module('block%d' % (i + 1), block)
|
|
||||||
channels = channels + num_layers * growth_rate
|
|
||||||
self.xvector.add_module(
|
|
||||||
'transit%d' % (i + 1),
|
|
||||||
TransitLayer(
|
|
||||||
channels, channels // 2, bias=False,
|
|
||||||
config_str=config_str))
|
|
||||||
channels //= 2
|
|
||||||
|
|
||||||
self.xvector.add_module('out_nonlinear',
|
|
||||||
get_nonlinear(config_str, channels))
|
|
||||||
|
|
||||||
if self.output_level == 'segment':
|
|
||||||
self.xvector.add_module('stats', StatsPool())
|
|
||||||
self.xvector.add_module(
|
|
||||||
'dense',
|
|
||||||
DenseLayer(
|
|
||||||
channels * 2, embedding_size, config_str='batchnorm_'))
|
|
||||||
else:
|
|
||||||
assert self.output_level == 'frame', '`output_level` should be set to \'segment\' or \'frame\'. '
|
|
||||||
|
|
||||||
for m in self.modules():
|
|
||||||
if isinstance(m, (nn.Conv1d, nn.Linear)):
|
|
||||||
nn.init.kaiming_normal_(m.weight.data)
|
|
||||||
if m.bias is not None:
|
|
||||||
nn.init.zeros_(m.bias)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
|
|
||||||
x = self.head(x)
|
|
||||||
x = self.xvector(x)
|
|
||||||
if self.output_level == 'frame':
|
|
||||||
x = x.transpose(1, 2)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
def get_nonlinear(config_str, channels):
|
|
||||||
nonlinear = nn.Sequential()
|
|
||||||
for name in config_str.split('-'):
|
|
||||||
if name == 'relu':
|
|
||||||
nonlinear.add_module('relu', nn.ReLU(inplace=True))
|
|
||||||
elif name == 'prelu':
|
|
||||||
nonlinear.add_module('prelu', nn.PReLU(channels))
|
|
||||||
elif name == 'batchnorm':
|
|
||||||
nonlinear.add_module('batchnorm', nn.BatchNorm1d(channels))
|
|
||||||
elif name == 'batchnorm_':
|
|
||||||
nonlinear.add_module('batchnorm',
|
|
||||||
nn.BatchNorm1d(channels, affine=False))
|
|
||||||
else:
|
|
||||||
raise ValueError('Unexpected module ({}).'.format(name))
|
|
||||||
return nonlinear
|
|
||||||
|
|
||||||
|
|
||||||
def statistics_pooling(x, dim=-1, keepdim=False, unbiased=True, eps=1e-2):
|
|
||||||
mean = x.mean(dim=dim)
|
|
||||||
std = x.std(dim=dim, unbiased=unbiased)
|
|
||||||
stats = torch.cat([mean, std], dim=-1)
|
|
||||||
if keepdim:
|
|
||||||
stats = stats.unsqueeze(dim=dim)
|
|
||||||
return stats
|
|
||||||
|
|
||||||
|
|
||||||
class StatsPool(nn.Module):
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return statistics_pooling(x)
|
|
||||||
|
|
||||||
|
|
||||||
class TDNNLayer(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
in_channels,
|
|
||||||
out_channels,
|
|
||||||
kernel_size,
|
|
||||||
stride=1,
|
|
||||||
padding=0,
|
|
||||||
dilation=1,
|
|
||||||
bias=False,
|
|
||||||
config_str='batchnorm-relu'):
|
|
||||||
super(TDNNLayer, self).__init__()
|
|
||||||
if padding < 0:
|
|
||||||
assert kernel_size % 2 == 1, 'Expect equal paddings, but got even kernel size ({})'.format(
|
|
||||||
kernel_size)
|
|
||||||
padding = (kernel_size - 1) // 2 * dilation
|
|
||||||
self.linear = nn.Conv1d(
|
|
||||||
in_channels,
|
|
||||||
out_channels,
|
|
||||||
kernel_size,
|
|
||||||
stride=stride,
|
|
||||||
padding=padding,
|
|
||||||
dilation=dilation,
|
|
||||||
bias=bias)
|
|
||||||
self.nonlinear = get_nonlinear(config_str, out_channels)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = self.linear(x)
|
|
||||||
x = self.nonlinear(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
def extract_feature(audio):
|
def extract_feature(audio):
|
||||||
features = []
|
features = []
|
||||||
for au in audio:
|
for au in audio:
|
||||||
@ -387,116 +156,6 @@ class CAMLayer(nn.Module):
|
|||||||
return seg
|
return seg
|
||||||
|
|
||||||
|
|
||||||
class CAMDenseTDNNLayer(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
in_channels,
|
|
||||||
out_channels,
|
|
||||||
bn_channels,
|
|
||||||
kernel_size,
|
|
||||||
stride=1,
|
|
||||||
dilation=1,
|
|
||||||
bias=False,
|
|
||||||
config_str='batchnorm-relu',
|
|
||||||
memory_efficient=False):
|
|
||||||
super(CAMDenseTDNNLayer, self).__init__()
|
|
||||||
assert kernel_size % 2 == 1, 'Expect equal paddings, but got even kernel size ({})'.format(
|
|
||||||
kernel_size)
|
|
||||||
padding = (kernel_size - 1) // 2 * dilation
|
|
||||||
self.memory_efficient = memory_efficient
|
|
||||||
self.nonlinear1 = get_nonlinear(config_str, in_channels)
|
|
||||||
self.linear1 = nn.Conv1d(in_channels, bn_channels, 1, bias=False)
|
|
||||||
self.nonlinear2 = get_nonlinear(config_str, bn_channels)
|
|
||||||
self.cam_layer = CAMLayer(
|
|
||||||
bn_channels,
|
|
||||||
out_channels,
|
|
||||||
kernel_size,
|
|
||||||
stride=stride,
|
|
||||||
padding=padding,
|
|
||||||
dilation=dilation,
|
|
||||||
bias=bias)
|
|
||||||
|
|
||||||
def bn_function(self, x):
|
|
||||||
return self.linear1(self.nonlinear1(x))
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
if self.training and self.memory_efficient:
|
|
||||||
x = cp.checkpoint(self.bn_function, x)
|
|
||||||
else:
|
|
||||||
x = self.bn_function(x)
|
|
||||||
x = self.cam_layer(self.nonlinear2(x))
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class CAMDenseTDNNBlock(nn.ModuleList):
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
num_layers,
|
|
||||||
in_channels,
|
|
||||||
out_channels,
|
|
||||||
bn_channels,
|
|
||||||
kernel_size,
|
|
||||||
stride=1,
|
|
||||||
dilation=1,
|
|
||||||
bias=False,
|
|
||||||
config_str='batchnorm-relu',
|
|
||||||
memory_efficient=False):
|
|
||||||
super(CAMDenseTDNNBlock, self).__init__()
|
|
||||||
for i in range(num_layers):
|
|
||||||
layer = CAMDenseTDNNLayer(
|
|
||||||
in_channels=in_channels + i * out_channels,
|
|
||||||
out_channels=out_channels,
|
|
||||||
bn_channels=bn_channels,
|
|
||||||
kernel_size=kernel_size,
|
|
||||||
stride=stride,
|
|
||||||
dilation=dilation,
|
|
||||||
bias=bias,
|
|
||||||
config_str=config_str,
|
|
||||||
memory_efficient=memory_efficient)
|
|
||||||
self.add_module('tdnnd%d' % (i + 1), layer)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
for layer in self:
|
|
||||||
x = torch.cat([x, layer(x)], dim=1)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class TransitLayer(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
in_channels,
|
|
||||||
out_channels,
|
|
||||||
bias=True,
|
|
||||||
config_str='batchnorm-relu'):
|
|
||||||
super(TransitLayer, self).__init__()
|
|
||||||
self.nonlinear = get_nonlinear(config_str, in_channels)
|
|
||||||
self.linear = nn.Conv1d(in_channels, out_channels, 1, bias=bias)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = self.nonlinear(x)
|
|
||||||
x = self.linear(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class DenseLayer(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
in_channels,
|
|
||||||
out_channels,
|
|
||||||
bias=False,
|
|
||||||
config_str='batchnorm-relu'):
|
|
||||||
super(DenseLayer, self).__init__()
|
|
||||||
self.linear = nn.Conv1d(in_channels, out_channels, 1, bias=bias)
|
|
||||||
self.nonlinear = get_nonlinear(config_str, out_channels)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
if len(x.shape) == 2:
|
|
||||||
x = self.linear(x.unsqueeze(dim=-1)).squeeze(dim=-1)
|
|
||||||
else:
|
|
||||||
x = self.linear(x)
|
|
||||||
x = self.nonlinear(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
def postprocess(segments: list, vad_segments: list,
|
def postprocess(segments: list, vad_segments: list,
|
||||||
labels: np.ndarray, embeddings: np.ndarray) -> list:
|
labels: np.ndarray, embeddings: np.ndarray) -> list:
|
||||||
assert len(segments) == len(labels)
|
assert len(segments) == len(labels)
|
||||||
@ -592,300 +251,3 @@ def distribute_spk(sentence_list, sd_time_list):
|
|||||||
d['spk'] = sentence_spk
|
d['spk'] = sentence_spk
|
||||||
sd_sentence_list.append(d)
|
sd_sentence_list.append(d)
|
||||||
return sd_sentence_list
|
return sd_sentence_list
|
||||||
|
|
||||||
|
|
||||||
class AFF(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, channels=64, r=4):
|
|
||||||
super(AFF, self).__init__()
|
|
||||||
inter_channels = int(channels // r)
|
|
||||||
|
|
||||||
self.local_att = nn.Sequential(
|
|
||||||
nn.Conv2d(channels * 2, inter_channels, kernel_size=1, stride=1, padding=0),
|
|
||||||
nn.BatchNorm2d(inter_channels),
|
|
||||||
nn.SiLU(inplace=True),
|
|
||||||
nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
|
|
||||||
nn.BatchNorm2d(channels),
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x, ds_y):
|
|
||||||
xa = torch.cat((x, ds_y), dim=1)
|
|
||||||
x_att = self.local_att(xa)
|
|
||||||
x_att = 1.0 + torch.tanh(x_att)
|
|
||||||
xo = torch.mul(x, x_att) + torch.mul(ds_y, 2.0 - x_att)
|
|
||||||
|
|
||||||
return xo
|
|
||||||
|
|
||||||
|
|
||||||
class TSTP(nn.Module):
|
|
||||||
"""
|
|
||||||
Temporal statistics pooling, concatenate mean and std, which is used in
|
|
||||||
x-vector
|
|
||||||
Comment: simple concatenation can not make full use of both statistics
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
|
||||||
super(TSTP, self).__init__()
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
# The last dimension is the temporal axis
|
|
||||||
pooling_mean = x.mean(dim=-1)
|
|
||||||
pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-8)
|
|
||||||
pooling_mean = pooling_mean.flatten(start_dim=1)
|
|
||||||
pooling_std = pooling_std.flatten(start_dim=1)
|
|
||||||
|
|
||||||
stats = torch.cat((pooling_mean, pooling_std), 1)
|
|
||||||
return stats
|
|
||||||
|
|
||||||
|
|
||||||
class ReLU(nn.Hardtanh):
|
|
||||||
|
|
||||||
def __init__(self, inplace=False):
|
|
||||||
super(ReLU, self).__init__(0, 20, inplace)
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
inplace_str = 'inplace' if self.inplace else ''
|
|
||||||
return self.__class__.__name__ + ' (' \
|
|
||||||
+ inplace_str + ')'
|
|
||||||
|
|
||||||
|
|
||||||
def conv1x1(in_planes, out_planes, stride=1):
|
|
||||||
"1x1 convolution without padding"
|
|
||||||
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
|
|
||||||
padding=0, bias=False)
|
|
||||||
|
|
||||||
|
|
||||||
def conv3x3(in_planes, out_planes, stride=1):
|
|
||||||
"3x3 convolution with padding"
|
|
||||||
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
|
||||||
padding=1, bias=False)
|
|
||||||
|
|
||||||
|
|
||||||
class BasicBlockERes2Net(nn.Module):
|
|
||||||
expansion = 4
|
|
||||||
|
|
||||||
def __init__(self, in_planes, planes, stride=1, baseWidth=24, scale=3):
|
|
||||||
super(BasicBlockERes2Net, self).__init__()
|
|
||||||
width = int(math.floor(planes * (baseWidth / 64.0)))
|
|
||||||
self.conv1 = conv1x1(in_planes, width * scale, stride)
|
|
||||||
self.bn1 = nn.BatchNorm2d(width * scale)
|
|
||||||
self.nums = scale
|
|
||||||
|
|
||||||
convs = []
|
|
||||||
bns = []
|
|
||||||
for i in range(self.nums):
|
|
||||||
convs.append(conv3x3(width, width))
|
|
||||||
bns.append(nn.BatchNorm2d(width))
|
|
||||||
self.convs = nn.ModuleList(convs)
|
|
||||||
self.bns = nn.ModuleList(bns)
|
|
||||||
self.relu = ReLU(inplace=True)
|
|
||||||
|
|
||||||
self.conv3 = conv1x1(width * scale, planes * self.expansion)
|
|
||||||
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
|
|
||||||
self.shortcut = nn.Sequential()
|
|
||||||
if stride != 1 or in_planes != self.expansion * planes:
|
|
||||||
self.shortcut = nn.Sequential(
|
|
||||||
nn.Conv2d(in_planes,
|
|
||||||
self.expansion * planes,
|
|
||||||
kernel_size=1,
|
|
||||||
stride=stride,
|
|
||||||
bias=False),
|
|
||||||
nn.BatchNorm2d(self.expansion * planes))
|
|
||||||
self.stride = stride
|
|
||||||
self.width = width
|
|
||||||
self.scale = scale
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
residual = x
|
|
||||||
|
|
||||||
out = self.conv1(x)
|
|
||||||
out = self.bn1(out)
|
|
||||||
out = self.relu(out)
|
|
||||||
spx = torch.split(out, self.width, 1)
|
|
||||||
for i in range(self.nums):
|
|
||||||
if i == 0:
|
|
||||||
sp = spx[i]
|
|
||||||
else:
|
|
||||||
sp = sp + spx[i]
|
|
||||||
sp = self.convs[i](sp)
|
|
||||||
sp = self.relu(self.bns[i](sp))
|
|
||||||
if i == 0:
|
|
||||||
out = sp
|
|
||||||
else:
|
|
||||||
out = torch.cat((out, sp), 1)
|
|
||||||
|
|
||||||
out = self.conv3(out)
|
|
||||||
out = self.bn3(out)
|
|
||||||
|
|
||||||
residual = self.shortcut(x)
|
|
||||||
out += residual
|
|
||||||
out = self.relu(out)
|
|
||||||
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
class BasicBlockERes2Net_diff_AFF(nn.Module):
|
|
||||||
expansion = 4
|
|
||||||
|
|
||||||
def __init__(self, in_planes, planes, stride=1, baseWidth=24, scale=3):
|
|
||||||
super(BasicBlockERes2Net_diff_AFF, self).__init__()
|
|
||||||
width = int(math.floor(planes * (baseWidth / 64.0)))
|
|
||||||
self.conv1 = conv1x1(in_planes, width * scale, stride)
|
|
||||||
self.bn1 = nn.BatchNorm2d(width * scale)
|
|
||||||
|
|
||||||
self.nums = scale
|
|
||||||
|
|
||||||
convs = []
|
|
||||||
fuse_models = []
|
|
||||||
bns = []
|
|
||||||
for i in range(self.nums):
|
|
||||||
convs.append(conv3x3(width, width))
|
|
||||||
bns.append(nn.BatchNorm2d(width))
|
|
||||||
for j in range(self.nums - 1):
|
|
||||||
fuse_models.append(AFF(channels=width))
|
|
||||||
|
|
||||||
self.convs = nn.ModuleList(convs)
|
|
||||||
self.bns = nn.ModuleList(bns)
|
|
||||||
self.fuse_models = nn.ModuleList(fuse_models)
|
|
||||||
self.relu = ReLU(inplace=True)
|
|
||||||
|
|
||||||
self.conv3 = conv1x1(width * scale, planes * self.expansion)
|
|
||||||
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
|
|
||||||
self.shortcut = nn.Sequential()
|
|
||||||
if stride != 1 or in_planes != self.expansion * planes:
|
|
||||||
self.shortcut = nn.Sequential(
|
|
||||||
nn.Conv2d(in_planes,
|
|
||||||
self.expansion * planes,
|
|
||||||
kernel_size=1,
|
|
||||||
stride=stride,
|
|
||||||
bias=False),
|
|
||||||
nn.BatchNorm2d(self.expansion * planes))
|
|
||||||
self.stride = stride
|
|
||||||
self.width = width
|
|
||||||
self.scale = scale
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
residual = x
|
|
||||||
|
|
||||||
out = self.conv1(x)
|
|
||||||
out = self.bn1(out)
|
|
||||||
out = self.relu(out)
|
|
||||||
spx = torch.split(out, self.width, 1)
|
|
||||||
for i in range(self.nums):
|
|
||||||
if i == 0:
|
|
||||||
sp = spx[i]
|
|
||||||
else:
|
|
||||||
sp = self.fuse_models[i - 1](sp, spx[i])
|
|
||||||
|
|
||||||
sp = self.convs[i](sp)
|
|
||||||
sp = self.relu(self.bns[i](sp))
|
|
||||||
if i == 0:
|
|
||||||
out = sp
|
|
||||||
else:
|
|
||||||
out = torch.cat((out, sp), 1)
|
|
||||||
|
|
||||||
out = self.conv3(out)
|
|
||||||
out = self.bn3(out)
|
|
||||||
|
|
||||||
residual = self.shortcut(x)
|
|
||||||
out += residual
|
|
||||||
out = self.relu(out)
|
|
||||||
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
class ERes2Net(nn.Module):
|
|
||||||
def __init__(self,
|
|
||||||
block=BasicBlockERes2Net,
|
|
||||||
block_fuse=BasicBlockERes2Net_diff_AFF,
|
|
||||||
num_blocks=[3, 4, 6, 3],
|
|
||||||
m_channels=64,
|
|
||||||
feat_dim=80,
|
|
||||||
embedding_size=192,
|
|
||||||
pooling_func='TSTP',
|
|
||||||
two_emb_layer=False):
|
|
||||||
super(ERes2Net, self).__init__()
|
|
||||||
self.in_planes = m_channels
|
|
||||||
self.feat_dim = feat_dim
|
|
||||||
self.embedding_size = embedding_size
|
|
||||||
self.stats_dim = int(feat_dim / 8) * m_channels * 8
|
|
||||||
self.two_emb_layer = two_emb_layer
|
|
||||||
|
|
||||||
self.conv1 = nn.Conv2d(1,
|
|
||||||
m_channels,
|
|
||||||
kernel_size=3,
|
|
||||||
stride=1,
|
|
||||||
padding=1,
|
|
||||||
bias=False)
|
|
||||||
self.bn1 = nn.BatchNorm2d(m_channels)
|
|
||||||
self.layer1 = self._make_layer(block,
|
|
||||||
m_channels,
|
|
||||||
num_blocks[0],
|
|
||||||
stride=1)
|
|
||||||
self.layer2 = self._make_layer(block,
|
|
||||||
m_channels * 2,
|
|
||||||
num_blocks[1],
|
|
||||||
stride=2)
|
|
||||||
self.layer3 = self._make_layer(block_fuse,
|
|
||||||
m_channels * 4,
|
|
||||||
num_blocks[2],
|
|
||||||
stride=2)
|
|
||||||
self.layer4 = self._make_layer(block_fuse,
|
|
||||||
m_channels * 8,
|
|
||||||
num_blocks[3],
|
|
||||||
stride=2)
|
|
||||||
|
|
||||||
self.layer1_downsample = nn.Conv2d(m_channels * 4, m_channels * 8, kernel_size=3, padding=1, stride=2,
|
|
||||||
bias=False)
|
|
||||||
self.layer2_downsample = nn.Conv2d(m_channels * 8, m_channels * 16, kernel_size=3, padding=1, stride=2,
|
|
||||||
bias=False)
|
|
||||||
self.layer3_downsample = nn.Conv2d(m_channels * 16, m_channels * 32, kernel_size=3, padding=1, stride=2,
|
|
||||||
bias=False)
|
|
||||||
self.fuse_mode12 = AFF(channels=m_channels * 8)
|
|
||||||
self.fuse_mode123 = AFF(channels=m_channels * 16)
|
|
||||||
self.fuse_mode1234 = AFF(channels=m_channels * 32)
|
|
||||||
|
|
||||||
self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == "TSDP" else 2
|
|
||||||
self.pool = TSTP(in_dim=self.stats_dim * block.expansion)
|
|
||||||
self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats,
|
|
||||||
embedding_size)
|
|
||||||
if self.two_emb_layer:
|
|
||||||
self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
|
|
||||||
self.seg_2 = nn.Linear(embedding_size, embedding_size)
|
|
||||||
else:
|
|
||||||
self.seg_bn_1 = nn.Identity()
|
|
||||||
self.seg_2 = nn.Identity()
|
|
||||||
|
|
||||||
def _make_layer(self, block, planes, num_blocks, stride):
|
|
||||||
strides = [stride] + [1] * (num_blocks - 1)
|
|
||||||
layers = []
|
|
||||||
for stride in strides:
|
|
||||||
layers.append(block(self.in_planes, planes, stride))
|
|
||||||
self.in_planes = planes * block.expansion
|
|
||||||
return nn.Sequential(*layers)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
|
|
||||||
|
|
||||||
x = x.unsqueeze_(1)
|
|
||||||
out = F.relu(self.bn1(self.conv1(x)))
|
|
||||||
out1 = self.layer1(out)
|
|
||||||
out2 = self.layer2(out1)
|
|
||||||
out1_downsample = self.layer1_downsample(out1)
|
|
||||||
fuse_out12 = self.fuse_mode12(out2, out1_downsample)
|
|
||||||
out3 = self.layer3(out2)
|
|
||||||
fuse_out12_downsample = self.layer2_downsample(fuse_out12)
|
|
||||||
fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample)
|
|
||||||
out4 = self.layer4(out3)
|
|
||||||
fuse_out123_downsample = self.layer3_downsample(fuse_out123)
|
|
||||||
fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample)
|
|
||||||
stats = self.pool(fuse_out1234)
|
|
||||||
|
|
||||||
embed_a = self.seg_1(stats)
|
|
||||||
if self.two_emb_layer:
|
|
||||||
out = F.relu(embed_a)
|
|
||||||
out = self.seg_bn_1(out)
|
|
||||||
embed_b = self.seg_2(out)
|
|
||||||
return embed_b
|
|
||||||
else:
|
|
||||||
return embed_a
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user