diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py index f61c0859d..402a91197 100644 --- a/funasr/bin/asr_inference_launch.py +++ b/funasr/bin/asr_inference_launch.py @@ -48,13 +48,13 @@ from funasr.utils.types import str2bool from funasr.utils.types import str2triple_str from funasr.utils.types import str_or_none from funasr.utils.vad_utils import slice_padding_fbank -from funasr.utils.speaker_utils import (check_audio_list, - sv_preprocess, - sv_chunk, - CAMPPlus, - extract_feature, +from funasr.utils.speaker_utils import (check_audio_list, + sv_preprocess, + sv_chunk, + CAMPPlus, + extract_feature, postprocess, - distribute_spk) + distribute_spk, ERes2Net) from funasr.build_utils.build_model_from_file import build_model_from_file from funasr.utils.cluster_backend import ClusterBackend from funasr.utils.modelscope_utils import get_cache_dir @@ -819,6 +819,10 @@ def inference_paraformer_vad_speaker( ) sv_model_file = asr_model_file.replace("model.pb", "campplus_cn_common.bin") + if not os.path.exists(sv_model_file): + sv_model_file = asr_model_file.replace("model.pb", "pretrained_eres2net_aug.ckpt") + if not os.path.exists(sv_model_file): + raise FileNotFoundError("sv_model_file not found: {}".format(sv_model_file)) if param_dict is not None: hotword_list_or_file = param_dict.get('hotword') @@ -944,8 +948,14 @@ def inference_paraformer_vad_speaker( ##### speaker_verification ##### ################################## # load sv model - sv_model_dict = torch.load(sv_model_file, map_location=torch.device('cpu')) - sv_model = CAMPPlus() + sv_model_dict = torch.load(sv_model_file) + print(f'load sv model params: {sv_model_file}') + if os.path.basename(sv_model_file) == "campplus_cn_common.bin": + sv_model = CAMPPlus() + else: + sv_model = ERes2Net() + if ngpu > 0: + sv_model.cuda() sv_model.load_state_dict(sv_model_dict) sv_model.eval() cb_model = ClusterBackend() @@ -969,9 +979,11 @@ def inference_paraformer_vad_speaker( embs = [] for x in wavs: x = extract_feature([x]) + if ngpu > 0: + x = x.cuda() embs.append(sv_model(x)) embs = torch.cat(embs) - embeddings.append(embs.detach().numpy()) + embeddings.append(embs.cpu().detach().numpy()) embeddings = np.concatenate(embeddings) labels = cb_model(embeddings) sv_output = postprocess(segments, vad_segments, labels, embeddings) diff --git a/funasr/utils/speaker_utils.py b/funasr/utils/speaker_utils.py index edaf58b75..df3eca7d8 100644 --- a/funasr/utils/speaker_utils.py +++ b/funasr/utils/speaker_utils.py @@ -1,6 +1,7 @@ # Copyright (c) Alibaba, Inc. and its affiliates. """ Some implementations are adapted from https://github.com/yuyq96/D-TDNN """ +import math import torch import torch.nn.functional as F @@ -590,4 +591,301 @@ def distribute_spk(sentence_list, sd_time_list): sentence_spk = spk d['spk'] = sentence_spk sd_sentence_list.append(d) - return sd_sentence_list \ No newline at end of file + return sd_sentence_list + + +class AFF(nn.Module): + + def __init__(self, channels=64, r=4): + super(AFF, self).__init__() + inter_channels = int(channels // r) + + self.local_att = nn.Sequential( + nn.Conv2d(channels * 2, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(inter_channels), + nn.SiLU(inplace=True), + nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(channels), + ) + + def forward(self, x, ds_y): + xa = torch.cat((x, ds_y), dim=1) + x_att = self.local_att(xa) + x_att = 1.0 + torch.tanh(x_att) + xo = torch.mul(x, x_att) + torch.mul(ds_y, 2.0 - x_att) + + return xo + + +class TSTP(nn.Module): + """ + Temporal statistics pooling, concatenate mean and std, which is used in + x-vector + Comment: simple concatenation can not make full use of both statistics + """ + + def __init__(self, **kwargs): + super(TSTP, self).__init__() + + def forward(self, x): + # The last dimension is the temporal axis + pooling_mean = x.mean(dim=-1) + pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-8) + pooling_mean = pooling_mean.flatten(start_dim=1) + pooling_std = pooling_std.flatten(start_dim=1) + + stats = torch.cat((pooling_mean, pooling_std), 1) + return stats + + +class ReLU(nn.Hardtanh): + + def __init__(self, inplace=False): + super(ReLU, self).__init__(0, 20, inplace) + + def __repr__(self): + inplace_str = 'inplace' if self.inplace else '' + return self.__class__.__name__ + ' (' \ + + inplace_str + ')' + + +def conv1x1(in_planes, out_planes, stride=1): + "1x1 convolution without padding" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, + padding=0, bias=False) + + +def conv3x3(in_planes, out_planes, stride=1): + "3x3 convolution with padding" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + +class BasicBlockERes2Net(nn.Module): + expansion = 4 + + def __init__(self, in_planes, planes, stride=1, baseWidth=24, scale=3): + super(BasicBlockERes2Net, self).__init__() + width = int(math.floor(planes * (baseWidth / 64.0))) + self.conv1 = conv1x1(in_planes, width * scale, stride) + self.bn1 = nn.BatchNorm2d(width * scale) + self.nums = scale + + convs = [] + bns = [] + for i in range(self.nums): + convs.append(conv3x3(width, width)) + bns.append(nn.BatchNorm2d(width)) + self.convs = nn.ModuleList(convs) + self.bns = nn.ModuleList(bns) + self.relu = ReLU(inplace=True) + + self.conv3 = conv1x1(width * scale, planes * self.expansion) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion * planes: + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, + self.expansion * planes, + kernel_size=1, + stride=stride, + bias=False), + nn.BatchNorm2d(self.expansion * planes)) + self.stride = stride + self.width = width + self.scale = scale + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + spx = torch.split(out, self.width, 1) + for i in range(self.nums): + if i == 0: + sp = spx[i] + else: + sp = sp + spx[i] + sp = self.convs[i](sp) + sp = self.relu(self.bns[i](sp)) + if i == 0: + out = sp + else: + out = torch.cat((out, sp), 1) + + out = self.conv3(out) + out = self.bn3(out) + + residual = self.shortcut(x) + out += residual + out = self.relu(out) + + return out + + +class BasicBlockERes2Net_diff_AFF(nn.Module): + expansion = 4 + + def __init__(self, in_planes, planes, stride=1, baseWidth=24, scale=3): + super(BasicBlockERes2Net_diff_AFF, self).__init__() + width = int(math.floor(planes * (baseWidth / 64.0))) + self.conv1 = conv1x1(in_planes, width * scale, stride) + self.bn1 = nn.BatchNorm2d(width * scale) + + self.nums = scale + + convs = [] + fuse_models = [] + bns = [] + for i in range(self.nums): + convs.append(conv3x3(width, width)) + bns.append(nn.BatchNorm2d(width)) + for j in range(self.nums - 1): + fuse_models.append(AFF(channels=width)) + + self.convs = nn.ModuleList(convs) + self.bns = nn.ModuleList(bns) + self.fuse_models = nn.ModuleList(fuse_models) + self.relu = ReLU(inplace=True) + + self.conv3 = conv1x1(width * scale, planes * self.expansion) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion * planes: + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, + self.expansion * planes, + kernel_size=1, + stride=stride, + bias=False), + nn.BatchNorm2d(self.expansion * planes)) + self.stride = stride + self.width = width + self.scale = scale + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + spx = torch.split(out, self.width, 1) + for i in range(self.nums): + if i == 0: + sp = spx[i] + else: + sp = self.fuse_models[i - 1](sp, spx[i]) + + sp = self.convs[i](sp) + sp = self.relu(self.bns[i](sp)) + if i == 0: + out = sp + else: + out = torch.cat((out, sp), 1) + + out = self.conv3(out) + out = self.bn3(out) + + residual = self.shortcut(x) + out += residual + out = self.relu(out) + + return out + + +class ERes2Net(nn.Module): + def __init__(self, + block=BasicBlockERes2Net, + block_fuse=BasicBlockERes2Net_diff_AFF, + num_blocks=[3, 4, 6, 3], + m_channels=64, + feat_dim=80, + embedding_size=192, + pooling_func='TSTP', + two_emb_layer=False): + super(ERes2Net, self).__init__() + self.in_planes = m_channels + self.feat_dim = feat_dim + self.embedding_size = embedding_size + self.stats_dim = int(feat_dim / 8) * m_channels * 8 + self.two_emb_layer = two_emb_layer + + self.conv1 = nn.Conv2d(1, + m_channels, + kernel_size=3, + stride=1, + padding=1, + bias=False) + self.bn1 = nn.BatchNorm2d(m_channels) + self.layer1 = self._make_layer(block, + m_channels, + num_blocks[0], + stride=1) + self.layer2 = self._make_layer(block, + m_channels * 2, + num_blocks[1], + stride=2) + self.layer3 = self._make_layer(block_fuse, + m_channels * 4, + num_blocks[2], + stride=2) + self.layer4 = self._make_layer(block_fuse, + m_channels * 8, + num_blocks[3], + stride=2) + + self.layer1_downsample = nn.Conv2d(m_channels * 4, m_channels * 8, kernel_size=3, padding=1, stride=2, + bias=False) + self.layer2_downsample = nn.Conv2d(m_channels * 8, m_channels * 16, kernel_size=3, padding=1, stride=2, + bias=False) + self.layer3_downsample = nn.Conv2d(m_channels * 16, m_channels * 32, kernel_size=3, padding=1, stride=2, + bias=False) + self.fuse_mode12 = AFF(channels=m_channels * 8) + self.fuse_mode123 = AFF(channels=m_channels * 16) + self.fuse_mode1234 = AFF(channels=m_channels * 32) + + self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == "TSDP" else 2 + self.pool = TSTP(in_dim=self.stats_dim * block.expansion) + self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats, + embedding_size) + if self.two_emb_layer: + self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False) + self.seg_2 = nn.Linear(embedding_size, embedding_size) + else: + self.seg_bn_1 = nn.Identity() + self.seg_2 = nn.Identity() + + def _make_layer(self, block, planes, num_blocks, stride): + strides = [stride] + [1] * (num_blocks - 1) + layers = [] + for stride in strides: + layers.append(block(self.in_planes, planes, stride)) + self.in_planes = planes * block.expansion + return nn.Sequential(*layers) + + def forward(self, x): + x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T) + + x = x.unsqueeze_(1) + out = F.relu(self.bn1(self.conv1(x))) + out1 = self.layer1(out) + out2 = self.layer2(out1) + out1_downsample = self.layer1_downsample(out1) + fuse_out12 = self.fuse_mode12(out2, out1_downsample) + out3 = self.layer3(out2) + fuse_out12_downsample = self.layer2_downsample(fuse_out12) + fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample) + out4 = self.layer4(out3) + fuse_out123_downsample = self.layer3_downsample(fuse_out123) + fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample) + stats = self.pool(fuse_out1234) + + embed_a = self.seg_1(stats) + if self.two_emb_layer: + out = F.relu(embed_a) + out = self.seg_bn_1(out) + embed_b = self.seg_2(out) + return embed_b + else: + return embed_a