Merge branch 'dev_gzf' of github.com:alibaba-damo-academy/FunASR into dev_gzf

add
This commit is contained in:
游雁 2024-02-29 17:19:46 +08:00
commit 5c8af9c7e5
18 changed files with 1716 additions and 5 deletions

View File

@@ -0,0 +1,19 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
from funasr import AutoModel
multilingual_wavs = [
"example_zh-CN.mp3",
"example_en.mp3",
"example_ja.mp3",
"example_ko.mp3",
]
model = AutoModel(model="iic/speech_whisper-large_lid_multilingual_pytorch", model_revision="v2.0.4")
for wav_id in multilingual_wavs:
wav_file = f"{model.model_path}/examples/{wav_id}"
res = model.generate(input=wav_file, data_type="sound", inference_clip_length=250)
print("detect sample {}: {}".format(wav_id, res))

View File

@@ -0,0 +1,22 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
multilingual_wavs=[
"https://www.modelscope.cn/api/v1/models/iic/speech_whisper-large_lid_multilingual_pytorch/repo?Revision=master&FilePath=examples/example_zh-CN.mp3",
"https://www.modelscope.cn/api/v1/models/iic/speech_whisper-large_lid_multilingual_pytorch/repo?Revision=master&FilePath=examples/example_en.mp3",
"https://www.modelscope.cn/api/v1/models/iic/speech_whisper-large_lid_multilingual_pytorch/repo?Revision=master&FilePath=examples/example_ja.mp3",
"https://www.modelscope.cn/api/v1/models/iic/speech_whisper-large_lid_multilingual_pytorch/repo?Revision=master&FilePath=examples/example_ko.mp3",
]
inference_pipeline = pipeline(
task=Tasks.auto_speech_recognition,
model='iic/speech_whisper-large_lid_multilingual_pytorch', model_revision="v2.0.4")
for wav in multilingual_wavs:
rec_result = inference_pipeline(input=wav, inference_clip_length=250)
print(rec_result)

View File

@@ -0,0 +1,102 @@
from typing import Tuple
import torch
import torch.nn as nn
import whisper
from whisper.audio import HOP_LENGTH, N_FFT, N_SAMPLES
from funasr.register import tables
from torch.nn.utils.rnn import pad_sequence
@tables.register("frontend_classes", "WhisperFrontend")
class WhisperFrontend(nn.Module):
"""Speech Representation Using Encoder Outputs from OpenAI's Whisper Model:
URL: https://github.com/openai/whisper
"""
def __init__(
self,
fs: int = 16000,
whisper_model: str = "large-v3",
do_pad_trim: bool = True,
):
super().__init__()
assert fs == 16000
self.fs = fs
self.n_fft = N_FFT
self.win_length = N_FFT
self.hop_length = HOP_LENGTH
self.pad_samples = N_SAMPLES
self.frame_shift = self.hop_length
self.lfr_n = 1
if whisper_model == "large-v3" or whisper_model == "large":
self.n_mels = 128
else:
self.n_mels = 80
self.mel_filters = whisper.audio.mel_filters
self.do_pad_trim = do_pad_trim
if do_pad_trim:
self.pad_or_trim = whisper.pad_or_trim
assert whisper_model in whisper.available_models()
def output_size(self) -> int:
return self.n_mels
def log_mel_spectrogram(
self,
audio: torch.Tensor,
ilens: torch.Tensor = None,
) -> torch.Tensor:
window = torch.hann_window(self.win_length).to(audio.device)
stft = torch.stft(
audio, self.n_fft, self.hop_length, window=window, return_complex=True
)
# whisper deletes the last frame by default (Shih-Lun)
magnitudes = stft[..., :-1].abs() ** 2
filters = self.mel_filters(audio.device, self.n_mels)
mel_spec = filters @ magnitudes
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
if ilens is not None:
olens = ilens // self.hop_length
else:
olens = None
log_spec = torch.maximum(
log_spec,
log_spec.view(audio.size(0), -1).max(dim=-1)[0][:, None, None] - 8.0,
)
log_spec = (log_spec + 4.0) / 4.0
return log_spec, olens
def forward(
self, input: torch.Tensor, input_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size = input.size(0)
feats = []
feats_lens = []
for i in range(batch_size):
if self.do_pad_trim:
feat = self.pad_or_trim(input[i], self.pad_samples)
else:
feat = input[i]
feat, feat_len = self.log_mel_spectrogram(feat[None, :], input_lengths[i])
feats.append(feat[0])
feats_lens.append(feat_len)
feats_lens = torch.as_tensor(feats_lens)
if batch_size == 1:
feats_pad = feats[0][None, :, :]
else:
feats_pad = pad_sequence(feats,
batch_first=True,
padding_value=0.0)
return feats_pad, feats_lens
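A hypothetical smoke test for the frontend above, not part of the commit; it only assumes a recent `openai-whisper` release is installed (the frontend uses Whisper's mel filterbank, not the model weights):

```python
import torch

# WhisperFrontend refers to the class defined above.
frontend = WhisperFrontend(fs=16000, whisper_model="large-v3", do_pad_trim=True)
wav = torch.randn(1, 16000 * 3)              # 3 s of dummy audio: (batch, samples)
wav_lens = torch.tensor([wav.shape[1]])
feats, feats_lens = frontend(wav, wav_lens)
print(feats.shape, feats_lens)               # expect torch.Size([1, 128, 3000]) and tensor([300])
```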

View File

View File

@@ -0,0 +1,167 @@
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import copy
from typing import Any, List, Tuple
import torch
from torch import nn
import whisper
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.register import tables
@tables.register("decoder_classes", "OpenAIWhisperDecoderWarp")
class OpenAIWhisperDecoderWarp(nn.Module):
"""Transformer-based Speech-to-Text Decoder from OpenAI's Whisper Model:
URL: https://github.com/openai/whisper
"""
def __init__(
self,
dropout_rate: float = 0.0,
whisper_model: str = "small",
download_dir: str = None,
use_padmask: bool = False,
):
super().__init__()
assert whisper_model in whisper.available_models()
_model = whisper.load_model(
whisper_model, download_root=download_dir, device="cpu"
)
self.decoders = copy.deepcopy(_model.decoder)
attention_dim = self.decoders.token_embedding.embedding_dim
# note that originally Whisper doesn't use dropouts
self.dropout = torch.nn.Dropout(dropout_rate)
self.decoders.train()
del _model
self.use_padmask = use_padmask
def forward(
self,
hs_pad: torch.Tensor,
hlens: torch.Tensor,
ys_in_pad: torch.Tensor,
ys_in_lens: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward decoder.
Args:
hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
hlens: (batch)
ys_in_pad:
input token ids, int64 (batch, maxlen_out)
if input_layer == "embed"
input tensor (batch, maxlen_out, #mels) in the other cases
ys_in_lens: (batch)
Returns:
(tuple): tuple containing:
x: decoded token score before softmax (batch, maxlen_out, token)
if use_output_layer is True,
olens: (batch, )
"""
tgt, memory = ys_in_pad, hs_pad
tgt = (
self.decoders.token_embedding(tgt)
+ self.decoders.positional_embedding[: tgt.size(1)]
)
tgt = self.dropout(tgt)
x = tgt.to(memory.dtype)
if self.use_padmask:
memory_mask = (~make_pad_mask(hlens)[:, None, :]).to(memory.device)
else:
memory_mask = None
for layer, block in enumerate(self.decoders.blocks):
x = block(x, memory, mask=self.decoders.mask, memory_mask=memory_mask, is_pad_mask=False, is_pad_memory_mask=True)
if layer < len(self.decoders.blocks) - 1:
x = self.dropout(x)
x = self.decoders.ln(x)
x = (
x @ torch.transpose(self.decoders.token_embedding.weight.to(x.dtype), 0, 1)
).float()
return x, ys_in_lens
def forward_one_step(
self,
tgt: torch.Tensor,
tgt_mask: torch.Tensor,
memory: torch.Tensor,
cache: List[torch.Tensor] = None,
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
"""Forward one step.
Args:
tgt: input token ids, int64 (batch, maxlen_out)
tgt_mask: input token mask, (batch, maxlen_out)
dtype=torch.uint8 in PyTorch 1.2-
dtype=torch.bool in PyTorch 1.2+ (including 1.2)
memory: encoded memory, float32 (batch, maxlen_in, feat)
cache: cached output list of (batch, max_time_out-1, size)
Returns:
y, cache: NN output value and cache per `self.decoders`.
`y.shape` is (batch, maxlen_out, token)
NOTE (Shih-Lun):
cache implementation is ignored for now
for simplicity & correctness
"""
x = (
self.decoders.token_embedding(tgt)
+ self.decoders.positional_embedding[: tgt.size(1)]
)
x = self.dropout(x)
x = x.to(memory.dtype)
for layer, block in enumerate(self.decoders.blocks):
x = block(x, memory, mask=self.decoders.mask)
if layer < len(self.decoders.blocks) - 1:
x = self.dropout(x)
x = self.decoders.ln(x)
y = x[:, -1]
y = (
y @ torch.transpose(self.decoders.token_embedding.weight.to(x.dtype), 0, 1)
).float()
y = torch.log_softmax(y, dim=-1)
return y, None
def score(self, ys, state, x):
"""Score."""
logp, state = self.forward_one_step(
ys.unsqueeze(0), torch.empty(0), x.unsqueeze(0), cache=state # dummy mask
)
return logp.squeeze(0), state
def batch_score(
self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
) -> Tuple[torch.Tensor, List[Any]]:
"""Score new token batch.
Args:
ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
states (List[Any]): Scorer states for prefix tokens.
xs (torch.Tensor):
The encoder feature that generates ys (n_batch, xlen, n_feat).
Returns:
tuple[torch.Tensor, List[Any]]: Tuple of
batchfied scores for next token with shape of `(n_batch, n_vocab)`
and next state list for ys.
"""
# batch decoding, dummy mask is passed
logp, states = self.forward_one_step(ys, torch.empty(0), xs, cache=None)
return logp, None
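Note how the logits above are produced by multiplying the decoder states with the transposed token embedding (weight tying) rather than a separate output layer. A minimal, self-contained sketch of that pattern with toy shapes:

```python
import torch

vocab_size, d_model = 7, 4
token_embedding = torch.nn.Embedding(vocab_size, d_model)
hidden = torch.randn(2, 5, d_model)              # decoder hidden states: (batch, length, dim)
logits = hidden @ token_embedding.weight.t()     # (batch, length, vocab_size)
print(logits.shape)                              # torch.Size([2, 5, 7])
```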

View File

@@ -0,0 +1,119 @@
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import copy
from typing import Optional, Tuple, Union
import torch
from torch import nn
import torch.nn.functional as F
import whisper
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.models.specaug.specaug import SpecAug
from funasr.register import tables
@tables.register("encoder_classes", "OpenAIWhisperEncoderWarp")
class OpenAIWhisperEncoderWarp(nn.Module):
"""Transformer-based Speech Encoder from OpenAI's Whisper Model:
URL: https://github.com/openai/whisper
"""
def __init__(
self,
dropout_rate: float = 0.0,
whisper_model: str = "small",
download_dir: str = None,
use_specaug: bool = False,
use_padmask: bool = False,
specaug_conf: Union[dict, None] = None,
):
super().__init__()
# note that originally Whisper doesn't use dropouts
self.dropout = torch.nn.Dropout(dropout_rate)
assert whisper_model in whisper.available_models()
_model = whisper.load_model(
whisper_model, download_root=download_dir, device="cpu"
)
self.encoders = copy.deepcopy(_model.encoder)
self.encoders.train()
del _model
if use_specaug:
self.specaug = SpecAug(**specaug_conf)
else:
self.specaug = None
self.use_padmask = use_padmask
def whisper_encode(
self,
input: torch.Tensor,
ilens: torch.Tensor = None,
) -> torch.Tensor:
x = F.gelu(self.encoders.conv1(input))
x = F.gelu(self.encoders.conv2(x))
x = x.permute(0, 2, 1)
n_frames = x.size(1)
max_pos = self.encoders.positional_embedding.size(0)
if n_frames <= max_pos:
x = (x + self.encoders.positional_embedding[: x.size(1), :]).to(x.dtype)
else:
# due to positional encoding, audios >30 sec won't be accepted
x = x[:, :max_pos, :] + self.encoders.positional_embedding
if ilens is not None:
olens = (
1
+ (
ilens
- self.encoders.conv2.kernel_size[0]
+ 2 * self.encoders.conv2.padding[0]
)
// self.encoders.conv2.stride[0]
)
olens = torch.clamp(olens, max=max_pos)
else:
olens = None
if self.use_padmask:
padding_mask = (~make_pad_mask(olens)[:, None, :]).to(x.device)
else:
padding_mask = None
x = self.dropout(x)
for layer, block in enumerate(self.encoders.blocks):
x = block(x)
if layer < len(self.encoders.blocks) - 1:
x = self.dropout(x)
x = self.encoders.ln_post(x)
return x, olens
def output_size(self) -> int:
# dummy output size
return self.encoders.conv2.weight.shape[0]
def forward(
self,
xs_pad: torch.Tensor,
ilens: torch.Tensor,
prev_states: torch.Tensor = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
feats, feats_lens = xs_pad, ilens
if self.specaug is not None and self.encoders.training:
feats = torch.transpose(feats, 1, 2)
feats, feats_lens = self.specaug(feats, feats_lens)
feats = torch.transpose(feats, 1, 2)
xs_pad, olens = self.whisper_encode(feats, feats_lens)
return xs_pad, olens, None
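A worked example of the `olens` formula in `whisper_encode` above. It assumes Whisper's second convolution uses kernel_size=3, stride=2, padding=1, which is what the openai-whisper encoder ships with:

```python
# Standard 1-D conv output length: 1 + (L_in - kernel + 2*padding) // stride
ilens = 3000                        # 30 s of audio at a 10 ms mel hop
kernel, stride, padding = 3, 2, 1
olens = 1 + (ilens - kernel + 2 * padding) // stride
print(olens)                        # 1500, i.e. Whisper's positional-embedding length (max_pos)
```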

View File

@@ -0,0 +1,428 @@
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
""" Res2Net implementation is adapted from https://github.com/wenet-e2e/wespeaker.
ERes2Net incorporates both local and global feature fusion techniques to improve the performance.
The local feature fusion (LFF) fuses the features within one single residual block to extract the local signal.
The global feature fusion (GFF) takes acoustic features of different scales as input to aggregate global signal.
ERes2Net-Large is an upgraded version of ERes2Net that uses a larger number of parameters to achieve better
recognition performance. Parameters expansion, baseWidth, and scale can be modified to obtain optimal performance.
"""
import torch
import math
import torch.nn as nn
import torch.nn.functional as F
import funasr.models.whisper_lid.eres2net.pooling_layers as pooling_layers
from funasr.models.whisper_lid.eres2net.fusion import AFF
class ReLU(nn.Hardtanh):
def __init__(self, inplace=False):
super(ReLU, self).__init__(0, 20, inplace)
def __repr__(self):
inplace_str = 'inplace' if self.inplace else ''
return self.__class__.__name__ + ' (' \
+ inplace_str + ')'
def conv1x1(in_planes, out_planes, stride=1):
"1x1 convolution without padding"
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
padding=0, bias=False)
def conv3x3(in_planes, out_planes, stride=1):
"3x3 convolution with padding"
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=1, bias=False)
class BasicBlockERes2Net(nn.Module):
expansion = 2
def __init__(self, in_planes, planes, stride=1, baseWidth=32, scale=2):
super(BasicBlockERes2Net, self).__init__()
width = int(math.floor(planes * (baseWidth / 64.0)))
self.conv1 = conv1x1(in_planes, width * scale, stride)
self.bn1 = nn.BatchNorm2d(width * scale)
self.nums = scale
convs = []
bns = []
for i in range(self.nums):
convs.append(conv3x3(width, width))
bns.append(nn.BatchNorm2d(width))
self.convs = nn.ModuleList(convs)
self.bns = nn.ModuleList(bns)
self.relu = ReLU(inplace=True)
self.conv3 = conv1x1(width * scale, planes * self.expansion)
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
self.shortcut = nn.Sequential()
if stride != 1 or in_planes != self.expansion * planes:
self.shortcut = nn.Sequential(
nn.Conv2d(in_planes,
self.expansion * planes,
kernel_size=1,
stride=stride,
bias=False),
nn.BatchNorm2d(self.expansion * planes))
self.stride = stride
self.width = width
self.scale = scale
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
spx = torch.split(out, self.width, 1)
for i in range(self.nums):
if i == 0:
sp = spx[i]
else:
sp = sp + spx[i]
sp = self.convs[i](sp)
sp = self.relu(self.bns[i](sp))
if i == 0:
out = sp
else:
out = torch.cat((out, sp), 1)
out = self.conv3(out)
out = self.bn3(out)
residual = self.shortcut(x)
out += residual
out = self.relu(out)
return out
class BasicBlockERes2Net_diff_AFF(nn.Module):
expansion = 2
def __init__(self, in_planes, planes, stride=1, baseWidth=32, scale=2):
super(BasicBlockERes2Net_diff_AFF, self).__init__()
width = int(math.floor(planes * (baseWidth / 64.0)))
self.conv1 = conv1x1(in_planes, width * scale, stride)
self.bn1 = nn.BatchNorm2d(width * scale)
self.nums = scale
convs = []
fuse_models = []
bns = []
for i in range(self.nums):
convs.append(conv3x3(width, width))
bns.append(nn.BatchNorm2d(width))
for j in range(self.nums - 1):
fuse_models.append(AFF(channels=width))
self.convs = nn.ModuleList(convs)
self.bns = nn.ModuleList(bns)
self.fuse_models = nn.ModuleList(fuse_models)
self.relu = ReLU(inplace=True)
self.conv3 = conv1x1(width * scale, planes * self.expansion)
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
self.shortcut = nn.Sequential()
if stride != 1 or in_planes != self.expansion * planes:
self.shortcut = nn.Sequential(
nn.Conv2d(in_planes,
self.expansion * planes,
kernel_size=1,
stride=stride,
bias=False),
nn.BatchNorm2d(self.expansion * planes))
self.stride = stride
self.width = width
self.scale = scale
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
spx = torch.split(out, self.width, 1)
for i in range(self.nums):
if i == 0:
sp = spx[i]
else:
sp = self.fuse_models[i - 1](sp, spx[i])
sp = self.convs[i](sp)
sp = self.relu(self.bns[i](sp))
if i == 0:
out = sp
else:
out = torch.cat((out, sp), 1)
out = self.conv3(out)
out = self.bn3(out)
residual = self.shortcut(x)
out += residual
out = self.relu(out)
return out
class ERes2Net(nn.Module):
def __init__(self,
block=BasicBlockERes2Net,
block_fuse=BasicBlockERes2Net_diff_AFF,
num_blocks=[3, 4, 6, 3],
m_channels=32,
feat_dim=80,
embedding_size=192,
pooling_func='TSTP',
two_emb_layer=False):
super(ERes2Net, self).__init__()
self.in_planes = m_channels
self.feat_dim = feat_dim
self.embedding_size = embedding_size
self.stats_dim = int(feat_dim / 8) * m_channels * 8
self.two_emb_layer = two_emb_layer
self._output_size = embedding_size
self.conv1 = nn.Conv2d(1,
m_channels,
kernel_size=3,
stride=1,
padding=1,
bias=False)
self.bn1 = nn.BatchNorm2d(m_channels)
self.layer1 = self._make_layer(block,
m_channels,
num_blocks[0],
stride=1)
self.layer2 = self._make_layer(block,
m_channels * 2,
num_blocks[1],
stride=2)
self.layer3 = self._make_layer(block_fuse,
m_channels * 4,
num_blocks[2],
stride=2)
self.layer4 = self._make_layer(block_fuse,
m_channels * 8,
num_blocks[3],
stride=2)
# Downsampling module for each layer
self.layer1_downsample = nn.Conv2d(m_channels * 2, m_channels * 4, kernel_size=3, stride=2, padding=1,
bias=False)
self.layer2_downsample = nn.Conv2d(m_channels * 4, m_channels * 8, kernel_size=3, padding=1, stride=2,
bias=False)
self.layer3_downsample = nn.Conv2d(m_channels * 8, m_channels * 16, kernel_size=3, padding=1, stride=2,
bias=False)
# Bottom-up fusion module
self.fuse_mode12 = AFF(channels=m_channels * 4)
self.fuse_mode123 = AFF(channels=m_channels * 8)
self.fuse_mode1234 = AFF(channels=m_channels * 16)
self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == "TSDP" else 2
self.pool = getattr(pooling_layers, pooling_func)(
in_dim=self.stats_dim * block.expansion)
self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats,
embedding_size)
if self.two_emb_layer:
self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
self.seg_2 = nn.Linear(embedding_size, embedding_size)
else:
self.seg_bn_1 = nn.Identity()
self.seg_2 = nn.Identity()
def _make_layer(self, block, planes, num_blocks, stride):
strides = [stride] + [1] * (num_blocks - 1)
layers = []
for stride in strides:
layers.append(block(self.in_planes, planes, stride))
self.in_planes = planes * block.expansion
return nn.Sequential(*layers)
def output_size(self) -> int:
return self._output_size
def forward(self, x, ilens):
# assert x.shape[1] == ilens.max()
x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
x = x.unsqueeze_(1)
out = F.relu(self.bn1(self.conv1(x)))
out1 = self.layer1(out)
out2 = self.layer2(out1)
out1_downsample = self.layer1_downsample(out1)
fuse_out12 = self.fuse_mode12(out2, out1_downsample)
out3 = self.layer3(out2)
fuse_out12_downsample = self.layer2_downsample(fuse_out12)
fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample)
out4 = self.layer4(out3)
fuse_out123_downsample = self.layer3_downsample(fuse_out123)
fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample)
olens = (((((ilens - 1) // 2 + 1) - 1) // 2 + 1) - 1) // 2 + 1
stats = self.pool(fuse_out1234, olens)
embed_a = self.seg_1(stats)
if self.two_emb_layer:
out = F.relu(embed_a)
out = self.seg_bn_1(out)
embed_b = self.seg_2(out)
return embed_b
else:
return embed_a
class BasicBlockRes2Net(nn.Module):
expansion = 2
def __init__(self, in_planes, planes, stride=1, baseWidth=32, scale=2):
super(BasicBlockRes2Net, self).__init__()
width = int(math.floor(planes * (baseWidth / 64.0)))
self.conv1 = conv1x1(in_planes, width * scale, stride)
self.bn1 = nn.BatchNorm2d(width * scale)
self.nums = scale - 1
convs = []
bns = []
for i in range(self.nums):
convs.append(conv3x3(width, width))
bns.append(nn.BatchNorm2d(width))
self.convs = nn.ModuleList(convs)
self.bns = nn.ModuleList(bns)
self.relu = ReLU(inplace=True)
self.conv3 = conv1x1(width * scale, planes * self.expansion)
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
self.shortcut = nn.Sequential()
if stride != 1 or in_planes != self.expansion * planes:
self.shortcut = nn.Sequential(
nn.Conv2d(in_planes,
self.expansion * planes,
kernel_size=1,
stride=stride,
bias=False),
nn.BatchNorm2d(self.expansion * planes))
self.stride = stride
self.width = width
self.scale = scale
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
spx = torch.split(out, self.width, 1)
for i in range(self.nums):
if i == 0:
sp = spx[i]
else:
sp = sp + spx[i]
sp = self.convs[i](sp)
sp = self.relu(self.bns[i](sp))
if i == 0:
out = sp
else:
out = torch.cat((out, sp), 1)
out = torch.cat((out, spx[self.nums]), 1)
out = self.conv3(out)
out = self.bn3(out)
residual = self.shortcut(x)
out += residual
out = self.relu(out)
return out
class Res2Net(nn.Module):
def __init__(self,
block=BasicBlockRes2Net,
num_blocks=[3, 4, 6, 3],
m_channels=32,
feat_dim=80,
embedding_size=192,
pooling_func='TSTP',
two_emb_layer=False):
super(Res2Net, self).__init__()
self.in_planes = m_channels
self.feat_dim = feat_dim
self.embedding_size = embedding_size
self.stats_dim = int(feat_dim / 8) * m_channels * 8
self.two_emb_layer = two_emb_layer
self.conv1 = nn.Conv2d(1,
m_channels,
kernel_size=3,
stride=1,
padding=1,
bias=False)
self.bn1 = nn.BatchNorm2d(m_channels)
self.layer1 = self._make_layer(block,
m_channels,
num_blocks[0],
stride=1)
self.layer2 = self._make_layer(block,
m_channels * 2,
num_blocks[1],
stride=2)
self.layer3 = self._make_layer(block,
m_channels * 4,
num_blocks[2],
stride=2)
self.layer4 = self._make_layer(block,
m_channels * 8,
num_blocks[3],
stride=2)
self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == "TSDP" else 2
self.pool = getattr(pooling_layers, pooling_func)(
in_dim=self.stats_dim * block.expansion)
self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats,
embedding_size)
if self.two_emb_layer:
self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
self.seg_2 = nn.Linear(embedding_size, embedding_size)
else:
self.seg_bn_1 = nn.Identity()
self.seg_2 = nn.Identity()
def _make_layer(self, block, planes, num_blocks, stride):
strides = [stride] + [1] * (num_blocks - 1)
layers = []
for stride in strides:
layers.append(block(self.in_planes, planes, stride))
self.in_planes = planes * block.expansion
return nn.Sequential(*layers)
def forward(self, x):
x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
x = x.unsqueeze_(1)
out = F.relu(self.bn1(self.conv1(x)))
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
out = self.layer4(out)
stats = self.pool(out)
embed_a = self.seg_1(stats)
if self.two_emb_layer:
out = F.relu(embed_a)
out = self.seg_bn_1(out)
embed_b = self.seg_2(out)
return embed_b
else:
return embed_a
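A hypothetical CPU smoke test for the ERes2Net class above (it assumes the FunASR `pooling_layers` and `fusion` modules imported by this file resolve; shapes are illustrative only):

```python
import torch

model = ERes2Net(feat_dim=80, embedding_size=192)   # defaults: TSTP pooling, single embedding layer
feats = torch.randn(2, 200, 80)                     # (batch, frames, feat_dim)
ilens = torch.tensor([200, 150])                    # true lengths before padding
emb = model(feats, ilens)
print(emb.shape)                                    # expect torch.Size([2, 192])
```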

View File

@@ -0,0 +1,29 @@
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import torch
import torch.nn as nn
class AFF(nn.Module):
def __init__(self, channels=64, r=4):
super(AFF, self).__init__()
inter_channels = int(channels // r)
self.local_att = nn.Sequential(
nn.Conv2d(channels * 2, inter_channels, kernel_size=1, stride=1, padding=0),
nn.BatchNorm2d(inter_channels),
nn.SiLU(inplace=True),
nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
nn.BatchNorm2d(channels),
)
def forward(self, x, ds_y):
xa = torch.cat((x, ds_y), dim=1)
x_att = self.local_att(xa)
x_att = 1.0 + torch.tanh(x_att)
xo = torch.mul(x, x_att) + torch.mul(ds_y, 2.0-x_att)
return xo
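A quick, hedged check of the gate in `AFF.forward` above: `1 + tanh(·)` lies in (0, 2) and the two fusion weights always sum to 2, so the module blends `x` and `ds_y` element-wise rather than selecting one of them:

```python
import torch

att = torch.randn(4)                  # stand-in for the local_att output
gate = 1.0 + torch.tanh(att)          # in (0, 2)
print(torch.allclose(gate + (2.0 - gate), torch.full_like(gate, 2.0)))  # True
```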

View File

@@ -0,0 +1,118 @@
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
""" This implementation is adapted from https://github.com/wenet-e2e/wespeaker."""
import torch
import torch.nn as nn
from funasr.models.transformer.utils.nets_utils import make_pad_mask
class TAP(nn.Module):
"""
Temporal average pooling, only first-order mean is considered
"""
def __init__(self, **kwargs):
super(TAP, self).__init__()
def forward(self, x):
pooling_mean = x.mean(dim=-1)
# To be compatible with 2D input
pooling_mean = pooling_mean.flatten(start_dim=1)
return pooling_mean
class TSDP(nn.Module):
"""
Temporal standard deviation pooling, only second-order std is considered
"""
def __init__(self, **kwargs):
super(TSDP, self).__init__()
def forward(self, x):
# The last dimension is the temporal axis
pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-8)
pooling_std = pooling_std.flatten(start_dim=1)
return pooling_std
class TSTP(nn.Module):
"""
Temporal statistics pooling: concatenates mean and std, as used in
the x-vector
Comment: simple concatenation cannot make full use of both statistics
"""
def __init__(self, **kwargs):
super(TSTP, self).__init__()
def forward(self, x, olens):
# The last dimension is the temporal axis
masks = (~make_pad_mask(olens, maxlen=x.shape[-1])[:, None, None, :]).to(x.device)
x_masked = x * masks
sum_without_padding = torch.sum(x_masked, axis=-1)
count_without_padding = torch.sum(masks, axis=-1)
mean_without_padding = sum_without_padding / count_without_padding
var_without_padding = ((x_masked - mean_without_padding.unsqueeze(-1)) ** 2 * masks).sum(-1) / count_without_padding
pooling_mean = mean_without_padding
pooling_std = torch.sqrt(var_without_padding + 1e-8)
pooling_mean = pooling_mean.flatten(start_dim=1)
pooling_std = pooling_std.flatten(start_dim=1)
stats = torch.cat((pooling_mean, pooling_std), 1)
return stats
class ASTP(nn.Module):
""" Attentive statistics pooling: Channel- and context-dependent
statistics pooling, first used in ECAPA_TDNN.
"""
def __init__(self, in_dim, bottleneck_dim=128, global_context_att=False):
super(ASTP, self).__init__()
self.global_context_att = global_context_att
# Use Conv1d with stride == 1 rather than Linear, then we don't
# need to transpose inputs.
if global_context_att:
self.linear1 = nn.Conv1d(
in_dim * 3, bottleneck_dim,
kernel_size=1) # equals W and b in the paper
else:
self.linear1 = nn.Conv1d(
in_dim, bottleneck_dim,
kernel_size=1) # equals W and b in the paper
self.linear2 = nn.Conv1d(bottleneck_dim, in_dim,
kernel_size=1) # equals V and k in the paper
def forward(self, x):
"""
x: a 3-dimensional tensor in tdnn-based architecture (B,F,T)
or a 4-dimensional tensor in resnet architecture (B,C,F,T)
0-dim: batch-dimension, last-dim: time-dimension (frame-dimension)
"""
if len(x.shape) == 4:
x = x.reshape(x.shape[0], x.shape[1] * x.shape[2], x.shape[3])
assert len(x.shape) == 3
if self.global_context_att:
context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
context_std = torch.sqrt(
torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
x_in = torch.cat((x, context_mean, context_std), dim=1)
else:
x_in = x
# DON'T use ReLU here! With ReLU, training may be hard to converge.
alpha = torch.tanh(
self.linear1(x_in)) # alpha = F.relu(self.linear1(x_in))
alpha = torch.softmax(self.linear2(alpha), dim=2)
mean = torch.sum(alpha * x, dim=2)
var = torch.sum(alpha * (x ** 2), dim=2) - mean ** 2
std = torch.sqrt(var.clamp(min=1e-10))
return torch.cat([mean, std], dim=1)
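A hedged sanity check for the masked statistics in `TSTP` above: with no padding, the masked mean/std should reduce to a plain mean and (biased) std along the time axis. It assumes FunASR's `make_pad_mask` import resolves so the class can run:

```python
import torch

x = torch.randn(2, 4, 3, 10)                    # (batch, channels, freq, time)
olens = torch.tensor([10, 10])                  # no padding
pooled = TSTP()(x, olens)                       # TSTP as defined above
ref = torch.cat((x.mean(-1).flatten(1), x.std(-1, unbiased=False).flatten(1)), dim=1)
print(torch.allclose(pooled, ref, atol=1e-3))   # True, up to the 1e-8 variance floor
```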

View File

@@ -0,0 +1,17 @@
import torch
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.transformer.utils.nets_utils import make_pad_mask
class SimpleAvg(AbsEncoder):
def __init__(self, feat_dim):
super(SimpleAvg, self).__init__()
self.feat_dim = feat_dim
def forward(self, x, ilens):
mask = ~make_pad_mask(ilens, maxlen=x.shape[1]).to(x.device)
avg_x = (x * mask[:, :, None]).sum(1) / mask.sum(-1)[:, None]
return avg_x
def output_size(self) -> int:
return self.feat_dim
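Similarly, a hedged check for `SimpleAvg`: with no padding it is simply the mean over the time axis (again assuming the imports above resolve in the installed FunASR version):

```python
import torch

x = torch.randn(2, 6, 4)                        # (batch, time, feat)
ilens = torch.tensor([6, 6])                    # no padding
avg = SimpleAvg(feat_dim=4)(x, ilens)
print(torch.allclose(avg, x.mean(dim=1)))       # True
```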

View File

@@ -0,0 +1,25 @@
from funasr.register import tables
from funasr.models.whisper_lid.eres2net.ResNet import ERes2Net, BasicBlockERes2Net, BasicBlockERes2Net_diff_AFF
@tables.register("lid_predictor_classes", "LidPredictor")
class LidPredictor(ERes2Net):
def __init__(self,
block=BasicBlockERes2Net,
block_fuse=BasicBlockERes2Net_diff_AFF,
num_blocks=[3, 4, 6, 3],
m_channels=32,
feat_dim=80,
embedding_size=192,
pooling_func='TSTP',
two_emb_layer=False):
super(LidPredictor, self).__init__(
block=block,
block_fuse=block_fuse,
num_blocks=num_blocks,
m_channels=m_channels,
feat_dim=feat_dim,
embedding_size=embedding_size,
pooling_func=pooling_func,
two_emb_layer=two_emb_layer
)

View File

@@ -0,0 +1,665 @@
import logging
from typing import Union, Dict, List, Tuple, Optional
import time
import torch
import numpy as np
import torch.nn as nn
from torch.cuda.amp import autocast
from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
from funasr.models.ctc.ctc import CTC
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.metrics.compute_acc import th_accuracy
from funasr.train_utils.device_funcs import force_gatherable
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.utils import postprocess_utils
from funasr.utils.datadir_writer import DatadirWriter
from funasr.register import tables
@tables.register("model_classes", "OpenAIWhisperModel")
class OpenAIWhisperModel(nn.Module):
"""CTC-attention hybrid Encoder-Decoder model"""
def __init__(
self,
specaug: str = None,
specaug_conf: dict = None,
normalize: str = None,
normalize_conf: dict = None,
encoder: str = None,
encoder_conf: dict = None,
decoder: str = None,
decoder_conf: dict = None,
ctc: str = None,
ctc_conf: dict = None,
ctc_weight: float = 0.5,
interctc_weight: float = 0.0,
input_size: int = 80,
vocab_size: int = -1,
ignore_id: int = -1,
blank_id: int = 0,
sos: int = 1,
eos: int = 2,
lsm_weight: float = 0.0,
length_normalized_loss: bool = False,
report_cer: bool = True,
report_wer: bool = True,
sym_space: str = "<space>",
sym_blank: str = "<blank>",
# extract_feats_in_collect_stats: bool = True,
share_embedding: bool = False,
# preencoder: Optional[AbsPreEncoder] = None,
# postencoder: Optional[AbsPostEncoder] = None,
**kwargs,
):
super().__init__()
if specaug is not None:
specaug_class = tables.specaug_classes.get(specaug)
specaug = specaug_class(**specaug_conf)
if normalize is not None:
normalize_class = tables.normalize_classes.get(normalize)
normalize = normalize_class(**normalize_conf)
encoder_class = tables.encoder_classes.get(encoder)
encoder = encoder_class(input_size=input_size, **encoder_conf)
encoder_output_size = encoder.output_size()
if decoder is not None:
decoder_class = tables.decoder_classes.get(decoder)
decoder = decoder_class(**decoder_conf)
if ctc_weight > 0.0:
if ctc_conf is None:
ctc_conf = {}
ctc = CTC(
odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf
)
self.blank_id = blank_id
self.sos = sos if sos is not None else vocab_size - 1
self.eos = eos if eos is not None else vocab_size - 1
self.vocab_size = vocab_size
self.ignore_id = ignore_id
self.ctc_weight = ctc_weight
self.specaug = specaug
self.normalize = normalize
self.encoder = encoder
if not hasattr(self.encoder, "interctc_use_conditioning"):
self.encoder.interctc_use_conditioning = False
if self.encoder.interctc_use_conditioning:
self.encoder.conditioning_layer = torch.nn.Linear(
vocab_size, self.encoder.output_size()
)
self.interctc_weight = interctc_weight
# self.error_calculator = None
if ctc_weight == 1.0:
self.decoder = None
else:
self.decoder = decoder
self.criterion_att = LabelSmoothingLoss(
size=vocab_size,
padding_idx=ignore_id,
smoothing=lsm_weight,
normalize_length=length_normalized_loss,
)
#
# if report_cer or report_wer:
# self.error_calculator = ErrorCalculator(
# token_list, sym_space, sym_blank, report_cer, report_wer
# )
#
self.error_calculator = None
if ctc_weight == 0.0:
self.ctc = None
else:
self.ctc = ctc
self.share_embedding = share_embedding
if self.share_embedding:
self.decoder.embed = None
self.length_normalized_loss = length_normalized_loss
self.beam_search = None
def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
"""
# import pdb;
# pdb.set_trace()
if len(text_lengths.size()) > 1:
text_lengths = text_lengths[:, 0]
if len(speech_lengths.size()) > 1:
speech_lengths = speech_lengths[:, 0]
batch_size = speech.shape[0]
# 1. Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
intermediate_outs = None
if isinstance(encoder_out, tuple):
intermediate_outs = encoder_out[1]
encoder_out = encoder_out[0]
loss_att, acc_att, cer_att, wer_att = None, None, None, None
loss_ctc, cer_ctc = None, None
stats = dict()
# 2a. CTC branch
if self.ctc_weight != 0.0:
loss_ctc, cer_ctc = self._calc_ctc_loss(
encoder_out, encoder_out_lens, text, text_lengths
)
# Collect CTC branch stats
stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
stats["cer_ctc"] = cer_ctc
# Intermediate CTC (optional)
loss_interctc = 0.0
if self.interctc_weight != 0.0 and intermediate_outs is not None:
for layer_idx, intermediate_out in intermediate_outs:
# we assume intermediate_out has the same length & padding
# as those of encoder_out
loss_ic, cer_ic = self._calc_ctc_loss(
intermediate_out, encoder_out_lens, text, text_lengths
)
loss_interctc = loss_interctc + loss_ic
# Collect intermediate CTC stats
stats["loss_interctc_layer{}".format(layer_idx)] = (
loss_ic.detach() if loss_ic is not None else None
)
stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
loss_interctc = loss_interctc / len(intermediate_outs)
# calculate whole encoder loss
loss_ctc = (
1 - self.interctc_weight
) * loss_ctc + self.interctc_weight * loss_interctc
# 2b. Attention decoder branch
loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
encoder_out, encoder_out_lens, text, text_lengths
)
# 3. CTC-Att loss definition
if self.ctc_weight == 0.0:
loss = loss_att
elif self.ctc_weight == 1.0:
loss = loss_ctc
else:
loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att
# Collect Attn branch stats
stats["loss_att"] = loss_att.detach() if loss_att is not None else None
stats["acc"] = acc_att
stats["cer"] = cer_att
stats["wer"] = wer_att
# Collect total loss stats
stats["loss"] = torch.clone(loss.detach())
# force_gatherable: to-device and to-tensor if scalar for DataParallel
if self.length_normalized_loss:
batch_size = int((text_lengths + 1).sum())
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
return loss, stats, weight
def encode(
self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder. Note that this method is used by asr_inference.py
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
ind: int
"""
with autocast(False):
# Data augmentation
if self.specaug is not None and self.training:
speech, speech_lengths = self.specaug(speech, speech_lengths)
# Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
if self.normalize is not None:
speech, speech_lengths = self.normalize(speech, speech_lengths)
# Forward encoder
# feats: (Batch, Length, Dim)
# -> encoder_out: (Batch, Length2, Dim2)
if self.encoder.interctc_use_conditioning:
encoder_out, encoder_out_lens, _ = self.encoder(
speech, speech_lengths, ctc=self.ctc
)
else:
encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
intermediate_outs = None
if isinstance(encoder_out, tuple):
intermediate_outs = encoder_out[1]
encoder_out = encoder_out[0]
if intermediate_outs is not None:
return (encoder_out, intermediate_outs), encoder_out_lens
return encoder_out, encoder_out_lens
def _calc_att_loss(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
ys_pad: torch.Tensor,
ys_pad_lens: torch.Tensor,
):
ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
ys_in_lens = ys_pad_lens + 1
# 1. Forward decoder
decoder_out, _ = self.decoder(
encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
)
# 2. Compute attention loss
loss_att = self.criterion_att(decoder_out, ys_out_pad)
acc_att = th_accuracy(
decoder_out.view(-1, self.vocab_size),
ys_out_pad,
ignore_label=self.ignore_id,
)
# Compute cer/wer using attention-decoder
if self.training or self.error_calculator is None:
cer_att, wer_att = None, None
else:
ys_hat = decoder_out.argmax(dim=-1)
cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
return loss_att, acc_att, cer_att, wer_att
def _calc_ctc_loss(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
ys_pad: torch.Tensor,
ys_pad_lens: torch.Tensor,
):
# Calc CTC loss
loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
# Calc CER using CTC
cer_ctc = None
if not self.training and self.error_calculator is not None:
ys_hat = self.ctc.argmax(encoder_out).data
cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
return loss_ctc, cer_ctc
def init_beam_search(self,
**kwargs,
):
from funasr.models.transformer.search import BeamSearch
from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
from funasr.models.transformer.scorers.length_bonus import LengthBonus
# 1. Build ASR model
scorers = {}
if self.ctc is not None:
ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
scorers.update(
ctc=ctc
)
token_list = kwargs.get("token_list")
scorers.update(
decoder=self.decoder,
length_bonus=LengthBonus(len(token_list)),
)
# 3. Build ngram model
# ngram is not supported now
ngram = None
scorers["ngram"] = ngram
weights = dict(
decoder=1.0 - kwargs.get("decoding_ctc_weight", 0.5),
ctc=kwargs.get("decoding_ctc_weight", 0.5),
lm=kwargs.get("lm_weight", 0.0),
ngram=kwargs.get("ngram_weight", 0.0),
length_bonus=kwargs.get("penalty", 0.0),
)
beam_search = BeamSearch(
beam_size=kwargs.get("beam_size", 10),
weights=weights,
scorers=scorers,
sos=self.sos,
eos=self.eos,
vocab_size=len(token_list),
token_list=token_list,
pre_beam_score_key=None if self.ctc_weight == 1.0 else "full",
)
self.beam_search = beam_search
def inference(self,
data_in,
data_lengths=None,
key: list=None,
tokenizer=None,
frontend=None,
**kwargs,
):
if kwargs.get("batch_size", 1) > 1:
raise NotImplementedError("batch decoding is not implemented")
# init beamsearch
if self.beam_search is None:
logging.info("enable beam_search")
self.init_beam_search(**kwargs)
self.nbest = kwargs.get("nbest", 1)
meta_data = {}
if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank": # fbank
speech, speech_lengths = data_in, data_lengths
if len(speech.shape) < 3:
speech = speech[None, :, :]
if speech_lengths is None:
speech_lengths = speech.shape[1]
else:
# extract fbank feats
time1 = time.perf_counter()
audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000),
data_type=kwargs.get("data_type", "sound"),
tokenizer=tokenizer)
time2 = time.perf_counter()
meta_data["load_data"] = f"{time2 - time1:0.3f}"
speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
frontend=frontend)
time3 = time.perf_counter()
meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
speech = speech.to(device=kwargs["device"])
speech_lengths = speech_lengths.to(device=kwargs["device"])
# Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
if isinstance(encoder_out, tuple):
encoder_out = encoder_out[0]
# c. Pass the encoder result to the beam search
nbest_hyps = self.beam_search(
x=encoder_out[0], maxlenratio=kwargs.get("maxlenratio", 0.0), minlenratio=kwargs.get("minlenratio", 0.0)
)
nbest_hyps = nbest_hyps[: self.nbest]
results = []
b, n, d = encoder_out.size()
for i in range(b):
for nbest_idx, hyp in enumerate(nbest_hyps):
ibest_writer = None
if kwargs.get("output_dir") is not None:
if not hasattr(self, "writer"):
self.writer = DatadirWriter(kwargs.get("output_dir"))
ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"]
# remove sos/eos and get results
last_pos = -1
if isinstance(hyp.yseq, list):
token_int = hyp.yseq[1:last_pos]
else:
token_int = hyp.yseq[1:last_pos].tolist()
# remove blank symbol id, which is assumed to be 0
token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
# Change integer-ids to tokens
token = tokenizer.ids2tokens(token_int)
text = tokenizer.tokens2text(token)
text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
result_i = {"key": key[i], "token": token, "text": text_postprocessed}
results.append(result_i)
if ibest_writer is not None:
ibest_writer["token"][key[i]] = " ".join(token)
ibest_writer["text"][key[i]] = text_postprocessed
return results, meta_data
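For reference, a worked example (illustrative numbers only, not from any experiment) of how the `forward` method of `OpenAIWhisperModel` above interpolates the CTC, intermediate-CTC, and attention losses:

```python
ctc_weight, interctc_weight = 0.3, 0.5
loss_att, loss_ctc, loss_interctc = 2.0, 5.0, 6.0                                # made-up scalars
loss_ctc = (1 - interctc_weight) * loss_ctc + interctc_weight * loss_interctc    # 5.5
loss = ctc_weight * loss_ctc + (1 - ctc_weight) * loss_att                       # 3.05
print(loss)
```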
@tables.register("model_classes", "OpenAIWhisperLIDModel")
class OpenAIWhisperLIDModel(nn.Module):
"""WhisperEncoder and EResNet based LID Model"""
def __init__(
self,
vocab_size: int,
specaug: str = None,
specaug_conf: dict = None,
encoder: str = None,
encoder_conf: dict = None,
lid_predictor: str = None,
lid_predictor_conf: dict = None,
proj_dim: int = None,
clip_frames: int = None,
random_clip: bool = False,
**kwargs,
):
super().__init__()
if specaug is not None:
specaug_class = tables.specaug_classes.get(specaug)
specaug = specaug_class(**specaug_conf)
encoder_class = tables.encoder_classes.get(encoder)
encoder = encoder_class(**encoder_conf)
lid_predictor_class = tables.lid_predictor_classes.get(lid_predictor)
lid_predictor = lid_predictor_class(**lid_predictor_conf)
if encoder.output_size() != proj_dim:
self.proj_layer = torch.nn.Linear(encoder.output_size(), proj_dim)
else:
self.proj_layer = None
self.output_layer = torch.nn.Linear(lid_predictor.output_size(), vocab_size)
self.criterion_lid = LabelSmoothingLoss(
size=vocab_size,
padding_idx=-1,
smoothing=0.0,
normalize_length=False,
)
self.specaug = specaug
self.encoder = encoder
self.lid_predictor = lid_predictor
self.clip_frames = clip_frames
self.random_clip = random_clip
self.normalize = None
self.beam_search = None
if not hasattr(self.encoder, "interctc_use_conditioning"):
self.encoder.interctc_use_conditioning = False
def forward(self,
speech: torch.Tensor, # may be padding
speech_lengths: torch.Tensor, # actual length
lid: torch.Tensor, # lid label, (batch_size, 1)
lid_lengths: torch.Tensor,
):
assert lid.shape[1] == 1
batch_size = speech.shape[0]
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
# re-generate encoder_out
if self.clip_frames is None:
reduced_encoder_out = torch.zeros(batch_size, encoder_out_lens.max(), encoder_out.shape[-1]).to(encoder_out.dtype).to(encoder_out.device)
for i, enc_length in enumerate(encoder_out_lens):
reduced_encoder_out[i, :enc_length] = encoder_out[i, :enc_length]
else:
reduced_encoder_out = torch.zeros(batch_size, self.clip_frames, encoder_out.shape[-1]).to(encoder_out.dtype).to(encoder_out.device)
if self.random_clip:
for i, enc_length in enumerate(encoder_out_lens):
if enc_length <= self.clip_frames:
reduced_encoder_out[i, :enc_length] = encoder_out[i, :enc_length]
encoder_out_lens[i] = enc_length
else:
max_start_index = enc_length.item() - self.clip_frames
start_index = np.random.randint(0, max_start_index + 1)
reduced_encoder_out[i, :self.clip_frames] = encoder_out[i, start_index:start_index + self.clip_frames]
encoder_out_lens[i] = self.clip_frames
else:
for i, enc_length in enumerate(encoder_out_lens):
enc_length = self.clip_frames if enc_length >= self.clip_frames else enc_length
reduced_encoder_out[i, :enc_length] = encoder_out[i, :enc_length]
encoder_out_lens[i] = enc_length
if self.proj_layer is not None:
reduced_encoder_out = self.proj_layer(reduced_encoder_out)
lid_output = self.lid_predictor(reduced_encoder_out, encoder_out_lens) # (B, D)
lid_logits = self.output_layer(lid_output) # (B, num_classes)
loss = self.criterion_lid(lid_logits[:, None, :], lid)
with torch.no_grad():
_, predicted_lid = torch.max(lid_logits, 1)
correct = (predicted_lid == lid[:, 0]).sum().item()
lid_acc = correct * 1.0 / lid_logits.shape[0]
stats = dict()
stats["batch_size"] = batch_size
stats["loss"] = torch.clone(loss.detach())
stats["acc"] = lid_acc
stats["token_length"] = speech_lengths.max()
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
return loss, stats, weight
def encode(
self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder. Note that this method is used by asr_inference.py
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
"""
with autocast(False):
# Data augmentation
if self.specaug is not None and self.training:
speech = speech.permute(0, 2, 1)
# match Whisper's fixed-length padding
padded_speech_lengths = torch.ones_like(speech_lengths) * speech.shape[1]
speech, padded_speech_lengths = self.specaug(speech, padded_speech_lengths)
speech = speech.permute(0, 2, 1)
# Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
if self.normalize is not None:
speech, speech_lengths = self.normalize(speech, speech_lengths)
# Forward encoder
# feats: (Batch, Length, Dim)
# -> encoder_out: (Batch, Length2, Dim2)
if self.encoder.interctc_use_conditioning:
encoder_out, encoder_out_lens, _ = self.encoder(
speech, speech_lengths, ctc=self.ctc
)
else:
encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
intermediate_outs = None
if isinstance(encoder_out, tuple):
intermediate_outs = encoder_out[1]
encoder_out = encoder_out[0]
if intermediate_outs is not None:
return (encoder_out, intermediate_outs), encoder_out_lens
return encoder_out, encoder_out_lens
def inference(self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
if kwargs.get("batch_size", 1) > 1:
raise NotImplementedError("batch decoding is not implemented")
meta_data = {}
if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank": # fbank
speech, speech_lengths = data_in, data_lengths
if len(speech.shape) < 3:
speech = speech[None, :, :]
if speech_lengths is None:
speech_lengths = speech.shape[1]
else:
# extract fbank feats
time1 = time.perf_counter()
audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000),
data_type=kwargs.get("data_type", "sound"),
tokenizer=tokenizer)
time2 = time.perf_counter()
meta_data["load_data"] = f"{time2 - time1:0.3f}"
speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
frontend=frontend)
time3 = time.perf_counter()
meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
speech = speech.to(device=kwargs["device"])
speech_lengths = speech_lengths.to(device=kwargs["device"])
# Encoder
enc, enc_out_lens = self.encode(speech, speech_lengths)
inference_clip_length = kwargs.get("inference_clip_length", None)
if self.clip_frames is not None:
if inference_clip_length is None:
reduced_enc = torch.zeros(enc.shape[0], self.clip_frames, enc.shape[-1]).to(enc.dtype).to(enc.device)
for i, enc_length in enumerate(enc_out_lens):
enc_length = self.clip_frames if enc_length >= self.clip_frames else enc_length
reduced_enc[i, :enc_length] = enc[i, :enc_length]
enc_out_lens[i] = enc_length
else:
assert inference_clip_length > 0, "inference_clip_length must be larger than 0"
reduced_enc = torch.zeros(enc.shape[0], inference_clip_length, enc.shape[-1]).to(enc.dtype).to(enc.device)
for i, enc_length in enumerate(enc_out_lens):
enc_length = inference_clip_length if enc_length >= inference_clip_length else enc_length
reduced_enc[i, :enc_length] = enc[i, :enc_length]
enc_out_lens[i] = enc_length
else:
reduced_enc = torch.zeros(enc.shape[0], enc_out_lens.max(), enc.shape[-1]).to(enc.dtype).to(enc.device)
for i, enc_length in enumerate(enc_out_lens):
reduced_enc[i, :enc_length] = enc[i, :enc_length]
if self.proj_layer is not None:
reduced_enc = self.proj_layer(reduced_enc)
lid_output = self.lid_predictor(reduced_enc, enc_out_lens) # (B, D)
lid_logits = self.output_layer(lid_output) # (B, num_classes)
_, predicted_lid_index = torch.max(lid_logits, 1)
predicted_lid = tokenizer.ids2tokens([predicted_lid_index[0].cpu()])[0]
if kwargs.get("output_dir") is not None:
if not hasattr(self, "writer"):
self.writer = DatadirWriter(kwargs.get("output_dir"))
lid_writer = self.writer["lid"]
lid_writer[key[0]] = predicted_lid
results = [{"key": key[0], "lid": predicted_lid}]
return results, meta_data

View File

@@ -1380,7 +1380,7 @@ sampleClientRun(){
run_cmd="${client_exec} --server-ip ${server_ip} --port ${host_port} --wav-path ${wav_path}"
;;
Python)
client_exec="${PARAMS_FUNASR_SAMPLES_LOCAL_DIR}/python/wss_client_asr.py"
client_exec="${PARAMS_FUNASR_SAMPLES_LOCAL_DIR}/python/funasr_wss_client.py"
run_cmd="python3 ${client_exec} --host ${server_ip} --port ${host_port} --mode offline --audio_in ${wav_path} --send_without_sleep --output_dir ${PARAMS_FUNASR_SAMPLES_LOCAL_DIR}/python"
pre_cmd="pip3 install click>=8.0.4"
echo -e " Run ${BLUE}${pre_cmd}${PLAIN}"

View File

@@ -1455,7 +1455,7 @@ sampleClientRun(){
run_cmd="${client_exec} --server-ip ${server_ip} --port ${host_port} --wav-path ${wav_path}"
;;
Python)
client_exec="${PARAMS_FUNASR_SAMPLES_LOCAL_DIR}/python/wss_client_asr.py"
client_exec="${PARAMS_FUNASR_SAMPLES_LOCAL_DIR}/python/funasr_wss_client.py"
run_cmd="python3 ${client_exec} --host ${server_ip} --port ${host_port} --mode offline --audio_in ${wav_path} --send_without_sleep --output_dir ${PARAMS_FUNASR_SAMPLES_LOCAL_DIR}/python"
pre_cmd="pip3 install click>=8.0.4"
echo -e " Run ${BLUE}${pre_cmd}${PLAIN}"

View File

@@ -1380,7 +1380,7 @@ sampleClientRun(){
run_cmd="${client_exec} --server-ip ${server_ip} --port ${host_port} --wav-path ${wav_path}"
;;
Python)
client_exec="${PARAMS_FUNASR_SAMPLES_LOCAL_DIR}/python/wss_client_asr.py"
client_exec="${PARAMS_FUNASR_SAMPLES_LOCAL_DIR}/python/funasr_wss_client.py"
run_cmd="python3 ${client_exec} --host ${server_ip} --port ${host_port} --mode 2pass --audio_in ${wav_path} --send_without_sleep --output_dir ${PARAMS_FUNASR_SAMPLES_LOCAL_DIR}/python"
pre_cmd="pip3 install click>=8.0.4"
echo -e " Run ${BLUE}${pre_cmd}${PLAIN}"

View File

@@ -59,7 +59,7 @@ wget https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/sample/funasr_sa
For illustration, we will use the Python language client, which supports audio formats (.wav, .pcm) and a multi-file list wav.scp input.
```shell
python3 wss_client_asr.py --host "127.0.0.1" --port 10095 --mode 2pass
python3 funasr_wss_client.py --host "127.0.0.1" --port 10095 --mode 2pass
```
------------------

View File

@@ -75,7 +75,7 @@ wget https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/sample/funasr_sa
```
Here we take the Python client as an example; it supports .wav and .pcm audio formats as well as a multi-file wav.scp list as input (for other client versions, please refer to the documentation [here](#客户端用法详解)).
```shell
python3 wss_client_asr.py --host "127.0.0.1" --port 10095 --mode 2pass
python3 funasr_wss_client.py --host "127.0.0.1" --port 10095 --mode 2pass
```
------------------