From 73613cefc97bd43699d10b8d162c69b2c4544ad5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=9C=E9=9B=A8=E9=A3=98=E9=9B=B6?= Date: Mon, 4 Dec 2023 21:41:07 +0800 Subject: [PATCH 1/2] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E5=88=86=E8=A7=92?= =?UTF-8?q?=E8=89=B2=E8=AF=AD=E9=9F=B3=E8=AF=86=E5=88=AB=E5=AF=B9ERes2Net?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E7=9A=84=E6=94=AF=E6=8C=81=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- funasr/bin/asr_inference_launch.py | 30 ++- funasr/utils/speaker_utils.py | 300 ++++++++++++++++++++++++++++- 2 files changed, 320 insertions(+), 10 deletions(-) diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py index f61c0859d..402a91197 100644 --- a/funasr/bin/asr_inference_launch.py +++ b/funasr/bin/asr_inference_launch.py @@ -48,13 +48,13 @@ from funasr.utils.types import str2bool from funasr.utils.types import str2triple_str from funasr.utils.types import str_or_none from funasr.utils.vad_utils import slice_padding_fbank -from funasr.utils.speaker_utils import (check_audio_list, - sv_preprocess, - sv_chunk, - CAMPPlus, - extract_feature, +from funasr.utils.speaker_utils import (check_audio_list, + sv_preprocess, + sv_chunk, + CAMPPlus, + extract_feature, postprocess, - distribute_spk) + distribute_spk, ERes2Net) from funasr.build_utils.build_model_from_file import build_model_from_file from funasr.utils.cluster_backend import ClusterBackend from funasr.utils.modelscope_utils import get_cache_dir @@ -819,6 +819,10 @@ def inference_paraformer_vad_speaker( ) sv_model_file = asr_model_file.replace("model.pb", "campplus_cn_common.bin") + if not os.path.exists(sv_model_file): + sv_model_file = asr_model_file.replace("model.pb", "pretrained_eres2net_aug.ckpt") + if not os.path.exists(sv_model_file): + raise FileNotFoundError("sv_model_file not found: {}".format(sv_model_file)) if param_dict is not None: hotword_list_or_file = param_dict.get('hotword') @@ -944,8 +948,14 @@ def inference_paraformer_vad_speaker( ##### speaker_verification ##### ################################## # load sv model - sv_model_dict = torch.load(sv_model_file, map_location=torch.device('cpu')) - sv_model = CAMPPlus() + sv_model_dict = torch.load(sv_model_file) + print(f'load sv model params: {sv_model_file}') + if os.path.basename(sv_model_file) == "campplus_cn_common.bin": + sv_model = CAMPPlus() + else: + sv_model = ERes2Net() + if ngpu > 0: + sv_model.cuda() sv_model.load_state_dict(sv_model_dict) sv_model.eval() cb_model = ClusterBackend() @@ -969,9 +979,11 @@ def inference_paraformer_vad_speaker( embs = [] for x in wavs: x = extract_feature([x]) + if ngpu > 0: + x = x.cuda() embs.append(sv_model(x)) embs = torch.cat(embs) - embeddings.append(embs.detach().numpy()) + embeddings.append(embs.cpu().detach().numpy()) embeddings = np.concatenate(embeddings) labels = cb_model(embeddings) sv_output = postprocess(segments, vad_segments, labels, embeddings) diff --git a/funasr/utils/speaker_utils.py b/funasr/utils/speaker_utils.py index edaf58b75..df3eca7d8 100644 --- a/funasr/utils/speaker_utils.py +++ b/funasr/utils/speaker_utils.py @@ -1,6 +1,7 @@ # Copyright (c) Alibaba, Inc. and its affiliates. """ Some implementations are adapted from https://github.com/yuyq96/D-TDNN """ +import math import torch import torch.nn.functional as F @@ -590,4 +591,301 @@ def distribute_spk(sentence_list, sd_time_list): sentence_spk = spk d['spk'] = sentence_spk sd_sentence_list.append(d) - return sd_sentence_list \ No newline at end of file + return sd_sentence_list + + +class AFF(nn.Module): + + def __init__(self, channels=64, r=4): + super(AFF, self).__init__() + inter_channels = int(channels // r) + + self.local_att = nn.Sequential( + nn.Conv2d(channels * 2, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(inter_channels), + nn.SiLU(inplace=True), + nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(channels), + ) + + def forward(self, x, ds_y): + xa = torch.cat((x, ds_y), dim=1) + x_att = self.local_att(xa) + x_att = 1.0 + torch.tanh(x_att) + xo = torch.mul(x, x_att) + torch.mul(ds_y, 2.0 - x_att) + + return xo + + +class TSTP(nn.Module): + """ + Temporal statistics pooling, concatenate mean and std, which is used in + x-vector + Comment: simple concatenation can not make full use of both statistics + """ + + def __init__(self, **kwargs): + super(TSTP, self).__init__() + + def forward(self, x): + # The last dimension is the temporal axis + pooling_mean = x.mean(dim=-1) + pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-8) + pooling_mean = pooling_mean.flatten(start_dim=1) + pooling_std = pooling_std.flatten(start_dim=1) + + stats = torch.cat((pooling_mean, pooling_std), 1) + return stats + + +class ReLU(nn.Hardtanh): + + def __init__(self, inplace=False): + super(ReLU, self).__init__(0, 20, inplace) + + def __repr__(self): + inplace_str = 'inplace' if self.inplace else '' + return self.__class__.__name__ + ' (' \ + + inplace_str + ')' + + +def conv1x1(in_planes, out_planes, stride=1): + "1x1 convolution without padding" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, + padding=0, bias=False) + + +def conv3x3(in_planes, out_planes, stride=1): + "3x3 convolution with padding" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + +class BasicBlockERes2Net(nn.Module): + expansion = 4 + + def __init__(self, in_planes, planes, stride=1, baseWidth=24, scale=3): + super(BasicBlockERes2Net, self).__init__() + width = int(math.floor(planes * (baseWidth / 64.0))) + self.conv1 = conv1x1(in_planes, width * scale, stride) + self.bn1 = nn.BatchNorm2d(width * scale) + self.nums = scale + + convs = [] + bns = [] + for i in range(self.nums): + convs.append(conv3x3(width, width)) + bns.append(nn.BatchNorm2d(width)) + self.convs = nn.ModuleList(convs) + self.bns = nn.ModuleList(bns) + self.relu = ReLU(inplace=True) + + self.conv3 = conv1x1(width * scale, planes * self.expansion) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion * planes: + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, + self.expansion * planes, + kernel_size=1, + stride=stride, + bias=False), + nn.BatchNorm2d(self.expansion * planes)) + self.stride = stride + self.width = width + self.scale = scale + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + spx = torch.split(out, self.width, 1) + for i in range(self.nums): + if i == 0: + sp = spx[i] + else: + sp = sp + spx[i] + sp = self.convs[i](sp) + sp = self.relu(self.bns[i](sp)) + if i == 0: + out = sp + else: + out = torch.cat((out, sp), 1) + + out = self.conv3(out) + out = self.bn3(out) + + residual = self.shortcut(x) + out += residual + out = self.relu(out) + + return out + + +class BasicBlockERes2Net_diff_AFF(nn.Module): + expansion = 4 + + def __init__(self, in_planes, planes, stride=1, baseWidth=24, scale=3): + super(BasicBlockERes2Net_diff_AFF, self).__init__() + width = int(math.floor(planes * (baseWidth / 64.0))) + self.conv1 = conv1x1(in_planes, width * scale, stride) + self.bn1 = nn.BatchNorm2d(width * scale) + + self.nums = scale + + convs = [] + fuse_models = [] + bns = [] + for i in range(self.nums): + convs.append(conv3x3(width, width)) + bns.append(nn.BatchNorm2d(width)) + for j in range(self.nums - 1): + fuse_models.append(AFF(channels=width)) + + self.convs = nn.ModuleList(convs) + self.bns = nn.ModuleList(bns) + self.fuse_models = nn.ModuleList(fuse_models) + self.relu = ReLU(inplace=True) + + self.conv3 = conv1x1(width * scale, planes * self.expansion) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion * planes: + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, + self.expansion * planes, + kernel_size=1, + stride=stride, + bias=False), + nn.BatchNorm2d(self.expansion * planes)) + self.stride = stride + self.width = width + self.scale = scale + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + spx = torch.split(out, self.width, 1) + for i in range(self.nums): + if i == 0: + sp = spx[i] + else: + sp = self.fuse_models[i - 1](sp, spx[i]) + + sp = self.convs[i](sp) + sp = self.relu(self.bns[i](sp)) + if i == 0: + out = sp + else: + out = torch.cat((out, sp), 1) + + out = self.conv3(out) + out = self.bn3(out) + + residual = self.shortcut(x) + out += residual + out = self.relu(out) + + return out + + +class ERes2Net(nn.Module): + def __init__(self, + block=BasicBlockERes2Net, + block_fuse=BasicBlockERes2Net_diff_AFF, + num_blocks=[3, 4, 6, 3], + m_channels=64, + feat_dim=80, + embedding_size=192, + pooling_func='TSTP', + two_emb_layer=False): + super(ERes2Net, self).__init__() + self.in_planes = m_channels + self.feat_dim = feat_dim + self.embedding_size = embedding_size + self.stats_dim = int(feat_dim / 8) * m_channels * 8 + self.two_emb_layer = two_emb_layer + + self.conv1 = nn.Conv2d(1, + m_channels, + kernel_size=3, + stride=1, + padding=1, + bias=False) + self.bn1 = nn.BatchNorm2d(m_channels) + self.layer1 = self._make_layer(block, + m_channels, + num_blocks[0], + stride=1) + self.layer2 = self._make_layer(block, + m_channels * 2, + num_blocks[1], + stride=2) + self.layer3 = self._make_layer(block_fuse, + m_channels * 4, + num_blocks[2], + stride=2) + self.layer4 = self._make_layer(block_fuse, + m_channels * 8, + num_blocks[3], + stride=2) + + self.layer1_downsample = nn.Conv2d(m_channels * 4, m_channels * 8, kernel_size=3, padding=1, stride=2, + bias=False) + self.layer2_downsample = nn.Conv2d(m_channels * 8, m_channels * 16, kernel_size=3, padding=1, stride=2, + bias=False) + self.layer3_downsample = nn.Conv2d(m_channels * 16, m_channels * 32, kernel_size=3, padding=1, stride=2, + bias=False) + self.fuse_mode12 = AFF(channels=m_channels * 8) + self.fuse_mode123 = AFF(channels=m_channels * 16) + self.fuse_mode1234 = AFF(channels=m_channels * 32) + + self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == "TSDP" else 2 + self.pool = TSTP(in_dim=self.stats_dim * block.expansion) + self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats, + embedding_size) + if self.two_emb_layer: + self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False) + self.seg_2 = nn.Linear(embedding_size, embedding_size) + else: + self.seg_bn_1 = nn.Identity() + self.seg_2 = nn.Identity() + + def _make_layer(self, block, planes, num_blocks, stride): + strides = [stride] + [1] * (num_blocks - 1) + layers = [] + for stride in strides: + layers.append(block(self.in_planes, planes, stride)) + self.in_planes = planes * block.expansion + return nn.Sequential(*layers) + + def forward(self, x): + x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T) + + x = x.unsqueeze_(1) + out = F.relu(self.bn1(self.conv1(x))) + out1 = self.layer1(out) + out2 = self.layer2(out1) + out1_downsample = self.layer1_downsample(out1) + fuse_out12 = self.fuse_mode12(out2, out1_downsample) + out3 = self.layer3(out2) + fuse_out12_downsample = self.layer2_downsample(fuse_out12) + fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample) + out4 = self.layer4(out3) + fuse_out123_downsample = self.layer3_downsample(fuse_out123) + fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample) + stats = self.pool(fuse_out1234) + + embed_a = self.seg_1(stats) + if self.two_emb_layer: + out = F.relu(embed_a) + out = self.seg_bn_1(out) + embed_b = self.seg_2(out) + return embed_b + else: + return embed_a From 18b1449d1ff06c469e54190508c4f6be05c73d85 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=9C=E9=9B=A8=E9=A3=98=E9=9B=B6?= Date: Tue, 5 Dec 2023 22:04:14 +0800 Subject: [PATCH 2/2] =?UTF-8?q?=E5=88=86=E8=A7=92=E8=89=B2=E8=AF=AD?= =?UTF-8?q?=E9=9F=B3=E8=AF=86=E5=88=AB=E6=94=AF=E6=8C=81=E6=9B=B4=E5=A4=9A?= =?UTF-8?q?=E7=9A=84=E6=A8=A1=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- funasr/bin/asr_inference_launch.py | 25 +- funasr/models/pooling/pooling_layers.py | 108 ++++ funasr/modules/cnn/DTDNN.py | 124 +++++ funasr/modules/cnn/ResNet.py | 420 +++++++++++++++ funasr/modules/cnn/ResNet_aug.py | 273 ++++++++++ funasr/modules/cnn/__init__.py | 3 + funasr/modules/cnn/fusion.py | 29 ++ funasr/modules/cnn/layers.py | 254 +++++++++ funasr/utils/speaker_utils.py | 650 +----------------------- 9 files changed, 1230 insertions(+), 656 deletions(-) create mode 100644 funasr/models/pooling/pooling_layers.py create mode 100644 funasr/modules/cnn/DTDNN.py create mode 100644 funasr/modules/cnn/ResNet.py create mode 100644 funasr/modules/cnn/ResNet_aug.py create mode 100644 funasr/modules/cnn/__init__.py create mode 100644 funasr/modules/cnn/fusion.py create mode 100644 funasr/modules/cnn/layers.py diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py index 402a91197..59e61ee64 100644 --- a/funasr/bin/asr_inference_launch.py +++ b/funasr/bin/asr_inference_launch.py @@ -51,10 +51,10 @@ from funasr.utils.vad_utils import slice_padding_fbank from funasr.utils.speaker_utils import (check_audio_list, sv_preprocess, sv_chunk, - CAMPPlus, extract_feature, postprocess, - distribute_spk, ERes2Net) + distribute_spk) +import funasr.modules.cnn as sv_module from funasr.build_utils.build_model_from_file import build_model_from_file from funasr.utils.cluster_backend import ClusterBackend from funasr.utils.modelscope_utils import get_cache_dir @@ -818,11 +818,15 @@ def inference_paraformer_vad_speaker( format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", ) - sv_model_file = asr_model_file.replace("model.pb", "campplus_cn_common.bin") - if not os.path.exists(sv_model_file): - sv_model_file = asr_model_file.replace("model.pb", "pretrained_eres2net_aug.ckpt") - if not os.path.exists(sv_model_file): - raise FileNotFoundError("sv_model_file not found: {}".format(sv_model_file)) + sv_model_config_path = asr_model_file.replace("model.pb", "sv_model_config.yaml") + if not os.path.exists(sv_model_config_path): + sv_model_config = {'sv_model_class': 'CAMPPlus','sv_model_file': 'campplus_cn_common.bin', 'models_config': {}} + else: + with open(sv_model_config_path, 'r') as f: + sv_model_config = yaml.load(f, Loader=yaml.FullLoader) + if sv_model_config['models_config'] is None: + sv_model_config['models_config'] = {} + sv_model_file = asr_model_file.replace("model.pb", sv_model_config['sv_model_file']) if param_dict is not None: hotword_list_or_file = param_dict.get('hotword') @@ -949,14 +953,11 @@ def inference_paraformer_vad_speaker( ################################## # load sv model sv_model_dict = torch.load(sv_model_file) - print(f'load sv model params: {sv_model_file}') - if os.path.basename(sv_model_file) == "campplus_cn_common.bin": - sv_model = CAMPPlus() - else: - sv_model = ERes2Net() + sv_model = getattr(sv_module, sv_model_config['sv_model_class'])(**sv_model_config['models_config']) if ngpu > 0: sv_model.cuda() sv_model.load_state_dict(sv_model_dict) + print(f'load sv model params: {sv_model_file}') sv_model.eval() cb_model = ClusterBackend() vad_segments = [] diff --git a/funasr/models/pooling/pooling_layers.py b/funasr/models/pooling/pooling_layers.py new file mode 100644 index 000000000..0aa10fefb --- /dev/null +++ b/funasr/models/pooling/pooling_layers.py @@ -0,0 +1,108 @@ +# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved. +# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +""" This implementation is adapted from https://github.com/wenet-e2e/wespeaker.""" + +import torch +import torch.nn as nn + + +class TAP(nn.Module): + """ + Temporal average pooling, only first-order mean is considered + """ + + def __init__(self, **kwargs): + super(TAP, self).__init__() + + def forward(self, x): + pooling_mean = x.mean(dim=-1) + # To be compatable with 2D input + pooling_mean = pooling_mean.flatten(start_dim=1) + return pooling_mean + + +class TSDP(nn.Module): + """ + Temporal standard deviation pooling, only second-order std is considered + """ + + def __init__(self, **kwargs): + super(TSDP, self).__init__() + + def forward(self, x): + # The last dimension is the temporal axis + pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-8) + pooling_std = pooling_std.flatten(start_dim=1) + return pooling_std + + +class TSTP(nn.Module): + """ + Temporal statistics pooling, concatenate mean and std, which is used in + x-vector + Comment: simple concatenation can not make full use of both statistics + """ + + def __init__(self, **kwargs): + super(TSTP, self).__init__() + + def forward(self, x): + # The last dimension is the temporal axis + pooling_mean = x.mean(dim=-1) + pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-8) + pooling_mean = pooling_mean.flatten(start_dim=1) + pooling_std = pooling_std.flatten(start_dim=1) + + stats = torch.cat((pooling_mean, pooling_std), 1) + return stats + + +class ASTP(nn.Module): + """ Attentive statistics pooling: Channel- and context-dependent + statistics pooling, first used in ECAPA_TDNN. + """ + + def __init__(self, in_dim, bottleneck_dim=128, global_context_att=False): + super(ASTP, self).__init__() + self.global_context_att = global_context_att + + # Use Conv1d with stride == 1 rather than Linear, then we don't + # need to transpose inputs. + if global_context_att: + self.linear1 = nn.Conv1d( + in_dim * 3, bottleneck_dim, + kernel_size=1) # equals W and b in the paper + else: + self.linear1 = nn.Conv1d( + in_dim, bottleneck_dim, + kernel_size=1) # equals W and b in the paper + self.linear2 = nn.Conv1d(bottleneck_dim, in_dim, + kernel_size=1) # equals V and k in the paper + + def forward(self, x): + """ + x: a 3-dimensional tensor in tdnn-based architecture (B,F,T) + or a 4-dimensional tensor in resnet architecture (B,C,F,T) + 0-dim: batch-dimension, last-dim: time-dimension (frame-dimension) + """ + if len(x.shape) == 4: + x = x.reshape(x.shape[0], x.shape[1] * x.shape[2], x.shape[3]) + assert len(x.shape) == 3 + + if self.global_context_att: + context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x) + context_std = torch.sqrt( + torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x) + x_in = torch.cat((x, context_mean, context_std), dim=1) + else: + x_in = x + + # DON'T use ReLU here! ReLU may be hard to converge. + alpha = torch.tanh( + self.linear1(x_in)) # alpha = F.relu(self.linear1(x_in)) + alpha = torch.softmax(self.linear2(alpha), dim=2) + mean = torch.sum(alpha * x, dim=2) + var = torch.sum(alpha * (x ** 2), dim=2) - mean ** 2 + std = torch.sqrt(var.clamp(min=1e-10)) + return torch.cat([mean, std], dim=1) diff --git a/funasr/modules/cnn/DTDNN.py b/funasr/modules/cnn/DTDNN.py new file mode 100644 index 000000000..3de0b1e45 --- /dev/null +++ b/funasr/modules/cnn/DTDNN.py @@ -0,0 +1,124 @@ +# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved. +# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +from collections import OrderedDict + +import torch.nn.functional as F +from torch import nn + +from funasr.modules.cnn.layers import DenseLayer, StatsPool, TDNNLayer, CAMDenseTDNNBlock, TransitLayer, \ + BasicResBlock, get_nonlinear + + +class FCM(nn.Module): + def __init__(self, + block=BasicResBlock, + num_blocks=[2, 2], + m_channels=32, + feat_dim=80): + super(FCM, self).__init__() + self.in_planes = m_channels + self.conv1 = nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(m_channels) + + self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=2) + self.layer2 = self._make_layer(block, m_channels, num_blocks[0], stride=2) + + self.conv2 = nn.Conv2d(m_channels, m_channels, kernel_size=3, stride=(2, 1), padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(m_channels) + self.out_channels = m_channels * (feat_dim // 8) + + def _make_layer(self, block, planes, num_blocks, stride): + strides = [stride] + [1] * (num_blocks - 1) + layers = [] + for stride in strides: + layers.append(block(self.in_planes, planes, stride)) + self.in_planes = planes * block.expansion + return nn.Sequential(*layers) + + def forward(self, x): + x = x.unsqueeze(1) + out = F.relu(self.bn1(self.conv1(x))) + out = self.layer1(out) + out = self.layer2(out) + out = F.relu(self.bn2(self.conv2(out))) + + shape = out.shape + out = out.reshape(shape[0], shape[1] * shape[2], shape[3]) + return out + + +class CAMPPlus(nn.Module): + def __init__(self, + feat_dim=80, + embedding_size=192, + growth_rate=32, + bn_size=4, + init_channels=128, + config_str='batchnorm-relu', + memory_efficient=True, + output_level='segment'): + super(CAMPPlus, self).__init__() + + self.head = FCM(feat_dim=feat_dim) + channels = self.head.out_channels + self.output_level = output_level + + self.xvector = nn.Sequential( + OrderedDict([ + + ('tdnn', + TDNNLayer(channels, + init_channels, + 5, + stride=2, + dilation=1, + padding=-1, + config_str=config_str)), + ])) + channels = init_channels + for i, (num_layers, kernel_size, + dilation) in enumerate(zip((12, 24, 16), (3, 3, 3), (1, 2, 2))): + block = CAMDenseTDNNBlock(num_layers=num_layers, + in_channels=channels, + out_channels=growth_rate, + bn_channels=bn_size * growth_rate, + kernel_size=kernel_size, + dilation=dilation, + config_str=config_str, + memory_efficient=memory_efficient) + self.xvector.add_module('block%d' % (i + 1), block) + channels = channels + num_layers * growth_rate + self.xvector.add_module( + 'transit%d' % (i + 1), + TransitLayer(channels, + channels // 2, + bias=False, + config_str=config_str)) + channels //= 2 + + self.xvector.add_module( + 'out_nonlinear', get_nonlinear(config_str, channels)) + + if self.output_level == 'segment': + self.xvector.add_module('stats', StatsPool()) + self.xvector.add_module( + 'dense', + DenseLayer( + channels * 2, embedding_size, config_str='batchnorm_')) + else: + assert self.output_level == 'frame', '`output_level` should be set to \'segment\' or \'frame\'. ' + + for m in self.modules(): + if isinstance(m, (nn.Conv1d, nn.Linear)): + nn.init.kaiming_normal_(m.weight.data) + if m.bias is not None: + nn.init.zeros_(m.bias) + + def forward(self, x): + x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T) + x = self.head(x) + x = self.xvector(x) + if self.output_level == 'frame': + x = x.transpose(1, 2) + return x diff --git a/funasr/modules/cnn/ResNet.py b/funasr/modules/cnn/ResNet.py new file mode 100644 index 000000000..c3bf13cf6 --- /dev/null +++ b/funasr/modules/cnn/ResNet.py @@ -0,0 +1,420 @@ +# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved. +# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +""" Res2Net implementation is adapted from https://github.com/wenet-e2e/wespeaker. + ERes2Net incorporates both local and global feature fusion techniques to improve the performance. + The local feature fusion (LFF) fuses the features within one single residual block to extract the local signal. + The global feature fusion (GFF) takes acoustic features of different scales as input to aggregate global signal. + ERes2Net-Large is an upgraded version of ERes2Net that uses a larger number of parameters to achieve better + recognition performance. Parameters expansion, baseWidth, and scale can be modified to obtain optimal performance. +""" + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import funasr.models.pooling.pooling_layers as pooling_layers +from funasr.modules.cnn.fusion import AFF + + +class ReLU(nn.Hardtanh): + + def __init__(self, inplace=False): + super(ReLU, self).__init__(0, 20, inplace) + + def __repr__(self): + inplace_str = 'inplace' if self.inplace else '' + return self.__class__.__name__ + ' (' \ + + inplace_str + ')' + + +def conv1x1(in_planes, out_planes, stride=1): + "1x1 convolution without padding" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, + padding=0, bias=False) + + +def conv3x3(in_planes, out_planes, stride=1): + "3x3 convolution with padding" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + +class BasicBlockERes2Net(nn.Module): + expansion = 2 + + def __init__(self, in_planes, planes, stride=1, baseWidth=32, scale=2): + super(BasicBlockERes2Net, self).__init__() + width = int(math.floor(planes * (baseWidth / 64.0))) + self.conv1 = conv1x1(in_planes, width * scale, stride) + self.bn1 = nn.BatchNorm2d(width * scale) + self.nums = scale + + convs = [] + bns = [] + for i in range(self.nums): + convs.append(conv3x3(width, width)) + bns.append(nn.BatchNorm2d(width)) + self.convs = nn.ModuleList(convs) + self.bns = nn.ModuleList(bns) + self.relu = ReLU(inplace=True) + + self.conv3 = conv1x1(width * scale, planes * self.expansion) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion * planes: + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, + self.expansion * planes, + kernel_size=1, + stride=stride, + bias=False), + nn.BatchNorm2d(self.expansion * planes)) + self.stride = stride + self.width = width + self.scale = scale + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + spx = torch.split(out, self.width, 1) + for i in range(self.nums): + if i == 0: + sp = spx[i] + else: + sp = sp + spx[i] + sp = self.convs[i](sp) + sp = self.relu(self.bns[i](sp)) + if i == 0: + out = sp + else: + out = torch.cat((out, sp), 1) + + out = self.conv3(out) + out = self.bn3(out) + + residual = self.shortcut(x) + out += residual + out = self.relu(out) + + return out + + +class BasicBlockERes2Net_diff_AFF(nn.Module): + expansion = 2 + + def __init__(self, in_planes, planes, stride=1, baseWidth=32, scale=2): + super(BasicBlockERes2Net_diff_AFF, self).__init__() + width = int(math.floor(planes * (baseWidth / 64.0))) + self.conv1 = conv1x1(in_planes, width * scale, stride) + self.bn1 = nn.BatchNorm2d(width * scale) + self.nums = scale + + convs = [] + fuse_models = [] + bns = [] + for i in range(self.nums): + convs.append(conv3x3(width, width)) + bns.append(nn.BatchNorm2d(width)) + for j in range(self.nums - 1): + fuse_models.append(AFF(channels=width)) + + self.convs = nn.ModuleList(convs) + self.bns = nn.ModuleList(bns) + self.fuse_models = nn.ModuleList(fuse_models) + self.relu = ReLU(inplace=True) + + self.conv3 = conv1x1(width * scale, planes * self.expansion) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion * planes: + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, + self.expansion * planes, + kernel_size=1, + stride=stride, + bias=False), + nn.BatchNorm2d(self.expansion * planes)) + self.stride = stride + self.width = width + self.scale = scale + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + spx = torch.split(out, self.width, 1) + for i in range(self.nums): + if i == 0: + sp = spx[i] + else: + sp = self.fuse_models[i - 1](sp, spx[i]) + + sp = self.convs[i](sp) + sp = self.relu(self.bns[i](sp)) + if i == 0: + out = sp + else: + out = torch.cat((out, sp), 1) + + out = self.conv3(out) + out = self.bn3(out) + + residual = self.shortcut(x) + out += residual + out = self.relu(out) + + return out + + +class ERes2Net(nn.Module): + def __init__(self, + block=BasicBlockERes2Net, + block_fuse=BasicBlockERes2Net_diff_AFF, + num_blocks=[3, 4, 6, 3], + m_channels=32, + feat_dim=80, + embedding_size=192, + pooling_func='TSTP', + two_emb_layer=False): + super(ERes2Net, self).__init__() + self.in_planes = m_channels + self.feat_dim = feat_dim + self.embedding_size = embedding_size + self.stats_dim = int(feat_dim / 8) * m_channels * 8 + self.two_emb_layer = two_emb_layer + + self.conv1 = nn.Conv2d(1, + m_channels, + kernel_size=3, + stride=1, + padding=1, + bias=False) + self.bn1 = nn.BatchNorm2d(m_channels) + self.layer1 = self._make_layer(block, + m_channels, + num_blocks[0], + stride=1) + self.layer2 = self._make_layer(block, + m_channels * 2, + num_blocks[1], + stride=2) + self.layer3 = self._make_layer(block_fuse, + m_channels * 4, + num_blocks[2], + stride=2) + self.layer4 = self._make_layer(block_fuse, + m_channels * 8, + num_blocks[3], + stride=2) + + # Downsampling module for each layer + self.layer1_downsample = nn.Conv2d(m_channels * 2, m_channels * 4, kernel_size=3, stride=2, padding=1, + bias=False) + self.layer2_downsample = nn.Conv2d(m_channels * 4, m_channels * 8, kernel_size=3, padding=1, stride=2, + bias=False) + self.layer3_downsample = nn.Conv2d(m_channels * 8, m_channels * 16, kernel_size=3, padding=1, stride=2, + bias=False) + + # Bottom-up fusion module + self.fuse_mode12 = AFF(channels=m_channels * 4) + self.fuse_mode123 = AFF(channels=m_channels * 8) + self.fuse_mode1234 = AFF(channels=m_channels * 16) + + self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == "TSDP" else 2 + self.pool = getattr(pooling_layers, pooling_func)( + in_dim=self.stats_dim * block.expansion) + self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats, + embedding_size) + if self.two_emb_layer: + self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False) + self.seg_2 = nn.Linear(embedding_size, embedding_size) + else: + self.seg_bn_1 = nn.Identity() + self.seg_2 = nn.Identity() + + def _make_layer(self, block, planes, num_blocks, stride): + strides = [stride] + [1] * (num_blocks - 1) + layers = [] + for stride in strides: + layers.append(block(self.in_planes, planes, stride)) + self.in_planes = planes * block.expansion + return nn.Sequential(*layers) + + def forward(self, x): + x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T) + x = x.unsqueeze_(1) + out = F.relu(self.bn1(self.conv1(x))) + out1 = self.layer1(out) + out2 = self.layer2(out1) + out1_downsample = self.layer1_downsample(out1) + fuse_out12 = self.fuse_mode12(out2, out1_downsample) + out3 = self.layer3(out2) + fuse_out12_downsample = self.layer2_downsample(fuse_out12) + fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample) + out4 = self.layer4(out3) + fuse_out123_downsample = self.layer3_downsample(fuse_out123) + fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample) + stats = self.pool(fuse_out1234) + + embed_a = self.seg_1(stats) + if self.two_emb_layer: + out = F.relu(embed_a) + out = self.seg_bn_1(out) + embed_b = self.seg_2(out) + return embed_b + else: + return embed_a + + +class BasicBlockRes2Net(nn.Module): + expansion = 2 + + def __init__(self, in_planes, planes, stride=1, baseWidth=32, scale=2): + super(BasicBlockRes2Net, self).__init__() + width = int(math.floor(planes * (baseWidth / 64.0))) + self.conv1 = conv1x1(in_planes, width * scale, stride) + self.bn1 = nn.BatchNorm2d(width * scale) + self.nums = scale - 1 + convs = [] + bns = [] + for i in range(self.nums): + convs.append(conv3x3(width, width)) + bns.append(nn.BatchNorm2d(width)) + self.convs = nn.ModuleList(convs) + self.bns = nn.ModuleList(bns) + self.relu = ReLU(inplace=True) + + self.conv3 = conv1x1(width * scale, planes * self.expansion) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion * planes: + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, + self.expansion * planes, + kernel_size=1, + stride=stride, + bias=False), + nn.BatchNorm2d(self.expansion * planes)) + self.stride = stride + self.width = width + self.scale = scale + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + spx = torch.split(out, self.width, 1) + for i in range(self.nums): + if i == 0: + sp = spx[i] + else: + sp = sp + spx[i] + sp = self.convs[i](sp) + sp = self.relu(self.bns[i](sp)) + if i == 0: + out = sp + else: + out = torch.cat((out, sp), 1) + + out = torch.cat((out, spx[self.nums]), 1) + + out = self.conv3(out) + out = self.bn3(out) + + residual = self.shortcut(x) + out += residual + out = self.relu(out) + + return out + + +class Res2Net(nn.Module): + def __init__(self, + block=BasicBlockRes2Net, + num_blocks=[3, 4, 6, 3], + m_channels=32, + feat_dim=80, + embedding_size=192, + pooling_func='TSTP', + two_emb_layer=False): + super(Res2Net, self).__init__() + self.in_planes = m_channels + self.feat_dim = feat_dim + self.embedding_size = embedding_size + self.stats_dim = int(feat_dim / 8) * m_channels * 8 + self.two_emb_layer = two_emb_layer + + self.conv1 = nn.Conv2d(1, + m_channels, + kernel_size=3, + stride=1, + padding=1, + bias=False) + self.bn1 = nn.BatchNorm2d(m_channels) + self.layer1 = self._make_layer(block, + m_channels, + num_blocks[0], + stride=1) + self.layer2 = self._make_layer(block, + m_channels * 2, + num_blocks[1], + stride=2) + self.layer3 = self._make_layer(block, + m_channels * 4, + num_blocks[2], + stride=2) + self.layer4 = self._make_layer(block, + m_channels * 8, + num_blocks[3], + stride=2) + + self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == "TSDP" else 2 + self.pool = getattr(pooling_layers, pooling_func)( + in_dim=self.stats_dim * block.expansion) + self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats, + embedding_size) + if self.two_emb_layer: + self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False) + self.seg_2 = nn.Linear(embedding_size, embedding_size) + else: + self.seg_bn_1 = nn.Identity() + self.seg_2 = nn.Identity() + + def _make_layer(self, block, planes, num_blocks, stride): + strides = [stride] + [1] * (num_blocks - 1) + layers = [] + for stride in strides: + layers.append(block(self.in_planes, planes, stride)) + self.in_planes = planes * block.expansion + return nn.Sequential(*layers) + + def forward(self, x): + x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T) + + x = x.unsqueeze_(1) + out = F.relu(self.bn1(self.conv1(x))) + out = self.layer1(out) + out = self.layer2(out) + out = self.layer3(out) + out = self.layer4(out) + + stats = self.pool(out) + + embed_a = self.seg_1(stats) + if self.two_emb_layer: + out = F.relu(embed_a) + out = self.seg_bn_1(out) + embed_b = self.seg_2(out) + return embed_b + else: + return embed_a diff --git a/funasr/modules/cnn/ResNet_aug.py b/funasr/modules/cnn/ResNet_aug.py new file mode 100644 index 000000000..d2d845d93 --- /dev/null +++ b/funasr/modules/cnn/ResNet_aug.py @@ -0,0 +1,273 @@ +# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved. +# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +""" Res2Net implementation is adapted from https://github.com/wenet-e2e/wespeaker. + ERes2Net incorporates both local and global feature fusion techniques to improve the performance. + The local feature fusion (LFF) fuses the features within one single residual block to extract the local signal. + The global feature fusion (GFF) takes acoustic features of different scales as input to aggregate global signal. + ERes2Net-Large is an upgraded version of ERes2Net that uses a larger number of parameters to achieve better + recognition performance. Parameters expansion, baseWidth, and scale can be modified to obtain optimal performance. +""" + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import funasr.models.pooling.pooling_layers as pooling_layers +from funasr.modules.cnn.fusion import AFF + + +class ReLU(nn.Hardtanh): + + def __init__(self, inplace=False): + super(ReLU, self).__init__(0, 20, inplace) + + def __repr__(self): + inplace_str = 'inplace' if self.inplace else '' + return self.__class__.__name__ + ' (' \ + + inplace_str + ')' + + +def conv1x1(in_planes, out_planes, stride=1): + "1x1 convolution without padding" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, + padding=0, bias=False) + + +def conv3x3(in_planes, out_planes, stride=1): + "3x3 convolution with padding" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + +class BasicBlockERes2Net(nn.Module): + expansion = 4 + + def __init__(self, in_planes, planes, stride=1, baseWidth=24, scale=3): + super(BasicBlockERes2Net, self).__init__() + width = int(math.floor(planes * (baseWidth / 64.0))) + self.conv1 = conv1x1(in_planes, width * scale, stride) + self.bn1 = nn.BatchNorm2d(width * scale) + self.nums = scale + + convs = [] + bns = [] + for i in range(self.nums): + convs.append(conv3x3(width, width)) + bns.append(nn.BatchNorm2d(width)) + self.convs = nn.ModuleList(convs) + self.bns = nn.ModuleList(bns) + self.relu = ReLU(inplace=True) + + self.conv3 = conv1x1(width * scale, planes * self.expansion) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion * planes: + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, + self.expansion * planes, + kernel_size=1, + stride=stride, + bias=False), + nn.BatchNorm2d(self.expansion * planes)) + self.stride = stride + self.width = width + self.scale = scale + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + spx = torch.split(out, self.width, 1) + for i in range(self.nums): + if i == 0: + sp = spx[i] + else: + sp = sp + spx[i] + sp = self.convs[i](sp) + sp = self.relu(self.bns[i](sp)) + if i == 0: + out = sp + else: + out = torch.cat((out, sp), 1) + + out = self.conv3(out) + out = self.bn3(out) + + residual = self.shortcut(x) + out += residual + out = self.relu(out) + + return out + + +class BasicBlockERes2Net_diff_AFF(nn.Module): + expansion = 4 + + def __init__(self, in_planes, planes, stride=1, baseWidth=24, scale=3): + super(BasicBlockERes2Net_diff_AFF, self).__init__() + width = int(math.floor(planes * (baseWidth / 64.0))) + self.conv1 = conv1x1(in_planes, width * scale, stride) + self.bn1 = nn.BatchNorm2d(width * scale) + + self.nums = scale + + convs = [] + fuse_models = [] + bns = [] + for i in range(self.nums): + convs.append(conv3x3(width, width)) + bns.append(nn.BatchNorm2d(width)) + for j in range(self.nums - 1): + fuse_models.append(AFF(channels=width)) + + self.convs = nn.ModuleList(convs) + self.bns = nn.ModuleList(bns) + self.fuse_models = nn.ModuleList(fuse_models) + self.relu = ReLU(inplace=True) + + self.conv3 = conv1x1(width * scale, planes * self.expansion) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion * planes: + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, + self.expansion * planes, + kernel_size=1, + stride=stride, + bias=False), + nn.BatchNorm2d(self.expansion * planes)) + self.stride = stride + self.width = width + self.scale = scale + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + spx = torch.split(out, self.width, 1) + for i in range(self.nums): + if i == 0: + sp = spx[i] + else: + sp = self.fuse_models[i - 1](sp, spx[i]) + + sp = self.convs[i](sp) + sp = self.relu(self.bns[i](sp)) + if i == 0: + out = sp + else: + out = torch.cat((out, sp), 1) + + out = self.conv3(out) + out = self.bn3(out) + + residual = self.shortcut(x) + out += residual + out = self.relu(out) + + return out + + +class ERes2NetAug(nn.Module): + def __init__(self, + block=BasicBlockERes2Net, + block_fuse=BasicBlockERes2Net_diff_AFF, + num_blocks=[3, 4, 6, 3], + m_channels=64, + feat_dim=80, + embedding_size=192, + pooling_func='TSTP', + two_emb_layer=False): + super(ERes2NetAug, self).__init__() + self.in_planes = m_channels + self.feat_dim = feat_dim + self.embedding_size = embedding_size + self.stats_dim = int(feat_dim / 8) * m_channels * 8 + self.two_emb_layer = two_emb_layer + + self.conv1 = nn.Conv2d(1, + m_channels, + kernel_size=3, + stride=1, + padding=1, + bias=False) + self.bn1 = nn.BatchNorm2d(m_channels) + self.layer1 = self._make_layer(block, + m_channels, + num_blocks[0], + stride=1) + self.layer2 = self._make_layer(block, + m_channels * 2, + num_blocks[1], + stride=2) + self.layer3 = self._make_layer(block_fuse, + m_channels * 4, + num_blocks[2], + stride=2) + self.layer4 = self._make_layer(block_fuse, + m_channels * 8, + num_blocks[3], + stride=2) + + self.layer1_downsample = nn.Conv2d(m_channels * 4, m_channels * 8, kernel_size=3, padding=1, stride=2, + bias=False) + self.layer2_downsample = nn.Conv2d(m_channels * 8, m_channels * 16, kernel_size=3, padding=1, stride=2, + bias=False) + self.layer3_downsample = nn.Conv2d(m_channels * 16, m_channels * 32, kernel_size=3, padding=1, stride=2, + bias=False) + self.fuse_mode12 = AFF(channels=m_channels * 8) + self.fuse_mode123 = AFF(channels=m_channels * 16) + self.fuse_mode1234 = AFF(channels=m_channels * 32) + + self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == "TSDP" else 2 + self.pool = getattr(pooling_layers, pooling_func)( + in_dim=self.stats_dim * block.expansion) + self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats, + embedding_size) + if self.two_emb_layer: + self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False) + self.seg_2 = nn.Linear(embedding_size, embedding_size) + else: + self.seg_bn_1 = nn.Identity() + self.seg_2 = nn.Identity() + + def _make_layer(self, block, planes, num_blocks, stride): + strides = [stride] + [1] * (num_blocks - 1) + layers = [] + for stride in strides: + layers.append(block(self.in_planes, planes, stride)) + self.in_planes = planes * block.expansion + return nn.Sequential(*layers) + + def forward(self, x): + x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T) + + x = x.unsqueeze_(1) + out = F.relu(self.bn1(self.conv1(x))) + out1 = self.layer1(out) + out2 = self.layer2(out1) + out1_downsample = self.layer1_downsample(out1) + fuse_out12 = self.fuse_mode12(out2, out1_downsample) + out3 = self.layer3(out2) + fuse_out12_downsample = self.layer2_downsample(fuse_out12) + fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample) + out4 = self.layer4(out3) + fuse_out123_downsample = self.layer3_downsample(fuse_out123) + fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample) + stats = self.pool(fuse_out1234) + + embed_a = self.seg_1(stats) + if self.two_emb_layer: + out = F.relu(embed_a) + out = self.seg_bn_1(out) + embed_b = self.seg_2(out) + return embed_b + else: + return embed_a diff --git a/funasr/modules/cnn/__init__.py b/funasr/modules/cnn/__init__.py new file mode 100644 index 000000000..d434c988e --- /dev/null +++ b/funasr/modules/cnn/__init__.py @@ -0,0 +1,3 @@ +from .DTDNN import CAMPPlus +from .ResNet import ERes2Net +from .ResNet_aug import ERes2NetAug diff --git a/funasr/modules/cnn/fusion.py b/funasr/modules/cnn/fusion.py new file mode 100644 index 000000000..2aff7a721 --- /dev/null +++ b/funasr/modules/cnn/fusion.py @@ -0,0 +1,29 @@ +# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved. +# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +import torch +import torch.nn as nn + + +class AFF(nn.Module): + + def __init__(self, channels=64, r=4): + super(AFF, self).__init__() + inter_channels = int(channels // r) + + self.local_att = nn.Sequential( + nn.Conv2d(channels * 2, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(inter_channels), + nn.SiLU(inplace=True), + nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(channels), + ) + + def forward(self, x, ds_y): + xa = torch.cat((x, ds_y), dim=1) + x_att = self.local_att(xa) + x_att = 1.0 + torch.tanh(x_att) + xo = torch.mul(x, x_att) + torch.mul(ds_y, 2.0-x_att) + + return xo + diff --git a/funasr/modules/cnn/layers.py b/funasr/modules/cnn/layers.py new file mode 100644 index 000000000..0475612a9 --- /dev/null +++ b/funasr/modules/cnn/layers.py @@ -0,0 +1,254 @@ +# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved. +# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint as cp +from torch import nn + + +def get_nonlinear(config_str, channels): + nonlinear = nn.Sequential() + for name in config_str.split('-'): + if name == 'relu': + nonlinear.add_module('relu', nn.ReLU(inplace=True)) + elif name == 'prelu': + nonlinear.add_module('prelu', nn.PReLU(channels)) + elif name == 'batchnorm': + nonlinear.add_module('batchnorm', nn.BatchNorm1d(channels)) + elif name == 'batchnorm_': + nonlinear.add_module('batchnorm', + nn.BatchNorm1d(channels, affine=False)) + else: + raise ValueError('Unexpected module ({}).'.format(name)) + return nonlinear + + +def statistics_pooling(x, dim=-1, keepdim=False, unbiased=True, eps=1e-2): + mean = x.mean(dim=dim) + std = x.std(dim=dim, unbiased=unbiased) + stats = torch.cat([mean, std], dim=-1) + if keepdim: + stats = stats.unsqueeze(dim=dim) + return stats + + +class StatsPool(nn.Module): + def forward(self, x): + return statistics_pooling(x) + + +class TDNNLayer(nn.Module): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + bias=False, + config_str='batchnorm-relu'): + super(TDNNLayer, self).__init__() + if padding < 0: + assert kernel_size % 2 == 1, 'Expect equal paddings, but got even kernel size ({})'.format( + kernel_size) + padding = (kernel_size - 1) // 2 * dilation + self.linear = nn.Conv1d(in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=bias) + self.nonlinear = get_nonlinear(config_str, out_channels) + + def forward(self, x): + x = self.linear(x) + x = self.nonlinear(x) + return x + + +class CAMLayer(nn.Module): + def __init__(self, + bn_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + bias, + reduction=2): + super(CAMLayer, self).__init__() + self.linear_local = nn.Conv1d(bn_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=bias) + self.linear1 = nn.Conv1d(bn_channels, bn_channels // reduction, 1) + self.relu = nn.ReLU(inplace=True) + self.linear2 = nn.Conv1d(bn_channels // reduction, out_channels, 1) + self.sigmoid = nn.Sigmoid() + + def forward(self, x): + y = self.linear_local(x) + context = x.mean(-1, keepdim=True) + self.seg_pooling(x) + context = self.relu(self.linear1(context)) + m = self.sigmoid(self.linear2(context)) + return y * m + + def seg_pooling(self, x, seg_len=100, stype='avg'): + if stype == 'avg': + seg = F.avg_pool1d(x, kernel_size=seg_len, stride=seg_len, ceil_mode=True) + elif stype == 'max': + seg = F.max_pool1d(x, kernel_size=seg_len, stride=seg_len, ceil_mode=True) + else: + raise ValueError('Wrong segment pooling type.') + shape = seg.shape + seg = seg.unsqueeze(-1).expand(*shape, seg_len).reshape(*shape[:-1], -1) + seg = seg[..., :x.shape[-1]] + return seg + + +class CAMDenseTDNNLayer(nn.Module): + def __init__(self, + in_channels, + out_channels, + bn_channels, + kernel_size, + stride=1, + dilation=1, + bias=False, + config_str='batchnorm-relu', + memory_efficient=False): + super(CAMDenseTDNNLayer, self).__init__() + assert kernel_size % 2 == 1, 'Expect equal paddings, but got even kernel size ({})'.format( + kernel_size) + padding = (kernel_size - 1) // 2 * dilation + self.memory_efficient = memory_efficient + self.nonlinear1 = get_nonlinear(config_str, in_channels) + self.linear1 = nn.Conv1d(in_channels, bn_channels, 1, bias=False) + self.nonlinear2 = get_nonlinear(config_str, bn_channels) + self.cam_layer = CAMLayer(bn_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=bias) + + def bn_function(self, x): + return self.linear1(self.nonlinear1(x)) + + def forward(self, x): + if self.training and self.memory_efficient: + x = cp.checkpoint(self.bn_function, x) + else: + x = self.bn_function(x) + x = self.cam_layer(self.nonlinear2(x)) + return x + + +class CAMDenseTDNNBlock(nn.ModuleList): + def __init__(self, + num_layers, + in_channels, + out_channels, + bn_channels, + kernel_size, + stride=1, + dilation=1, + bias=False, + config_str='batchnorm-relu', + memory_efficient=False): + super(CAMDenseTDNNBlock, self).__init__() + for i in range(num_layers): + layer = CAMDenseTDNNLayer(in_channels=in_channels + i * out_channels, + out_channels=out_channels, + bn_channels=bn_channels, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + bias=bias, + config_str=config_str, + memory_efficient=memory_efficient) + self.add_module('tdnnd%d' % (i + 1), layer) + + def forward(self, x): + for layer in self: + x = torch.cat([x, layer(x)], dim=1) + return x + + +class TransitLayer(nn.Module): + def __init__(self, + in_channels, + out_channels, + bias=True, + config_str='batchnorm-relu'): + super(TransitLayer, self).__init__() + self.nonlinear = get_nonlinear(config_str, in_channels) + self.linear = nn.Conv1d(in_channels, out_channels, 1, bias=bias) + + def forward(self, x): + x = self.nonlinear(x) + x = self.linear(x) + return x + + +class DenseLayer(nn.Module): + def __init__(self, + in_channels, + out_channels, + bias=False, + config_str='batchnorm-relu'): + super(DenseLayer, self).__init__() + self.linear = nn.Conv1d(in_channels, out_channels, 1, bias=bias) + self.nonlinear = get_nonlinear(config_str, out_channels) + + def forward(self, x): + if len(x.shape) == 2: + x = self.linear(x.unsqueeze(dim=-1)).squeeze(dim=-1) + else: + x = self.linear(x) + x = self.nonlinear(x) + return x + + +class BasicResBlock(nn.Module): + expansion = 1 + + def __init__(self, in_planes, planes, stride=1): + super(BasicResBlock, self).__init__() + self.conv1 = nn.Conv2d(in_planes, + planes, + kernel_size=3, + stride=(stride, 1), + padding=1, + bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, + planes, + kernel_size=3, + stride=1, + padding=1, + bias=False) + self.bn2 = nn.BatchNorm2d(planes) + + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion * planes: + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, + self.expansion * planes, + kernel_size=1, + stride=(stride, 1), + bias=False), + nn.BatchNorm2d(self.expansion * planes)) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = self.bn2(self.conv2(out)) + out += self.shortcut(x) + out = F.relu(out) + return out diff --git a/funasr/utils/speaker_utils.py b/funasr/utils/speaker_utils.py index df3eca7d8..a3eebf9d9 100644 --- a/funasr/utils/speaker_utils.py +++ b/funasr/utils/speaker_utils.py @@ -1,25 +1,18 @@ # Copyright (c) Alibaba, Inc. and its affiliates. """ Some implementations are adapted from https://github.com/yuyq96/D-TDNN """ -import math - -import torch -import torch.nn.functional as F -import torch.utils.checkpoint as cp -from torch import nn import io -import os -from typing import Any, Dict, List, Union +from typing import Union -import numpy as np import librosa as sf +import numpy as np import torch -import torchaudio -import logging -from funasr.utils.modelscope_file import File -from collections import OrderedDict +import torch.nn.functional as F import torchaudio.compliance.kaldi as Kaldi +from torch import nn + +from funasr.utils.modelscope_file import File def check_audio_list(audio: list): @@ -104,230 +97,6 @@ def sv_chunk(vad_segments: list, fs = 16000) -> list: return segs -class BasicResBlock(nn.Module): - expansion = 1 - - def __init__(self, in_planes, planes, stride=1): - super(BasicResBlock, self).__init__() - self.conv1 = nn.Conv2d( - in_planes, - planes, - kernel_size=3, - stride=(stride, 1), - padding=1, - bias=False) - self.bn1 = nn.BatchNorm2d(planes) - self.conv2 = nn.Conv2d( - planes, planes, kernel_size=3, stride=1, padding=1, bias=False) - self.bn2 = nn.BatchNorm2d(planes) - - self.shortcut = nn.Sequential() - if stride != 1 or in_planes != self.expansion * planes: - self.shortcut = nn.Sequential( - nn.Conv2d( - in_planes, - self.expansion * planes, - kernel_size=1, - stride=(stride, 1), - bias=False), nn.BatchNorm2d(self.expansion * planes)) - - def forward(self, x): - out = F.relu(self.bn1(self.conv1(x))) - out = self.bn2(self.conv2(out)) - out += self.shortcut(x) - out = F.relu(out) - return out - - -class FCM(nn.Module): - - def __init__(self, - block=BasicResBlock, - num_blocks=[2, 2], - m_channels=32, - feat_dim=80): - super(FCM, self).__init__() - self.in_planes = m_channels - self.conv1 = nn.Conv2d( - 1, m_channels, kernel_size=3, stride=1, padding=1, bias=False) - self.bn1 = nn.BatchNorm2d(m_channels) - - self.layer1 = self._make_layer( - block, m_channels, num_blocks[0], stride=2) - self.layer2 = self._make_layer( - block, m_channels, num_blocks[0], stride=2) - - self.conv2 = nn.Conv2d( - m_channels, - m_channels, - kernel_size=3, - stride=(2, 1), - padding=1, - bias=False) - self.bn2 = nn.BatchNorm2d(m_channels) - self.out_channels = m_channels * (feat_dim // 8) - - def _make_layer(self, block, planes, num_blocks, stride): - strides = [stride] + [1] * (num_blocks - 1) - layers = [] - for stride in strides: - layers.append(block(self.in_planes, planes, stride)) - self.in_planes = planes * block.expansion - return nn.Sequential(*layers) - - def forward(self, x): - x = x.unsqueeze(1) - out = F.relu(self.bn1(self.conv1(x))) - out = self.layer1(out) - out = self.layer2(out) - out = F.relu(self.bn2(self.conv2(out))) - - shape = out.shape - out = out.reshape(shape[0], shape[1] * shape[2], shape[3]) - return out - - -class CAMPPlus(nn.Module): - - def __init__(self, - feat_dim=80, - embedding_size=192, - growth_rate=32, - bn_size=4, - init_channels=128, - config_str='batchnorm-relu', - memory_efficient=True, - output_level='segment'): - super(CAMPPlus, self).__init__() - - self.head = FCM(feat_dim=feat_dim) - channels = self.head.out_channels - self.output_level = output_level - - self.xvector = nn.Sequential( - OrderedDict([ - ('tdnn', - TDNNLayer( - channels, - init_channels, - 5, - stride=2, - dilation=1, - padding=-1, - config_str=config_str)), - ])) - channels = init_channels - for i, (num_layers, kernel_size, dilation) in enumerate( - zip((12, 24, 16), (3, 3, 3), (1, 2, 2))): - block = CAMDenseTDNNBlock( - num_layers=num_layers, - in_channels=channels, - out_channels=growth_rate, - bn_channels=bn_size * growth_rate, - kernel_size=kernel_size, - dilation=dilation, - config_str=config_str, - memory_efficient=memory_efficient) - self.xvector.add_module('block%d' % (i + 1), block) - channels = channels + num_layers * growth_rate - self.xvector.add_module( - 'transit%d' % (i + 1), - TransitLayer( - channels, channels // 2, bias=False, - config_str=config_str)) - channels //= 2 - - self.xvector.add_module('out_nonlinear', - get_nonlinear(config_str, channels)) - - if self.output_level == 'segment': - self.xvector.add_module('stats', StatsPool()) - self.xvector.add_module( - 'dense', - DenseLayer( - channels * 2, embedding_size, config_str='batchnorm_')) - else: - assert self.output_level == 'frame', '`output_level` should be set to \'segment\' or \'frame\'. ' - - for m in self.modules(): - if isinstance(m, (nn.Conv1d, nn.Linear)): - nn.init.kaiming_normal_(m.weight.data) - if m.bias is not None: - nn.init.zeros_(m.bias) - - def forward(self, x): - x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T) - x = self.head(x) - x = self.xvector(x) - if self.output_level == 'frame': - x = x.transpose(1, 2) - return x - - -def get_nonlinear(config_str, channels): - nonlinear = nn.Sequential() - for name in config_str.split('-'): - if name == 'relu': - nonlinear.add_module('relu', nn.ReLU(inplace=True)) - elif name == 'prelu': - nonlinear.add_module('prelu', nn.PReLU(channels)) - elif name == 'batchnorm': - nonlinear.add_module('batchnorm', nn.BatchNorm1d(channels)) - elif name == 'batchnorm_': - nonlinear.add_module('batchnorm', - nn.BatchNorm1d(channels, affine=False)) - else: - raise ValueError('Unexpected module ({}).'.format(name)) - return nonlinear - - -def statistics_pooling(x, dim=-1, keepdim=False, unbiased=True, eps=1e-2): - mean = x.mean(dim=dim) - std = x.std(dim=dim, unbiased=unbiased) - stats = torch.cat([mean, std], dim=-1) - if keepdim: - stats = stats.unsqueeze(dim=dim) - return stats - - -class StatsPool(nn.Module): - - def forward(self, x): - return statistics_pooling(x) - - -class TDNNLayer(nn.Module): - - def __init__(self, - in_channels, - out_channels, - kernel_size, - stride=1, - padding=0, - dilation=1, - bias=False, - config_str='batchnorm-relu'): - super(TDNNLayer, self).__init__() - if padding < 0: - assert kernel_size % 2 == 1, 'Expect equal paddings, but got even kernel size ({})'.format( - kernel_size) - padding = (kernel_size - 1) // 2 * dilation - self.linear = nn.Conv1d( - in_channels, - out_channels, - kernel_size, - stride=stride, - padding=padding, - dilation=dilation, - bias=bias) - self.nonlinear = get_nonlinear(config_str, out_channels) - - def forward(self, x): - x = self.linear(x) - x = self.nonlinear(x) - return x - - def extract_feature(audio): features = [] for au in audio: @@ -387,116 +156,6 @@ class CAMLayer(nn.Module): return seg -class CAMDenseTDNNLayer(nn.Module): - - def __init__(self, - in_channels, - out_channels, - bn_channels, - kernel_size, - stride=1, - dilation=1, - bias=False, - config_str='batchnorm-relu', - memory_efficient=False): - super(CAMDenseTDNNLayer, self).__init__() - assert kernel_size % 2 == 1, 'Expect equal paddings, but got even kernel size ({})'.format( - kernel_size) - padding = (kernel_size - 1) // 2 * dilation - self.memory_efficient = memory_efficient - self.nonlinear1 = get_nonlinear(config_str, in_channels) - self.linear1 = nn.Conv1d(in_channels, bn_channels, 1, bias=False) - self.nonlinear2 = get_nonlinear(config_str, bn_channels) - self.cam_layer = CAMLayer( - bn_channels, - out_channels, - kernel_size, - stride=stride, - padding=padding, - dilation=dilation, - bias=bias) - - def bn_function(self, x): - return self.linear1(self.nonlinear1(x)) - - def forward(self, x): - if self.training and self.memory_efficient: - x = cp.checkpoint(self.bn_function, x) - else: - x = self.bn_function(x) - x = self.cam_layer(self.nonlinear2(x)) - return x - - -class CAMDenseTDNNBlock(nn.ModuleList): - - def __init__(self, - num_layers, - in_channels, - out_channels, - bn_channels, - kernel_size, - stride=1, - dilation=1, - bias=False, - config_str='batchnorm-relu', - memory_efficient=False): - super(CAMDenseTDNNBlock, self).__init__() - for i in range(num_layers): - layer = CAMDenseTDNNLayer( - in_channels=in_channels + i * out_channels, - out_channels=out_channels, - bn_channels=bn_channels, - kernel_size=kernel_size, - stride=stride, - dilation=dilation, - bias=bias, - config_str=config_str, - memory_efficient=memory_efficient) - self.add_module('tdnnd%d' % (i + 1), layer) - - def forward(self, x): - for layer in self: - x = torch.cat([x, layer(x)], dim=1) - return x - - -class TransitLayer(nn.Module): - - def __init__(self, - in_channels, - out_channels, - bias=True, - config_str='batchnorm-relu'): - super(TransitLayer, self).__init__() - self.nonlinear = get_nonlinear(config_str, in_channels) - self.linear = nn.Conv1d(in_channels, out_channels, 1, bias=bias) - - def forward(self, x): - x = self.nonlinear(x) - x = self.linear(x) - return x - - -class DenseLayer(nn.Module): - - def __init__(self, - in_channels, - out_channels, - bias=False, - config_str='batchnorm-relu'): - super(DenseLayer, self).__init__() - self.linear = nn.Conv1d(in_channels, out_channels, 1, bias=bias) - self.nonlinear = get_nonlinear(config_str, out_channels) - - def forward(self, x): - if len(x.shape) == 2: - x = self.linear(x.unsqueeze(dim=-1)).squeeze(dim=-1) - else: - x = self.linear(x) - x = self.nonlinear(x) - return x - def postprocess(segments: list, vad_segments: list, labels: np.ndarray, embeddings: np.ndarray) -> list: assert len(segments) == len(labels) @@ -592,300 +251,3 @@ def distribute_spk(sentence_list, sd_time_list): d['spk'] = sentence_spk sd_sentence_list.append(d) return sd_sentence_list - - -class AFF(nn.Module): - - def __init__(self, channels=64, r=4): - super(AFF, self).__init__() - inter_channels = int(channels // r) - - self.local_att = nn.Sequential( - nn.Conv2d(channels * 2, inter_channels, kernel_size=1, stride=1, padding=0), - nn.BatchNorm2d(inter_channels), - nn.SiLU(inplace=True), - nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), - nn.BatchNorm2d(channels), - ) - - def forward(self, x, ds_y): - xa = torch.cat((x, ds_y), dim=1) - x_att = self.local_att(xa) - x_att = 1.0 + torch.tanh(x_att) - xo = torch.mul(x, x_att) + torch.mul(ds_y, 2.0 - x_att) - - return xo - - -class TSTP(nn.Module): - """ - Temporal statistics pooling, concatenate mean and std, which is used in - x-vector - Comment: simple concatenation can not make full use of both statistics - """ - - def __init__(self, **kwargs): - super(TSTP, self).__init__() - - def forward(self, x): - # The last dimension is the temporal axis - pooling_mean = x.mean(dim=-1) - pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-8) - pooling_mean = pooling_mean.flatten(start_dim=1) - pooling_std = pooling_std.flatten(start_dim=1) - - stats = torch.cat((pooling_mean, pooling_std), 1) - return stats - - -class ReLU(nn.Hardtanh): - - def __init__(self, inplace=False): - super(ReLU, self).__init__(0, 20, inplace) - - def __repr__(self): - inplace_str = 'inplace' if self.inplace else '' - return self.__class__.__name__ + ' (' \ - + inplace_str + ')' - - -def conv1x1(in_planes, out_planes, stride=1): - "1x1 convolution without padding" - return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, - padding=0, bias=False) - - -def conv3x3(in_planes, out_planes, stride=1): - "3x3 convolution with padding" - return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, - padding=1, bias=False) - - -class BasicBlockERes2Net(nn.Module): - expansion = 4 - - def __init__(self, in_planes, planes, stride=1, baseWidth=24, scale=3): - super(BasicBlockERes2Net, self).__init__() - width = int(math.floor(planes * (baseWidth / 64.0))) - self.conv1 = conv1x1(in_planes, width * scale, stride) - self.bn1 = nn.BatchNorm2d(width * scale) - self.nums = scale - - convs = [] - bns = [] - for i in range(self.nums): - convs.append(conv3x3(width, width)) - bns.append(nn.BatchNorm2d(width)) - self.convs = nn.ModuleList(convs) - self.bns = nn.ModuleList(bns) - self.relu = ReLU(inplace=True) - - self.conv3 = conv1x1(width * scale, planes * self.expansion) - self.bn3 = nn.BatchNorm2d(planes * self.expansion) - self.shortcut = nn.Sequential() - if stride != 1 or in_planes != self.expansion * planes: - self.shortcut = nn.Sequential( - nn.Conv2d(in_planes, - self.expansion * planes, - kernel_size=1, - stride=stride, - bias=False), - nn.BatchNorm2d(self.expansion * planes)) - self.stride = stride - self.width = width - self.scale = scale - - def forward(self, x): - residual = x - - out = self.conv1(x) - out = self.bn1(out) - out = self.relu(out) - spx = torch.split(out, self.width, 1) - for i in range(self.nums): - if i == 0: - sp = spx[i] - else: - sp = sp + spx[i] - sp = self.convs[i](sp) - sp = self.relu(self.bns[i](sp)) - if i == 0: - out = sp - else: - out = torch.cat((out, sp), 1) - - out = self.conv3(out) - out = self.bn3(out) - - residual = self.shortcut(x) - out += residual - out = self.relu(out) - - return out - - -class BasicBlockERes2Net_diff_AFF(nn.Module): - expansion = 4 - - def __init__(self, in_planes, planes, stride=1, baseWidth=24, scale=3): - super(BasicBlockERes2Net_diff_AFF, self).__init__() - width = int(math.floor(planes * (baseWidth / 64.0))) - self.conv1 = conv1x1(in_planes, width * scale, stride) - self.bn1 = nn.BatchNorm2d(width * scale) - - self.nums = scale - - convs = [] - fuse_models = [] - bns = [] - for i in range(self.nums): - convs.append(conv3x3(width, width)) - bns.append(nn.BatchNorm2d(width)) - for j in range(self.nums - 1): - fuse_models.append(AFF(channels=width)) - - self.convs = nn.ModuleList(convs) - self.bns = nn.ModuleList(bns) - self.fuse_models = nn.ModuleList(fuse_models) - self.relu = ReLU(inplace=True) - - self.conv3 = conv1x1(width * scale, planes * self.expansion) - self.bn3 = nn.BatchNorm2d(planes * self.expansion) - self.shortcut = nn.Sequential() - if stride != 1 or in_planes != self.expansion * planes: - self.shortcut = nn.Sequential( - nn.Conv2d(in_planes, - self.expansion * planes, - kernel_size=1, - stride=stride, - bias=False), - nn.BatchNorm2d(self.expansion * planes)) - self.stride = stride - self.width = width - self.scale = scale - - def forward(self, x): - residual = x - - out = self.conv1(x) - out = self.bn1(out) - out = self.relu(out) - spx = torch.split(out, self.width, 1) - for i in range(self.nums): - if i == 0: - sp = spx[i] - else: - sp = self.fuse_models[i - 1](sp, spx[i]) - - sp = self.convs[i](sp) - sp = self.relu(self.bns[i](sp)) - if i == 0: - out = sp - else: - out = torch.cat((out, sp), 1) - - out = self.conv3(out) - out = self.bn3(out) - - residual = self.shortcut(x) - out += residual - out = self.relu(out) - - return out - - -class ERes2Net(nn.Module): - def __init__(self, - block=BasicBlockERes2Net, - block_fuse=BasicBlockERes2Net_diff_AFF, - num_blocks=[3, 4, 6, 3], - m_channels=64, - feat_dim=80, - embedding_size=192, - pooling_func='TSTP', - two_emb_layer=False): - super(ERes2Net, self).__init__() - self.in_planes = m_channels - self.feat_dim = feat_dim - self.embedding_size = embedding_size - self.stats_dim = int(feat_dim / 8) * m_channels * 8 - self.two_emb_layer = two_emb_layer - - self.conv1 = nn.Conv2d(1, - m_channels, - kernel_size=3, - stride=1, - padding=1, - bias=False) - self.bn1 = nn.BatchNorm2d(m_channels) - self.layer1 = self._make_layer(block, - m_channels, - num_blocks[0], - stride=1) - self.layer2 = self._make_layer(block, - m_channels * 2, - num_blocks[1], - stride=2) - self.layer3 = self._make_layer(block_fuse, - m_channels * 4, - num_blocks[2], - stride=2) - self.layer4 = self._make_layer(block_fuse, - m_channels * 8, - num_blocks[3], - stride=2) - - self.layer1_downsample = nn.Conv2d(m_channels * 4, m_channels * 8, kernel_size=3, padding=1, stride=2, - bias=False) - self.layer2_downsample = nn.Conv2d(m_channels * 8, m_channels * 16, kernel_size=3, padding=1, stride=2, - bias=False) - self.layer3_downsample = nn.Conv2d(m_channels * 16, m_channels * 32, kernel_size=3, padding=1, stride=2, - bias=False) - self.fuse_mode12 = AFF(channels=m_channels * 8) - self.fuse_mode123 = AFF(channels=m_channels * 16) - self.fuse_mode1234 = AFF(channels=m_channels * 32) - - self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == "TSDP" else 2 - self.pool = TSTP(in_dim=self.stats_dim * block.expansion) - self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats, - embedding_size) - if self.two_emb_layer: - self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False) - self.seg_2 = nn.Linear(embedding_size, embedding_size) - else: - self.seg_bn_1 = nn.Identity() - self.seg_2 = nn.Identity() - - def _make_layer(self, block, planes, num_blocks, stride): - strides = [stride] + [1] * (num_blocks - 1) - layers = [] - for stride in strides: - layers.append(block(self.in_planes, planes, stride)) - self.in_planes = planes * block.expansion - return nn.Sequential(*layers) - - def forward(self, x): - x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T) - - x = x.unsqueeze_(1) - out = F.relu(self.bn1(self.conv1(x))) - out1 = self.layer1(out) - out2 = self.layer2(out1) - out1_downsample = self.layer1_downsample(out1) - fuse_out12 = self.fuse_mode12(out2, out1_downsample) - out3 = self.layer3(out2) - fuse_out12_downsample = self.layer2_downsample(fuse_out12) - fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample) - out4 = self.layer4(out3) - fuse_out123_downsample = self.layer3_downsample(fuse_out123) - fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample) - stats = self.pool(fuse_out1234) - - embed_a = self.seg_1(stats) - if self.two_emb_layer: - out = F.relu(embed_a) - out = self.seg_bn_1(out) - embed_b = self.seg_2(out) - return embed_b - else: - return embed_a