# FunASR/funasr/utils/speaker_utils.py
# Copyright (c) Alibaba, Inc. and its affiliates.
""" Some implementations are adapted from https://github.com/yuyq96/D-TDNN
"""
import math
import torch
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from torch import nn
import io
import os
from typing import Any, Dict, List, Union
import numpy as np
import soundfile as sf
import torchaudio
import logging
from funasr.utils.modelscope_file import File
from collections import OrderedDict
import torchaudio.compliance.kaldi as Kaldi
def check_audio_list(audio: list):
audio_dur = 0
for i in range(len(audio)):
seg = audio[i]
assert seg[1] >= seg[0], 'modelscope error: Wrong time stamps.'
assert isinstance(seg[2], np.ndarray), 'modelscope error: Wrong data type.'
        assert int(seg[1] * 16000) - int(seg[0] * 16000) == seg[2].shape[0], \
            'modelscope error: audio data in list is inconsistent with time length.'
if i > 0:
            assert seg[0] >= audio[i - 1][1], 'modelscope error: Wrong time stamps.'
audio_dur += seg[1] - seg[0]
return audio_dur
# assert audio_dur > 5, 'modelscope error: The effective audio duration is too short.'
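# Illustrative input for check_audio_list (layout inferred from the asserts
# above): a list of [start_sec, end_sec, waveform] items at 16 kHz, contiguous
# and non-overlapping; the return value is the total speech duration:
#   >>> seg = [0.0, 1.0, np.zeros(16000, dtype='float32')]
#   >>> check_audio_list([seg])
#   1.0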
def sv_preprocess(inputs: Union[np.ndarray, list]):
output = []
for i in range(len(inputs)):
if isinstance(inputs[i], str):
file_bytes = File.read(inputs[i])
            data, fs = sf.read(io.BytesIO(file_bytes), dtype='float32')
if len(data.shape) == 2:
data = data[:, 0]
            data = torch.from_numpy(data)
elif isinstance(inputs[i], np.ndarray):
            assert len(inputs[i].shape) == 1, \
                'modelscope error: Input array should be 1-D [T]'
data = inputs[i]
if data.dtype in ['int16', 'int32', 'int64']:
data = (data / (1 << 15)).astype('float32')
else:
data = data.astype('float32')
data = torch.from_numpy(data)
else:
raise ValueError(
                'modelscope error: The input type is restricted to audio file paths and numpy arrays.'
)
output.append(data)
return output
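# Hedged usage sketch: inputs may mix wav paths/URLs (read via File) and 1-D
# numpy waveforms; integer PCM is rescaled to float32 in [-1, 1):
#   >>> pcm = (np.ones(16000) * 16384).astype('int16')
#   >>> sv_preprocess([pcm])[0].max()
#   tensor(0.5000)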
def sv_chunk(vad_segments: list, fs=16000) -> list:
config = {
'seg_dur': 1.5,
'seg_shift': 0.75,
}
def seg_chunk(seg_data):
seg_st = seg_data[0]
data = seg_data[2]
chunk_len = int(config['seg_dur'] * fs)
chunk_shift = int(config['seg_shift'] * fs)
last_chunk_ed = 0
seg_res = []
for chunk_st in range(0, data.shape[0], chunk_shift):
chunk_ed = min(chunk_st + chunk_len, data.shape[0])
if chunk_ed <= last_chunk_ed:
break
last_chunk_ed = chunk_ed
chunk_st = max(0, chunk_ed - chunk_len)
chunk_data = data[chunk_st:chunk_ed]
if chunk_data.shape[0] < chunk_len:
chunk_data = np.pad(chunk_data,
(0, chunk_len - chunk_data.shape[0]),
'constant')
seg_res.append([
chunk_st / fs + seg_st, chunk_ed / fs + seg_st,
chunk_data
])
return seg_res
segs = []
    for s in vad_segments:
        segs.extend(seg_chunk(s))
return segs
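# Worked example: a 2 s segment is split into 1.5 s chunks with 0.75 s shift;
# the final chunk is shifted back so it ends exactly at the segment boundary:
#   >>> seg = [0.0, 2.0, np.zeros(32000, dtype='float32')]
#   >>> [(st, ed) for st, ed, _ in sv_chunk([seg])]
#   [(0.0, 1.5), (0.5, 2.0)]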
class BasicResBlock(nn.Module):
expansion = 1
def __init__(self, in_planes, planes, stride=1):
super(BasicResBlock, self).__init__()
self.conv1 = nn.Conv2d(
in_planes,
planes,
kernel_size=3,
stride=(stride, 1),
padding=1,
bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(
planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.shortcut = nn.Sequential()
if stride != 1 or in_planes != self.expansion * planes:
self.shortcut = nn.Sequential(
nn.Conv2d(
in_planes,
self.expansion * planes,
kernel_size=1,
stride=(stride, 1),
bias=False), nn.BatchNorm2d(self.expansion * planes))
def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
out += self.shortcut(x)
out = F.relu(out)
return out
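# Note: BasicResBlock uses stride=(stride, 1), so downsampling acts only on
# the frequency axis (dim 2 of the (B, C, F, T) input); time resolution is
# preserved for the TDNN trunk that follows.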
class FCM(nn.Module):
def __init__(self,
block=BasicResBlock,
num_blocks=[2, 2],
m_channels=32,
feat_dim=80):
super(FCM, self).__init__()
self.in_planes = m_channels
self.conv1 = nn.Conv2d(
1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(m_channels)
self.layer1 = self._make_layer(
block, m_channels, num_blocks[0], stride=2)
        self.layer2 = self._make_layer(
            block, m_channels, num_blocks[1], stride=2)
self.conv2 = nn.Conv2d(
m_channels,
m_channels,
kernel_size=3,
stride=(2, 1),
padding=1,
bias=False)
self.bn2 = nn.BatchNorm2d(m_channels)
self.out_channels = m_channels * (feat_dim // 8)
def _make_layer(self, block, planes, num_blocks, stride):
strides = [stride] + [1] * (num_blocks - 1)
layers = []
for stride in strides:
layers.append(block(self.in_planes, planes, stride))
self.in_planes = planes * block.expansion
return nn.Sequential(*layers)
def forward(self, x):
x = x.unsqueeze(1)
out = F.relu(self.bn1(self.conv1(x)))
out = self.layer1(out)
out = self.layer2(out)
out = F.relu(self.bn2(self.conv2(out)))
shape = out.shape
out = out.reshape(shape[0], shape[1] * shape[2], shape[3])
return out
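# Shape sketch for the defaults (feat_dim=80, m_channels=32): (B, 80, T)
# -> unsqueeze to (B, 1, 80, T) -> layer1, layer2 and conv2 each halve the
# frequency axis -> (B, 32, 10, T) -> reshape to (B, 320, T), matching
# out_channels = m_channels * (feat_dim // 8).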
class CAMPPlus(nn.Module):
def __init__(self,
feat_dim=80,
embedding_size=192,
growth_rate=32,
bn_size=4,
init_channels=128,
config_str='batchnorm-relu',
memory_efficient=True,
output_level='segment'):
super(CAMPPlus, self).__init__()
self.head = FCM(feat_dim=feat_dim)
channels = self.head.out_channels
self.output_level = output_level
self.xvector = nn.Sequential(
OrderedDict([
('tdnn',
TDNNLayer(
channels,
init_channels,
5,
stride=2,
dilation=1,
padding=-1,
config_str=config_str)),
]))
channels = init_channels
for i, (num_layers, kernel_size, dilation) in enumerate(
zip((12, 24, 16), (3, 3, 3), (1, 2, 2))):
block = CAMDenseTDNNBlock(
num_layers=num_layers,
in_channels=channels,
out_channels=growth_rate,
bn_channels=bn_size * growth_rate,
kernel_size=kernel_size,
dilation=dilation,
config_str=config_str,
memory_efficient=memory_efficient)
self.xvector.add_module('block%d' % (i + 1), block)
channels = channels + num_layers * growth_rate
self.xvector.add_module(
'transit%d' % (i + 1),
TransitLayer(
channels, channels // 2, bias=False,
config_str=config_str))
channels //= 2
self.xvector.add_module('out_nonlinear',
get_nonlinear(config_str, channels))
if self.output_level == 'segment':
self.xvector.add_module('stats', StatsPool())
self.xvector.add_module(
'dense',
DenseLayer(
channels * 2, embedding_size, config_str='batchnorm_'))
else:
assert self.output_level == 'frame', '`output_level` should be set to \'segment\' or \'frame\'. '
for m in self.modules():
if isinstance(m, (nn.Conv1d, nn.Linear)):
nn.init.kaiming_normal_(m.weight.data)
if m.bias is not None:
nn.init.zeros_(m.bias)
def forward(self, x):
x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
x = self.head(x)
x = self.xvector(x)
if self.output_level == 'frame':
x = x.transpose(1, 2)
return x
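# Minimal usage sketch (shapes illustrative; 80-dim fbank input assumed):
#   >>> model = CAMPPlus(feat_dim=80, embedding_size=192)
#   >>> feats = torch.randn(2, 200, 80)   # (B, T, F)
#   >>> model(feats).shape                # segment-level embedding
#   torch.Size([2, 192])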
def get_nonlinear(config_str, channels):
nonlinear = nn.Sequential()
for name in config_str.split('-'):
if name == 'relu':
nonlinear.add_module('relu', nn.ReLU(inplace=True))
elif name == 'prelu':
nonlinear.add_module('prelu', nn.PReLU(channels))
elif name == 'batchnorm':
nonlinear.add_module('batchnorm', nn.BatchNorm1d(channels))
elif name == 'batchnorm_':
nonlinear.add_module('batchnorm',
nn.BatchNorm1d(channels, affine=False))
else:
raise ValueError('Unexpected module ({}).'.format(name))
return nonlinear
def statistics_pooling(x, dim=-1, keepdim=False, unbiased=True, eps=1e-2):
mean = x.mean(dim=dim)
std = x.std(dim=dim, unbiased=unbiased)
stats = torch.cat([mean, std], dim=-1)
if keepdim:
stats = stats.unsqueeze(dim=dim)
return stats
class StatsPool(nn.Module):
def forward(self, x):
return statistics_pooling(x)
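# StatsPool reduces (B, C, T) to (B, 2 * C): per-channel temporal mean and
# standard deviation, concatenated.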
class TDNNLayer(nn.Module):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
bias=False,
config_str='batchnorm-relu'):
super(TDNNLayer, self).__init__()
if padding < 0:
assert kernel_size % 2 == 1, 'Expect equal paddings, but got even kernel size ({})'.format(
kernel_size)
padding = (kernel_size - 1) // 2 * dilation
self.linear = nn.Conv1d(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
bias=bias)
self.nonlinear = get_nonlinear(config_str, out_channels)
def forward(self, x):
x = self.linear(x)
x = self.nonlinear(x)
return x
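# Note: padding < 0 asks the layer to derive symmetric padding for odd
# kernels, (kernel_size - 1) // 2 * dilation; CAMPPlus passes padding=-1.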
def extract_feature(audio):
features = []
for au in audio:
feature = Kaldi.fbank(
au.unsqueeze(0), num_mel_bins=80)
feature = feature - feature.mean(dim=0, keepdim=True)
features.append(feature.unsqueeze(0))
features = torch.cat(features)
return features
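# Hedged note: Kaldi.fbank defaults to 16 kHz input and yields
# (num_frames, 80) features here; subtracting the utterance mean is a simple
# CMN. torch.cat needs equal-length inputs, so this fits the fixed-size
# chunks produced by sv_chunk:
#   >>> chunks = [torch.randn(24000) for _ in range(3)]   # 1.5 s @ 16 kHz
#   >>> extract_feature(chunks).shape
#   torch.Size([3, 148, 80])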
class CAMLayer(nn.Module):
def __init__(self,
bn_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
bias,
reduction=2):
super(CAMLayer, self).__init__()
self.linear_local = nn.Conv1d(
bn_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
bias=bias)
self.linear1 = nn.Conv1d(bn_channels, bn_channels // reduction, 1)
self.relu = nn.ReLU(inplace=True)
self.linear2 = nn.Conv1d(bn_channels // reduction, out_channels, 1)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
y = self.linear_local(x)
context = x.mean(-1, keepdim=True) + self.seg_pooling(x)
context = self.relu(self.linear1(context))
m = self.sigmoid(self.linear2(context))
return y * m
def seg_pooling(self, x, seg_len=100, stype='avg'):
if stype == 'avg':
seg = F.avg_pool1d(
x, kernel_size=seg_len, stride=seg_len, ceil_mode=True)
elif stype == 'max':
seg = F.max_pool1d(
x, kernel_size=seg_len, stride=seg_len, ceil_mode=True)
else:
raise ValueError('Wrong segment pooling type.')
shape = seg.shape
seg = seg.unsqueeze(-1).expand(*shape,
seg_len).reshape(*shape[:-1], -1)
seg = seg[..., :x.shape[-1]]
return seg
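# The CAM mask mixes a global context (mean over all frames) with a coarse
# local context (seg_pooling over 100-frame, i.e. roughly 1 s, windows),
# passes it through a bottleneck and a sigmoid, and gates the output of the
# local convolution.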
class CAMDenseTDNNLayer(nn.Module):
def __init__(self,
in_channels,
out_channels,
bn_channels,
kernel_size,
stride=1,
dilation=1,
bias=False,
config_str='batchnorm-relu',
memory_efficient=False):
super(CAMDenseTDNNLayer, self).__init__()
assert kernel_size % 2 == 1, 'Expect equal paddings, but got even kernel size ({})'.format(
kernel_size)
padding = (kernel_size - 1) // 2 * dilation
self.memory_efficient = memory_efficient
self.nonlinear1 = get_nonlinear(config_str, in_channels)
self.linear1 = nn.Conv1d(in_channels, bn_channels, 1, bias=False)
self.nonlinear2 = get_nonlinear(config_str, bn_channels)
self.cam_layer = CAMLayer(
bn_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
bias=bias)
def bn_function(self, x):
return self.linear1(self.nonlinear1(x))
def forward(self, x):
if self.training and self.memory_efficient:
x = cp.checkpoint(self.bn_function, x)
else:
x = self.bn_function(x)
x = self.cam_layer(self.nonlinear2(x))
return x
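# With memory_efficient=True, the bottleneck (nonlinear1 -> linear1) is
# recomputed during backward via torch.utils.checkpoint, trading compute for
# activation memory; checkpointing applies only in training mode.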
class CAMDenseTDNNBlock(nn.ModuleList):
def __init__(self,
num_layers,
in_channels,
out_channels,
bn_channels,
kernel_size,
stride=1,
dilation=1,
bias=False,
config_str='batchnorm-relu',
memory_efficient=False):
super(CAMDenseTDNNBlock, self).__init__()
for i in range(num_layers):
layer = CAMDenseTDNNLayer(
in_channels=in_channels + i * out_channels,
out_channels=out_channels,
bn_channels=bn_channels,
kernel_size=kernel_size,
stride=stride,
dilation=dilation,
bias=bias,
config_str=config_str,
memory_efficient=memory_efficient)
self.add_module('tdnnd%d' % (i + 1), layer)
def forward(self, x):
for layer in self:
x = torch.cat([x, layer(x)], dim=1)
return x
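# Dense connectivity: layer i receives in_channels + i * out_channels inputs
# and appends out_channels new maps, so the block emits
# in_channels + num_layers * out_channels channels; the channel bookkeeping
# in CAMPPlus.__init__ relies on exactly this growth.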
class TransitLayer(nn.Module):
def __init__(self,
in_channels,
out_channels,
bias=True,
config_str='batchnorm-relu'):
super(TransitLayer, self).__init__()
self.nonlinear = get_nonlinear(config_str, in_channels)
self.linear = nn.Conv1d(in_channels, out_channels, 1, bias=bias)
def forward(self, x):
x = self.nonlinear(x)
x = self.linear(x)
return x
class DenseLayer(nn.Module):
def __init__(self,
in_channels,
out_channels,
bias=False,
config_str='batchnorm-relu'):
super(DenseLayer, self).__init__()
self.linear = nn.Conv1d(in_channels, out_channels, 1, bias=bias)
self.nonlinear = get_nonlinear(config_str, out_channels)
def forward(self, x):
if len(x.shape) == 2:
x = self.linear(x.unsqueeze(dim=-1)).squeeze(dim=-1)
else:
x = self.linear(x)
x = self.nonlinear(x)
return x
def postprocess(segments: list, vad_segments: list,
labels: np.ndarray, embeddings: np.ndarray) -> list:
assert len(segments) == len(labels)
labels = correct_labels(labels)
distribute_res = []
for i in range(len(segments)):
distribute_res.append([segments[i][0], segments[i][1], labels[i]])
# merge the same speakers chronologically
distribute_res = merge_seque(distribute_res)
    # acquire speaker centers
spk_embs = []
for i in range(labels.max() + 1):
spk_emb = embeddings[labels == i].mean(0)
spk_embs.append(spk_emb)
spk_embs = np.stack(spk_embs)
    def is_overlapped(t1, t2):
        return t1 > t2 + 1e-4
# distribute the overlap region
for i in range(1, len(distribute_res)):
if is_overlapped(distribute_res[i - 1][1], distribute_res[i][0]):
p = (distribute_res[i][0] + distribute_res[i - 1][1]) / 2
distribute_res[i][0] = p
distribute_res[i - 1][1] = p
# smooth the result
distribute_res = smooth(distribute_res)
return distribute_res
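# Overall flow: chunk-level labels -> renumber by first appearance -> merge
# same-speaker neighbors -> split overlaps at the midpoint -> smooth short
# segments; returns [start_sec, end_sec, speaker_id] triples.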
def correct_labels(labels):
labels_id = 0
id2id = {}
new_labels = []
for i in labels:
if i not in id2id:
id2id[i] = labels_id
labels_id += 1
new_labels.append(id2id[i])
return np.array(new_labels)
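# Renumbers cluster ids in order of first appearance, e.g.:
#   >>> correct_labels(np.array([2, 2, 0, 1]))
#   array([0, 0, 1, 2])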
def merge_seque(distribute_res):
res = [distribute_res[0]]
for i in range(1, len(distribute_res)):
        if distribute_res[i][2] != res[-1][2] or distribute_res[i][0] > res[-1][1]:
res.append(distribute_res[i])
else:
res[-1][1] = distribute_res[i][1]
return res
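# Merges consecutive entries that share a label and touch or overlap in time:
#   >>> merge_seque([[0.0, 1.0, 0], [1.0, 2.0, 0], [2.5, 3.0, 1]])
#   [[0.0, 2.0, 0], [2.5, 3.0, 1]]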
def smooth(res, mindur=1):
# short segments are assigned to nearest speakers.
for i in range(len(res)):
res[i][0] = round(res[i][0], 2)
res[i][1] = round(res[i][1], 2)
if res[i][1] - res[i][0] < mindur:
if i == 0:
res[i][2] = res[i + 1][2]
elif i == len(res) - 1:
res[i][2] = res[i - 1][2]
elif res[i][0] - res[i - 1][1] <= res[i + 1][0] - res[i][1]:
res[i][2] = res[i - 1][2]
else:
res[i][2] = res[i + 1][2]
# merge the speakers
res = merge_seque(res)
return res
def distribute_spk(sentence_list, sd_time_list):
sd_sentence_list = []
for d in sentence_list:
sentence_start = d['ts_list'][0][0]
sentence_end = d['ts_list'][-1][1]
sentence_spk = 0
max_overlap = 0
for sd_time in sd_time_list:
spk_st, spk_ed, spk = sd_time
spk_st = spk_st*1000
spk_ed = spk_ed*1000
overlap = max(
min(sentence_end, spk_ed) - max(sentence_start, spk_st), 0)
if overlap > max_overlap:
max_overlap = overlap
sentence_spk = spk
d['spk'] = sentence_spk
sd_sentence_list.append(d)
return sd_sentence_list
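# Assigns each sentence the diarization speaker with the largest time
# overlap; sd_time_list entries are in seconds and are scaled to ms to match
# the millisecond timestamps in d['ts_list'].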
class AFF(nn.Module):
def __init__(self, channels=64, r=4):
super(AFF, self).__init__()
inter_channels = int(channels // r)
self.local_att = nn.Sequential(
nn.Conv2d(channels * 2, inter_channels, kernel_size=1, stride=1, padding=0),
nn.BatchNorm2d(inter_channels),
nn.SiLU(inplace=True),
nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
nn.BatchNorm2d(channels),
)
def forward(self, x, ds_y):
xa = torch.cat((x, ds_y), dim=1)
x_att = self.local_att(xa)
x_att = 1.0 + torch.tanh(x_att)
xo = torch.mul(x, x_att) + torch.mul(ds_y, 2.0 - x_att)
return xo
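# Attentional feature fusion: x_att = 1 + tanh(.) lies in (0, 2), so the two
# branches are mixed with complementary weights x_att and 2 - x_att.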
class TSTP(nn.Module):
"""
    Temporal statistics pooling: concatenate the mean and std over time, as
    used in x-vector.
    Comment: simple concatenation cannot make full use of both statistics.
"""
def __init__(self, **kwargs):
super(TSTP, self).__init__()
def forward(self, x):
# The last dimension is the temporal axis
pooling_mean = x.mean(dim=-1)
pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-8)
pooling_mean = pooling_mean.flatten(start_dim=1)
pooling_std = pooling_std.flatten(start_dim=1)
stats = torch.cat((pooling_mean, pooling_std), 1)
return stats
class ReLU(nn.Hardtanh):
def __init__(self, inplace=False):
super(ReLU, self).__init__(0, 20, inplace)
def __repr__(self):
inplace_str = 'inplace' if self.inplace else ''
return self.__class__.__name__ + ' (' \
+ inplace_str + ')'
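# This ReLU is Hardtanh(0, 20), i.e. a ReLU clipped at 20 to bound
# activations.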
def conv1x1(in_planes, out_planes, stride=1):
"1x1 convolution without padding"
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
padding=0, bias=False)
def conv3x3(in_planes, out_planes, stride=1):
"3x3 convolution with padding"
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=1, bias=False)
class BasicBlockERes2Net(nn.Module):
expansion = 4
def __init__(self, in_planes, planes, stride=1, baseWidth=24, scale=3):
super(BasicBlockERes2Net, self).__init__()
width = int(math.floor(planes * (baseWidth / 64.0)))
self.conv1 = conv1x1(in_planes, width * scale, stride)
self.bn1 = nn.BatchNorm2d(width * scale)
self.nums = scale
convs = []
bns = []
for i in range(self.nums):
convs.append(conv3x3(width, width))
bns.append(nn.BatchNorm2d(width))
self.convs = nn.ModuleList(convs)
self.bns = nn.ModuleList(bns)
self.relu = ReLU(inplace=True)
self.conv3 = conv1x1(width * scale, planes * self.expansion)
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
self.shortcut = nn.Sequential()
if stride != 1 or in_planes != self.expansion * planes:
self.shortcut = nn.Sequential(
nn.Conv2d(in_planes,
self.expansion * planes,
kernel_size=1,
stride=stride,
bias=False),
nn.BatchNorm2d(self.expansion * planes))
self.stride = stride
self.width = width
self.scale = scale
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
spx = torch.split(out, self.width, 1)
for i in range(self.nums):
if i == 0:
sp = spx[i]
else:
sp = sp + spx[i]
sp = self.convs[i](sp)
sp = self.relu(self.bns[i](sp))
if i == 0:
out = sp
else:
out = torch.cat((out, sp), 1)
out = self.conv3(out)
out = self.bn3(out)
residual = self.shortcut(x)
out += residual
out = self.relu(out)
return out
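# Res2Net-style block: the bottleneck output is split into `scale` channel
# groups; each group is summed with the previous group's output before its
# 3x3 conv, giving multi-scale receptive fields inside a single block.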
class BasicBlockERes2Net_diff_AFF(nn.Module):
expansion = 4
def __init__(self, in_planes, planes, stride=1, baseWidth=24, scale=3):
super(BasicBlockERes2Net_diff_AFF, self).__init__()
width = int(math.floor(planes * (baseWidth / 64.0)))
self.conv1 = conv1x1(in_planes, width * scale, stride)
self.bn1 = nn.BatchNorm2d(width * scale)
self.nums = scale
convs = []
fuse_models = []
bns = []
for i in range(self.nums):
convs.append(conv3x3(width, width))
bns.append(nn.BatchNorm2d(width))
for j in range(self.nums - 1):
fuse_models.append(AFF(channels=width))
self.convs = nn.ModuleList(convs)
self.bns = nn.ModuleList(bns)
self.fuse_models = nn.ModuleList(fuse_models)
self.relu = ReLU(inplace=True)
self.conv3 = conv1x1(width * scale, planes * self.expansion)
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
self.shortcut = nn.Sequential()
if stride != 1 or in_planes != self.expansion * planes:
self.shortcut = nn.Sequential(
nn.Conv2d(in_planes,
self.expansion * planes,
kernel_size=1,
stride=stride,
bias=False),
nn.BatchNorm2d(self.expansion * planes))
self.stride = stride
self.width = width
self.scale = scale
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
spx = torch.split(out, self.width, 1)
for i in range(self.nums):
if i == 0:
sp = spx[i]
else:
sp = self.fuse_models[i - 1](sp, spx[i])
sp = self.convs[i](sp)
sp = self.relu(self.bns[i](sp))
if i == 0:
out = sp
else:
out = torch.cat((out, sp), 1)
out = self.conv3(out)
out = self.bn3(out)
residual = self.shortcut(x)
out += residual
out = self.relu(out)
return out
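# Same multi-scale structure as BasicBlockERes2Net, but adjacent groups are
# fused with an AFF module instead of plain addition.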
class ERes2Net(nn.Module):
def __init__(self,
block=BasicBlockERes2Net,
block_fuse=BasicBlockERes2Net_diff_AFF,
num_blocks=[3, 4, 6, 3],
m_channels=64,
feat_dim=80,
embedding_size=192,
pooling_func='TSTP',
two_emb_layer=False):
super(ERes2Net, self).__init__()
self.in_planes = m_channels
self.feat_dim = feat_dim
self.embedding_size = embedding_size
self.stats_dim = int(feat_dim / 8) * m_channels * 8
self.two_emb_layer = two_emb_layer
self.conv1 = nn.Conv2d(1,
m_channels,
kernel_size=3,
stride=1,
padding=1,
bias=False)
self.bn1 = nn.BatchNorm2d(m_channels)
self.layer1 = self._make_layer(block,
m_channels,
num_blocks[0],
stride=1)
self.layer2 = self._make_layer(block,
m_channels * 2,
num_blocks[1],
stride=2)
self.layer3 = self._make_layer(block_fuse,
m_channels * 4,
num_blocks[2],
stride=2)
self.layer4 = self._make_layer(block_fuse,
m_channels * 8,
num_blocks[3],
stride=2)
self.layer1_downsample = nn.Conv2d(m_channels * 4, m_channels * 8, kernel_size=3, padding=1, stride=2,
bias=False)
self.layer2_downsample = nn.Conv2d(m_channels * 8, m_channels * 16, kernel_size=3, padding=1, stride=2,
bias=False)
self.layer3_downsample = nn.Conv2d(m_channels * 16, m_channels * 32, kernel_size=3, padding=1, stride=2,
bias=False)
self.fuse_mode12 = AFF(channels=m_channels * 8)
self.fuse_mode123 = AFF(channels=m_channels * 16)
self.fuse_mode1234 = AFF(channels=m_channels * 32)
self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == "TSDP" else 2
self.pool = TSTP(in_dim=self.stats_dim * block.expansion)
self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats,
embedding_size)
if self.two_emb_layer:
self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
self.seg_2 = nn.Linear(embedding_size, embedding_size)
else:
self.seg_bn_1 = nn.Identity()
self.seg_2 = nn.Identity()
def _make_layer(self, block, planes, num_blocks, stride):
strides = [stride] + [1] * (num_blocks - 1)
layers = []
for stride in strides:
layers.append(block(self.in_planes, planes, stride))
self.in_planes = planes * block.expansion
return nn.Sequential(*layers)
def forward(self, x):
x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
x = x.unsqueeze_(1)
out = F.relu(self.bn1(self.conv1(x)))
out1 = self.layer1(out)
out2 = self.layer2(out1)
out1_downsample = self.layer1_downsample(out1)
fuse_out12 = self.fuse_mode12(out2, out1_downsample)
out3 = self.layer3(out2)
fuse_out12_downsample = self.layer2_downsample(fuse_out12)
fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample)
out4 = self.layer4(out3)
fuse_out123_downsample = self.layer3_downsample(fuse_out123)
fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample)
stats = self.pool(fuse_out1234)
embed_a = self.seg_1(stats)
if self.two_emb_layer:
out = F.relu(embed_a)
out = self.seg_bn_1(out)
embed_b = self.seg_2(out)
return embed_b
else:
return embed_a
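# Minimal usage sketch (illustrative):
#   >>> model = ERes2Net(feat_dim=80, embedding_size=192)
#   >>> feats = torch.randn(2, 200, 80)   # (B, T, F) fbank features
#   >>> model(feats).shape
#   torch.Size([2, 192])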