mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
minmo
This commit is contained in:
parent
5abb8367a3
commit
62c6f50a1d
File diff suppressed because it is too large
Load Diff
@ -1,365 +0,0 @@
|
||||
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
|
||||
def get_nonlinear(config_str, channels):
    """Assemble a Sequential of nonlinearity/normalization modules.

    Args:
        config_str: dash-separated module names, e.g. ``'batchnorm-relu'``.
            Supported tokens: ``relu``, ``prelu``, ``batchnorm`` and
            ``batchnorm_`` (trailing underscore = non-affine BatchNorm1d).
        channels: channel count used by PReLU / BatchNorm1d.

    Returns:
        nn.Sequential with one module per token.

    Raises:
        ValueError: if an unknown token appears in ``config_str``.
    """
    builders = {
        'relu': lambda: ('relu', nn.ReLU(inplace=True)),
        'prelu': lambda: ('prelu', nn.PReLU(channels)),
        'batchnorm': lambda: ('batchnorm', nn.BatchNorm1d(channels)),
        # registered under the same 'batchnorm' key as the affine variant
        'batchnorm_': lambda: ('batchnorm', nn.BatchNorm1d(channels, affine=False)),
    }
    seq = nn.Sequential()
    for token in config_str.split('-'):
        if token not in builders:
            raise ValueError('Unexpected module ({}).'.format(token))
        name, module = builders[token]()
        seq.add_module(name, module)
    return seq
|
||||
|
||||
|
||||
def statistics_pooling(x, dim=-1, keepdim=False, unbiased=True, eps=1e-2):
    """Concatenate mean and standard deviation of ``x`` along ``dim``.

    NOTE(review): the ``eps`` argument is currently unused — presumably
    intended as a std floor; confirm before relying on it.
    """
    mu = x.mean(dim=dim)
    sigma = x.std(dim=dim, unbiased=unbiased)
    pooled = torch.cat([mu, sigma], dim=-1)
    return pooled.unsqueeze(dim=dim) if keepdim else pooled
|
||||
|
||||
|
||||
class StatsPool(nn.Module):
    """Temporal statistics pooling.

    Maps a (..., T) feature tensor to the concatenation of its mean and
    standard deviation via :func:`statistics_pooling` (default settings).
    """

    def forward(self, x):
        # Delegate to the module-level helper.
        return statistics_pooling(x)
|
||||
|
||||
|
||||
class TDNNLayer(nn.Module):
    """1-D TDNN layer: Conv1d followed by a configurable nonlinearity.

    A negative ``padding`` requests 'same'-style padding, derived from the
    (odd) kernel size and dilation.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 bias=False,
                 config_str='batchnorm-relu'):
        super(TDNNLayer, self).__init__()
        if padding < 0:
            # Auto-derive symmetric padding; requires an odd kernel.
            assert kernel_size % 2 == 1, 'Expect equal paddings, but got even kernel size ({})'.format(
                kernel_size)
            padding = (kernel_size - 1) // 2 * dilation
        self.linear = nn.Conv1d(in_channels,
                                out_channels,
                                kernel_size,
                                stride=stride,
                                padding=padding,
                                dilation=dilation,
                                bias=bias)
        self.nonlinear = get_nonlinear(config_str, out_channels)

    def forward(self, x):
        # Convolution, then the activation/normalization stack.
        return self.nonlinear(self.linear(x))
|
||||
|
||||
|
||||
class CAMLayer(nn.Module):
    """Context-aware masking (CAM) layer.

    A local Conv1d response is modulated by a sigmoid mask computed from a
    global-mean plus segment-pooled context, squeezed through a small
    bottleneck (factor ``reduction``).
    """

    def __init__(self,
                 bn_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 padding,
                 dilation,
                 bias,
                 reduction=2):
        super(CAMLayer, self).__init__()
        self.linear_local = nn.Conv1d(bn_channels,
                                      out_channels,
                                      kernel_size,
                                      stride=stride,
                                      padding=padding,
                                      dilation=dilation,
                                      bias=bias)
        self.linear1 = nn.Conv1d(bn_channels, bn_channels // reduction, 1)
        self.relu = nn.ReLU(inplace=True)
        self.linear2 = nn.Conv1d(bn_channels // reduction, out_channels, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        local_out = self.linear_local(x)
        # Global mean plus per-segment statistics form the context signal.
        context = x.mean(-1, keepdim=True) + self.seg_pooling(x)
        context = self.relu(self.linear1(context))
        mask = self.sigmoid(self.linear2(context))
        return local_out * mask

    def seg_pooling(self, x, seg_len=100, stype='avg'):
        """Pool ``x`` over non-overlapping ``seg_len``-frame segments and
        broadcast each segment statistic back to frame resolution."""
        if stype == 'avg':
            pooled = F.avg_pool1d(x, kernel_size=seg_len, stride=seg_len, ceil_mode=True)
        elif stype == 'max':
            pooled = F.max_pool1d(x, kernel_size=seg_len, stride=seg_len, ceil_mode=True)
        else:
            raise ValueError('Wrong segment pooling type.')
        shape = pooled.shape
        # Repeat each segment value seg_len times, then trim to input length.
        pooled = pooled.unsqueeze(-1).expand(*shape, seg_len).reshape(*shape[:-1], -1)
        return pooled[..., :x.shape[-1]]
|
||||
|
||||
|
||||
class CAMDenseTDNNLayer(nn.Module):
    """Bottleneck (1x1 conv) followed by a CAM conv layer, used inside a
    densely connected TDNN block.

    NOTE(review): ``memory_efficient`` is stored but never consulted here —
    upstream variants use it to checkpoint ``bn_function``; confirm whether
    dropping that was intentional.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 bn_channels,
                 kernel_size,
                 stride=1,
                 dilation=1,
                 bias=False,
                 config_str='batchnorm-relu',
                 memory_efficient=False):
        super(CAMDenseTDNNLayer, self).__init__()
        # 'same' padding requires an odd kernel.
        assert kernel_size % 2 == 1, 'Expect equal paddings, but got even kernel size ({})'.format(
            kernel_size)
        padding = (kernel_size - 1) // 2 * dilation
        self.memory_efficient = memory_efficient
        self.nonlinear1 = get_nonlinear(config_str, in_channels)
        self.linear1 = nn.Conv1d(in_channels, bn_channels, 1, bias=False)
        self.nonlinear2 = get_nonlinear(config_str, bn_channels)
        self.cam_layer = CAMLayer(bn_channels,
                                  out_channels,
                                  kernel_size,
                                  stride=stride,
                                  padding=padding,
                                  dilation=dilation,
                                  bias=bias)

    def bn_function(self, x):
        # Pre-activation bottleneck: nonlinearity then 1x1 projection.
        return self.linear1(self.nonlinear1(x))

    def forward(self, x):
        hidden = self.bn_function(x)
        return self.cam_layer(self.nonlinear2(hidden))
|
||||
|
||||
|
||||
class CAMDenseTDNNBlock(nn.ModuleList):
    """Densely connected stack of :class:`CAMDenseTDNNLayer` modules.

    Layer ``i`` consumes the channel-wise concatenation of the block input
    and all previous layer outputs, so its input width is
    ``in_channels + i * out_channels``.
    """

    def __init__(self,
                 num_layers,
                 in_channels,
                 out_channels,
                 bn_channels,
                 kernel_size,
                 stride=1,
                 dilation=1,
                 bias=False,
                 config_str='batchnorm-relu',
                 memory_efficient=False):
        super(CAMDenseTDNNBlock, self).__init__()
        for idx in range(num_layers):
            self.add_module(
                'tdnnd%d' % (idx + 1),
                CAMDenseTDNNLayer(in_channels=in_channels + idx * out_channels,
                                  out_channels=out_channels,
                                  bn_channels=bn_channels,
                                  kernel_size=kernel_size,
                                  stride=stride,
                                  dilation=dilation,
                                  bias=bias,
                                  config_str=config_str,
                                  memory_efficient=memory_efficient))

    def forward(self, x):
        # Dense connectivity: append each layer's output along channels.
        for layer in self:
            x = torch.cat([x, layer(x)], dim=1)
        return x
|
||||
|
||||
|
||||
class TransitLayer(nn.Module):
    """Transition layer between dense blocks: nonlinearity, then a 1x1
    Conv1d that changes the channel count (typically halving it)."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 bias=True,
                 config_str='batchnorm-relu'):
        super(TransitLayer, self).__init__()
        self.nonlinear = get_nonlinear(config_str, in_channels)
        self.linear = nn.Conv1d(in_channels, out_channels, 1, bias=bias)

    def forward(self, x):
        # Pre-activation ordering: nonlinearity before projection.
        return self.linear(self.nonlinear(x))
|
||||
|
||||
|
||||
class DenseLayer(nn.Module):
    """1x1 Conv1d projection followed by a nonlinearity.

    Accepts either (B, C, T) inputs or (B, C) vectors; the latter are
    temporarily given a singleton time axis for the convolution.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 bias=False,
                 config_str='batchnorm-relu'):
        super(DenseLayer, self).__init__()
        self.linear = nn.Conv1d(in_channels, out_channels, 1, bias=bias)
        self.nonlinear = get_nonlinear(config_str, out_channels)

    def forward(self, x):
        if len(x.shape) == 2:
            # (B, C) -> (B, C, 1) -> conv -> (B, C', 1) -> (B, C')
            out = self.linear(x.unsqueeze(dim=-1)).squeeze(dim=-1)
        else:
            out = self.linear(x)
        return self.nonlinear(out)
|
||||
|
||||
|
||||
class BasicResBlock(nn.Module):
    """Basic 2-D residual block.

    Downsampling (when ``stride != 1``) is applied along the frequency axis
    only — conv strides are ``(stride, 1)`` — and the shortcut becomes a
    1x1 conv + BatchNorm whenever shape or channel count changes.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicResBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes,
                               planes,
                               kernel_size=3,
                               stride=(stride, 1),
                               padding=1,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes,
                               planes,
                               kernel_size=3,
                               stride=1,
                               padding=1,
                               bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # Identity shortcut unless shape/channels change.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes,
                          self.expansion * planes,
                          kernel_size=1,
                          stride=(stride, 1),
                          bias=False),
                nn.BatchNorm2d(self.expansion * planes))

    def forward(self, x):
        residual = self.shortcut(x)
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = out + residual
        return F.relu(out)
|
||||
|
||||
|
||||
class FCM(nn.Module):
    """Front-end convolution module (2-D residual stem).

    Treats the (B, F, T) input as a one-channel image, applies a conv stem,
    two residual stages (each halving the frequency axis), a final
    stride-(2, 1) conv, and flattens channels x frequency so the output is
    (B, m_channels * feat_dim // 8, T).
    """

    def __init__(self,
                 block=BasicResBlock,
                 num_blocks=[2, 2],
                 m_channels=32,
                 feat_dim=80):
        super(FCM, self).__init__()
        self.in_planes = m_channels
        self.conv1 = nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(m_channels)

        self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=2)
        self.layer2 = self._make_layer(block, m_channels, num_blocks[1], stride=2)

        self.conv2 = nn.Conv2d(m_channels, m_channels, kernel_size=3, stride=(2, 1), padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(m_channels)
        # Frequency axis is downsampled three times (2*2*2 = 8).
        self.out_channels = m_channels * (feat_dim // 8)

    def _make_layer(self, block, planes, num_blocks, stride):
        # First block downsamples; the rest keep resolution.
        layers = []
        for s in [stride] + [1] * (num_blocks - 1):
            layers.append(block(self.in_planes, planes, s))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        feat = x.unsqueeze(1)  # (B, F, T) -> (B, 1, F, T)
        feat = F.relu(self.bn1(self.conv1(feat)))
        feat = self.layer1(feat)
        feat = self.layer2(feat)
        feat = F.relu(self.bn2(self.conv2(feat)))

        b, c, f, t = feat.shape
        # Merge channel and frequency axes: (B, C*F, T).
        return feat.reshape(b, c * f, t)
|
||||
|
||||
|
||||
class CAMPPlus(nn.Module):
    """CAM++ speaker embedding network.

    Pipeline: FCM 2-D stem -> stride-2 TDNN -> three CAM dense TDNN blocks
    (each followed by a channel-halving transition) -> output nonlinearity
    -> statistics pooling -> dense layer producing ``embedding_size``
    embeddings.
    """

    def __init__(self,
                 feat_dim=80,
                 embedding_size=512,
                 growth_rate=32,
                 bn_size=4,
                 init_channels=128,
                 config_str='batchnorm-relu',
                 memory_efficient=True):
        super(CAMPPlus, self).__init__()

        self.head = FCM(feat_dim=feat_dim)
        channels = self.head.out_channels

        # Stem TDNN with stride 2 over time; padding=-1 means 'same'.
        self.xvector = nn.Sequential(
            OrderedDict([
                ('tdnn',
                 TDNNLayer(channels,
                           init_channels,
                           5,
                           stride=2,
                           dilation=1,
                           padding=-1,
                           config_str=config_str)),
            ]))
        channels = init_channels
        # Three dense blocks with (num_layers, kernel, dilation) schedules.
        block_specs = zip((12, 24, 16), (3, 3, 3), (1, 2, 2))
        for i, (num_layers, kernel_size, dilation) in enumerate(block_specs):
            self.xvector.add_module(
                'block%d' % (i + 1),
                CAMDenseTDNNBlock(num_layers=num_layers,
                                  in_channels=channels,
                                  out_channels=growth_rate,
                                  bn_channels=bn_size * growth_rate,
                                  kernel_size=kernel_size,
                                  dilation=dilation,
                                  config_str=config_str,
                                  memory_efficient=memory_efficient))
            channels = channels + num_layers * growth_rate
            self.xvector.add_module(
                'transit%d' % (i + 1),
                TransitLayer(channels,
                             channels // 2,
                             bias=False,
                             config_str=config_str))
            channels //= 2

        self.xvector.add_module(
            'out_nonlinear', get_nonlinear(config_str, channels))

        # StatsPool doubles the channel count (mean + std).
        self.xvector.add_module('stats', StatsPool())
        self.xvector.add_module(
            'dense',
            DenseLayer(channels * 2, embedding_size, config_str='prelu'))

        # Kaiming init for conv/linear weights; zero their biases.
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight.data)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x):
        x = x.permute(0, 2, 1)  # (B,T,F) => (B,F,T)
        x = self.head(x)
        return self.xvector(x)
|
||||
@ -1,258 +0,0 @@
|
||||
import torch
|
||||
import logging
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class CTC(torch.nn.Module):
    """CTC module.

    Args:
        odim: dimension of outputs
        encoder_output_size: number of encoder projection units
        dropout_rate: dropout rate (0.0 ~ 1.0)
        ctc_type: builtin or warpctc (only "builtin" is accepted by
            ``__init__``; the other branches in ``loss_fn`` are dead code)
        reduce: reduce the CTC loss into a scalar
        ignore_nan_grad: drop samples whose CTC gradient is non-finite and
            recompute the loss without them (see ``loss_fn``)
        length_normalize: if "olen", divide each sample's loss by its target
            length; any other non-None value divides by the input length
    """

    def __init__(
        self,
        odim: int,
        encoder_output_size: int,
        dropout_rate: float = 0.0,
        ctc_type: str = "builtin",
        reduce: bool = True,
        ignore_nan_grad: bool = True,
        length_normalize: str = None,
    ):
        super().__init__()
        eprojs = encoder_output_size
        self.dropout_rate = dropout_rate
        # Projection from encoder units to the output vocabulary.
        self.ctc_lo = torch.nn.Linear(eprojs, odim)
        self.ctc_type = ctc_type
        self.ignore_nan_grad = ignore_nan_grad
        self.length_normalize = length_normalize

        if self.ctc_type == "builtin":
            # reduction="none" keeps per-sample losses so loss_fn can drop
            # samples with nan gradients before averaging.
            self.ctc_loss = torch.nn.CTCLoss(reduction="none")
        else:
            raise ValueError(
                f'ctc_type must be "builtin": {self.ctc_type}'
            )

        self.reduce = reduce

    def loss_fn(self, th_pred, th_target, th_ilen, th_olen) -> torch.Tensor:
        """Compute the CTC loss.

        Args:
            th_pred: (Tmax, B, odim) unnormalized logits (builtin path).
            th_target: flattened target ids, shape (sum(th_olen),).
            th_ilen: (B,) input lengths.
            th_olen: (B,) target lengths.
        """
        if self.ctc_type == "builtin":
            th_pred = th_pred.log_softmax(2)
            loss = self.ctc_loss(th_pred, th_target, th_ilen, th_olen)

            if loss.requires_grad and self.ignore_nan_grad:
                # Probe the backward function directly to find samples whose
                # CTC gradient is non-finite (e.g. when olen is too long).
                # ctc_grad: (L, B, O)
                ctc_grad = loss.grad_fn(torch.ones_like(loss))
                ctc_grad = ctc_grad.sum([0, 2])
                indices = torch.isfinite(ctc_grad)
                size = indices.long().sum()
                if size == 0:
                    # Return as is
                    logging.warning(
                        "All samples in this mini-batch got nan grad."
                        " Returning nan value instead of CTC loss"
                    )
                elif size != th_pred.size(1):
                    logging.warning(
                        f"{th_pred.size(1) - size}/{th_pred.size(1)}"
                        " samples got nan grad."
                        " These were ignored for CTC loss."
                    )

                    # Create mask for target
                    target_mask = torch.full(
                        [th_target.size(0)],
                        1,
                        dtype=torch.bool,
                        device=th_target.device,
                    )
                    s = 0
                    # th_target is flattened, so walk it with cumulative
                    # target lengths to mask out the dropped samples.
                    for ind, le in enumerate(th_olen):
                        if not indices[ind]:
                            target_mask[s : s + le] = 0
                        s += le

                    # Calc loss again using masked data
                    loss = self.ctc_loss(
                        th_pred[:, indices, :],
                        th_target[target_mask],
                        th_ilen[indices],
                        th_olen[indices],
                    )
                    th_ilen, th_olen = th_ilen[indices], th_olen[indices]
            else:
                size = th_pred.size(1)

            if self.length_normalize is not None:
                # Per-sample length normalization before reduction.
                if self.length_normalize == "olen":
                    loss = loss / th_olen
                else:
                    loss = loss / th_ilen

            if self.reduce:
                # Batch-size average
                loss = loss.sum() / size
            else:
                loss = loss / size
            return loss

        # NOTE(review): __init__ raises for any ctc_type other than
        # "builtin", so the branches below are currently unreachable.
        elif self.ctc_type == "warpctc":
            # warpctc only supports float32
            th_pred = th_pred.to(dtype=torch.float32)

            th_target = th_target.cpu().int()
            th_ilen = th_ilen.cpu().int()
            th_olen = th_olen.cpu().int()
            loss = self.ctc_loss(th_pred, th_target, th_ilen, th_olen)
            if self.reduce:
                # NOTE: sum() is needed to keep consistency since warpctc
                # return as tensor w/ shape (1,)
                # but builtin return as tensor w/o shape (scalar).
                loss = loss.sum()
            return loss

        elif self.ctc_type == "gtnctc":
            log_probs = torch.nn.functional.log_softmax(th_pred, dim=2)
            return self.ctc_loss(log_probs, th_target, th_ilen, 0, "none")

        else:
            raise NotImplementedError

    def forward(self, hs_pad, hlens, ys_pad, ys_lens):
        """Calculate CTC loss.

        Args:
            hs_pad: batch of padded hidden state sequences (B, Tmax, D)
            hlens: batch of lengths of hidden state sequences (B)
            ys_pad: batch of padded character id sequence tensor (B, Lmax)
            ys_lens: batch of lengths of character sequence (B)
        """
        # hs_pad: (B, L, NProj) -> ys_hat: (B, L, Nvocab)
        ys_hat = self.ctc_lo(F.dropout(hs_pad, p=self.dropout_rate))

        if self.ctc_type == "gtnctc":
            # gtn expects list form for ys
            ys_true = [y[y != -1] for y in ys_pad]  # parse padded ys
        else:
            # ys_hat: (B, L, D) -> (L, B, D)
            ys_hat = ys_hat.transpose(0, 1)
            # (B, L) -> (BxL,)
            ys_true = torch.cat([ys_pad[i, :l] for i, l in enumerate(ys_lens)])

        loss = self.loss_fn(ys_hat, ys_true, hlens, ys_lens).to(
            device=hs_pad.device, dtype=hs_pad.dtype
        )

        return loss

    def softmax(self, hs_pad):
        """softmax of frame activations

        Args:
            Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
        Returns:
            torch.Tensor: softmax applied 3d tensor (B, Tmax, odim)
        """
        return F.softmax(self.ctc_lo(hs_pad), dim=2)

    def log_softmax(self, hs_pad):
        """log_softmax of frame activations

        Args:
            Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
        Returns:
            torch.Tensor: log softmax applied 3d tensor (B, Tmax, odim)
        """
        return F.log_softmax(self.ctc_lo(hs_pad), dim=2)

    def argmax(self, hs_pad):
        """argmax of frame activations

        Args:
            torch.Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
        Returns:
            torch.Tensor: argmax applied 2d tensor (B, Tmax)
        """
        return torch.argmax(self.ctc_lo(hs_pad), dim=2)
|
||||
|
||||
|
||||
def ctc_forced_align(
    log_probs: torch.Tensor,
    targets: torch.Tensor,
    input_lengths: torch.Tensor,
    target_lengths: torch.Tensor,
    blank: int = 0,
    ignore_id: int = -1,
) -> torch.Tensor:
    """Align a CTC label sequence to an emission (Viterbi forced alignment).

    Args:
        log_probs (Tensor): log probability of CTC emission output.
            Tensor of shape `(B, T, C)`. where `B` is the batch size, `T` is the input length,
            `C` is the number of characters in alphabet including blank.
        targets (Tensor): Target sequence. Tensor of shape `(B, L)`,
            where `L` is the target length. Positions equal to ``ignore_id``
            are treated as blanks; the caller's tensor is left unmodified.
        input_lengths (Tensor):
            Lengths of the inputs (max value must each be <= `T`). 1-D Tensor of shape `(B,)`.
        target_lengths (Tensor):
            Lengths of the targets. 1-D Tensor of shape `(B,)`.
        blank (int, optional): The index of blank symbol in CTC emission. (Default: 0)
        ignore_id (int, optional): The index of ignore symbol in targets. (Default: -1)

    Returns:
        Tensor of shape ``(B, T)`` holding the frame-level labels (blanks
        and target tokens) of the best alignment path.
    """
    # Work on a copy so the caller's tensor is not mutated (the previous
    # implementation overwrote ``targets`` in place).
    targets = targets.masked_fill(targets == ignore_id, blank)

    batch_size, input_time_size, _ = log_probs.size()
    bsz_indices = torch.arange(batch_size, device=input_lengths.device)

    # Interleave blanks with the labels: [blank, y1, blank, y2, ..., blank].
    _t_a_r_g_e_t_s_ = torch.cat(
        (
            torch.stack((torch.full_like(targets, blank), targets), dim=-1).flatten(start_dim=1),
            torch.full_like(targets[:, :1], blank),
        ),
        dim=-1,
    )
    # diff_labels[:, s] is True when the diagonal transition s-2 -> s is
    # legal, i.e. the two labels two states apart differ.
    diff_labels = torch.cat(
        (
            torch.as_tensor([[False, False]], device=targets.device).expand(batch_size, -1),
            _t_a_r_g_e_t_s_[:, 2:] != _t_a_r_g_e_t_s_[:, :-2],
        ),
        dim=1,
    )

    neg_inf = torch.tensor(float("-inf"), device=log_probs.device, dtype=log_probs.dtype)
    # Two -inf padding states in front make the shifted-window maximum
    # below valid for the first two real states.
    padding_num = 2
    padded_t = padding_num + _t_a_r_g_e_t_s_.size(-1)
    best_score = torch.full((batch_size, padded_t), neg_inf, device=log_probs.device, dtype=log_probs.dtype)
    # Initialization: a path may start at the leading blank or first label.
    best_score[:, padding_num + 0] = log_probs[:, 0, blank]
    best_score[:, padding_num + 1] = log_probs[bsz_indices, 0, _t_a_r_g_e_t_s_[:, 1]]

    backpointers = torch.zeros((batch_size, input_time_size, padded_t), device=log_probs.device, dtype=targets.dtype)

    # Viterbi recursion over time; the predecessors of state s are
    # s (stay), s-1, and s-2 when diff_labels allows the skip.
    for t in range(1, input_time_size):
        prev = torch.stack(
            (best_score[:, 2:], best_score[:, 1:-1], torch.where(diff_labels, best_score[:, :-2], neg_inf))
        )
        prev_max_value, prev_max_idx = prev.max(dim=0)
        best_score[:, padding_num:] = log_probs[:, t].gather(-1, _t_a_r_g_e_t_s_) + prev_max_value
        backpointers[:, t, padding_num:] = prev_max_idx

    # A path may end on the last label (state 2L-1) or trailing blank (2L).
    l1l2 = best_score.gather(
        -1, torch.stack((padding_num + target_lengths * 2 - 1, padding_num + target_lengths * 2), dim=-1)
    )

    path = torch.zeros((batch_size, input_time_size), device=best_score.device, dtype=torch.long)
    path[bsz_indices, input_lengths - 1] = padding_num + target_lengths * 2 - 1 + l1l2.argmax(dim=-1)

    # Backtrace: prev_max_idx stores the state offset (0, 1 or 2) taken.
    for t in range(input_time_size - 1, 0, -1):
        target_indices = path[:, t]
        prev_max_idx = backpointers[bsz_indices, t, target_indices]
        path[:, t - 1] += target_indices - prev_max_idx

    # Map padded state indices back to token ids (blanks included).
    alignments = _t_a_r_g_e_t_s_.gather(dim=-1, index=(path - padding_num).clamp(min=0))
    return alignments
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,542 +0,0 @@
|
||||
import logging
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
import torch
|
||||
from torch import nn
|
||||
from funasr.models.transformer.attention import (
|
||||
MultiHeadedAttention, # noqa: H301
|
||||
RelPositionMultiHeadedAttention, # noqa: H301
|
||||
LegacyRelPositionMultiHeadedAttention, # noqa: H301
|
||||
)
|
||||
from funasr.models.transformer.embedding import (
|
||||
PositionalEncoding, # noqa: H301
|
||||
ScaledPositionalEncoding, # noqa: H301
|
||||
RelPositionalEncoding, # noqa: H301
|
||||
LegacyRelPositionalEncoding, # noqa: H301
|
||||
)
|
||||
from funasr.models.transformer.layer_norm import LayerNorm
|
||||
from funasr.models.transformer.utils.multi_layer_conv import Conv1dLinear
|
||||
from funasr.models.transformer.utils.multi_layer_conv import MultiLayeredConv1d
|
||||
from funasr.models.transformer.utils.nets_utils import get_activation
|
||||
from funasr.models.transformer.utils.nets_utils import make_pad_mask
|
||||
from funasr.models.transformer.utils.mask import subsequent_mask, causal_block_mask
|
||||
from funasr.models.transformer.positionwise_feed_forward import (
|
||||
PositionwiseFeedForward, # noqa: H301
|
||||
)
|
||||
from funasr.models.transformer.utils.repeat import repeat
|
||||
from funasr.models.transformer.utils.subsampling import (
|
||||
Conv2dSubsampling, Conv2dSubsampling2, Conv2dSubsampling6, Conv2dSubsampling8, TooShortUttError,
|
||||
check_short_utt, Conv2dSubsamplingPad
|
||||
)
|
||||
import torch.nn.functional as F
|
||||
from funasr.models.llm_asr.conformer_encoder import ConvolutionModule, EncoderLayer
|
||||
from funasr.models.ctc.ctc import CTC
|
||||
|
||||
|
||||
class Upsample1D(nn.Module):
    """A 1D upsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to smooth a nearest-neighbor upsample with a convolution.
        use_conv_transpose (`bool`, default `False` per docstring; the
            constructor default is `True`):
            option to use a convolution transpose.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        channel_first (`bool`, default `True`):
            if False, tensors are (B, T, C) and transposed internally.
        stride (`int`, default 2): temporal upsampling factor.
        causal (`bool`, default `False`): left-only padding in conv mode.
    """

    def __init__(self, channels, use_conv=False, use_conv_transpose=True,
                 out_channels=None, name="conv", channel_first=True, stride=2, causal=False):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name
        self.channel_first = channel_first
        self.stride = stride
        self.causal = causal

        self.conv = None
        if use_conv_transpose:
            # transpose conv doesn't support causal mode.
            assert not causal
            kernel_size = stride * 2 + stride % 2
            padding = (kernel_size - stride) // 2
            self.conv = nn.ConvTranspose1d(channels, self.out_channels, kernel_size, stride, padding)
        elif use_conv:
            # In this mode, first repeat interpolate, than conv with stride=1
            self.conv = nn.Conv1d(
                self.channels, self.out_channels, stride * 2 + 1, stride=1,
                padding=0,
            )

    def forward(self, inputs, input_lengths=None):
        """Upsample ``inputs`` by ``self.stride`` along time.

        Returns the upsampled tensor and the correspondingly scaled lengths.
        """
        if not self.channel_first:
            inputs = inputs.transpose(1, 2).contiguous()
        assert inputs.shape[1] == self.channels

        if self.use_conv_transpose:
            out = self.conv(inputs)
        else:
            # Nearest-neighbor repeat, optionally smoothed by a conv whose
            # padding keeps the output length at stride * input length.
            out = F.interpolate(inputs, scale_factor=self.stride, mode="nearest")
            if self.use_conv:
                pad = (self.stride * 2, 0) if self.causal else (self.stride, self.stride)
                out = self.conv(F.pad(out, pad))

        if not self.channel_first:
            out = out.transpose(1, 2).contiguous()
        return out, input_lengths * self.stride
|
||||
|
||||
|
||||
class PreLookaheadLayer(nn.Module):
    """Convolutional layer with a small fixed lookahead.

    ``conv1`` consumes ``pre_lookahead_len`` future frames (right padding),
    ``conv2`` (kernel 3) is causal via left padding; a residual connection
    and the padding mask keep the output shape equal to the input and zero
    out padded frames.
    """

    def __init__(self, channels: int, pre_lookahead_len: int = 1):
        super().__init__()
        self.channels = channels
        self.pre_lookahead_len = pre_lookahead_len
        self.conv1 = nn.Conv1d(
            channels, channels,
            kernel_size=pre_lookahead_len + 1,
            stride=1, padding=0,
        )
        self.conv2 = nn.Conv1d(
            channels, channels,
            kernel_size=3, stride=1, padding=0,
        )

    def forward(self, inputs, ilens):
        """
        inputs: (batch_size, seq_len, channels)
        """
        hidden = inputs.transpose(1, 2).contiguous()
        # Right-pad so conv1 can look pre_lookahead_len frames ahead.
        hidden = F.pad(hidden, (0, self.pre_lookahead_len), mode='constant', value=0)
        hidden = F.leaky_relu(self.conv1(hidden))
        # Left-pad so conv2 stays causal and length-preserving.
        hidden = F.pad(hidden, (2, 0), mode='constant', value=0)
        hidden = self.conv2(hidden)
        hidden = hidden.transpose(1, 2).contiguous()

        valid = (~make_pad_mask(ilens).unsqueeze(-1).to(inputs.device))
        # Residual connection, with padded frames zeroed out.
        return (hidden + inputs) * valid, ilens
|
||||
|
||||
|
||||
class UpsampleConformerEncoder(nn.Module):
|
||||
"""Progressive upsampling Conformer encoder module.
|
||||
|
||||
Args:
|
||||
input_size (int): Input dimension.
|
||||
output_size (int): Dimension of attention.
|
||||
attention_heads (int): The number of heads of multi head attention.
|
||||
linear_units (int): The number of units of position-wise feed forward.
|
||||
num_blocks (int): The number of decoder blocks.
|
||||
dropout_rate (float): Dropout rate.
|
||||
attention_dropout_rate (float): Dropout rate in attention.
|
||||
positional_dropout_rate (float): Dropout rate after adding positional encoding.
|
||||
input_layer (Union[str, torch.nn.Module]): Input layer type.
|
||||
normalize_before (bool): Whether to use layer_norm before the first block.
|
||||
concat_after (bool): Whether to concat attention layer's input and output.
|
||||
If True, additional linear will be applied.
|
||||
i.e. x -> x + linear(concat(x, att(x)))
|
||||
If False, no additional linear will be applied. i.e. x -> x + att(x)
|
||||
positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear".
|
||||
positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer.
|
||||
rel_pos_type (str): Whether to use the latest relative positional encoding or
|
||||
the legacy one. The legacy relative positional encoding will be deprecated
|
||||
in the future. More Details can be found in
|
||||
https://github.com/espnet/espnet/pull/2816.
|
||||
encoder_pos_enc_layer_type (str): Encoder positional encoding layer type.
|
||||
encoder_attn_layer_type (str): Encoder attention layer type.
|
||||
activation_type (str): Encoder activation function type.
|
||||
macaron_style (bool): Whether to use macaron style for positionwise layer.
|
||||
use_cnn_module (bool): Whether to use convolution module.
|
||||
zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
|
||||
cnn_module_kernel (int): Kernerl size of convolution module.
|
||||
padding_idx (int): Padding idx for input_layer=embed.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    input_size: int,
    output_size: int = 256,
    attention_heads: int = 4,
    linear_units: int = 2048,
    num_blocks: int = 6,
    upsample_blocks: int = 3,
    upsample_attn_layers: int = 2,
    upsample_ratios: tuple = None,
    dropout_rate: float = 0.1,
    positional_dropout_rate: float = 0.1,
    attention_dropout_rate: float = 0.0,
    input_layer: Optional[str] = "conv2d",
    normalize_before: bool = True,
    concat_after: bool = False,
    positionwise_layer_type: str = "linear",
    positionwise_conv_kernel_size: int = 3,
    macaron_style: bool = False,
    rel_pos_type: str = "legacy",
    pos_enc_layer_type: str = "rel_pos",
    selfattention_layer_type: str = "rel_selfattn",
    activation_type: str = "swish",
    use_cnn_module: bool = True,
    zero_triu: bool = False,
    cnn_module_kernel: int = 31,
    padding_idx: int = -1,
    causal: bool = False,
    skip: bool = False,
    channel_first: bool = False,
    use_causal_prob: float = None,
    pre_lookahead_len: int = None,
):
    """Build the encoder: input embed, main Conformer stack, and a chain of
    `upsample_blocks` (upsample conv -> linear/pos-enc -> attention) stages
    whose per-stage stride comes from `upsample_ratios` (defaults to 2 each).
    """
    super().__init__()
    self._output_size = output_size
    self.causal = causal
    self.skip = skip
    self.channel_first = channel_first
    self.pre_lookahead_len = pre_lookahead_len
    # Probability of applying the causal attention mask per sample in
    # forward(); None means "always causal" when self.causal is set.
    self.use_causal_prob = use_causal_prob

    # Resolve legacy vs. latest relative positional encoding naming
    # (see https://github.com/espnet/espnet/pull/2816).
    if rel_pos_type == "legacy":
        if pos_enc_layer_type == "rel_pos":
            pos_enc_layer_type = "legacy_rel_pos"
        if selfattention_layer_type == "rel_selfattn":
            selfattention_layer_type = "legacy_rel_selfattn"
    elif rel_pos_type == "latest":
        assert selfattention_layer_type != "legacy_rel_selfattn"
        assert pos_enc_layer_type != "legacy_rel_pos"
    else:
        raise ValueError("unknown rel_pos_type: " + rel_pos_type)

    activation = get_activation(activation_type)
    # Select the positional encoding class; it must be paired with a
    # matching self-attention type (asserted below per branch).
    if pos_enc_layer_type == "abs_pos":
        pos_enc_class = PositionalEncoding
    elif pos_enc_layer_type == "scaled_abs_pos":
        pos_enc_class = ScaledPositionalEncoding
    elif pos_enc_layer_type == "rel_pos":
        assert selfattention_layer_type == "rel_selfattn"
        pos_enc_class = RelPositionalEncoding
    elif pos_enc_layer_type == "legacy_rel_pos":
        assert selfattention_layer_type == "legacy_rel_selfattn"
        pos_enc_class = LegacyRelPositionalEncoding
        logging.warning(
            "Using legacy_rel_pos and it will be deprecated in the future."
        )
    else:
        raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)

    # Input embedding / subsampling front-end.
    if input_layer == "linear":
        self.embed = torch.nn.Sequential(
            torch.nn.Linear(input_size, output_size),
            torch.nn.LayerNorm(output_size),
            torch.nn.Dropout(dropout_rate),
            pos_enc_class(output_size, positional_dropout_rate),
        )
    elif input_layer == "conv2d":
        self.embed = Conv2dSubsampling(
            input_size,
            output_size,
            dropout_rate,
            pos_enc_class(output_size, positional_dropout_rate),
        )
    elif input_layer == "conv2dpad":
        self.embed = Conv2dSubsamplingPad(
            input_size,
            output_size,
            dropout_rate,
            pos_enc_class(output_size, positional_dropout_rate),
        )
    elif input_layer == "conv2d2":
        self.embed = Conv2dSubsampling2(
            input_size,
            output_size,
            dropout_rate,
            pos_enc_class(output_size, positional_dropout_rate),
        )
    elif input_layer == "conv2d6":
        self.embed = Conv2dSubsampling6(
            input_size,
            output_size,
            dropout_rate,
            pos_enc_class(output_size, positional_dropout_rate),
        )
    elif input_layer == "conv2d8":
        self.embed = Conv2dSubsampling8(
            input_size,
            output_size,
            dropout_rate,
            pos_enc_class(output_size, positional_dropout_rate),
        )
    elif input_layer == "embed":
        self.embed = torch.nn.Sequential(
            torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
            pos_enc_class(output_size, positional_dropout_rate),
        )
    elif isinstance(input_layer, torch.nn.Module):
        self.embed = torch.nn.Sequential(
            input_layer,
            pos_enc_class(output_size, positional_dropout_rate),
        )
    elif input_layer is None:
        self.embed = torch.nn.Sequential(
            pos_enc_class(output_size, positional_dropout_rate)
        )
    else:
        raise ValueError("unknown input_layer: " + input_layer)
    self.normalize_before = normalize_before
    # Position-wise feed-forward factory shared by all encoder layers.
    if positionwise_layer_type == "linear":
        positionwise_layer = PositionwiseFeedForward
        positionwise_layer_args = (
            output_size,
            linear_units,
            dropout_rate,
            activation,
        )
    elif positionwise_layer_type == "conv1d":
        positionwise_layer = MultiLayeredConv1d
        positionwise_layer_args = (
            output_size,
            linear_units,
            positionwise_conv_kernel_size,
            dropout_rate,
        )
    elif positionwise_layer_type == "conv1d-linear":
        positionwise_layer = Conv1dLinear
        positionwise_layer_args = (
            output_size,
            linear_units,
            positionwise_conv_kernel_size,
            dropout_rate,
        )
    else:
        raise NotImplementedError("Support only linear or conv1d.")

    # Optional fixed-size lookahead applied right after the input embed.
    if pre_lookahead_len is not None:
        self.pre_lookahead_layer = PreLookaheadLayer(output_size, pre_lookahead_len)

    # Self-attention factory shared by the main stack and upsample stages.
    if selfattention_layer_type == "selfattn":
        encoder_selfattn_layer = MultiHeadedAttention
        encoder_selfattn_layer_args = (
            attention_heads,
            output_size,
            attention_dropout_rate,
        )
    elif selfattention_layer_type == "legacy_rel_selfattn":
        assert pos_enc_layer_type == "legacy_rel_pos"
        encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention
        encoder_selfattn_layer_args = (
            attention_heads,
            output_size,
            attention_dropout_rate,
        )
        logging.warning(
            "Using legacy_rel_selfattn and it will be deprecated in the future."
        )
    elif selfattention_layer_type == "rel_selfattn":
        assert pos_enc_layer_type == "rel_pos"
        encoder_selfattn_layer = RelPositionMultiHeadedAttention
        encoder_selfattn_layer_args = (
            attention_heads,
            output_size,
            attention_dropout_rate,
            zero_triu,
        )
    else:
        raise ValueError("unknown encoder_attn_layer: " + selfattention_layer_type)

    convolution_layer = ConvolutionModule
    convolution_layer_args = (output_size, cnn_module_kernel, activation)

    # Main Conformer stack operating at the input frame rate.
    self.encoders = repeat(
        num_blocks,
        lambda lnum: EncoderLayer(
            output_size,
            encoder_selfattn_layer(*encoder_selfattn_layer_args),
            positionwise_layer(*positionwise_layer_args),
            positionwise_layer(*positionwise_layer_args) if macaron_style else None,
            convolution_layer(*convolution_layer_args) if use_cnn_module else None,
            dropout_rate,
            normalize_before,
            concat_after,
            stochastic_depth_rate=0.0,
        ),
    )
    # NOTE: the int parameter `upsample_blocks` is shadowed below by the
    # ModuleList of the same name; the int is only used for construction.
    self.upsample_blocks = nn.ModuleList()
    if upsample_ratios is None:
        upsample_ratios = [2] * upsample_blocks
    self.upsample_ratios = upsample_ratios
    assert upsample_blocks == len(upsample_ratios)
    for i in range(upsample_blocks):
        # Non-causal mode uses transposed convolution; causal mode uses a
        # plain (causal) convolution inside Upsample1D.
        if not causal:
            upsample_conv_block = Upsample1D(
                channels=output_size, use_conv=False, use_conv_transpose=True,
                out_channels=output_size, channel_first=False, stride=upsample_ratios[i], causal=False,
            )
        else:
            upsample_conv_block = Upsample1D(
                channels=output_size, use_conv=True, use_conv_transpose=False,
                out_channels=output_size, channel_first=False, stride=upsample_ratios[i], causal=True,
            )
        upsample_attn_block = repeat(
            upsample_attn_layers,
            lambda lnum: EncoderLayer(
                output_size,
                encoder_selfattn_layer(*encoder_selfattn_layer_args),
                positionwise_layer(*positionwise_layer_args),
                positionwise_layer(*positionwise_layer_args) if macaron_style else None,
                convolution_layer(*convolution_layer_args) if use_cnn_module else None,
                dropout_rate,
                normalize_before,
                concat_after,
                stochastic_depth_rate=0.0,
            ),
        )
        # Re-projects and re-applies positional encoding after upsampling.
        attn_input_layer = torch.nn.Sequential(
            torch.nn.Linear(output_size, output_size),
            torch.nn.LayerNorm(output_size),
            torch.nn.Dropout(dropout_rate),
            pos_enc_class(output_size, positional_dropout_rate),
        )
        self.upsample_blocks.append(nn.ModuleList([upsample_conv_block, attn_input_layer, upsample_attn_block]))

    if self.normalize_before:
        self.after_norm = LayerNorm(output_size)
||||
def output_size(self) -> int:
    """Return the feature dimension produced by this encoder."""
    size: int = self._output_size
    return size
||||
def rand_mix_masks(self, causal, noncausal):
    """Per-sample random mix of causal and non-causal attention masks.

    Bug fix: the original read ``self.uni_encoder_prob``, an attribute that
    is never assigned anywhere in this class — ``__init__`` only defines
    ``self.use_causal_prob``, and ``forward`` performs the exact same
    per-sample mixing with ``use_causal_prob`` — so calling this method
    always raised ``AttributeError``.  It now uses ``self.use_causal_prob``.

    Args:
        causal: causal mask tensor of shape (batch, ..., ...).
        noncausal: non-causal mask tensor, same shape as ``causal``.

    Returns:
        Tensor where each batch item is its causal mask with probability
        ``self.use_causal_prob`` and its non-causal mask otherwise.
    """
    # One Bernoulli(use_causal_prob) draw per batch item, broadcast over
    # the mask dimensions; cast to the masks' dtype/device via .to(causal).
    use_causal = (torch.rand([causal.shape[0], 1, 1]) <= self.use_causal_prob).to(causal)
    masks = use_causal * causal + (1 - use_causal) * noncausal
    return masks
||||
def forward(
    self,
    xs_pad: torch.Tensor,
    ilens: torch.Tensor = None,
    prev_states: torch.Tensor = None,
    ctc: CTC = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Calculate forward propagation.

    Args:
        xs_pad (torch.Tensor): Input tensor (#batch, L, input_size).
        ilens (torch.Tensor): Input length (#batch).
        prev_states (torch.Tensor): Not to be used now.
        ctc (CTC): Not used in this method.

    Returns:
        torch.Tensor: Output tensor (#batch, L * prod(upsample_ratios), output_size).
        torch.Tensor: Output length (#batch).
        torch.Tensor: Not to be used now.

    NOTE: when ``ilens`` is None only ``xs_pad`` is returned (no tuple),
    and ``olens`` stays None through the upsample loop — presumably the
    upsample layers tolerate a None length in that case; TODO confirm.
    """
    raw_input = xs_pad
    if self.channel_first:
        # (B, C, T) -> (B, T, C) for the attention stack.
        xs_pad = xs_pad.permute(0, 2, 1)

    # Padding mask: True on valid frames; all-True when lengths are absent.
    if ilens is not None:
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
    else:
        masks = torch.ones(
            xs_pad.shape[0], 1, xs_pad.shape[1],
            dtype=torch.bool, device=xs_pad.device
        )

    # Per-sample flag: 1 -> use causal attention mask, 0 -> bidirectional.
    if self.use_causal_prob is not None:
        use_causal = (torch.rand([xs_pad.shape[0], 1, 1]) <= self.use_causal_prob).to(xs_pad)
    else:
        use_causal = torch.ones([xs_pad.shape[0], 1, 1]).to(xs_pad)

    if self.causal:
        causal_mask = subsequent_mask(
            xs_pad.shape[1], device=xs_pad.device, dtype=masks.dtype
        ).unsqueeze(0)
        causal_mask = masks & causal_mask
        # whether to train causal & non-causal in a single model
        masks = use_causal * causal_mask + (1 - use_causal) * masks

    # Conv2d front-ends subsample time and therefore also consume/emit masks.
    if (
        isinstance(self.embed, Conv2dSubsampling)
        or isinstance(self.embed, Conv2dSubsampling2)
        or isinstance(self.embed, Conv2dSubsampling6)
        or isinstance(self.embed, Conv2dSubsampling8)
        or isinstance(self.embed, Conv2dSubsamplingPad)
    ):
        short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
        if short_status:
            raise TooShortUttError(
                f"has {xs_pad.size(1)} frames and is too short for subsampling "
                + f"(it needs more than {limit_size} frames), return empty results",
                xs_pad.size(1),
                limit_size,
            )
        xs_pad, masks = self.embed(xs_pad, masks)
    else:
        xs_pad = self.embed(xs_pad)

    # Optional lookahead; xs_pad may be a (tensor, pos_emb) tuple from the
    # relative positional encoding, so unwrap/rewrap around the layer call.
    if self.pre_lookahead_len is not None:
        xs = xs_pad
        if isinstance(xs_pad, tuple):
            xs = xs_pad[0]
        xs, _ = self.pre_lookahead_layer(xs, ilens)
        if isinstance(xs_pad, tuple):
            xs_pad = (xs, xs_pad[1])

    # 1. modeling on inputs
    intermediate_outs = []
    xs_pad, masks = self.encoders(xs_pad, masks)

    # 2. progressive upsampling: each stage upsamples in time, rebuilds the
    # padding (and, if causal, block-causal) mask, then applies attention.
    outs, olens = xs_pad, ilens
    total_ratio = 1
    for up_ratio, layer in zip(self.upsample_ratios, self.upsample_blocks):
        up_layer, attn_input_layer, attn_layer = layer

        if isinstance(outs, tuple):
            outs = outs[0]
        outs, olens = up_layer(outs, olens)
        masks = (~make_pad_mask(olens)[:, None, :]).to(outs.device)
        total_ratio = total_ratio * up_ratio
        if self.causal:
            # Block-causal mask: frames within the same pre-upsampling block
            # (size = cumulative ratio) may attend to each other.
            causal_mask = causal_block_mask(
                outs.shape[1], total_ratio, device=outs.device, dtype=masks.dtype
            ).unsqueeze(0)
            causal_mask = masks & causal_mask
            masks = use_causal * causal_mask + (1 - use_causal) * masks
        outs = attn_input_layer(outs)
        outs, _ = attn_layer(outs, masks)

    xs_pad = outs
    if isinstance(xs_pad, tuple):
        xs_pad = xs_pad[0]
    if self.normalize_before:
        xs_pad = self.after_norm(xs_pad)

    if self.channel_first:
        # Back to (B, C, T) to mirror the input layout.
        xs_pad = xs_pad.permute(0, 2, 1)

    # Residual connection to the raw input; only shape-valid when the
    # upsample ratios preserve the time length — TODO confirm at call sites.
    if self.skip:
        xs_pad = xs_pad + raw_input

    # olens = masks.squeeze(1).sum(1)
    # intermediate_outs is never filled here; branch kept for interface parity.
    if len(intermediate_outs) > 0:
        return (xs_pad, intermediate_outs), olens, None

    if ilens is not None:
        return xs_pad, olens, None
    else:
        return xs_pad
||||
|
||||
@ -1,625 +0,0 @@
|
||||
import logging
|
||||
from typing import List, Tuple, Dict, Optional, Union
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from funasr.models.transformer.utils.nets_utils import make_pad_mask
|
||||
import torch.nn.functional as F
|
||||
from funasr.train_utils.device_funcs import force_gatherable
|
||||
from funasr.models.llm_asr.label_smoothing_loss import LabelSmoothingLoss
|
||||
from copy import deepcopy
|
||||
from funasr.metrics.compute_acc import th_accuracy
|
||||
from funasr.models.transformer.utils.nets_utils import pad_list
|
||||
import random
|
||||
import numpy as np
|
||||
from funasr.utils.hinter import hint_once
|
||||
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
|
||||
from funasr.models.llm_asr.tts_models.ctc_alignment import ctc_forced_align
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
import itertools
|
||||
from distutils.version import LooseVersion
|
||||
from contextlib import contextmanager
|
||||
# Use native automatic mixed precision when available (torch >= 1.6);
# otherwise install a no-op stand-in so `with autocast():` call sites work.
# NOTE(review): `distutils.version.LooseVersion` is deprecated and distutils
# was removed in Python 3.12 — `packaging.version.parse` is the usual
# replacement; verify before bumping the supported Python version.
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
else:
    # Nothing to do if torch<1.6.0
    @contextmanager
    def autocast(enabled=True):
        # `enabled` accepted for signature compatibility; ignored.
        yield
||||
|
||||
|
||||
class NARCTCModel(nn.Module):
    """Non-autoregressive CTC model that maps embedded text to speech tokens.

    The encoder (typically an upsampling Conformer) expands text-rate
    features toward the speech-token rate; a CTC head scores speech tokens
    against the encoder output, optionally combined with an attention
    decoder loss.  Speech-token ids are shifted by +1 internally so id 0 is
    the CTC <blank>.
    """

    def __init__(
        self,
        input_size: int,
        vocab_size: int,
        encoder: Union[nn.Module, dict],
        decoder: Optional[nn.Module] = None,
        ctc_weight: float = 0.5,
        ignore_id: int = -1,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
        **kwargs,
    ):
        """Construct the model.

        Args:
            input_size: embedding/encoder input dimension.
            vocab_size: output vocabulary size (blank + tokens + sos + eos).
            encoder: an instantiated encoder module, or a config dict passed
                to :meth:`build_encoder`.
            decoder: optional attention decoder (used when ctc_weight != 1).
            ctc_weight: interpolation weight between CTC and attention loss.
            ignore_id: label id ignored by the losses.
            lsm_weight: label-smoothing weight for the attention loss.
            length_normalized_loss: normalize attention loss by length.
            **kwargs: may carry `ctc_conf`, `xvec_size`,
                `length_regulator_conf`.
        """
        super().__init__()
        self.input_size = input_size
        self.decoder = decoder
        self.encoder = encoder if isinstance(encoder, nn.Module) else self.build_encoder(encoder)
        self.output_size = self.encoder.output_size()
        self.ignore_id = ignore_id
        self.vocab_size = vocab_size
        self.ctc_weight = ctc_weight

        # build ctc module
        from funasr.models.llm_asr.tts_models.ctc_alignment import CTC
        ctc_conf = kwargs.pop("ctc_conf", {})
        self.ctc = CTC(vocab_size, encoder_output_size=self.output_size, **ctc_conf)

        # Separate embeddings for input text tokens and target speech tokens.
        self.text_embedding = torch.nn.Embedding(self.vocab_size, input_size)
        self.token_embedding = torch.nn.Embedding(vocab_size, input_size)
        # Optional speaker-embedding (x-vector) projection into input space.
        xvec_size = kwargs.get("xvec_size", None)
        if xvec_size is not None:
            self.xvec_proj = torch.nn.Linear(xvec_size, input_size)
        else:
            self.xvec_proj = None
        self.criterion_att = LabelSmoothingLoss(
            size=vocab_size,
            padding_idx=ignore_id,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )
        # Last two vocabulary entries are reserved for sos/eos.
        self.sos = vocab_size - 2
        self.eos = vocab_size - 1
        # Optional text-length regulator applied before the encoder.
        self.length_regulator_conf = kwargs.get("length_regulator_conf", None)
        if self.length_regulator_conf is not None:
            self.length_regulator = self.build_length_regulator()
        else:
            self.length_regulator = None
        # NOTE(review): `_calc_att_loss` reads `self.error_calculator`,
        # which is never assigned here — eval-mode attention scoring would
        # raise AttributeError unless it is set externally; verify.

    def build_encoder(self, encoder_conf: dict):
        """Instantiate the encoder from a config dict.

        The `name` key selects the class ("transformer", "conformer", or
        "upsampling_conformer"); it is popped for construction and restored
        afterwards so the caller's dict is left intact.  Returns None for an
        unknown name (no error is raised).
        """
        if encoder_conf is None:
            assert hasattr(self, "encoder_conf"), \
                "function param encoder_conf is None and model doesn't has encoder_conf attribute either."
            encoder_conf = self.encoder_conf

        encoder_name = encoder_conf.pop("name", "transformer")
        model = None
        if encoder_name == "transformer":
            from funasr.models.llm_asr.conformer_encoder import ConformerEncoder
            # "transformer" is a Conformer with CNN module and macaron
            # style disabled.
            model = ConformerEncoder(
                **encoder_conf,
                input_size=self.input_size,
                use_cnn_module=False,
                macaron_style=False,
            )
        elif encoder_name == "conformer":
            from funasr.models.llm_asr.conformer_encoder import ConformerEncoder
            model = ConformerEncoder(
                **encoder_conf,
                input_size=self.input_size,
            )
        elif encoder_name == "upsampling_conformer":
            from funasr.models.llm_asr.tts_models.encoders import UpsampleConformerEncoder
            model = UpsampleConformerEncoder(
                **encoder_conf,
                input_size=self.input_size,
            )

        encoder_conf["name"] = encoder_name

        return model

    def build_length_regulator(self):
        """Instantiate the length regulator named in
        `self.length_regulator_conf` ("upsampling", "downsampling",
        "interpolate", or "upsampling_cif").  Restores the popped `name`
        key; returns None for an unknown name.
        """
        name = self.length_regulator_conf.pop("name", None)
        model = None
        if name == "upsampling":
            from funasr.models.llm_asr.diffusion_models.length_regulator import UpSamplingRegulator
            model = UpSamplingRegulator(self.input_size, self.length_regulator_conf.get("sampling_ratios"))
        elif name == "downsampling":
            from funasr.models.llm_asr.diffusion_models.length_regulator import DownSamplingRegulator
            model = DownSamplingRegulator(self.input_size, self.length_regulator_conf.get("sampling_ratios"))
        elif name == "interpolate":
            from funasr.models.llm_asr.diffusion_models.length_regulator import InterpolateRegulator
            model = InterpolateRegulator(self.input_size, **self.length_regulator_conf)
        elif name == "upsampling_cif":
            from funasr.models.llm_asr.diffusion_models.length_regulator import UpsamplingCifRegulator
            model = UpsamplingCifRegulator(self.input_size, **self.length_regulator_conf)

        self.length_regulator_conf["name"] = name

        return model

    @staticmethod
    def norm_and_sample_xvec(xvec, xvec_lengths):
        """Pick one finite x-vector per utterance and L2-normalize it.

        For each sample, a random frame index within the valid length is
        drawn and re-drawn until the vector at that index contains no
        NaN/Inf.  NOTE(review): if every frame of a sample is non-finite
        this loop never terminates — presumably upstream guarantees at
        least one finite frame; verify.
        """
        xvec_list = []
        for i, ilen in enumerate(xvec_lengths):
            idx = random.randint(0, ilen - 1)
            while torch.any(~torch.isfinite(xvec[i, idx])):
                idx = random.randint(0, ilen - 1)
            xvec_list.append(xvec[i, idx])
        rand_xvec = torch.vstack(xvec_list)
        rand_xvec = F.normalize(rand_xvec, dim=1)

        return rand_xvec

    def _calc_att_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        """Attention-decoder loss: label-smoothed CE plus accuracy; CER/WER
        only outside training when an error calculator is available.
        """
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        ys_in_lens = ys_pad_lens + 1

        # 1. Forward decoder
        decoder_out, _ = self.decoder(
            encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
        )

        # 2. Compute attention loss
        loss_att = self.criterion_att(decoder_out, ys_out_pad)
        acc_att = th_accuracy(
            decoder_out.view(-1, self.decoder.vocab_size),
            ys_out_pad,
            ignore_label=self.ignore_id,
        )

        # Compute cer/wer using attention-decoder
        # NOTE(review): `self.error_calculator` is not set in __init__ —
        # this branch raises AttributeError in eval mode unless the
        # attribute is assigned elsewhere.
        if self.training or self.error_calculator is None:
            cer_att, wer_att = None, None
        else:
            ys_hat = decoder_out.argmax(dim=-1)
            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        return loss_att, acc_att, cer_att, wer_att

    def model_forward(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        xvec: Optional[torch.Tensor] = None,
        xvec_lengths: Optional[torch.Tensor] = None,
        **kwargs
    ):
        """Length-regulate text, optionally add a speaker x-vector bias,
        and run the encoder.  Returns (encoder_out, encoder_out_lens).
        """
        # 0. Up-sampling text length
        if self.length_regulator is not None:
            text, text_lengths = self.length_regulator(text, text_lengths)

        # 1. padding xvec
        if xvec is not None and self.xvec_proj is not None:
            xvec = xvec[:, :xvec_lengths.max()]
            # random select a xvec from xvec matrix
            xvec = self.norm_and_sample_xvec(xvec, xvec_lengths)
            xvec = self.xvec_proj(xvec)
            # Broadcast the per-utterance speaker vector over time.
            text = text + xvec.unsqueeze(1)
            hint_once("use xvec", "use_xvec")

        # 1. Encoder
        encoder_out, encoder_out_lens, _ = self.encoder(text, text_lengths)

        return encoder_out, encoder_out_lens

    def predictor(
        self,
        am: torch.Tensor,
        am_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
        alignment,
    ):
        """Pool encoder frames into per-token embeddings using a CTC
        alignment.

        For each utterance, consecutive identical alignment labels are
        grouped; non-blank groups contribute the mean of their encoder
        frames.  If the number of pooled groups disagrees with the target
        length, the ground-truth embedding `ys_pad` row is used instead.

        Returns:
            (padded pooled embeddings, fraction of utterances where the
            alignment-derived pooling was used).
        """
        acoustic_embeds = []
        use_pred_num = 0
        for am_xs, enc_len, ali, y, y_lens in zip(am, am_lens, alignment, ys_pad, ys_pad_lens):
            pred = itertools.groupby(ali[:enc_len])

            acoustic_embed = []
            _start = 0
            for pred_token, pred_frame in pred:
                _end = _start + len(list(pred_frame))
                if pred_token != 0:
                    # Mean-pool the frames aligned to this token.
                    acoustic_embed.append(torch.mean(am_xs[_start:_end, :], 0, keepdim=True))
                _start = _end
            if len(acoustic_embed) != y_lens:
                # Alignment length mismatch: fall back to ground truth.
                acoustic_embeds.append(y[:y_lens])
            else:
                acoustic_embeds.append(torch.cat(acoustic_embed, dim=0))
                use_pred_num += 1
        acoustic_embeds = pad_sequence(acoustic_embeds, batch_first=True, padding_value=0)
        return acoustic_embeds, use_pred_num / am.shape[0]

    def force_align_text(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        xvec: Optional[torch.Tensor] = None,
        xvec_lengths: Optional[torch.Tensor] = None,
        **kwargs
    ):
        """Run the encoder, CTC-force-align speech tokens to it, pool
        per-token embeddings, and compute the training loss.

        Returns:
            (loss, aligned per-token embeddings, stats dict).
        """
        # plus one to speech token, to make index 0 represent <blank>,
        # decoder vocab must be: 1 (blank) + num of token + 1 (sos) + 1 (eos)
        speech = torch.where(speech != -1, speech + 1, speech)

        encoder_out, encoder_out_lens = self.model_forward(
            text, text_lengths,
            xvec, xvec_lengths,
            **kwargs
        )
        log_probs = self.ctc.log_softmax(encoder_out)
        # Alignment is a target, not a differentiable path.
        with torch.no_grad():
            alignment = ctc_forced_align(
                log_probs.float(),
                speech.long(),
                encoder_out_lens.long(),
                speech_lengths.long(),
                ignore_id=self.ignore_id,
            )
        aligned_token_emb, use_pred_ratio = self.predictor(
            encoder_out, encoder_out_lens,
            self.token_embedding(speech), speech_lengths,
            alignment,
        )

        loss = 0
        states = dict(
            use_pred_ratio=use_pred_ratio,
        )
        if self.ctc_weight != 0.0:
            loss_ctc, logits = self._calc_ctc_loss(
                encoder_out, encoder_out_lens, speech, speech_lengths
            )
            states["loss_ctc"] = loss_ctc.item()
            loss = loss + self.ctc_weight * loss_ctc
        if self.ctc_weight != 1.0:
            loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
                encoder_out, encoder_out_lens, speech, speech_lengths
            )
            states["loss_att"] = loss_att.item()
            loss = loss + (1.0 - self.ctc_weight) * loss_att

        states["loss"] = loss.item()
        return loss, aligned_token_emb, states

    def _calc_ctc_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        """Return (CTC loss, log-softmax logits over the encoder output)."""
        # Calc CTC loss
        loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
        logits = self.ctc.log_softmax(encoder_out)

        return loss_ctc, logits

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        xvec: Optional[torch.Tensor] = None,
        xvec_lengths: Optional[torch.Tensor] = None,
        **kwargs
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss

        Args:
            speech: (Batch, Length, ...), speech tokens
            speech_lengths: (Batch, )
            text: (Batch, Length), text tokens
            text_lengths: (Batch, )
            xvec: (Batch, Length, ...) x-vectors
            xvec_lengths: (Batch, )

        Returns:
            (loss, stats dict, batch weight) gathered for DataParallel.
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (
            speech.shape[0]
            == speech_lengths.shape[0]
            == text.shape[0]
            == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
        batch_size = speech.shape[0]

        # for data-parallel
        text = text[:, : text_lengths.max()]
        speech = speech[:, : speech_lengths.max()]
        # plus one to speech token, to make index 0 represent <blank>,
        # decoder vocab must be: 1 (blank) + num of token + 1 (sos) + 1 (eos)
        speech = torch.where(speech != -1, speech + 1, speech)

        # embed text inputs; padding positions (-1) are zeroed after lookup
        mask = (text != -1).float().unsqueeze(-1)
        text = self.text_embedding(torch.clamp(text, min=0)) * mask

        encoder_out, encoder_out_lens = self.model_forward(
            text, text_lengths,
            xvec, xvec_lengths,
            **kwargs,
        )

        loss_att, acc_att, cer_att, wer_att = None, None, None, None
        loss_ctc, cer_ctc = None, None
        stats = dict(
            batch_size=float(batch_size),
            text_len=float(text.shape[1]),
            enc_len=float(encoder_out.shape[1]),
            speech_len=float(speech.shape[1]),
            token_text_ratio=float(speech.shape[1])/float(text.shape[1]),
        )

        # 1. CTC branch
        if self.ctc_weight != 0.0:
            loss_ctc, logits = self._calc_ctc_loss(
                encoder_out, encoder_out_lens, speech, speech_lengths
            )

        # Collect CTC branch stats
        stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
        stats["cer_ctc"] = cer_ctc

        # 2b. Attention decoder branch
        if self.ctc_weight != 1.0:
            loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
                encoder_out, encoder_out_lens, speech, speech_lengths
            )

        # 3. CTC-Att loss definition
        if self.ctc_weight == 0.0:
            loss = loss_att
        elif self.ctc_weight == 1.0:
            loss = loss_ctc
        else:
            loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att

        # Collect Attn branch stats
        stats["loss_att"] = loss_att.detach() if loss_att is not None else None
        stats["acc"] = acc_att
        stats["cer"] = cer_att
        stats["wer"] = wer_att

        # Collect total loss stats
        stats["loss"] = torch.clone(loss.detach())

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

    def topp_sampling(self, probs, top_p=0.8):
        """Nucleus (top-p) sampling: keep the smallest prefix of the sorted
        probabilities whose cumulative mass exceeds `top_p`, zero the rest,
        and draw one sample.  Mutates `probs` in place.

        NOTE(review): the flat indexing `probs[indices_to_remove]` looks
        written for a 1-D distribution, but `sampling_ids` passes a
        (batch, time, vocab) tensor — verify intended shape.
        """
        sorted_value, sorted_idx = probs.sort(descending=True, stable=True)
        cumulative_probs = torch.cumsum(sorted_value, dim=-1)
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift right so the first token crossing the threshold is kept.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False
        indices_to_remove = sorted_idx[sorted_indices_to_remove]
        probs[indices_to_remove] = 0
        top_ids = torch.multinomial(probs, num_samples=1)

        return top_ids

    def sampling_ids(self, enc_outs, sampling="greedy", blank_penalty=None, return_probs=False):
        """Turn encoder outputs into token ids via the CTC softmax.

        `sampling` selects the strategy by its string form: a float string
        ("0.8") -> top-p; an int string ("5") -> top-k; "greedy" -> argmax;
        "threshold_X" -> argmax after zeroing the blank probability wherever
        it does not exceed X (blank is only emitted when confident).

        NOTE(review): the default `blank_penalty=None` would fail at
        `blank_penalty > 0`; visible callers always pass a float.
        """
        probs = self.ctc.softmax(enc_outs)
        if blank_penalty > 0:
            # Scale (typically damp) the blank probability.
            probs[:, :, 0] = probs[:, :, 0] * blank_penalty

        # top-p sampling
        if "." in sampling:
            sampling = float(sampling)
            tokens = self.topp_sampling(probs, top_p=sampling)
            tokens = torch.tensor(tokens, dtype=torch.long).to(probs.device)
        # top-k sampling
        elif sampling.isdigit():
            sampling = int(sampling)
            probs = probs.topk(sampling)
            tokens = probs.multinomial(1, replacement=True)
        else:
            if sampling == "greedy":
                tokens = torch.argmax(probs, dim=-1)
            elif "threshold_" in sampling:
                threshold = float(sampling.split("_")[1])
                hint_once(f"Decoding mode: blank threshold={threshold:.2f}", "decoding_mode")
                # mask out blank according to threshold
                mask = probs[:, :, 0] > threshold
                probs[:, :, 0] = probs[:, :, 0] * mask
                tokens = torch.argmax(probs, dim=-1)
            else:
                raise NotImplementedError(f"sampling method {sampling} not implemented")

        if not return_probs:
            return tokens

        return tokens, probs

    def inference(
        self,
        text: torch.Tensor, text_lengths: torch.Tensor,
        xvec=None, xvec_lengths=None,
        sampling="greedy",
        blank_penalty: float = 0.0,
        text_is_embedding=False,
        return_hidden=False,
        **kwargs,
    ):
        """Decode speech tokens from text (batch size 1 assumed — only
        index [0] of the sampled tokens is post-processed).

        Returns:
            `tokens` of shape (1, N, 1) when `return_hidden` is False,
            otherwise ((tokens, frame-level fa_tokens), pooled acoustic
            embeddings, their lengths).
        """
        device = text.device
        # use causal mode at inference stage
        self.encoder.use_causal_prob = kwargs.get("use_causal_prob", 1.0)
        hint_once(f"use_causal_prob {self.encoder.use_causal_prob}.", "use_causal_prob")
        # embed text inputs
        if not text_is_embedding:
            mask = (text != -1).float().unsqueeze(-1)
            text = self.text_embedding(torch.clamp(text, min=0)) * mask

        # 1. Encoder
        encoder_out, encoder_out_lens = self.model_forward(
            text, text_lengths,
            xvec, xvec_lengths,
        )
        fa_tokens, enc_probs = self.sampling_ids(
            encoder_out,
            sampling=sampling,
            blank_penalty=blank_penalty,
            return_probs=True,
        )
        # Collapse repeated non-blank tokens; blank runs are kept verbatim.
        reduced_fa_tokens = []
        for pred_token, pred_frame in itertools.groupby(fa_tokens[0].cpu().tolist()):
            if pred_token != 0:
                reduced_fa_tokens.append(pred_token)
            else:
                reduced_fa_tokens.extend(list(pred_frame))
        fa_tokens = torch.tensor([reduced_fa_tokens]).to(fa_tokens)
        # remove blanks (id=0) and convert token ids into the original format
        tokens = [[x-1] for x in fa_tokens[0].cpu().tolist() if x > 0]
        tokens = torch.tensor([tokens], dtype=torch.int64, device=device)

        if not return_hidden:
            return tokens

        # Pool encoder frames per predicted token (mirrors `predictor`).
        acoustic_embs, acoustic_emb_lens = [], []
        for idx, (prob, enc) in enumerate(zip(enc_probs, encoder_out)):
            pred = itertools.groupby(prob.argmax(-1).cpu())
            acs_emb = []
            _start = 0
            for pred_token, pred_frame in pred:
                _end = _start + len(list(pred_frame))
                if pred_token != 0 and pred_token != -1:
                    acs_emb.append(torch.mean(enc[_start:_end, :], 0, keepdim=True))
                _start = _end
            acs_emb = torch.cat(acs_emb, dim=0)
            acoustic_embs.append(acs_emb)
            acoustic_emb_lens.append(acs_emb.shape[0])

        acoustic_embs = pad_list(acoustic_embs, 0.0)
        acoustic_emb_lens = torch.tensor(acoustic_emb_lens, dtype=torch.int64, device=device)

        return (tokens, fa_tokens), acoustic_embs, acoustic_emb_lens
||||
|
||||
|
||||
class NARCTCProbModel(NARCTCModel):
|
||||
def __init__(
    self,
    input_size: int,
    vocab_size: int,
    encoder: Union[nn.Module, dict],
    decoder: Optional[nn.Module] = None,
    ctc_weight: float = 0.5,
    ignore_id: int = -1,
    lsm_weight: float = 0.0,
    length_normalized_loss: bool = False,
    **kwargs,
):
    """Delegate construction entirely to the NARCTCModel base class;
    this subclass only overrides the alignment-pooling behavior.
    """
    super().__init__(
        input_size,
        vocab_size,
        encoder,
        decoder,
        ctc_weight,
        ignore_id,
        lsm_weight,
        length_normalized_loss,
        **kwargs,
    )
||||
|
||||
def predictor(
|
||||
self,
|
||||
am_probs: torch.Tensor,
|
||||
am_lens: torch.Tensor,
|
||||
ys_pad: torch.Tensor,
|
||||
ys_pad_lens: torch.Tensor,
|
||||
alignment,
|
||||
):
|
||||
acoustic_embeds = []
|
||||
use_pred_num = 0
|
||||
for probs, enc_len, ali, y, y_lens in zip(am_probs, am_lens, alignment, ys_pad, ys_pad_lens):
|
||||
pred = itertools.groupby(ali[:enc_len])
|
||||
|
||||
acoustic_embed = []
|
||||
_start = 0
|
||||
for pred_token, pred_frame in pred:
|
||||
_end = _start + len(list(pred_frame))
|
||||
if pred_token != 0:
|
||||
acoustic_embed.append(torch.mean(probs[_start:_end, :], 0, keepdim=True))
|
||||
_start = _end
|
||||
if len(acoustic_embed) != y_lens:
|
||||
acoustic_embeds.append(F.one_hot(y[:y_lens], self.vocab_size).float())
|
||||
else:
|
||||
acoustic_embeds.append(torch.cat(acoustic_embed, dim=0))
|
||||
use_pred_num += 1
|
||||
acoustic_embeds[-1] = torch.matmul(acoustic_embeds[-1], self.token_embedding.weight)
|
||||
acoustic_embeds = pad_sequence(acoustic_embeds, batch_first=True, padding_value=0)
|
||||
return acoustic_embeds, use_pred_num / am_probs.shape[0]
|
||||
|
||||
def force_align_text(self, speech: torch.Tensor, speech_lengths: torch.Tensor, text: torch.Tensor,
|
||||
text_lengths: torch.Tensor, xvec: Optional[torch.Tensor] = None,
|
||||
xvec_lengths: Optional[torch.Tensor] = None, **kwargs):
|
||||
# plus one to speech token, to make index 0 represent <blank>,
|
||||
# decoder vocab must be: 1 (blank) + num of token + 1 (sos) + 1 (eos)
|
||||
speech = torch.where(speech != -1, speech + 1, speech)
|
||||
|
||||
encoder_out, encoder_out_lens = self.model_forward(
|
||||
text, text_lengths,
|
||||
xvec, xvec_lengths,
|
||||
**kwargs
|
||||
)
|
||||
log_probs = self.ctc.log_softmax(encoder_out)
|
||||
with torch.no_grad():
|
||||
alignment = ctc_forced_align(
|
||||
log_probs.float(),
|
||||
speech.long(),
|
||||
encoder_out_lens.long(),
|
||||
speech_lengths.long(),
|
||||
ignore_id=self.ignore_id,
|
||||
)
|
||||
aligned_token_emb, use_pred_ratio = self.predictor(
|
||||
log_probs.float(), encoder_out_lens.long(),
|
||||
speech.long(), speech_lengths.long(),
|
||||
alignment,
|
||||
)
|
||||
|
||||
loss = 0
|
||||
states = dict(
|
||||
use_pred_ratio=use_pred_ratio,
|
||||
)
|
||||
if self.ctc_weight != 0.0:
|
||||
loss_ctc, logits = self._calc_ctc_loss(
|
||||
encoder_out, encoder_out_lens, speech, speech_lengths
|
||||
)
|
||||
states["loss_ctc"] = loss_ctc.item()
|
||||
loss = loss + self.ctc_weight * loss_ctc
|
||||
if self.ctc_weight != 1.0:
|
||||
loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
|
||||
encoder_out, encoder_out_lens, speech, speech_lengths
|
||||
)
|
||||
states["loss_att"] = loss_att.item()
|
||||
loss = loss + (1.0 - self.ctc_weight) * loss_att
|
||||
|
||||
states["loss"] = loss.item()
|
||||
return loss, aligned_token_emb, states
|
||||
|
||||
def inference(self, text: torch.Tensor, text_lengths: torch.Tensor, xvec=None, xvec_lengths=None, sampling="greedy",
|
||||
blank_penalty: float = 0.0, text_is_embedding=False, return_hidden=False, **kwargs):
|
||||
device = text.device
|
||||
# embed text inputs
|
||||
if not text_is_embedding:
|
||||
mask = (text != -1).float().unsqueeze(-1)
|
||||
text = self.text_embedding(torch.clamp(text, min=0)) * mask
|
||||
|
||||
# 0. Up-sampling text length
|
||||
if self.length_regulator is not None:
|
||||
text, text_lengths = self.length_regulator(text, text_lengths)
|
||||
|
||||
# 1. padding xvec
|
||||
if xvec is not None and self.xvec_proj is not None:
|
||||
xvec = xvec[:, :xvec_lengths.max()]
|
||||
# random select a xvec from xvec matrix
|
||||
xvec = self.norm_and_sample_xvec(xvec, xvec_lengths)
|
||||
xvec = self.xvec_proj(xvec)
|
||||
text = text + xvec.unsqueeze(1)
|
||||
hint_once("use xvec", "use_xvec")
|
||||
|
||||
# 1. Encoder
|
||||
encoder_out, encoder_out_lens = self.model_forward(
|
||||
text, text_lengths,
|
||||
xvec, xvec_lengths,
|
||||
)
|
||||
tokens, enc_probs = self.sampling_ids(
|
||||
encoder_out,
|
||||
sampling=sampling,
|
||||
blank_penalty=blank_penalty,
|
||||
return_probs=True,
|
||||
)
|
||||
# remove blanks (id=0) and convert token ids into the original format
|
||||
tokens = [[x - 1] for x in tokens[0].cpu().tolist() if x > 0]
|
||||
tokens = torch.tensor([tokens], dtype=torch.int64, device=device)
|
||||
|
||||
if not return_hidden:
|
||||
return tokens
|
||||
|
||||
acoustic_embs = self.token_embedding(tokens.squeeze(-1))
|
||||
acoustic_emb_lens = torch.tensor([acoustic_embs.shape[1]], dtype=torch.int64, device=device)
|
||||
|
||||
return tokens, acoustic_embs, acoustic_emb_lens
|
||||
@ -1,761 +0,0 @@
|
||||
# Copyright 2019 Shigeki Karita
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
"""Decoder definition."""
|
||||
from typing import Any
|
||||
from typing import List
|
||||
from typing import Sequence
|
||||
from typing import Tuple
|
||||
import torch
|
||||
from torch import nn
|
||||
from funasr.models.transformer.attention import MultiHeadedAttention
|
||||
from funasr.models.transformer.utils.dynamic_conv import DynamicConvolution
|
||||
from funasr.models.transformer.utils.dynamic_conv2d import DynamicConvolution2D
|
||||
from funasr.models.transformer.embedding import PositionalEncoding
|
||||
from funasr.models.transformer.layer_norm import LayerNorm
|
||||
from funasr.models.transformer.utils.lightconv import LightweightConvolution
|
||||
from funasr.models.transformer.utils.lightconv2d import LightweightConvolution2D
|
||||
from funasr.models.transformer.utils.mask import subsequent_mask
|
||||
from funasr.models.transformer.utils.nets_utils import make_pad_mask
|
||||
from funasr.models.transformer.positionwise_feed_forward import (
|
||||
PositionwiseFeedForward, # noqa: H301
|
||||
)
|
||||
from funasr.models.transformer.utils.repeat import repeat
|
||||
from funasr.models.transformer.scorers.scorer_interface import BatchScorerInterface
|
||||
|
||||
|
||||
class DecoderLayer(nn.Module):
    """Single decoder layer module.

    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttention` instance can be used as the argument.
        src_attn (torch.nn.Module): Source-attention module instance.
            `MultiHeadedAttention` instance can be used as the argument.
        feed_forward (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
            can be used as the argument.
        dropout_rate (float): Dropout rate.
        normalize_before (bool): Whether to use layer_norm before the first block.
        concat_after (bool): Whether to concat attention layer's input and output.
            if True, additional linear will be applied.
            i.e. x -> x + linear(concat(x, att(x)))
            if False, no additional linear will be applied. i.e. x -> x + att(x)
    """

    def __init__(
        self,
        size,
        self_attn,
        src_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
    ):
        """Construct a DecoderLayer object."""
        super(DecoderLayer, self).__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        # One LayerNorm per sub-block: self-attn, src-attn, feed-forward.
        self.norm1 = LayerNorm(size)
        self.norm2 = LayerNorm(size)
        self.norm3 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            # Projections back to `size` after concatenating input and attention output.
            self.concat_linear1 = nn.Linear(size + size, size)
            self.concat_linear2 = nn.Linear(size + size, size)

    def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None):
        """Compute decoded features.

        Args:
            tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
            tgt_mask (torch.Tensor): Mask for input tensor (#batch, 1, maxlen_out).
            memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
            memory_mask (torch.Tensor): Encoded memory mask (#batch, 1, maxlen_in).
            cache (List[torch.Tensor]): List of cached tensors.
                Each tensor shape should be (#batch, maxlen_out - 1, size).

        Returns:
            torch.Tensor: Output tensor(#batch, maxlen_out, size).
            torch.Tensor: Mask for output tensor (#batch, maxlen_out).
            torch.Tensor: Encoded memory (#batch, maxlen_in, size).
            torch.Tensor: Encoded memory mask (#batch, maxlen_in).
        """
        residual = tgt
        if self.normalize_before:
            tgt = self.norm1(tgt)

        if cache is None:
            tgt_q = tgt
            tgt_q_mask = tgt_mask
        else:
            # compute only the last frame query keeping dim: max_time_out -> 1
            assert cache.shape == (
                tgt.shape[0],
                tgt.shape[1] - 1,
                self.size,
            ), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
            tgt_q = tgt[:, -1:, :]
            residual = residual[:, -1:, :]
            tgt_q_mask = None
            if tgt_mask is not None:
                tgt_q_mask = tgt_mask[:, -1:, :]

        if self.concat_after:
            # NOTE: the concat path applies no dropout (matches upstream behavior).
            tgt_concat = torch.cat(
                (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1
            )
            x = residual + self.concat_linear1(tgt_concat)
        else:
            x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask))
        if not self.normalize_before:
            x = self.norm1(x)

        # Source (encoder-decoder) attention sub-block.
        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        if self.concat_after:
            x_concat = torch.cat(
                (x, self.src_attn(x, memory, memory, memory_mask)), dim=-1
            )
            x = residual + self.concat_linear2(x_concat)
        else:
            x = residual + self.dropout(self.src_attn(x, memory, memory, memory_mask))
        if not self.normalize_before:
            x = self.norm2(x)

        # Position-wise feed-forward sub-block.
        residual = x
        if self.normalize_before:
            x = self.norm3(x)
        x = residual + self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm3(x)

        if cache is not None:
            # Re-append the cached prefix so callers always see the full sequence.
            x = torch.cat([cache, x], dim=1)

        return x, tgt_mask, memory, memory_mask
|
||||
|
||||
|
||||
class BaseTransformerDecoder(nn.Module, BatchScorerInterface):
    """Base class of Transfomer decoder module.

    Args:
        vocab_size: output dim
        encoder_output_size: dimension of attention
        attention_heads: the number of heads of multi head attention
        linear_units: the number of units of position-wise feed forward
        num_blocks: the number of decoder blocks
        dropout_rate: dropout rate
        self_attention_dropout_rate: dropout rate for attention
        input_layer: input layer type
        use_output_layer: whether to use output layer
        pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
        normalize_before: whether to use layer_norm before the first block
        concat_after: whether to concat attention layer's input and output
            if True, additional linear will be applied.
            i.e. x -> x + linear(concat(x, att(x)))
            if False, no additional linear will be applied.
            i.e. x -> x + att(x)
    """

    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        input_layer: str = "embed",
        use_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        causal=True,
    ):
        """Build the shared input embedding / output projection scaffolding.

        Subclasses MUST assign ``self.decoders`` (the layer stack) afterwards.
        ``causal`` controls whether ``forward`` applies a subsequent mask.
        """
        super().__init__()
        attention_dim = encoder_output_size
        self.causal = causal
        self.vocab_size = vocab_size

        if input_layer == "embed":
            # Token-id inputs: embedding lookup + positional encoding.
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(vocab_size, attention_dim),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        elif input_layer == "linear":
            # Dense-vector inputs (e.g. posteriors): linear projection pipeline.
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(vocab_size, attention_dim),
                torch.nn.LayerNorm(attention_dim),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        else:
            raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")

        self.normalize_before = normalize_before
        if self.normalize_before:
            self.after_norm = LayerNorm(attention_dim)
        if use_output_layer:
            self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
        else:
            self.output_layer = None

        # Must set by the inheritance
        self.decoders = None

    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward decoder.

        Args:
            hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
            hlens: (batch)
            ys_in_pad:
                input token ids, int64 (batch, maxlen_out)
                if input_layer == "embed"
                input tensor (batch, maxlen_out, #mels) in the other cases
            ys_in_lens: (batch)
        Returns:
            (tuple): tuple containing:

            x: decoded token score before softmax (batch, maxlen_out, token)
                if use_output_layer is True,
            olens: (batch, )
        """
        tgt = ys_in_pad
        # tgt_mask: (B, 1, L)
        tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
        if self.causal:
            # m: (1, L, L)
            m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
            # tgt_mask: (B, L, L) — combine padding mask with causal mask.
            tgt_mask = tgt_mask & m

        memory = hs_pad
        memory_mask = (~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to(
            memory.device
        )
        # Padding for Longformer
        if memory_mask.shape[-1] != memory.shape[1]:
            padlen = memory.shape[1] - memory_mask.shape[-1]
            memory_mask = torch.nn.functional.pad(
                memory_mask, (0, padlen), "constant", False
            )

        x = self.embed(tgt)
        x, tgt_mask, memory, memory_mask = self.decoders(
            x, tgt_mask, memory, memory_mask
        )
        if self.normalize_before:
            x = self.after_norm(x)
        if self.output_layer is not None:
            x = self.output_layer(x)

        olens = tgt_mask.sum(1)
        return x, olens

    def forward_one_step(
        self,
        tgt: torch.Tensor,
        tgt_mask: torch.Tensor,
        memory: torch.Tensor,
        cache: List[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Forward one step.

        Args:
            tgt: input token ids, int64 (batch, maxlen_out)
            tgt_mask: input token mask,  (batch, maxlen_out)
                      dtype=torch.uint8 in PyTorch 1.2-
                      dtype=torch.bool in PyTorch 1.2+ (include 1.2)
            memory: encoded memory, float32  (batch, maxlen_in, feat)
            cache: cached output list of (batch, max_time_out-1, size)
        Returns:
            y, cache: NN output value and cache per `self.decoders`.
            y.shape` is (batch, maxlen_out, token)
        """
        x = self.embed(tgt)
        if cache is None:
            # One cache slot per decoder layer.
            cache = [None] * len(self.decoders)
        new_cache = []
        for c, decoder in zip(cache, self.decoders):
            # memory_mask is None during incremental decoding.
            x, tgt_mask, memory, memory_mask = decoder(
                x, tgt_mask, memory, None, cache=c
            )
            new_cache.append(x)

        if self.normalize_before:
            # Only the newest position is scored.
            y = self.after_norm(x[:, -1])
        else:
            y = x[:, -1]
        if self.output_layer is not None:
            y = torch.log_softmax(self.output_layer(y), dim=-1)

        return y, new_cache

    def score(self, ys, state, x):
        """Score a single hypothesis prefix (ScorerInterface API)."""
        ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0)
        logp, state = self.forward_one_step(
            ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state
        )
        return logp.squeeze(0), state

    def batch_score(
        self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
    ) -> Tuple[torch.Tensor, List[Any]]:
        """Score new token batch.

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            states (List[Any]): Scorer states for prefix tokens.
            xs (torch.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).

        Returns:
            tuple[torch.Tensor, List[Any]]: Tuple of
                batchfied scores for next token with shape of `(n_batch, n_vocab)`
                and next state list for ys.
        """
        # merge states
        n_batch = len(ys)
        n_layers = len(self.decoders)
        if states[0] is None:
            batch_state = None
        else:
            # transpose state of [batch, layer] into [layer, batch]
            batch_state = [
                torch.stack([states[b][i] for b in range(n_batch)])
                for i in range(n_layers)
            ]

        # batch decoding
        ys_mask = subsequent_mask(ys.size(-1), device=xs.device).unsqueeze(0)
        logp, states = self.forward_one_step(ys, ys_mask, xs, cache=batch_state)

        # transpose state of [layer, batch] into [batch, layer]
        state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
        return logp, state_list
|
||||
|
||||
|
||||
class TransformerDecoder(BaseTransformerDecoder):
    """Standard Transformer decoder: a stack of self-attention /
    source-attention / feed-forward blocks on top of
    :class:`BaseTransformerDecoder`'s embedding and output scaffolding.
    """

    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = "embed",
        use_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
        causal: bool = True,
    ):
        """See :class:`BaseTransformerDecoder` for argument semantics."""
        super().__init__(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            dropout_rate=dropout_rate,
            positional_dropout_rate=positional_dropout_rate,
            input_layer=input_layer,
            use_output_layer=use_output_layer,
            pos_enc_class=pos_enc_class,
            normalize_before=normalize_before,
            causal=causal,
        )

        hidden = encoder_output_size

        def _build_block(_layer_idx: int) -> "DecoderLayer":
            # Each block: masked self-attention, encoder-decoder attention, FFN.
            return DecoderLayer(
                hidden,
                MultiHeadedAttention(attention_heads, hidden, self_attention_dropout_rate),
                MultiHeadedAttention(attention_heads, hidden, src_attention_dropout_rate),
                PositionwiseFeedForward(hidden, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
            )

        self.decoders = repeat(num_blocks, _build_block)
|
||||
|
||||
|
||||
class ParaformerDecoderSAN(BaseTransformerDecoder):
    """
    author: Speech Lab, Alibaba Group, China
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2006.01713

    Non-autoregressive decoder variant: ``forward`` takes pre-embedded inputs
    (``self.embed`` is bypassed) and applies no causal mask. When
    ``embeds_id`` matches a layer index, that layer's hidden states are
    additionally returned.
    """
    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = "embed",
        use_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
        embeds_id: int = -1,
    ):
        # NOTE: `causal` is left at the base default; this class's own
        # `forward` never consults it (no subsequent mask is applied).
        super().__init__(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            dropout_rate=dropout_rate,
            positional_dropout_rate=positional_dropout_rate,
            input_layer=input_layer,
            use_output_layer=use_output_layer,
            pos_enc_class=pos_enc_class,
            normalize_before=normalize_before,
        )

        attention_dim = encoder_output_size
        self.decoders = repeat(
            num_blocks,
            lambda lnum: DecoderLayer(
                attention_dim,
                MultiHeadedAttention(
                    attention_heads, attention_dim, self_attention_dropout_rate
                ),
                MultiHeadedAttention(
                    attention_heads, attention_dim, src_attention_dropout_rate
                ),
                PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        # Index of the layer whose hidden states are exported; -1 disables it.
        self.embeds_id = embeds_id
        self.attention_dim = attention_dim

    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward decoder.

        Args:
            hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
            hlens: (batch)
            ys_in_pad:
                input token ids, int64 (batch, maxlen_out)
                if input_layer == "embed"
                input tensor (batch, maxlen_out, #mels) in the other cases
            ys_in_lens: (batch)
        Returns:
            (tuple): tuple containing:

            x: decoded token score before softmax (batch, maxlen_out, token)
                if use_output_layer is True,
            olens: (batch, )
            embeds_outputs: hidden states of layer ``embeds_id`` — only
                present when ``embeds_id`` hit a layer, so the return arity
                varies (2 or 3 elements); callers must handle both.
        """
        tgt = ys_in_pad
        tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)

        memory = hs_pad
        memory_mask = (~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to(
            memory.device
        )
        # Padding for Longformer
        if memory_mask.shape[-1] != memory.shape[1]:
            padlen = memory.shape[1] - memory_mask.shape[-1]
            memory_mask = torch.nn.functional.pad(
                memory_mask, (0, padlen), "constant", False
            )

        # Inputs are already embeddings (Paraformer feeds acoustic embeds),
        # so the base class's input layer is intentionally skipped.
        # x = self.embed(tgt)
        x = tgt
        embeds_outputs = None
        for layer_id, decoder in enumerate(self.decoders):
            x, tgt_mask, memory, memory_mask = decoder(
                x, tgt_mask, memory, memory_mask
            )
            if layer_id == self.embeds_id:
                embeds_outputs = x
        if self.normalize_before:
            x = self.after_norm(x)
        if self.output_layer is not None:
            x = self.output_layer(x)

        olens = tgt_mask.sum(1)
        if embeds_outputs is not None:
            return x, olens, embeds_outputs
        else:
            return x, olens
|
||||
|
||||
|
||||
class LightweightConvolutionTransformerDecoder(BaseTransformerDecoder):
    """Transformer decoder whose self-attention is replaced by 1-D
    lightweight convolution (one kernel size per block).
    """

    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = "embed",
        use_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
        conv_wshare: int = 4,
        conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
        conv_usebias: int = False,
    ):
        """One kernel size per block is required; raises ValueError otherwise."""
        if num_blocks != len(conv_kernel_length):
            raise ValueError(
                "conv_kernel_length must have equal number of values to num_blocks: "
                f"{len(conv_kernel_length)} != {num_blocks}"
            )
        super().__init__(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            dropout_rate=dropout_rate,
            positional_dropout_rate=positional_dropout_rate,
            input_layer=input_layer,
            use_output_layer=use_output_layer,
            pos_enc_class=pos_enc_class,
            normalize_before=normalize_before,
        )

        hidden = encoder_output_size

        def _build_block(layer_idx: int) -> "DecoderLayer":
            # Lightweight convolution stands in for masked self-attention.
            conv = LightweightConvolution(
                wshare=conv_wshare,
                n_feat=hidden,
                dropout_rate=self_attention_dropout_rate,
                kernel_size=conv_kernel_length[layer_idx],
                use_kernel_mask=True,
                use_bias=conv_usebias,
            )
            src_attn = MultiHeadedAttention(
                attention_heads, hidden, src_attention_dropout_rate
            )
            ffn = PositionwiseFeedForward(hidden, linear_units, dropout_rate)
            return DecoderLayer(
                hidden,
                conv,
                src_attn,
                ffn,
                dropout_rate,
                normalize_before,
                concat_after,
            )

        self.decoders = repeat(num_blocks, _build_block)
|
||||
|
||||
|
||||
class LightweightConvolution2DTransformerDecoder(BaseTransformerDecoder):
    """Transformer decoder whose self-attention is replaced by 2-D
    lightweight convolution (one kernel size per block).
    """

    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = "embed",
        use_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
        conv_wshare: int = 4,
        conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
        conv_usebias: int = False,
    ):
        """One kernel size per block is required; raises ValueError otherwise."""
        if num_blocks != len(conv_kernel_length):
            raise ValueError(
                "conv_kernel_length must have equal number of values to num_blocks: "
                f"{len(conv_kernel_length)} != {num_blocks}"
            )
        super().__init__(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            dropout_rate=dropout_rate,
            positional_dropout_rate=positional_dropout_rate,
            input_layer=input_layer,
            use_output_layer=use_output_layer,
            pos_enc_class=pos_enc_class,
            normalize_before=normalize_before,
        )

        hidden = encoder_output_size

        def _build_block(layer_idx: int) -> "DecoderLayer":
            # 2-D lightweight convolution stands in for masked self-attention.
            conv = LightweightConvolution2D(
                wshare=conv_wshare,
                n_feat=hidden,
                dropout_rate=self_attention_dropout_rate,
                kernel_size=conv_kernel_length[layer_idx],
                use_kernel_mask=True,
                use_bias=conv_usebias,
            )
            src_attn = MultiHeadedAttention(
                attention_heads, hidden, src_attention_dropout_rate
            )
            ffn = PositionwiseFeedForward(hidden, linear_units, dropout_rate)
            return DecoderLayer(
                hidden,
                conv,
                src_attn,
                ffn,
                dropout_rate,
                normalize_before,
                concat_after,
            )

        self.decoders = repeat(num_blocks, _build_block)
|
||||
|
||||
|
||||
class DynamicConvolutionTransformerDecoder(BaseTransformerDecoder):
    """Transformer decoder whose self-attention is replaced by 1-D
    dynamic convolution (one kernel size per block).
    """

    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = "embed",
        use_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
        conv_wshare: int = 4,
        conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
        conv_usebias: int = False,
    ):
        """One kernel size per block is required; raises ValueError otherwise."""
        if num_blocks != len(conv_kernel_length):
            raise ValueError(
                "conv_kernel_length must have equal number of values to num_blocks: "
                f"{len(conv_kernel_length)} != {num_blocks}"
            )
        super().__init__(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            dropout_rate=dropout_rate,
            positional_dropout_rate=positional_dropout_rate,
            input_layer=input_layer,
            use_output_layer=use_output_layer,
            pos_enc_class=pos_enc_class,
            normalize_before=normalize_before,
        )
        hidden = encoder_output_size

        def _build_block(layer_idx: int) -> "DecoderLayer":
            # Dynamic convolution stands in for masked self-attention.
            conv = DynamicConvolution(
                wshare=conv_wshare,
                n_feat=hidden,
                dropout_rate=self_attention_dropout_rate,
                kernel_size=conv_kernel_length[layer_idx],
                use_kernel_mask=True,
                use_bias=conv_usebias,
            )
            src_attn = MultiHeadedAttention(
                attention_heads, hidden, src_attention_dropout_rate
            )
            ffn = PositionwiseFeedForward(hidden, linear_units, dropout_rate)
            return DecoderLayer(
                hidden,
                conv,
                src_attn,
                ffn,
                dropout_rate,
                normalize_before,
                concat_after,
            )

        self.decoders = repeat(num_blocks, _build_block)
|
||||
|
||||
|
||||
class DynamicConvolution2DTransformerDecoder(BaseTransformerDecoder):
    """Transformer decoder whose self-attention is replaced by 2-D
    dynamic convolution (one kernel size per block).
    """

    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = "embed",
        use_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
        conv_wshare: int = 4,
        conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
        conv_usebias: int = False,
    ):
        """One kernel size per block is required; raises ValueError otherwise."""
        if num_blocks != len(conv_kernel_length):
            raise ValueError(
                "conv_kernel_length must have equal number of values to num_blocks: "
                f"{len(conv_kernel_length)} != {num_blocks}"
            )
        super().__init__(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            dropout_rate=dropout_rate,
            positional_dropout_rate=positional_dropout_rate,
            input_layer=input_layer,
            use_output_layer=use_output_layer,
            pos_enc_class=pos_enc_class,
            normalize_before=normalize_before,
        )
        hidden = encoder_output_size

        def _build_block(layer_idx: int) -> "DecoderLayer":
            # 2-D dynamic convolution stands in for masked self-attention.
            conv = DynamicConvolution2D(
                wshare=conv_wshare,
                n_feat=hidden,
                dropout_rate=self_attention_dropout_rate,
                kernel_size=conv_kernel_length[layer_idx],
                use_kernel_mask=True,
                use_bias=conv_usebias,
            )
            src_attn = MultiHeadedAttention(
                attention_heads, hidden, src_attention_dropout_rate
            )
            ffn = PositionwiseFeedForward(hidden, linear_units, dropout_rate)
            return DecoderLayer(
                hidden,
                conv,
                src_attn,
                ffn,
                dropout_rate,
                normalize_before,
                concat_after,
            )

        self.decoders = repeat(num_blocks, _build_block)
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -1,346 +0,0 @@
|
||||
@
|
||||
@
|
||||
@
|
||||
@
|
||||
@!
|
||||
@"
|
||||
@#
|
||||
@$
|
||||
@'
|
||||
@(
|
||||
@)
|
||||
@*
|
||||
@,
|
||||
@-
|
||||
@.
|
||||
@/
|
||||
@:
|
||||
@;
|
||||
@<
|
||||
@>
|
||||
@?
|
||||
@[
|
||||
@]
|
||||
@^
|
||||
@_
|
||||
@`
|
||||
@a_c1
|
||||
@a_c2
|
||||
@a_c3
|
||||
@a_c4
|
||||
@a_c5
|
||||
@aa0
|
||||
@aa1
|
||||
@aa2
|
||||
@ae0
|
||||
@ae1
|
||||
@ae2
|
||||
@ah0
|
||||
@ah1
|
||||
@ah2
|
||||
@ai_c1
|
||||
@ai_c2
|
||||
@ai_c3
|
||||
@ai_c4
|
||||
@ai_c5
|
||||
@an_c1
|
||||
@an_c2
|
||||
@an_c3
|
||||
@an_c4
|
||||
@an_c5
|
||||
@ang_c1
|
||||
@ang_c2
|
||||
@ang_c3
|
||||
@ang_c4
|
||||
@ang_c5
|
||||
@ao0
|
||||
@ao1
|
||||
@ao2
|
||||
@ao_c1
|
||||
@ao_c2
|
||||
@ao_c3
|
||||
@ao_c4
|
||||
@ao_c5
|
||||
@aw0
|
||||
@aw1
|
||||
@aw2
|
||||
@ay0
|
||||
@ay1
|
||||
@ay2
|
||||
@b
|
||||
@b_c
|
||||
@c_c
|
||||
@ch
|
||||
@ch_c
|
||||
@d
|
||||
@d_c
|
||||
@dh
|
||||
@e_c1
|
||||
@e_c2
|
||||
@e_c3
|
||||
@e_c4
|
||||
@e_c5
|
||||
@eh0
|
||||
@eh1
|
||||
@eh2
|
||||
@ei_c1
|
||||
@ei_c2
|
||||
@ei_c3
|
||||
@ei_c4
|
||||
@ei_c5
|
||||
@en_c1
|
||||
@en_c2
|
||||
@en_c3
|
||||
@en_c4
|
||||
@en_c5
|
||||
@eng_c1
|
||||
@eng_c2
|
||||
@eng_c3
|
||||
@eng_c4
|
||||
@eng_c5
|
||||
@er0
|
||||
@er1
|
||||
@er2
|
||||
@er_c1
|
||||
@er_c2
|
||||
@er_c3
|
||||
@er_c4
|
||||
@er_c5
|
||||
@ey0
|
||||
@ey1
|
||||
@ey2
|
||||
@f
|
||||
@f_c
|
||||
@g
|
||||
@g_c
|
||||
@ga
|
||||
@ge
|
||||
@go
|
||||
@h_c
|
||||
@hh
|
||||
@i_c1
|
||||
@i_c2
|
||||
@i_c3
|
||||
@i_c4
|
||||
@i_c5
|
||||
@ia_c1
|
||||
@ia_c2
|
||||
@ia_c3
|
||||
@ia_c4
|
||||
@ia_c5
|
||||
@ian_c1
|
||||
@ian_c2
|
||||
@ian_c3
|
||||
@ian_c4
|
||||
@ian_c5
|
||||
@iang_c1
|
||||
@iang_c2
|
||||
@iang_c3
|
||||
@iang_c4
|
||||
@iang_c5
|
||||
@iao_c1
|
||||
@iao_c2
|
||||
@iao_c3
|
||||
@iao_c4
|
||||
@iao_c5
|
||||
@ie_c1
|
||||
@ie_c2
|
||||
@ie_c3
|
||||
@ie_c4
|
||||
@ie_c5
|
||||
@ih0
|
||||
@ih1
|
||||
@ih2
|
||||
@ih_c1
|
||||
@ih_c2
|
||||
@ih_c3
|
||||
@ih_c4
|
||||
@ih_c5
|
||||
@ii_c1
|
||||
@ii_c2
|
||||
@ii_c3
|
||||
@ii_c4
|
||||
@ii_c5
|
||||
@in_c1
|
||||
@in_c2
|
||||
@in_c3
|
||||
@in_c4
|
||||
@in_c5
|
||||
@ing_c1
|
||||
@ing_c2
|
||||
@ing_c3
|
||||
@ing_c4
|
||||
@ing_c5
|
||||
@iong_c1
|
||||
@iong_c2
|
||||
@iong_c3
|
||||
@iong_c4
|
||||
@iong_c5
|
||||
@iou_c1
|
||||
@iou_c2
|
||||
@iou_c3
|
||||
@iou_c4
|
||||
@iou_c5
|
||||
@iy0
|
||||
@iy1
|
||||
@iy2
|
||||
@j_c
|
||||
@jh
|
||||
@k
|
||||
@k_c
|
||||
@l
|
||||
@l_c
|
||||
@m
|
||||
@m_c
|
||||
@n
|
||||
@n_c
|
||||
@ng
|
||||
@o_c1
|
||||
@o_c2
|
||||
@o_c3
|
||||
@o_c4
|
||||
@o_c5
|
||||
@ong_c1
|
||||
@ong_c2
|
||||
@ong_c3
|
||||
@ong_c4
|
||||
@ong_c5
|
||||
@ou_c1
|
||||
@ou_c2
|
||||
@ou_c3
|
||||
@ou_c4
|
||||
@ou_c5
|
||||
@ouh
|
||||
@ouj
|
||||
@oull
|
||||
@ouw
|
||||
@ow0
|
||||
@ow1
|
||||
@ow2
|
||||
@oy0
|
||||
@oy1
|
||||
@oy2
|
||||
@p
|
||||
@p_c
|
||||
@q_c
|
||||
@r
|
||||
@r_c
|
||||
@s
|
||||
@s_c
|
||||
@sh
|
||||
@sh_c
|
||||
@t
|
||||
@t_c
|
||||
@th
|
||||
@u_c1
|
||||
@u_c2
|
||||
@u_c3
|
||||
@u_c4
|
||||
@u_c5
|
||||
@ua_c1
|
||||
@ua_c2
|
||||
@ua_c3
|
||||
@ua_c4
|
||||
@ua_c5
|
||||
@uai_c1
|
||||
@uai_c2
|
||||
@uai_c3
|
||||
@uai_c4
|
||||
@uai_c5
|
||||
@uan_c1
|
||||
@uan_c2
|
||||
@uan_c3
|
||||
@uan_c4
|
||||
@uan_c5
|
||||
@uang_c1
|
||||
@uang_c2
|
||||
@uang_c3
|
||||
@uang_c4
|
||||
@uang_c5
|
||||
@uei_c1
|
||||
@uei_c2
|
||||
@uei_c3
|
||||
@uei_c4
|
||||
@uei_c5
|
||||
@uen_c1
|
||||
@uen_c2
|
||||
@uen_c3
|
||||
@uen_c4
|
||||
@uen_c5
|
||||
@uh0
|
||||
@uh1
|
||||
@uh2
|
||||
@uo_c1
|
||||
@uo_c2
|
||||
@uo_c3
|
||||
@uo_c4
|
||||
@uo_c5
|
||||
@uw0
|
||||
@uw1
|
||||
@uw2
|
||||
@v
|
||||
@v_c1
|
||||
@v_c2
|
||||
@v_c3
|
||||
@v_c4
|
||||
@v_c5
|
||||
@van_c1
|
||||
@van_c2
|
||||
@van_c3
|
||||
@van_c4
|
||||
@van_c5
|
||||
@ve_c1
|
||||
@ve_c2
|
||||
@ve_c3
|
||||
@ve_c4
|
||||
@ve_c5
|
||||
@vn_c1
|
||||
@vn_c2
|
||||
@vn_c3
|
||||
@vn_c4
|
||||
@vn_c5
|
||||
@w
|
||||
@w_c
|
||||
@xx_c
|
||||
@y
|
||||
@y_c
|
||||
@z
|
||||
@z_c
|
||||
@zh
|
||||
@zh_c
|
||||
@{
|
||||
@|
|
||||
@}
|
||||
@~
|
||||
@·
|
||||
@—
|
||||
@——
|
||||
@‘
|
||||
@’
|
||||
@“
|
||||
@”
|
||||
@…
|
||||
@……
|
||||
@‰
|
||||
@℃
|
||||
@∶
|
||||
@○
|
||||
@、
|
||||
@。
|
||||
@《
|
||||
@》
|
||||
@『
|
||||
@』
|
||||
@【
|
||||
@】
|
||||
@〔
|
||||
@〕
|
||||
@"
|
||||
@(
|
||||
@)
|
||||
@,
|
||||
@:
|
||||
@[
|
||||
@\
|
||||
@]
|
||||
@¥
|
||||
@ -1,31 +0,0 @@
|
||||
from pathlib import Path
|
||||
from typing import Iterable
|
||||
from typing import Union
|
||||
|
||||
|
||||
def build_tokenizer(
    token_type: str,
    bpemodel: Union[Path, str, Iterable[str]] = None,
    non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
    remove_non_linguistic_symbols: bool = False,
    space_symbol: str = "<space>",
    delimiter: str = None,
    g2p_type: str = None,
    p_word2phn: float = 0.5,
):
    """Factory for TTS text tokenizers.

    Only ``token_type`` values containing ``"whisper_rich_ttsfrd"`` are
    supported by this builder; anything else raises ``ValueError``.

    Args:
        token_type: tokenizer kind; must contain ``"whisper_rich_ttsfrd"``.
        bpemodel: for the whisper_rich_ttsfrd tokenizer this is forwarded as
            the ttsfrd resource directory (``ttsfrd_model``), not a BPE model.
        non_linguistic_symbols / remove_non_linguistic_symbols / space_symbol /
            delimiter / g2p_type: accepted only for signature compatibility
            with other tokenizer builders; unused here.
        p_word2phn: probability of replacing a word by its phoneme sequence.

    Returns:
        A ``WhisperRichTtsFrdTokenizer`` instance.

    Raises:
        ValueError: if ``token_type`` is not supported by this builder.
    """
    if "whisper_rich_ttsfrd" in token_type:
        # Imported lazily so this module has no hard dependency on the
        # whisper tokenizer stack unless it is actually requested.
        from funasr.models.llm_asr.tts_text_tokenizer.whisper_tokenizer import WhisperRichTtsFrdTokenizer

        return WhisperRichTtsFrdTokenizer(
            token_path="multilingual_zh_ja_yue_char_del",
            num_languages=105,
            task=None,
            language=None,
            ttsfrd_type="ttsfrd_rich",
            ttsfrd_model=bpemodel,
            p_word2phn=p_word2phn,
        )

    else:
        # Bug fix: the old message ("token_mode must be one of bpe, word,
        # char or phn") was copied from another builder and listed modes this
        # function never accepts.
        raise ValueError(
            f"token_type must contain 'whisper_rich_ttsfrd', got: {token_type}"
        )
@ -1,175 +0,0 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import re
|
||||
from typing import Iterable
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Union
|
||||
import warnings
|
||||
import os
|
||||
import json
|
||||
import jamo
|
||||
|
||||
|
||||
class TtsFrdRich:
    """
    rich text info: phoneme + puncs + boundary + [word2phone]

    Wraps the (optional) external ``ttsfrd`` frontend engine to turn raw text
    into phoneme-level annotations.  A single engine is created lazily in
    ``build()`` and switched between the Chinese ("pinyin") and English
    ("enus") frontends depending on the input text.
    """
    def __init__(self, remove_boundary=True, token_type="pronplus"):
        # remove_boundary: strip "/" (word) and "#" (syllable) boundary marks
        #   from the 'pronplus' output.
        # token_type: output format produced by __call__; one of
        #   "pronplus", "word2phn", "wordlist".
        super().__init__()
        self.remove_boundary = remove_boundary
        self.token_type = token_type
        self.g2p = None  # lazily-created ttsfrd.TtsFrontendEngine
        self.lang_type = None  # currently active frontend ("pinyin"/"enus")
        self.lang_type_map = {"zh-cn": "pinyin", "en-us": "enus"}

    @staticmethod
    def contains_chinese(str):
        # True if any CJK unified ideograph occurs in the string.
        # NOTE(review): the parameter shadows the builtin `str`.
        return bool(re.search(r'[\u4e00-\u9fff]', str))

    @staticmethod
    def is_full_half_punctuation_string(s):
        # Covers ASCII punctuation/control ranges plus common full-width
        # (CJK) punctuation.
        punctuation_pattern = r'[\u0000-\u002f\u003a-\u0040\u005b-\u0060\u007b-\u007f\u3000-\u303f\uff00-\uffef]'
        # Collect every character of `s` that matches the pattern.
        results = re.findall(punctuation_pattern, s)
        # If every character matched, the whole string is punctuation.
        return len(s) == len("".join(results))

    def build(self, resource_dir, lang_type="Zh-CN"):
        """Create the ttsfrd engine (once) and select the frontend for
        `lang_type` ("Zh-CN" or "En-US", case-insensitive)."""
        lang_type = lang_type.lower()
        new_lang_type = self.lang_type_map[lang_type]
        if self.g2p is None:
            import ttsfrd  # local import: optional third-party dependency
            assert os.path.isdir(resource_dir)
            fe = ttsfrd.TtsFrontendEngine()
            fe.initialize(resource_dir)
            self.g2p = fe
            # self.lang_type = new_lang_type
            self.set_lang_type(new_lang_type)
        # Switch frontends if a different language was requested later.
        if self.lang_type != new_lang_type:
            # self.lang_type = new_lang_type
            self.set_lang_type(new_lang_type)

    def set_lang_type(self, lang_type):
        """Switch the engine frontend and remember the active language."""
        if lang_type == "enus":
            self.g2p.set_lang_type(lang_type)
            self.g2p.enable_pinyin_mix(True)
            # self.g2p.set_breakmodel_index(0)
        else:
            self.g2p.set_lang_type(lang_type)
            self.g2p.enable_pinyin_mix(True)
            # self.g2p.set_breakmodel_index(1)
        self.lang_type = lang_type

    def set_token_type(self, token_type):
        """Change the output format produced by __call__."""
        assert token_type in ["pronplus", "word2phn", "wordlist"], token_type
        self.token_type = token_type

    def __call__(self, text) -> Union[List[str], str]:
        """Run the frontend on `text`.

        The English frontend is selected when the text contains no CJK
        ideographs, otherwise the Chinese ("pinyin") frontend is used.
        Returns a list of symbols for "pronplus", or a JSON string for
        "word2phn"/"wordlist".
        """
        assert self.g2p is not None
        if not self.contains_chinese(text):
            if self.lang_type != "enus":
                self.set_lang_type("enus")
        else:
            if self.lang_type != "pinyin":
                self.set_lang_type("pinyin")
        if self.token_type == "word2phn":
            return self._get_word2phn(text)
        elif self.token_type == "pronplus":
            return self._get_pronplus(text)
        elif self.token_type == "wordlist":
            return self._get_wordlist(text)
        else:
            raise ValueError(f"only type: [pronplus, word2phn, wordlist] supported, now type: {self.token_type}")

    def _get_pronplus(self, text) -> List[str]:
        """Return a flat list of pronunciation symbols for `text`.

        Alphabetic pronunciations are kept whole (lower-cased); anything else
        (punctuation runs) is split into individual characters.
        """
        pronplus = self.g2p.get_frd_extra_info(text, 'pronplus')
        if self.remove_boundary:
            pronplus = pronplus.replace("/", "")  # word boundary
            pronplus = pronplus.replace("#", "")  # syllable boundary
        # pronplus = pronplus.replace("\n", "")
        pronplus = pronplus.replace("\n", " ")
        symbols: List[str] = []
        for pron in pronplus.split(" "):
            pron = pron.strip().lower()
            if pron and pron[0].isalpha():
                symbols.append(pron)
            else:
                # punctuation run: emit each mark as its own symbol
                symbols.extend([mark for mark in pron if mark])
        return symbols

    def text2tokens(self, line: str) -> List[str]:
        """Convert `line` to "@"-prefixed phoneme/punctuation tokens.

        Words with a pronunciation expand to one token per phone; entries
        without phones (punctuation, inserted spaces) become a single token,
        with a bare space mapped to "<|space|>".
        """
        json_str = self._get_word2phn(line)
        data = json.loads(json_str)
        retval = []
        for one in data["word2phn"]:
            for key, value in one.items():
                if value is not None:
                    retval.extend([f"@{x}" for x in value])
                else:
                    if key == " ":
                        key = "<|space|>"
                    retval.append(f"@{key}")

        return retval

    def tokens2text(self, tokens: Iterable[str]) -> str:
        # Not implemented: the phoneme expansion is not invertible.
        pass

    def _get_wordlist(self, text) -> str:
        """Return the engine's raw 'wordlist' JSON output for `text`."""
        wordlist = self.g2p.get_frd_extra_info(text, 'wordlist')
        return wordlist

    def _get_word2phn(self, text) -> str:
        """Build a word-to-phoneme mapping as a JSON string.

        Output shape: {"raw": text, "word2phn": [{word: [phones] | None}, ...]}
        where None marks punctuation and inserted spaces.  Explicit {" ": None}
        separators are inserted between adjacent words so the mapping can be
        rendered back with correct spacing.
        """
        wordlist = self.g2p.get_frd_extra_info(text, 'wordlist')
        wordlist_subs = wordlist.split("\n")
        word2phn_info = []
        prev_word_type = None
        prev_word = None
        for json_str in wordlist_subs:
            if len(json_str) == 0:
                continue
            wordlist_info = json.loads(json_str)["wordlist"]
            for word_info in wordlist_info:
                is_english_word = True
                this_phone_list = None
                # --- classify the word: punc / en_word / ch_word -----------
                if word_info["syllables"] is None:
                    # punctuation
                    this_word_type = "punc"
                    pass
                elif self.is_full_half_punctuation_string(word_info["name"]):
                    # punctuation, handle some g2p's mistakes spelling punctuation!!!
                    this_word_type = "punc"
                    pass
                else:
                    this_phone_list = []
                    for syllable_info in word_info["syllables"]:
                        phn_count = syllable_info["phone_count"]
                        syllable_phone_list = syllable_info["pron_text"].split(" ")
                        assert len(syllable_phone_list) == phn_count, len(syllable_phone_list)
                        if "py_text" in syllable_info:
                            # chinese: append the tone digit to the final phone
                            syllable_phone_list[-1] = syllable_phone_list[-1]+str(syllable_info["tone"])
                            is_english_word = False
                        this_phone_list += syllable_phone_list
                    if is_english_word:
                        this_word_type = "en_word"
                    else:
                        this_word_type = "ch_word"
                # --- insert explicit space separators between words --------
                if this_word_type == "en_word":
                    if prev_word_type is None:
                        pass
                    elif prev_word_type == "en_word":
                        word2phn_info.append({" ": None})
                    elif prev_word_type == "punc":
                        # no space right after an opening quote/bracket
                        if (prev_word not in ["\"", "\'", "(", "(", "[", "【"] and
                                prev_word.split(" ")[-1] not in ["\"", "\'", "(", "(", "[", "【"]):
                            word2phn_info.append({" ": None})
                    elif prev_word_type == "ch_word":
                        word2phn_info.append({" ": None})
                elif this_word_type == "ch_word":
                    if prev_word_type is not None and prev_word_type == "en_word":
                        word2phn_info.append({" ": None})
                elif this_word_type == "punc":
                    if word_info["name"] in ["("]:
                        word2phn_info.append({" ": None})
                this_word2phn_dict = {word_info["name"]: this_phone_list}
                word2phn_info.append(this_word2phn_dict)
                prev_word_type = this_word_type
                prev_word = list(word2phn_info[-1].keys())[0]
        return json.dumps({"raw": text, "word2phn": word2phn_info}, ensure_ascii=False)
@ -1,462 +0,0 @@
|
||||
import base64
|
||||
import os
|
||||
import string
|
||||
from dataclasses import dataclass, field
|
||||
from functools import cached_property, lru_cache
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import tiktoken
|
||||
|
||||
# Whisper language-code -> language-name table.  Order matters: the first
# `num_languages` entries determine which language tokens exist and their ids
# (see get_encoding and Tokenizer.__post_init__), so entries must NOT be
# reordered.  The trailing entries after "su" extend upstream Whisper's 99
# languages with Chinese dialect / code-switch labels.
LANGUAGES = {
    "en": "english",
    "zh": "chinese",
    "de": "german",
    "es": "spanish",
    "ru": "russian",
    "ko": "korean",
    "fr": "french",
    "ja": "japanese",
    "pt": "portuguese",
    "tr": "turkish",
    "pl": "polish",
    "ca": "catalan",
    "nl": "dutch",
    "ar": "arabic",
    "sv": "swedish",
    "it": "italian",
    "id": "indonesian",
    "hi": "hindi",
    "fi": "finnish",
    "vi": "vietnamese",
    "he": "hebrew",
    "uk": "ukrainian",
    "el": "greek",
    "ms": "malay",
    "cs": "czech",
    "ro": "romanian",
    "da": "danish",
    "hu": "hungarian",
    "ta": "tamil",
    "no": "norwegian",
    "th": "thai",
    "ur": "urdu",
    "hr": "croatian",
    "bg": "bulgarian",
    "lt": "lithuanian",
    "la": "latin",
    "mi": "maori",
    "ml": "malayalam",
    "cy": "welsh",
    "sk": "slovak",
    "te": "telugu",
    "fa": "persian",
    "lv": "latvian",
    "bn": "bengali",
    "sr": "serbian",
    "az": "azerbaijani",
    "sl": "slovenian",
    "kn": "kannada",
    "et": "estonian",
    "mk": "macedonian",
    "br": "breton",
    "eu": "basque",
    "is": "icelandic",
    "hy": "armenian",
    "ne": "nepali",
    "mn": "mongolian",
    "bs": "bosnian",
    "kk": "kazakh",
    "sq": "albanian",
    "sw": "swahili",
    "gl": "galician",
    "mr": "marathi",
    "pa": "punjabi",
    "si": "sinhala",
    "km": "khmer",
    "sn": "shona",
    "yo": "yoruba",
    "so": "somali",
    "af": "afrikaans",
    "oc": "occitan",
    "ka": "georgian",
    "be": "belarusian",
    "tg": "tajik",
    "sd": "sindhi",
    "gu": "gujarati",
    "am": "amharic",
    "yi": "yiddish",
    "lo": "lao",
    "uz": "uzbek",
    "fo": "faroese",
    "ht": "haitian creole",
    "ps": "pashto",
    "tk": "turkmen",
    "nn": "nynorsk",
    "mt": "maltese",
    "sa": "sanskrit",
    "lb": "luxembourgish",
    "my": "myanmar",
    "bo": "tibetan",
    "tl": "tagalog",
    "mg": "malagasy",
    "as": "assamese",
    "tt": "tatar",
    "haw": "hawaiian",
    "ln": "lingala",
    "ha": "hausa",
    "ba": "bashkir",
    "jw": "javanese",
    "su": "sundanese",
    "yue": "cantonese",
    "minnan": "minnan",
    "wuyu": "wuyu",
    "dialect": "dialect",
    "zh/en": "zh/en",
    "en/zh": "en/zh",
}

# language code lookup by name, with a few language aliases
TO_LANGUAGE_CODE = {
    **{language: code for code, language in LANGUAGES.items()},
    "burmese": "my",
    "valencian": "ca",
    "flemish": "nl",
    "haitian": "ht",
    "letzeburgesch": "lb",
    "pushto": "ps",
    "panjabi": "pa",
    "moldavian": "ro",
    "moldovan": "ro",
    "sinhalese": "si",
    "castilian": "es",
    "mandarin": "zh",
}

# Audio-event labels registered as special tokens ("/X" closes an "X" span).
AUDIO_EVENT = {
    "ASR": "ASR",
    "AED": "AED",
    "SER": "SER",
    "Speech": "Speech",
    "/Speech": "/Speech",
    "BGM": "BGM",
    "/BGM": "/BGM",
    "Laughter": "Laughter",
    "/Laughter": "/Laughter",
    "Applause": "Applause",
    "/Applause": "/Applause",
}

# Emotion labels registered as special tokens.
EMOTION = {
    "HAPPY": "HAPPY",
    "SAD": "SAD",
    "ANGRY": "ANGRY",
    "NEUTRAL": "NEUTRAL",
}

# TTS vocal-control labels registered as special tokens (plus speaker slots
# TTS/SP03..TTS/SP13).
TTS_Vocal_Token = {
    "TTS/B": "TTS/B",
    "TTS/O": "TTS/O",
    "TTS/Q": "TTS/Q",
    "TTS/A": "TTS/A",
    "TTS/CO": "TTS/CO",
    "TTS/CL": "TTS/CL",
    "TTS/H": "TTS/H",
    "endofprompt": "endofprompt",
    "sil": "sil",
    **{f"TTS/SP{i:02d}": f"TTS/SP{i:02d}" for i in range(3, 14)}
}
|
||||
|
||||
@dataclass
class Tokenizer:
    """A thin wrapper around `tiktoken` providing quick access to special tokens"""

    encoding: tiktoken.Encoding              # BPE encoding built by get_encoding()
    num_languages: int                       # leading entries of LANGUAGES in use
    language: Optional[str] = None           # language code, e.g. "en"
    task: Optional[str] = None               # "transcribe" or "translate"
    sot_sequence: Tuple[int, ...] = ()       # filled in by __post_init__
    special_tokens: Dict[str, int] = field(default_factory=dict)  # "<|...|>" -> id

    def __post_init__(self):
        # Cache every special-token id, then build the start-of-transcript
        # sequence: [sot][lang token?][task token?].
        for special in self.encoding.special_tokens_set:
            special_token = self.encoding.encode_single_token(special)
            self.special_tokens[special] = special_token

        sot: int = self.special_tokens["<|startoftranscript|>"]
        translate: int = self.special_tokens["<|translate|>"]
        transcribe: int = self.special_tokens["<|transcribe|>"]

        # Language tokens sit directly after <|startoftranscript|> in
        # LANGUAGES order, so the id is sot + 1 + index.
        langs = tuple(LANGUAGES.keys())[: self.num_languages]
        sot_sequence = [sot]
        if self.language is not None:
            sot_sequence.append(sot + 1 + langs.index(self.language))
        if self.task is not None:
            task_token: int = transcribe if self.task == "transcribe" else translate
            sot_sequence.append(task_token)

        self.sot_sequence = tuple(sot_sequence)

    def encode(self, text, **kwargs):
        """Encode text to token ids (passthrough to the tiktoken encoding)."""
        return self.encoding.encode(text, **kwargs)

    def decode(self, token_ids: List[int], **kwargs) -> str:
        """Decode ids to text, dropping timestamp tokens (>= timestamp_begin)."""
        token_ids = [t for t in token_ids if t < self.timestamp_begin]
        return self.encoding.decode(token_ids, **kwargs)

    def decode_with_timestamps(self, token_ids: List[int], **kwargs) -> str:
        """
        Timestamp tokens are above other special tokens' id range and are ignored by `decode()`.
        This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
        """
        return self.encoding.decode(token_ids, **kwargs)

    def get_vocab_size(self) -> int:
        """Total vocabulary size including special tokens."""
        return self.encoding.n_vocab

    @cached_property
    def eot(self) -> int:
        # end-of-text token id
        return self.encoding.eot_token

    @cached_property
    def transcribe(self) -> int:
        return self.special_tokens["<|transcribe|>"]

    @cached_property
    def translate(self) -> int:
        return self.special_tokens["<|translate|>"]

    @cached_property
    def sot(self) -> int:
        return self.special_tokens["<|startoftranscript|>"]

    @cached_property
    def sot_lm(self) -> int:
        return self.special_tokens["<|startoflm|>"]

    @cached_property
    def sot_prev(self) -> int:
        return self.special_tokens["<|startofprev|>"]

    @cached_property
    def no_speech(self) -> int:
        return self.special_tokens["<|nospeech|>"]

    @cached_property
    def no_timestamps(self) -> int:
        return self.special_tokens["<|notimestamps|>"]

    @cached_property
    def timestamp_begin(self) -> int:
        # id of <|0.00|>; every id >= this is a timestamp token
        return self.special_tokens["<|0.00|>"]

    @cached_property
    def language_token(self) -> int:
        """Returns the token id corresponding to the value of the `language` field"""
        if self.language is None:
            raise ValueError("This tokenizer does not have language token configured")

        return self.to_language_token(self.language)

    def to_language_token(self, language):
        """Map a language code to its "<|code|>" special-token id."""
        # NOTE(review): a 0-valued token id would be treated as missing here;
        # in practice special-token ids are never 0.
        if token := self.special_tokens.get(f"<|{language}|>", None):
            return token

        raise KeyError(f"Language {language} not found in tokenizer.")

    @cached_property
    def all_language_tokens(self) -> Tuple[int, ...]:
        """Ids of all configured language special tokens."""
        result = []
        for token, token_id in self.special_tokens.items():
            if token.strip("<|>") in LANGUAGES:
                result.append(token_id)
        return tuple(result)[: self.num_languages]

    @cached_property
    def all_language_codes(self) -> Tuple[str, ...]:
        """Language codes corresponding to `all_language_tokens`."""
        return tuple(self.decode([_l]).strip("<|>") for _l in self.all_language_tokens)

    @cached_property
    def sot_sequence_including_notimestamps(self) -> Tuple[int, ...]:
        return tuple(list(self.sot_sequence) + [self.no_timestamps])

    @cached_property
    def non_speech_tokens(self) -> Tuple[int, ...]:
        """
        Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
        annotations, to prevent sampling texts that are not actually spoken in the audio, e.g.

        - ♪♪♪
        - ( SPEAKING FOREIGN LANGUAGE )
        - [DAVID] Hey there,

        keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
        """
        symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
        symbols += (
            "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
        )

        # symbols that may be a single token or multiple tokens depending on the tokenizer.
        # In case they're multiple tokens, suppress the first token, which is safe because:
        # These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
        # in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
        miscellaneous = set("♩♪♫♬♭♮♯")
        assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)

        # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
        result = {self.encoding.encode(" -")[0], self.encoding.encode(" '")[0]}
        for symbol in symbols + list(miscellaneous):
            for tokens in [
                self.encoding.encode(symbol),
                self.encoding.encode(" " + symbol),
            ]:
                if len(tokens) == 1 or symbol in miscellaneous:
                    result.add(tokens[0])

        return tuple(sorted(result))

    def split_to_word_tokens(self, tokens: List[int]):
        """Split decoded tokens into (words, per-word token lists)."""
        if self.language in {"zh", "ja", "th", "lo", "my", "yue"}:
            # These languages don't typically use spaces, so it is difficult to split words
            # without morpheme analysis. Here, we instead split words at any
            # position where the tokens are decoded as valid unicode points
            return self.split_tokens_on_unicode(tokens)

        return self.split_tokens_on_spaces(tokens)

    def split_tokens_on_unicode(self, tokens: List[int]):
        """Group tokens at boundaries where the decode is valid unicode.

        A partial multi-byte character decodes to U+FFFD; a group is closed
        only when its decode contains no such replacement character (or the
        replacement char is genuinely present in the full decode).
        """
        decoded_full = self.decode_with_timestamps(tokens)
        replacement_char = "\ufffd"

        words = []
        word_tokens = []
        current_tokens = []
        unicode_offset = 0

        for token in tokens:
            current_tokens.append(token)
            decoded = self.decode_with_timestamps(current_tokens)

            if (
                replacement_char not in decoded
                or decoded_full[unicode_offset + decoded.index(replacement_char)]
                == replacement_char
            ):
                words.append(decoded)
                word_tokens.append(current_tokens)
                current_tokens = []
                unicode_offset += len(decoded)

        return words, word_tokens

    def split_tokens_on_spaces(self, tokens: List[int]):
        """Merge unicode-split subwords into space-delimited words."""
        subwords, subword_tokens_list = self.split_tokens_on_unicode(tokens)
        words = []
        word_tokens = []

        for subword, subword_tokens in zip(subwords, subword_tokens_list):
            # start a new word on special tokens, leading spaces, punctuation
            special = subword_tokens[0] >= self.eot
            with_space = subword.startswith(" ")
            punctuation = subword.strip() in string.punctuation
            if special or with_space or punctuation or len(words) == 0:
                words.append(subword)
                word_tokens.append(subword_tokens)
            else:
                words[-1] = words[-1] + subword
                word_tokens[-1].extend(subword_tokens)

        return words, word_tokens
|
||||
|
||||
@lru_cache(maxsize=None)
def get_encoding(name: str = "gpt2", num_languages: int = 99, ttsfrd_name: Optional[str] = None):
    """Build (and cache) the tiktoken Encoding for a vocabulary name.

    Args:
        name: base vocabulary; "{name}.tiktoken" is loaded from assets/.
            "gpt2"/"multilingual" use the upstream Whisper special-token
            layout; any other name gets the extended FunASR layout
            (audio events, emotions, ASR/TTS special tokens).
        num_languages: how many leading LANGUAGES entries get language tokens.
        ttsfrd_name: optional "{ttsfrd_name}.token" file in assets/ whose
            lines are registered as extra "<|...|>" special tokens.

    Returns:
        A tiktoken.Encoding with all special tokens appended after the BPE
        ranks (ids are assigned in list order, so layout must stay stable).
    """
    vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
    # Each line is "<base64-token> <rank>".  Bug fix: use a context manager so
    # the file handle is closed deterministically instead of leaking until GC
    # (the ttsfrd branch below already did this correctly).
    with open(vocab_path) as vocab_file:
        ranks = {
            base64.b64decode(token): int(rank)
            for token, rank in (line.split() for line in vocab_file if line)
        }
    n_vocab = len(ranks)
    special_tokens = {}

    if name == "gpt2" or name == "multilingual":
        # Upstream Whisper special-token layout.
        specials = [
            "<|endoftext|>",
            "<|startoftranscript|>",
            *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
            "<|translate|>",
            "<|transcribe|>",
            "<|startoflm|>",
            "<|startofprev|>",
            "<|nospeech|>",
            "<|notimestamps|>",
            *[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
        ]
    else:
        # Extended FunASR layout: audio events, emotions, reserved ASR slots
        # and TTS vocal-control tokens before the timestamp tokens.
        specials = [
            "<|endoftext|>",
            "<|startoftranscript|>",
            *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
            *[f"<|{audio_event}|>" for audio_event in list(AUDIO_EVENT.keys())],
            *[f"<|{emotion}|>" for emotion in list(EMOTION.keys())],
            "<|translate|>",
            "<|transcribe|>",
            "<|startoflm|>",
            "<|startofprev|>",
            "<|nospeech|>",
            "<|notimestamps|>",
            *[f"<|SPECIAL_TOKEN_{i}|>" for i in range(1, 31)],  # register special tokens for ASR
            *[f"<|{tts}|>" for tts in list(TTS_Vocal_Token.keys())],  # register special tokens for TTS
            *[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
        ]
        if ttsfrd_name is not None:
            ttsfrd_vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{ttsfrd_name}.token")
            assert os.path.isfile(ttsfrd_vocab_path), f"{ttsfrd_vocab_path} missing"
            with open(ttsfrd_vocab_path, "r") as fr:
                specials.extend([f"<|{line.strip()}|>" for line in fr if line])
    # Special-token ids continue directly after the BPE ranks.
    for token in specials:
        special_tokens[token] = n_vocab
        n_vocab += 1

    return tiktoken.Encoding(
        name=os.path.basename(vocab_path),
        explicit_n_vocab=n_vocab,
        pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
        mergeable_ranks=ranks,
        special_tokens=special_tokens,
    )
|
||||
|
||||
@lru_cache(maxsize=None)
def get_tokenizer(
    multilingual: bool,
    *,
    num_languages: int = 99,
    language: Optional[str] = None,
    task: Optional[str] = None,  # Literal["transcribe", "translate", None]
    encoding_path: Optional[str] = None,
    ttsfrd_name: Optional[str] = None,
) -> Tokenizer:
    """Construct (and cache) a Tokenizer.

    Normalizes `language` (full names such as "mandarin" map to their codes),
    picks the encoding name from `multilingual` unless `encoding_path`
    overrides it, and defaults language/task for the multilingual case.

    Raises:
        ValueError: if `language` is neither a known code nor a known name.
    """
    if language is not None:
        language = language.lower()
        if language not in LANGUAGES:
            # Accept full language names / aliases as a fallback.
            try:
                language = TO_LANGUAGE_CODE[language]
            except KeyError:
                raise ValueError(f"Unsupported language: {language}") from None

    if not multilingual:
        # English-only GPT-2 vocabulary carries no language or task tokens.
        encoding_name, language, task = "gpt2", None, None
    else:
        encoding_name = "multilingual"
        language = language if language else "en"
        task = task if task else "transcribe"
    if encoding_path is not None:
        encoding_name = encoding_path

    enc = get_encoding(name=encoding_name, num_languages=num_languages, ttsfrd_name=ttsfrd_name)

    return Tokenizer(encoding=enc, num_languages=num_languages, language=language, task=task)
@ -1,164 +0,0 @@
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
from typing import Iterable, List, Union
|
||||
import numpy as np
|
||||
|
||||
|
||||
class WhisperRichTtsFrdTokenizer:
    """Whisper/tiktoken text tokenizer with optional ttsfrd phoneme mixing.

    `token_path` selects the tiktoken vocabulary: the whisper presets
    ("gpt2"/"whisper_en"/"whisper_gpt2", "multilingual"/"whisper_multilingual")
    or, for any other value, a custom encoding name combined with the
    `ttsfrd_type` extra special tokens.  When `ttsfrd_model` points to a
    ttsfrd resource directory, `text2tokens` first converts text to word2phn
    JSON and randomly (probability `p_word2phn` per decision) replaces words
    by their "<|@phone|>" special tokens.
    """

    def __init__(
        self,
        token_path: str,
        num_languages: int,
        task: str = None,
        language: str = None,
        ttsfrd_type: str = None,
        p_word2phn: float = 0.5,
        ttsfrd_model: str = None,
    ):
        # Imported lazily to avoid a hard dependency at module import time.
        import funasr.models.llm_asr.tts_text_tokenizer.voice_echo_rich_tokenizer as tokenizer

        self.token_path = token_path
        self.num_languages = num_languages
        self.language = language
        self.task = task
        self.ttsfrd_type = ttsfrd_type
        self.p_word2phn = p_word2phn
        if token_path in ("whisper_en", "whisper_gpt2", "gpt2"):
            self.tokenizer = tokenizer.get_tokenizer(multilingual=False, num_languages=num_languages)
        elif token_path in ("whisper_multilingual", "multilingual"):
            self.tokenizer = tokenizer.get_tokenizer(
                multilingual=True, language=self.language, task=self.task, num_languages=num_languages
            )
        else:
            # Custom encoding file under assets/, optionally extended with
            # ttsfrd phoneme special tokens.
            self.tokenizer = tokenizer.get_tokenizer(
                multilingual=True, language=self.language, task=self.task, num_languages=num_languages,
                encoding_path=token_path, ttsfrd_name=ttsfrd_type
            )
        if ttsfrd_model is not None and os.path.isdir(ttsfrd_model):
            from funasr.models.llm_asr.tts_text_tokenizer.phoneme_tokenizer import TtsFrdRich

            self.ttsfrd_tokenizer = TtsFrdRich(remove_boundary=True, token_type="word2phn")
            self.ttsfrd_tokenizer.build(ttsfrd_model)
        else:
            self.ttsfrd_tokenizer = None

    def text_mixing(self, line: str) -> str:
        """Randomly replace words in a word2phn JSON payload by phoneme tokens.

        `line` is either the JSON produced by TtsFrdRich._get_word2phn
        ({"raw": ..., "word2phn": [...]}) or plain text, which is returned
        unchanged.  Each replacement decision is sampled independently with
        probability `self.p_word2phn`; if the top-level draw fails, the
        untouched raw text is returned.
        """
        try:
            data_info = json.loads(line)
            if isinstance(data_info, dict) and "raw" in data_info and "word2phn" in data_info:
                raw_text = data_info["raw"]
                ttsfrd_word2phn = data_info["word2phn"]
                if random.random() < self.p_word2phn:
                    ret_text = ""
                    for ttsfrd_word in ttsfrd_word2phn:
                        for word_str, phn_list in ttsfrd_word.items():
                            # Words without phones (punctuation, spaces) are
                            # always kept verbatim.
                            if phn_list is not None and random.random() < self.p_word2phn:
                                ret_text = ret_text + "".join([f"<|@{p}|>" for p in phn_list])
                            else:
                                ret_text += word_str
                else:
                    ret_text = raw_text
            else:
                ret_text = line
        except json.JSONDecodeError:
            ret_text = line
        return ret_text

    def get_num_vocabulary_size(self) -> int:
        """Vocabulary size of the underlying tiktoken encoding."""
        return self.tokenizer.get_vocab_size()

    def text2ids(self, line: str, language: str) -> List[int]:
        """Encode `line` prefixed with [<|lang|>, transcribe(, notimestamps)].

        Timestamp special tokens such as "<|1.08|>" are only allowed through
        the encoder when the line actually contains them.
        """
        language_tok = "<|" + language + "|>"
        assert language_tok in self.tokenizer.special_tokens, "Language token not found, lang: {}, line: {}".format(language_tok, line)
        # Bug fix: the original pattern r'<|(\d+\.\d+)|>' left the "|"
        # unescaped, so it parsed as the alternation "<" | number | ">" and
        # matched any line containing "<", ">" or a bare decimal number,
        # wrongly switching those lines into timestamp mode.
        pattern = re.compile(r'<\|(\d+\.\d+)\|>')
        with_timestamps = pattern.search(line)
        if with_timestamps:
            sot_tok = [self.tokenizer.special_tokens.get(language_tok), self.tokenizer.transcribe]
            allowed_special = set([f"<|{i * 0.02:.2f}|>" for i in range(1501)])
            encoded_line = self.tokenizer.encode(line, allowed_special=allowed_special)
        else:
            sot_tok = [self.tokenizer.special_tokens.get(language_tok), self.tokenizer.transcribe, self.tokenizer.no_timestamps]
            encoded_line = self.tokenizer.encode(line)
        return sot_tok + encoded_line

    def ids2text(self, integers: Union[np.ndarray, Iterable[int]]) -> str:
        """Decode ids to text, keeping timestamp tokens annotated."""
        return self.tokenizer.decode_with_timestamps(integers)

    def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
        """Decode each id to its individual token string."""
        return [self.tokenizer.decode_with_timestamps([i]) for i in integers]

    def text2tokens(self, line: str, endofprompt="<|endofprompt|>", sil="<|sil|>") -> List[str]:
        """Tokenize `line`, optionally routing it through ttsfrd phoneme mixing.

        The prompt prefix (up to `endofprompt`) and leading/trailing `sil`
        markers are protected from the ttsfrd pass and re-attached afterwards.
        NOTE(review): returns token *ids* from tiktoken despite the List[str]
        annotation, matching the historical interface.
        """
        # keep prompt and sil markers out of the ttsfrd pass
        prompt_text = ""
        st_sil, ed_sil = False, False
        if endofprompt in line:
            pos = line.find(endofprompt)
            prompt_text = line[:pos + len(endofprompt)]
            line = line[pos + len(endofprompt):]
        if line.startswith(sil):
            line = line[len(sil):]
            st_sil = True
        if line.endswith(sil):
            line = line[:-len(sil)]
            ed_sil = True

        # text -> word2phn JSON -> randomly mixed phoneme/word string
        if self.ttsfrd_tokenizer is not None:
            line = self.ttsfrd_tokenizer(line)
            if self.ttsfrd_type is not None:
                line = self.text_mixing(line)

        # re-attach the protected prompt and sil markers
        if st_sil:
            line = sil + line
        if ed_sil:
            line = line + sil
        line = prompt_text + line

        return self.tokenizer.encode(line, allowed_special="all")

    def tokens2text(self, tokens: Iterable[str]) -> str:
        # NOTE(review): despite the annotation this forwards to tiktoken's
        # decoder, which expects token ids; confirm callers pass ids.
        return self.tokenizer.decode_with_timestamps(tokens)

    def get_sot(self, language: str = None, with_timestamps: bool = False) -> List[int]:
        """Build the start-of-transcript id sequence used to seed decoding."""
        if language is not None:
            language_tok = "<|" + language + "|>"
            assert language_tok in self.tokenizer.special_tokens
            if with_timestamps:
                sot_tok = [self.tokenizer.sot, self.tokenizer.special_tokens.get(language_tok), self.tokenizer.transcribe]
            else:
                sot_tok = [self.tokenizer.sot, self.tokenizer.special_tokens.get(language_tok), self.tokenizer.transcribe, self.tokenizer.no_timestamps]
        else:
            sot_tok = [self.tokenizer.sot]
        return sot_tok

    def get_all_languages(self) -> List[str]:
        """All language codes configured on the underlying tokenizer."""
        return list(self.tokenizer.all_language_codes)

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(model_type={self.token_path}, "
            f"language={self.language}, ttsfrd={self.ttsfrd_type})"
        )
Loading…
Reference in New Issue
Block a user