diff --git a/funasr/utils/speaker_utils.py b/funasr/utils/speaker_utils.py index a3eebf9d9..b769b8577 100644 --- a/funasr/utils/speaker_utils.py +++ b/funasr/utils/speaker_utils.py @@ -108,54 +108,6 @@ def extract_feature(audio): return features -class CAMLayer(nn.Module): - - def __init__(self, - bn_channels, - out_channels, - kernel_size, - stride, - padding, - dilation, - bias, - reduction=2): - super(CAMLayer, self).__init__() - self.linear_local = nn.Conv1d( - bn_channels, - out_channels, - kernel_size, - stride=stride, - padding=padding, - dilation=dilation, - bias=bias) - self.linear1 = nn.Conv1d(bn_channels, bn_channels // reduction, 1) - self.relu = nn.ReLU(inplace=True) - self.linear2 = nn.Conv1d(bn_channels // reduction, out_channels, 1) - self.sigmoid = nn.Sigmoid() - - def forward(self, x): - y = self.linear_local(x) - context = x.mean(-1, keepdim=True) + self.seg_pooling(x) - context = self.relu(self.linear1(context)) - m = self.sigmoid(self.linear2(context)) - return y * m - - def seg_pooling(self, x, seg_len=100, stype='avg'): - if stype == 'avg': - seg = F.avg_pool1d( - x, kernel_size=seg_len, stride=seg_len, ceil_mode=True) - elif stype == 'max': - seg = F.max_pool1d( - x, kernel_size=seg_len, stride=seg_len, ceil_mode=True) - else: - raise ValueError('Wrong segment pooling type.') - shape = seg.shape - seg = seg.unsqueeze(-1).expand(*shape, - seg_len).reshape(*shape[:-1], -1) - seg = seg[..., :x.shape[-1]] - return seg - - def postprocess(segments: list, vad_segments: list, labels: np.ndarray, embeddings: np.ndarray) -> list: assert len(segments) == len(labels)