mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
127 lines
4.8 KiB
Python
127 lines
4.8 KiB
Python
#!/usr/bin/env python3
|
|
# -*- encoding: utf-8 -*-
|
|
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
|
|
# MIT License (https://opensource.org/licenses/MIT)
|
|
# Modified from 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker)
|
|
|
|
import time
|
|
import torch
|
|
import numpy as np
|
|
from collections import OrderedDict
|
|
from contextlib import contextmanager
|
|
from distutils.version import LooseVersion
|
|
|
|
from funasr.register import tables
|
|
from funasr.models.campplus.utils import extract_feature
|
|
from funasr.utils.load_utils import load_audio_text_image_video
|
|
from funasr.models.campplus.components import DenseLayer, StatsPool, \
|
|
TDNNLayer, CAMDenseTDNNBlock, TransitLayer, get_nonlinear, FCM
|
|
|
|
|
|
# Provide a module-level ``autocast`` name regardless of the installed torch.
# From torch 1.6.0 onward the real ``torch.cuda.amp.autocast`` is used; older
# releases get a no-op context manager accepting the same ``enabled`` keyword.
# NOTE(review): ``distutils`` (source of ``LooseVersion``) was removed from the
# standard library in Python 3.12 (PEP 632); consider ``packaging.version``.
# NOTE(review): recent torch releases deprecate ``torch.cuda.amp.autocast`` in
# favour of ``torch.amp.autocast("cuda", ...)``.
if LooseVersion(torch.__version__) < LooseVersion("1.6.0"):
    @contextmanager
    def autocast(enabled=True):
        # Fallback for torch<1.6.0: enter and exit without doing anything.
        yield
else:
    from torch.cuda.amp import autocast
|
|
|
|
|
|
@tables.register("model_classes", "CAMPPlus")
class CAMPPlus(torch.nn.Module):
    """CAM++ speaker-embedding network (modified from 3D-Speaker).

    The network is an ``FCM`` front-end (``self.head``) followed by a
    ``torch.nn.Sequential`` (``self.xvector``). The sequential part contains
    one initial TDNN layer and three CAM dense-TDNN blocks, each followed by a
    transit layer that halves the channel count. With
    ``output_level='segment'`` the sequence is pooled by ``StatsPool`` and
    projected to ``embedding_size`` by a ``DenseLayer``. With
    ``output_level='frame'`` the per-frame features are returned instead.

    Args:
        feat_dim: Input feature dimension passed to ``FCM``.
        embedding_size: Output size of the final dense layer (segment level
            only).
        growth_rate: Channels added by each layer of a dense-TDNN block.
        bn_size: Bottleneck multiplier; bottleneck width is
            ``bn_size * growth_rate``.
        init_channels: Output channels of the initial TDNN layer.
        config_str: Normalisation/nonlinearity spec forwarded to the layers.
        memory_efficient: Forwarded to ``CAMDenseTDNNBlock``.
        output_level: ``'segment'`` or ``'frame'``.
        **kwargs: Accepted and ignored, so configs may carry extra keys.

    Raises:
        ValueError: If ``output_level`` is neither ``'segment'`` nor
            ``'frame'``.
    """

    def __init__(self,
                 feat_dim=80,
                 embedding_size=192,
                 growth_rate=32,
                 bn_size=4,
                 init_channels=128,
                 config_str='batchnorm-relu',
                 memory_efficient=True,
                 output_level='segment',
                 **kwargs,):
        super().__init__()

        # Validate with an explicit raise rather than ``assert``: asserts are
        # stripped under ``python -O``. Without this check, an unrecognised
        # value would silently build a frame-level model.
        if output_level not in ('segment', 'frame'):
            raise ValueError("`output_level` should be set to 'segment' or 'frame'. ")

        self.head = FCM(feat_dim=feat_dim)
        channels = self.head.out_channels
        self.output_level = output_level

        # Module names ('tdnn', 'block%d', 'transit%d', ...) define the
        # state_dict keys, so they must stay stable for checkpoint loading.
        self.xvector = torch.nn.Sequential(
            OrderedDict([
                ('tdnn',
                 TDNNLayer(channels,
                           init_channels,
                           5,
                           stride=2,
                           dilation=1,
                           padding=-1,
                           config_str=config_str)),
            ]))
        channels = init_channels
        # Three dense-TDNN stages described by (num_layers, kernel_size, dilation).
        for i, (num_layers, kernel_size,
                dilation) in enumerate(zip((12, 24, 16), (3, 3, 3), (1, 2, 2))):
            block = CAMDenseTDNNBlock(num_layers=num_layers,
                                      in_channels=channels,
                                      out_channels=growth_rate,
                                      bn_channels=bn_size * growth_rate,
                                      kernel_size=kernel_size,
                                      dilation=dilation,
                                      config_str=config_str,
                                      memory_efficient=memory_efficient)
            self.xvector.add_module('block%d' % (i + 1), block)
            # Each of the block's layers contributes ``growth_rate`` channels.
            channels = channels + num_layers * growth_rate
            # The transit layer halves the channel count after every block.
            self.xvector.add_module(
                'transit%d' % (i + 1),
                TransitLayer(channels,
                             channels // 2,
                             bias=False,
                             config_str=config_str))
            channels //= 2

        self.xvector.add_module(
            'out_nonlinear', get_nonlinear(config_str, channels))

        if self.output_level == 'segment':
            self.xvector.add_module('stats', StatsPool())
            # ``channels * 2``: presumably StatsPool concatenates two
            # statistics per channel (e.g. mean and std) -- TODO confirm.
            self.xvector.add_module(
                'dense',
                DenseLayer(
                    channels * 2, embedding_size, config_str='batchnorm_'))

        # Kaiming-normal init for every Conv1d/Linear weight; zero biases.
        for m in self.modules():
            if isinstance(m, (torch.nn.Conv1d, torch.nn.Linear)):
                torch.nn.init.kaiming_normal_(m.weight.data)
                if m.bias is not None:
                    torch.nn.init.zeros_(m.bias)

    def forward(self, x):
        """Run the network on a batch of features.

        Args:
            x: 3-D feature tensor. The permute below treats it as (B, T, F).

        Returns:
            The output of ``self.xvector``. For ``output_level='frame'``, the
            last two dimensions are swapped before returning.
        """
        x = x.permute(0, 2, 1)  # (B,T,F) => (B,F,T)
        x = self.head(x)
        x = self.xvector(x)
        if self.output_level == 'frame':
            x = x.transpose(1, 2)
        return x

    def inference(self,
                  data_in,
                  data_lengths=None,
                  key: list = None,
                  tokenizer=None,
                  frontend=None,
                  **kwargs,
                  ):
        """Load audio, extract features and return speaker embeddings.

        Args:
            data_in: Input(s) passed to ``load_audio_text_image_video``.
            data_lengths: Unused; accepted for interface compatibility.
            key: Unused; accepted for interface compatibility.
            tokenizer: Unused; accepted for interface compatibility.
            frontend: Unused; accepted for interface compatibility.
            **kwargs: ``fs`` is forwarded to the loader as ``audio_fs``
                (default 16000). ``device`` is the device the features are
                moved to; when absent or None, the device of this model's
                parameters is used.

        Returns:
            A ``(results, meta_data)`` tuple. ``results`` is a one-element list
            holding ``{"spk_embedding": <forward output>}``. ``meta_data``
            holds the ``load_data`` and ``extract_feat`` timing strings and
            ``batch_data_time``.
        """
        meta_data = {}
        time1 = time.perf_counter()
        audio_sample_list = load_audio_text_image_video(data_in, fs=16000, audio_fs=kwargs.get("fs", 16000), data_type="sound")
        time2 = time.perf_counter()
        meta_data["load_data"] = f"{time2 - time1:0.3f}"

        # extract fbank feats
        speech, speech_lengths, speech_times = extract_feature(audio_sample_list)
        # Fall back to the model's own device so direct calls without
        # ``device=`` no longer fail with a KeyError.
        device = kwargs.get("device")
        if device is None:
            device = next(self.parameters()).device
        speech = speech.to(device=device)
        time3 = time.perf_counter()
        meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
        # NOTE(review): assumes speech_times are sample counts at 16 kHz -- confirm.
        meta_data["batch_data_time"] = np.array(speech_times).sum().item() / 16000.0

        results = [{"spk_embedding": self.forward(speech)}]
        return results, meta_data