This commit is contained in:
游雁 2024-10-14 13:58:35 +08:00
parent 5abb8367a3
commit 62c6f50a1d
16 changed files with 627 additions and 128890 deletions

File diff suppressed because it is too large Load Diff

View File

@ -1,365 +0,0 @@
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
from collections import OrderedDict
import torch
from torch import nn
import torch.nn.functional as F
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import torch
import torch.nn.functional as F
from torch import nn
def get_nonlinear(config_str, channels):
    """Build an nn.Sequential of activation/normalization modules.

    Tokens in ``config_str`` are separated by '-'. Supported tokens:
    'relu', 'prelu', 'batchnorm', and 'batchnorm_' (affine-free batchnorm,
    registered under the module name 'batchnorm').

    Raises:
        ValueError: if an unknown token is encountered.
    """
    seq = nn.Sequential()
    for token in config_str.split('-'):
        if token == 'relu':
            module_name, module = 'relu', nn.ReLU(inplace=True)
        elif token == 'prelu':
            module_name, module = 'prelu', nn.PReLU(channels)
        elif token == 'batchnorm':
            module_name, module = 'batchnorm', nn.BatchNorm1d(channels)
        elif token == 'batchnorm_':
            # Registered under 'batchnorm' as well, matching checkpoints.
            module_name, module = 'batchnorm', nn.BatchNorm1d(channels, affine=False)
        else:
            raise ValueError('Unexpected module ({}).'.format(token))
        seq.add_module(module_name, module)
    return seq
def statistics_pooling(x, dim=-1, keepdim=False, unbiased=True, eps=1e-2):
    """Concatenate the mean and standard deviation of ``x`` along ``dim``.

    NOTE(review): ``eps`` is accepted for interface compatibility but is not
    used by this implementation.
    """
    mu = x.mean(dim=dim)
    sigma = x.std(dim=dim, unbiased=unbiased)
    pooled = torch.cat((mu, sigma), dim=-1)
    return pooled.unsqueeze(dim=dim) if keepdim else pooled
class StatsPool(nn.Module):
    """Statistics pooling layer.

    Reduces the last (time) axis of the input to the concatenation of the
    per-channel mean and standard deviation, i.e. (B, C, T) -> (B, 2*C).
    """

    def forward(self, x):
        # Inlined statistics_pooling(x) with its default arguments
        # (dim=-1, keepdim=False, unbiased=True).
        mean = x.mean(dim=-1)
        std = x.std(dim=-1, unbiased=True)
        return torch.cat([mean, std], dim=-1)
class TDNNLayer(nn.Module):
    """Conv1d ("TDNN") layer followed by a configurable nonlinearity.

    A negative ``padding`` requests "same"-style padding, which is only
    well-defined for odd kernel sizes.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 bias=False,
                 config_str='batchnorm-relu'):
        super(TDNNLayer, self).__init__()
        if padding < 0:
            assert kernel_size % 2 == 1, 'Expect equal paddings, but got even kernel size ({})'.format(
                kernel_size)
            padding = (kernel_size - 1) // 2 * dilation
        self.linear = nn.Conv1d(in_channels, out_channels, kernel_size,
                                stride=stride, padding=padding,
                                dilation=dilation, bias=bias)
        self.nonlinear = get_nonlinear(config_str, out_channels)

    def forward(self, x):
        return self.nonlinear(self.linear(x))
class CAMLayer(nn.Module):
    """Context-aware masking (CAM) layer.

    Applies a local Conv1d and modulates its output with a sigmoid mask
    derived from a global-mean plus segment-pooled context vector.
    """

    def __init__(self,
                 bn_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 padding,
                 dilation,
                 bias,
                 reduction=2):
        super(CAMLayer, self).__init__()
        self.linear_local = nn.Conv1d(bn_channels, out_channels, kernel_size,
                                      stride=stride, padding=padding,
                                      dilation=dilation, bias=bias)
        # Bottleneck that turns the context vector into a channel mask.
        self.linear1 = nn.Conv1d(bn_channels, bn_channels // reduction, 1)
        self.relu = nn.ReLU(inplace=True)
        self.linear2 = nn.Conv1d(bn_channels // reduction, out_channels, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        local_out = self.linear_local(x)
        # Global mean plus per-segment pooling forms the context.
        context = x.mean(-1, keepdim=True) + self.seg_pooling(x)
        context = self.relu(self.linear1(context))
        mask = self.sigmoid(self.linear2(context))
        return local_out * mask

    def seg_pooling(self, x, seg_len=100, stype='avg'):
        if stype == 'avg':
            pooled = F.avg_pool1d(x, kernel_size=seg_len, stride=seg_len, ceil_mode=True)
        elif stype == 'max':
            pooled = F.max_pool1d(x, kernel_size=seg_len, stride=seg_len, ceil_mode=True)
        else:
            raise ValueError('Wrong segment pooling type.')
        # Repeat each pooled frame seg_len times, then trim to input length.
        shape = pooled.shape
        pooled = pooled.unsqueeze(-1).expand(*shape, seg_len).reshape(*shape[:-1], -1)
        return pooled[..., :x.shape[-1]]
class CAMDenseTDNNLayer(nn.Module):
    """Pre-activation bottleneck (1x1 conv) followed by a CAM conv layer.

    Used as one layer inside a densely connected block; the kernel must be
    odd so that "same" padding is well-defined.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 bn_channels,
                 kernel_size,
                 stride=1,
                 dilation=1,
                 bias=False,
                 config_str='batchnorm-relu',
                 memory_efficient=False):
        super(CAMDenseTDNNLayer, self).__init__()
        assert kernel_size % 2 == 1, 'Expect equal paddings, but got even kernel size ({})'.format(
            kernel_size)
        padding = (kernel_size - 1) // 2 * dilation
        # NOTE(review): memory_efficient is stored but no checkpointing is
        # implemented here; forward always runs the plain path.
        self.memory_efficient = memory_efficient
        self.nonlinear1 = get_nonlinear(config_str, in_channels)
        self.linear1 = nn.Conv1d(in_channels, bn_channels, 1, bias=False)
        self.nonlinear2 = get_nonlinear(config_str, bn_channels)
        self.cam_layer = CAMLayer(bn_channels, out_channels, kernel_size,
                                  stride=stride, padding=padding,
                                  dilation=dilation, bias=bias)

    def bn_function(self, x):
        # Pre-activation bottleneck: nonlinearity first, then 1x1 conv.
        return self.linear1(self.nonlinear1(x))

    def forward(self, x):
        return self.cam_layer(self.nonlinear2(self.bn_function(x)))
class CAMDenseTDNNBlock(nn.ModuleList):
    """Dense block of CAMDenseTDNNLayers with channel-concatenated skips.

    Layer ``i`` receives the original input plus the outputs of all previous
    layers, so its input width is ``in_channels + i * out_channels``.
    """

    def __init__(self,
                 num_layers,
                 in_channels,
                 out_channels,
                 bn_channels,
                 kernel_size,
                 stride=1,
                 dilation=1,
                 bias=False,
                 config_str='batchnorm-relu',
                 memory_efficient=False):
        super(CAMDenseTDNNBlock, self).__init__()
        for idx in range(num_layers):
            self.add_module(
                'tdnnd%d' % (idx + 1),
                CAMDenseTDNNLayer(in_channels=in_channels + idx * out_channels,
                                  out_channels=out_channels,
                                  bn_channels=bn_channels,
                                  kernel_size=kernel_size,
                                  stride=stride,
                                  dilation=dilation,
                                  bias=bias,
                                  config_str=config_str,
                                  memory_efficient=memory_efficient))

    def forward(self, x):
        # Concatenate every layer's output onto the running feature map.
        for layer in self:
            x = torch.cat([x, layer(x)], dim=1)
        return x
class TransitLayer(nn.Module):
    """Transition layer: nonlinearity followed by a 1x1 conv channel change."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 bias=True,
                 config_str='batchnorm-relu'):
        super(TransitLayer, self).__init__()
        self.nonlinear = get_nonlinear(config_str, in_channels)
        self.linear = nn.Conv1d(in_channels, out_channels, 1, bias=bias)

    def forward(self, x):
        return self.linear(self.nonlinear(x))
class DenseLayer(nn.Module):
    """1x1 conv plus nonlinearity; also accepts 2-D (B, C) inputs."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 bias=False,
                 config_str='batchnorm-relu'):
        super(DenseLayer, self).__init__()
        self.linear = nn.Conv1d(in_channels, out_channels, 1, bias=bias)
        self.nonlinear = get_nonlinear(config_str, out_channels)

    def forward(self, x):
        if x.dim() == 2:
            # Treat (B, C) as a length-1 sequence for the 1x1 conv.
            out = self.linear(x.unsqueeze(dim=-1)).squeeze(dim=-1)
        else:
            out = self.linear(x)
        return self.nonlinear(out)
class BasicResBlock(nn.Module):
    """Basic 2-D residual block; strided only along the frequency axis."""

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicResBlock, self).__init__()
        # Stride applies to the frequency dim only: (stride, 1).
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3,
                               stride=(stride, 1), padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            # Projection shortcut when the shape changes.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes,
                          kernel_size=1, stride=(stride, 1), bias=False),
                nn.BatchNorm2d(self.expansion * planes))

    def forward(self, x):
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        hidden = hidden + self.shortcut(x)
        return F.relu(hidden)
class FCM(nn.Module):
    """2-D convolutional front-end (ResNet-style).

    Converts (B, F, T) acoustic features into a 1-D channel sequence
    (B, m_channels * F // 8, T) for the TDNN trunk; the frequency axis is
    reduced 8x in total (2 * 2 * 2), the time axis is kept.

    Args:
        block: residual block class (expects (in_planes, planes, stride)).
        num_blocks: number of residual blocks per stage.
        m_channels: base channel width.
        feat_dim: input feature (frequency) dimension.
    """

    def __init__(self,
                 block=BasicResBlock,
                 num_blocks=(2, 2),  # fix: tuple instead of mutable list default
                 m_channels=32,
                 feat_dim=80):
        super(FCM, self).__init__()
        self.in_planes = m_channels
        self.conv1 = nn.Conv2d(1, m_channels, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(m_channels)
        self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=2)
        self.layer2 = self._make_layer(block, m_channels, num_blocks[1], stride=2)
        self.conv2 = nn.Conv2d(m_channels, m_channels, kernel_size=3,
                               stride=(2, 1), padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(m_channels)
        # Frequency axis shrinks 8x overall, so the flattened channel count is:
        self.out_channels = m_channels * (feat_dim // 8)

    def _make_layer(self, block, planes, num_blocks, stride):
        # First block downsamples; the remaining blocks keep the resolution.
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for s in strides:
            layers.append(block(self.in_planes, planes, s))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        x = x.unsqueeze(1)  # (B, F, T) -> (B, 1, F, T)
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = F.relu(self.bn2(self.conv2(out)))
        # Flatten channel and frequency dims into a single channel axis.
        shape = out.shape
        out = out.reshape(shape[0], shape[1] * shape[2], shape[3])
        return out
class CAMPPlus(nn.Module):
    """CAM++ speaker-embedding network.

    A 2-D FCM front-end followed by a densely connected TDNN trunk
    (CAMDenseTDNNBlocks interleaved with channel-halving transition layers),
    statistics pooling, and a final dense projection to the embedding space.

    Args:
        feat_dim: input acoustic feature dimension.
        embedding_size: output speaker-embedding dimension.
        growth_rate: channels added by each dense layer.
        bn_size: bottleneck width multiplier inside dense layers.
        init_channels: channels after the initial TDNN layer.
        config_str: nonlinearity spec passed to get_nonlinear().
        memory_efficient: forwarded to the dense blocks.
    """

    def __init__(self,
                 feat_dim=80,
                 embedding_size=512,
                 growth_rate=32,
                 bn_size=4,
                 init_channels=128,
                 config_str='batchnorm-relu',
                 memory_efficient=True):
        super(CAMPPlus, self).__init__()
        # 2-D front-end; its flattened channel count feeds the TDNN trunk.
        self.head = FCM(feat_dim=feat_dim)
        channels = self.head.out_channels
        # Initial TDNN with stride 2 (padding=-1 requests "same"-style padding).
        self.xvector = nn.Sequential(
            OrderedDict([
                ('tdnn',
                 TDNNLayer(channels,
                           init_channels,
                           5,
                           stride=2,
                           dilation=1,
                           padding=-1,
                           config_str=config_str)),
            ]))
        channels = init_channels
        # Three dense blocks (12/24/16 layers, dilations 1/2/2); each block is
        # followed by a transition layer that halves the channel count.
        for i, (num_layers, kernel_size,
                dilation) in enumerate(zip((12, 24, 16), (3, 3, 3), (1, 2, 2))):
            block = CAMDenseTDNNBlock(num_layers=num_layers,
                                      in_channels=channels,
                                      out_channels=growth_rate,
                                      bn_channels=bn_size * growth_rate,
                                      kernel_size=kernel_size,
                                      dilation=dilation,
                                      config_str=config_str,
                                      memory_efficient=memory_efficient)
            self.xvector.add_module('block%d' % (i + 1), block)
            channels = channels + num_layers * growth_rate
            self.xvector.add_module(
                'transit%d' % (i + 1),
                TransitLayer(channels,
                             channels // 2,
                             bias=False,
                             config_str=config_str))
            channels //= 2
        self.xvector.add_module(
            'out_nonlinear', get_nonlinear(config_str, channels))
        # Statistics pooling doubles the channel count (mean + std).
        self.xvector.add_module('stats', StatsPool())
        self.xvector.add_module(
            'dense',
            DenseLayer(channels * 2, embedding_size, config_str='prelu'))
        # Kaiming initialization for all conv/linear weights.
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight.data)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x):
        x = x.permute(0, 2, 1)  # (B,T,F) => (B,F,T)
        x = self.head(x)
        x = self.xvector(x)
        return x

View File

@ -1,258 +0,0 @@
import torch
import logging
import torch.nn.functional as F
class CTC(torch.nn.Module):
    """CTC module.

    Args:
        odim: dimension of outputs
        encoder_output_size: number of encoder projection units
        dropout_rate: dropout rate (0.0 ~ 1.0)
        ctc_type: builtin or warpctc
        reduce: reduce the CTC loss into a scalar
        ignore_nan_grad: drop samples whose CTC gradient is non-finite and
            recompute the loss on the remaining samples
        length_normalize: optional per-sample normalization; "olen" divides
            by target lengths, any other non-None value divides by input
            lengths
    """

    def __init__(
        self,
        odim: int,
        encoder_output_size: int,
        dropout_rate: float = 0.0,
        ctc_type: str = "builtin",
        reduce: bool = True,
        ignore_nan_grad: bool = True,
        length_normalize: str = None,
    ):
        super().__init__()
        eprojs = encoder_output_size
        self.dropout_rate = dropout_rate
        # Linear projection from encoder output to vocabulary logits.
        self.ctc_lo = torch.nn.Linear(eprojs, odim)
        self.ctc_type = ctc_type
        self.ignore_nan_grad = ignore_nan_grad
        self.length_normalize = length_normalize
        if self.ctc_type == "builtin":
            # reduction="none" keeps a per-sample loss vector so that
            # NaN-gradient samples can be filtered out in loss_fn.
            self.ctc_loss = torch.nn.CTCLoss(reduction="none")
        else:
            # NOTE(review): only "builtin" can be constructed, so the
            # "warpctc"/"gtnctc" branches in loss_fn below are unreachable.
            raise ValueError(
                f'ctc_type must be "builtin": {self.ctc_type}'
            )
        self.reduce = reduce

    def loss_fn(self, th_pred, th_target, th_ilen, th_olen) -> torch.Tensor:
        """Compute the CTC loss, optionally filtering NaN-gradient samples.

        Args:
            th_pred: time-major logits (T, B, odim).
            th_target: concatenated target ids, shape (sum(th_olen),).
            th_ilen: input lengths (B,).
            th_olen: target lengths (B,).
        """
        if self.ctc_type == "builtin":
            th_pred = th_pred.log_softmax(2)
            loss = self.ctc_loss(th_pred, th_target, th_ilen, th_olen)
            if loss.requires_grad and self.ignore_nan_grad:
                # Probe gradients by calling the loss node's grad_fn directly.
                # ctc_grad: (L, B, O)
                ctc_grad = loss.grad_fn(torch.ones_like(loss))
                ctc_grad = ctc_grad.sum([0, 2])
                indices = torch.isfinite(ctc_grad)
                size = indices.long().sum()
                if size == 0:
                    # Return as is
                    logging.warning(
                        "All samples in this mini-batch got nan grad."
                        " Returning nan value instead of CTC loss"
                    )
                elif size != th_pred.size(1):
                    logging.warning(
                        f"{th_pred.size(1) - size}/{th_pred.size(1)}"
                        " samples got nan grad."
                        " These were ignored for CTC loss."
                    )
                    # Create mask for target: targets are flattened, so walk
                    # through them with the per-sample lengths.
                    target_mask = torch.full(
                        [th_target.size(0)],
                        1,
                        dtype=torch.bool,
                        device=th_target.device,
                    )
                    s = 0
                    for ind, le in enumerate(th_olen):
                        if not indices[ind]:
                            target_mask[s : s + le] = 0
                        s += le
                    # Calc loss again using masked data
                    loss = self.ctc_loss(
                        th_pred[:, indices, :],
                        th_target[target_mask],
                        th_ilen[indices],
                        th_olen[indices],
                    )
                    th_ilen, th_olen = th_ilen[indices], th_olen[indices]
            else:
                size = th_pred.size(1)
            if self.length_normalize is not None:
                # Per-sample length normalization before batch averaging.
                if self.length_normalize == "olen":
                    loss = loss / th_olen
                else:
                    loss = loss / th_ilen
            if self.reduce:
                # Batch-size average
                loss = loss.sum() / size
            else:
                loss = loss / size
            return loss
        elif self.ctc_type == "warpctc":
            # NOTE(review): unreachable — __init__ only allows "builtin".
            # warpctc only supports float32
            th_pred = th_pred.to(dtype=torch.float32)
            th_target = th_target.cpu().int()
            th_ilen = th_ilen.cpu().int()
            th_olen = th_olen.cpu().int()
            loss = self.ctc_loss(th_pred, th_target, th_ilen, th_olen)
            if self.reduce:
                # NOTE: sum() is needed to keep consistency since warpctc
                # return as tensor w/ shape (1,)
                # but builtin return as tensor w/o shape (scalar).
                loss = loss.sum()
            return loss
        elif self.ctc_type == "gtnctc":
            # NOTE(review): unreachable — __init__ only allows "builtin".
            log_probs = torch.nn.functional.log_softmax(th_pred, dim=2)
            return self.ctc_loss(log_probs, th_target, th_ilen, 0, "none")
        else:
            raise NotImplementedError

    def forward(self, hs_pad, hlens, ys_pad, ys_lens):
        """Calculate CTC loss.

        Args:
            hs_pad: batch of padded hidden state sequences (B, Tmax, D)
            hlens: batch of lengths of hidden state sequences (B)
            ys_pad: batch of padded character id sequence tensor (B, Lmax)
            ys_lens: batch of lengths of character sequence (B)
        """
        # hs_pad: (B, L, NProj) -> ys_hat: (B, L, Nvocab)
        ys_hat = self.ctc_lo(F.dropout(hs_pad, p=self.dropout_rate))
        if self.ctc_type == "gtnctc":
            # gtn expects list form for ys
            ys_true = [y[y != -1] for y in ys_pad]  # parse padded ys
        else:
            # ys_hat: (B, L, D) -> (L, B, D)
            ys_hat = ys_hat.transpose(0, 1)
            # (B, L) -> (BxL,)
            ys_true = torch.cat([ys_pad[i, :l] for i, l in enumerate(ys_lens)])
        loss = self.loss_fn(ys_hat, ys_true, hlens, ys_lens).to(
            device=hs_pad.device, dtype=hs_pad.dtype
        )
        return loss

    def softmax(self, hs_pad):
        """softmax of frame activations

        Args:
            Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
        Returns:
            torch.Tensor: softmax applied 3d tensor (B, Tmax, odim)
        """
        return F.softmax(self.ctc_lo(hs_pad), dim=2)

    def log_softmax(self, hs_pad):
        """log_softmax of frame activations

        Args:
            Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
        Returns:
            torch.Tensor: log softmax applied 3d tensor (B, Tmax, odim)
        """
        return F.log_softmax(self.ctc_lo(hs_pad), dim=2)

    def argmax(self, hs_pad):
        """argmax of frame activations

        Args:
            torch.Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
        Returns:
            torch.Tensor: argmax applied 2d tensor (B, Tmax)
        """
        return torch.argmax(self.ctc_lo(hs_pad), dim=2)
def ctc_forced_align(
    log_probs: torch.Tensor,
    targets: torch.Tensor,
    input_lengths: torch.Tensor,
    target_lengths: torch.Tensor,
    blank: int = 0,
    ignore_id: int = -1,
) -> torch.Tensor:
    """Align a CTC label sequence to an emission via Viterbi decoding.

    Args:
        log_probs (Tensor): log probability of CTC emission output.
            Tensor of shape `(B, T, C)`, where `B` is the batch size, `T` is the
            input length, `C` is the number of characters in alphabet including blank.
        targets (Tensor): Target sequence. Tensor of shape `(B, L)`,
            where `L` is the target length.
        input_lengths (Tensor):
            Lengths of the inputs (max value must each be <= `T`). 1-D Tensor of shape `(B,)`.
        target_lengths (Tensor):
            Lengths of the targets. 1-D Tensor of shape `(B,)`.
        blank (int, optional): The index of blank symbol in CTC emission. (Default: 0)
        ignore_id (int, optional): Padding value in `targets` replaced by `blank`
            before alignment. (Default: -1)

    Returns:
        Tensor: frame-level token alignment of shape `(B, T)`.
    """
    # Fix: work on a copy so the caller's `targets` is not mutated in place
    # (the original assigned into the argument tensor directly).
    targets = targets.clone()
    targets[targets == ignore_id] = blank
    batch_size, input_time_size, _ = log_probs.size()
    bsz_indices = torch.arange(batch_size, device=input_lengths.device)
    # Interleave blanks with the labels: [blank, y1, blank, y2, ..., blank].
    padded_targets = torch.cat(
        (
            torch.stack((torch.full_like(targets, blank), targets), dim=-1).flatten(start_dim=1),
            torch.full_like(targets[:, :1], blank),
        ),
        dim=-1,
    )
    # diff_labels[b, j] is True when the "skip" transition j-2 -> j is legal,
    # i.e. the two labels differ so the intermediate blank may be omitted.
    diff_labels = torch.cat(
        (
            torch.as_tensor([[False, False]], device=targets.device).expand(batch_size, -1),
            padded_targets[:, 2:] != padded_targets[:, :-2],
        ),
        dim=1,
    )
    neg_inf = torch.tensor(float("-inf"), device=log_probs.device, dtype=log_probs.dtype)
    # Two leading padding slots let the stay/step/skip slices below avoid
    # boundary checks.
    padding_num = 2
    padded_t = padding_num + padded_targets.size(-1)
    best_score = torch.full((batch_size, padded_t), neg_inf, device=log_probs.device, dtype=log_probs.dtype)
    # Initialization: a valid path starts on the leading blank or first label.
    best_score[:, padding_num + 0] = log_probs[:, 0, blank]
    best_score[:, padding_num + 1] = log_probs[bsz_indices, 0, padded_targets[:, 1]]
    backpointers = torch.zeros((batch_size, input_time_size, padded_t), device=log_probs.device, dtype=targets.dtype)
    for t in range(1, input_time_size):
        # Candidate predecessors of state j: stay (j), step (j-1), and
        # skip (j-2, only where diff_labels allows it).
        prev = torch.stack(
            (best_score[:, 2:], best_score[:, 1:-1], torch.where(diff_labels, best_score[:, :-2], neg_inf))
        )
        prev_max_value, prev_max_idx = prev.max(dim=0)
        best_score[:, padding_num:] = log_probs[:, t].gather(-1, padded_targets) + prev_max_value
        backpointers[:, t, padding_num:] = prev_max_idx
    # The best path ends either on the last label or on the trailing blank.
    l1l2 = best_score.gather(
        -1, torch.stack((padding_num + target_lengths * 2 - 1, padding_num + target_lengths * 2), dim=-1)
    )
    path = torch.zeros((batch_size, input_time_size), device=best_score.device, dtype=torch.long)
    path[bsz_indices, input_lengths - 1] = padding_num + target_lengths * 2 - 1 + l1l2.argmax(dim=-1)
    # Follow the backpointers from the last frame to the first; the stored
    # index k means "came from state j - k".
    for t in range(input_time_size - 1, 0, -1):
        target_indices = path[:, t]
        prev_max_idx = backpointers[bsz_indices, t, target_indices]
        path[:, t - 1] += target_indices - prev_max_idx
    # Map state indices back to token ids (clamp covers frames beyond
    # input_lengths, which keep their initial zero value).
    alignments = padded_targets.gather(dim=-1, index=(path - padding_num).clamp(min=0))
    return alignments

File diff suppressed because it is too large Load Diff

View File

@ -1,542 +0,0 @@
import logging
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import torch
from torch import nn
from funasr.models.transformer.attention import (
MultiHeadedAttention, # noqa: H301
RelPositionMultiHeadedAttention, # noqa: H301
LegacyRelPositionMultiHeadedAttention, # noqa: H301
)
from funasr.models.transformer.embedding import (
PositionalEncoding, # noqa: H301
ScaledPositionalEncoding, # noqa: H301
RelPositionalEncoding, # noqa: H301
LegacyRelPositionalEncoding, # noqa: H301
)
from funasr.models.transformer.layer_norm import LayerNorm
from funasr.models.transformer.utils.multi_layer_conv import Conv1dLinear
from funasr.models.transformer.utils.multi_layer_conv import MultiLayeredConv1d
from funasr.models.transformer.utils.nets_utils import get_activation
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.models.transformer.utils.mask import subsequent_mask, causal_block_mask
from funasr.models.transformer.positionwise_feed_forward import (
PositionwiseFeedForward, # noqa: H301
)
from funasr.models.transformer.utils.repeat import repeat
from funasr.models.transformer.utils.subsampling import (
Conv2dSubsampling, Conv2dSubsampling2, Conv2dSubsampling6, Conv2dSubsampling8, TooShortUttError,
check_short_utt, Conv2dSubsamplingPad
)
import torch.nn.functional as F
from funasr.models.llm_asr.conformer_encoder import ConvolutionModule, EncoderLayer
from funasr.models.ctc.ctc import CTC
class Upsample1D(nn.Module):
    """A 1D upsampling layer with an optional convolution.

    Upsamples the time axis by ``stride``, either with a transposed
    convolution or with nearest-neighbour interpolation optionally refined
    by a stride-1 convolution (the only mode that supports ``causal``).

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        use_conv_transpose (`bool`, default `False`):
            option to use a convolution transpose.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        channel_first (`bool`): if False, inputs/outputs are (B, T, C) and
            are transposed internally to (B, C, T).
        stride (`int`): upsampling factor.
        causal (`bool`): left-only padding in the conv path.
    """

    def __init__(self, channels, use_conv=False, use_conv_transpose=True,
                 out_channels=None, name="conv", channel_first=True, stride=2, causal=False):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name
        self.channel_first = channel_first
        self.stride = stride
        self.causal = causal
        self.conv = None
        if use_conv_transpose:
            # transpose conv doesn't support causal mode.
            assert not causal
            # Kernel/padding chosen so the output length is exactly T * stride.
            kernel_size = stride * 2 + stride % 2
            padding = (kernel_size - stride) // 2
            self.conv = nn.ConvTranspose1d(channels, self.out_channels, kernel_size, stride, padding)
        elif use_conv:
            # In this mode, first repeat-interpolate, then conv with stride=1.
            self.conv = nn.Conv1d(
                self.channels, self.out_channels, stride * 2 + 1, stride=1,
                padding=0,
            )

    def forward(self, inputs, input_lengths=None):
        """Returns (outputs, input_lengths * stride)."""
        if not self.channel_first:
            inputs = inputs.transpose(1, 2).contiguous()
        assert inputs.shape[1] == self.channels
        if self.use_conv_transpose:
            outputs = self.conv(inputs)
            if not self.channel_first:
                outputs = outputs.transpose(1, 2).contiguous()
            return outputs, input_lengths * self.stride
        # Nearest-neighbour upsampling, optionally refined by the conv.
        outputs = F.interpolate(inputs, scale_factor=self.stride, mode="nearest")
        if self.use_conv:
            if not self.causal:
                # Symmetric padding keeps the output length at T * stride.
                outputs = F.pad(outputs, (self.stride, self.stride))
            else:
                # Causal: pad only on the left.
                outputs = F.pad(outputs, (self.stride * 2, 0))
            outputs = self.conv(outputs)
        if not self.channel_first:
            outputs = outputs.transpose(1, 2).contiguous()
        return outputs, input_lengths * self.stride
class PreLookaheadLayer(nn.Module):
    """Conv block that peeks ``pre_lookahead_len`` frames ahead, then applies
    a causal 3-tap conv, with a masked residual connection."""

    def __init__(self, channels: int, pre_lookahead_len: int = 1):
        super().__init__()
        self.channels = channels
        self.pre_lookahead_len = pre_lookahead_len
        # Kernel covers the current frame plus pre_lookahead_len future frames.
        self.conv1 = nn.Conv1d(
            channels, channels,
            kernel_size=pre_lookahead_len + 1,
            stride=1, padding=0,
        )
        self.conv2 = nn.Conv1d(
            channels, channels,
            kernel_size=3, stride=1, padding=0,
        )

    def forward(self, inputs, ilens):
        """
        inputs: (batch_size, seq_len, channels)
        ilens: (batch_size,) valid lengths
        """
        hidden = inputs.transpose(1, 2).contiguous()
        # Right-pad so conv1 can look pre_lookahead_len frames ahead.
        hidden = F.pad(hidden, (0, self.pre_lookahead_len), mode='constant', value=0)
        hidden = F.leaky_relu(self.conv1(hidden))
        # Left-pad for the causal 3-tap conv2.
        hidden = F.pad(hidden, (2, 0), mode='constant', value=0)
        hidden = self.conv2(hidden).transpose(1, 2).contiguous()
        # Zero out padded positions and add the residual connection.
        mask = (~make_pad_mask(ilens).unsqueeze(-1).to(inputs.device))
        return (hidden + inputs) * mask, ilens
class UpsampleConformerEncoder(nn.Module):
    """Progressive upsampling Conformer encoder module.

    Runs a stack of Conformer blocks on the input, then progressively
    upsamples the sequence with (Upsample1D, input-projection, attention)
    stages, optionally with (block-)causal attention masks.

    Args:
        input_size (int): Input dimension.
        output_size (int): Dimension of attention.
        attention_heads (int): The number of heads of multi head attention.
        linear_units (int): The number of units of position-wise feed forward.
        num_blocks (int): The number of decoder blocks.
        upsample_blocks (int): Number of upsampling stages.
        upsample_attn_layers (int): Attention layers per upsampling stage.
        upsample_ratios (tuple): Per-stage upsampling factors (default: all 2).
        dropout_rate (float): Dropout rate.
        attention_dropout_rate (float): Dropout rate in attention.
        positional_dropout_rate (float): Dropout rate after adding positional encoding.
        input_layer (Union[str, torch.nn.Module]): Input layer type.
        normalize_before (bool): Whether to use layer_norm before the first block.
        concat_after (bool): Whether to concat attention layer's input and output.
            If True, additional linear will be applied.
            i.e. x -> x + linear(concat(x, att(x)))
            If False, no additional linear will be applied. i.e. x -> x + att(x)
        positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear".
        positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer.
        rel_pos_type (str): Whether to use the latest relative positional encoding or
            the legacy one. The legacy relative positional encoding will be deprecated
            in the future. More Details can be found in
            https://github.com/espnet/espnet/pull/2816.
        pos_enc_layer_type (str): Encoder positional encoding layer type.
        selfattention_layer_type (str): Encoder attention layer type.
        activation_type (str): Encoder activation function type.
        macaron_style (bool): Whether to use macaron style for positionwise layer.
        use_cnn_module (bool): Whether to use convolution module.
        zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
        cnn_module_kernel (int): Kernel size of convolution module.
        padding_idx (int): Padding idx for input_layer=embed.
        causal (bool): Use causal / block-causal attention masks.
        skip (bool): Add the raw input back onto the output (residual).
        channel_first (bool): Treat inputs as (B, C, T) and transpose internally.
        use_causal_prob (float): Per-sample probability of using the causal
            mask during training (None = always causal when ``causal``).
        pre_lookahead_len (int): If set, insert a PreLookaheadLayer after the
            input embedding.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        upsample_blocks: int = 3,
        upsample_attn_layers: int = 2,
        upsample_ratios: tuple = None,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: Optional[str] = "conv2d",
        normalize_before: bool = True,
        concat_after: bool = False,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 3,
        macaron_style: bool = False,
        rel_pos_type: str = "legacy",
        pos_enc_layer_type: str = "rel_pos",
        selfattention_layer_type: str = "rel_selfattn",
        activation_type: str = "swish",
        use_cnn_module: bool = True,
        zero_triu: bool = False,
        cnn_module_kernel: int = 31,
        padding_idx: int = -1,
        causal: bool = False,
        skip: bool = False,
        channel_first: bool = False,
        use_causal_prob: float = None,
        pre_lookahead_len: int = None,
    ):
        super().__init__()
        self._output_size = output_size
        self.causal = causal
        self.skip = skip
        self.channel_first = channel_first
        self.pre_lookahead_len = pre_lookahead_len
        self.use_causal_prob = use_causal_prob
        # Map the requested rel-pos flavour onto matching pos-enc/attn types.
        if rel_pos_type == "legacy":
            if pos_enc_layer_type == "rel_pos":
                pos_enc_layer_type = "legacy_rel_pos"
            if selfattention_layer_type == "rel_selfattn":
                selfattention_layer_type = "legacy_rel_selfattn"
        elif rel_pos_type == "latest":
            assert selfattention_layer_type != "legacy_rel_selfattn"
            assert pos_enc_layer_type != "legacy_rel_pos"
        else:
            raise ValueError("unknown rel_pos_type: " + rel_pos_type)
        activation = get_activation(activation_type)
        # Select the positional-encoding class.
        if pos_enc_layer_type == "abs_pos":
            pos_enc_class = PositionalEncoding
        elif pos_enc_layer_type == "scaled_abs_pos":
            pos_enc_class = ScaledPositionalEncoding
        elif pos_enc_layer_type == "rel_pos":
            assert selfattention_layer_type == "rel_selfattn"
            pos_enc_class = RelPositionalEncoding
        elif pos_enc_layer_type == "legacy_rel_pos":
            assert selfattention_layer_type == "legacy_rel_selfattn"
            pos_enc_class = LegacyRelPositionalEncoding
            logging.warning(
                "Using legacy_rel_pos and it will be deprecated in the future."
            )
        else:
            raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)
        # Build the input embedding / subsampling front-end.
        if input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(input_size, output_size),
                torch.nn.LayerNorm(output_size),
                torch.nn.Dropout(dropout_rate),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d":
            self.embed = Conv2dSubsampling(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2dpad":
            self.embed = Conv2dSubsamplingPad(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d2":
            self.embed = Conv2dSubsampling2(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d6":
            self.embed = Conv2dSubsampling6(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d8":
            self.embed = Conv2dSubsampling8(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif isinstance(input_layer, torch.nn.Module):
            self.embed = torch.nn.Sequential(
                input_layer,
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer is None:
            self.embed = torch.nn.Sequential(
                pos_enc_class(output_size, positional_dropout_rate)
            )
        else:
            raise ValueError("unknown input_layer: " + input_layer)
        self.normalize_before = normalize_before
        # Select the position-wise feed-forward flavour.
        if positionwise_layer_type == "linear":
            positionwise_layer = PositionwiseFeedForward
            positionwise_layer_args = (
                output_size,
                linear_units,
                dropout_rate,
                activation,
            )
        elif positionwise_layer_type == "conv1d":
            positionwise_layer = MultiLayeredConv1d
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d-linear":
            positionwise_layer = Conv1dLinear
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        else:
            raise NotImplementedError("Support only linear or conv1d.")
        if pre_lookahead_len is not None:
            self.pre_lookahead_layer = PreLookaheadLayer(output_size, pre_lookahead_len)
        # Select the self-attention flavour.
        if selfattention_layer_type == "selfattn":
            encoder_selfattn_layer = MultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
            )
        elif selfattention_layer_type == "legacy_rel_selfattn":
            assert pos_enc_layer_type == "legacy_rel_pos"
            encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
            )
            logging.warning(
                "Using legacy_rel_selfattn and it will be deprecated in the future."
            )
        elif selfattention_layer_type == "rel_selfattn":
            assert pos_enc_layer_type == "rel_pos"
            encoder_selfattn_layer = RelPositionMultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
                zero_triu,
            )
        else:
            raise ValueError("unknown encoder_attn_layer: " + selfattention_layer_type)
        convolution_layer = ConvolutionModule
        convolution_layer_args = (output_size, cnn_module_kernel, activation)
        # Main Conformer trunk.
        self.encoders = repeat(
            num_blocks,
            lambda lnum: EncoderLayer(
                output_size,
                encoder_selfattn_layer(*encoder_selfattn_layer_args),
                positionwise_layer(*positionwise_layer_args),
                positionwise_layer(*positionwise_layer_args) if macaron_style else None,
                convolution_layer(*convolution_layer_args) if use_cnn_module else None,
                dropout_rate,
                normalize_before,
                concat_after,
                stochastic_depth_rate=0.0,
            ),
        )
        # Upsampling stages: each is (Upsample1D, input projection, attention).
        self.upsample_blocks = nn.ModuleList()
        if upsample_ratios is None:
            upsample_ratios = [2] * upsample_blocks
        self.upsample_ratios = upsample_ratios
        assert upsample_blocks == len(upsample_ratios)
        for i in range(upsample_blocks):
            if not causal:
                upsample_conv_block = Upsample1D(
                    channels=output_size, use_conv=False, use_conv_transpose=True,
                    out_channels=output_size, channel_first=False, stride=upsample_ratios[i], causal=False,
                )
            else:
                # Transposed conv cannot be causal; use interpolation + conv.
                upsample_conv_block = Upsample1D(
                    channels=output_size, use_conv=True, use_conv_transpose=False,
                    out_channels=output_size, channel_first=False, stride=upsample_ratios[i], causal=True,
                )
            upsample_attn_block = repeat(
                upsample_attn_layers,
                lambda lnum: EncoderLayer(
                    output_size,
                    encoder_selfattn_layer(*encoder_selfattn_layer_args),
                    positionwise_layer(*positionwise_layer_args),
                    positionwise_layer(*positionwise_layer_args) if macaron_style else None,
                    convolution_layer(*convolution_layer_args) if use_cnn_module else None,
                    dropout_rate,
                    normalize_before,
                    concat_after,
                    stochastic_depth_rate=0.0,
                ),
            )
            attn_input_layer = torch.nn.Sequential(
                torch.nn.Linear(output_size, output_size),
                torch.nn.LayerNorm(output_size),
                torch.nn.Dropout(dropout_rate),
                pos_enc_class(output_size, positional_dropout_rate),
            )
            self.upsample_blocks.append(nn.ModuleList([upsample_conv_block, attn_input_layer, upsample_attn_block]))
        if self.normalize_before:
            self.after_norm = LayerNorm(output_size)

    def output_size(self) -> int:
        """Return the encoder output dimension."""
        return self._output_size

    def rand_mix_masks(self, causal, noncausal):
        """Per-sample random mix of causal and non-causal masks.

        NOTE(review): ``self.uni_encoder_prob`` is never defined in this
        class (``__init__`` only sets ``use_causal_prob``), so calling this
        method would raise AttributeError — presumably it should read
        ``self.use_causal_prob``. It is not called by forward().
        """
        use_causal = (torch.rand([causal.shape[0], 1, 1]) <= self.uni_encoder_prob).to(causal)
        masks = use_causal * causal + (1 - use_causal) * noncausal
        return masks

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor = None,
        prev_states: torch.Tensor = None,
        ctc: CTC = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Calculate forward propagation.

        Args:
            xs_pad (torch.Tensor): Input tensor (#batch, L, input_size).
            ilens (torch.Tensor): Input length (#batch).
            prev_states (torch.Tensor): Not to be used now.
            ctc (CTC): unused here; kept for interface compatibility.

        Returns:
            torch.Tensor: Output tensor (#batch, L, output_size).
            torch.Tensor: Output length (#batch).
            torch.Tensor: Not to be used now.
        """
        raw_input = xs_pad
        if self.channel_first:
            xs_pad = xs_pad.permute(0, 2, 1)
        # NOTE(review): ilens=None looks unsupported once upsampling starts
        # (make_pad_mask(None) below would fail) — confirm callers always
        # pass ilens.
        if ilens is not None:
            masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
        else:
            masks = torch.ones(
                xs_pad.shape[0], 1, xs_pad.shape[1],
                dtype=torch.bool, device=xs_pad.device
            )
        # Per-sample coin flip deciding causal vs. full attention.
        if self.use_causal_prob is not None:
            use_causal = (torch.rand([xs_pad.shape[0], 1, 1]) <= self.use_causal_prob).to(xs_pad)
        else:
            use_causal = torch.ones([xs_pad.shape[0], 1, 1]).to(xs_pad)
        if self.causal:
            causal_mask = subsequent_mask(
                xs_pad.shape[1], device=xs_pad.device, dtype=masks.dtype
            ).unsqueeze(0)
            causal_mask = masks & causal_mask
            # whether to train causal & non-causal in a single model
            masks = use_causal * causal_mask + (1 - use_causal) * masks
        if (
            isinstance(self.embed, Conv2dSubsampling)
            or isinstance(self.embed, Conv2dSubsampling2)
            or isinstance(self.embed, Conv2dSubsampling6)
            or isinstance(self.embed, Conv2dSubsampling8)
            or isinstance(self.embed, Conv2dSubsamplingPad)
        ):
            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
            if short_status:
                raise TooShortUttError(
                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
                    + f"(it needs more than {limit_size} frames), return empty results",
                    xs_pad.size(1),
                    limit_size,
                )
            xs_pad, masks = self.embed(xs_pad, masks)
        else:
            xs_pad = self.embed(xs_pad)
        if self.pre_lookahead_len is not None:
            # With rel-pos encodings xs_pad is a (tensor, pos_emb) tuple;
            # apply the lookahead layer to the tensor part only.
            xs = xs_pad
            if isinstance(xs_pad, tuple):
                xs = xs_pad[0]
            xs, _ = self.pre_lookahead_layer(xs, ilens)
            if isinstance(xs_pad, tuple):
                xs_pad = (xs, xs_pad[1])
        # 1. modeling on inputs
        intermediate_outs = []
        xs_pad, masks = self.encoders(xs_pad, masks)
        # 2. progressive upsampling
        outs, olens = xs_pad, ilens
        total_ratio = 1
        for up_ratio, layer in zip(self.upsample_ratios, self.upsample_blocks):
            up_layer, attn_input_layer, attn_layer = layer
            if isinstance(outs, tuple):
                outs = outs[0]
            outs, olens = up_layer(outs, olens)
            masks = (~make_pad_mask(olens)[:, None, :]).to(outs.device)
            total_ratio = total_ratio * up_ratio
            if self.causal:
                # Block-causal mask: frames within one upsampled block of
                # size total_ratio may attend to each other.
                causal_mask = causal_block_mask(
                    outs.shape[1], total_ratio, device=outs.device, dtype=masks.dtype
                ).unsqueeze(0)
                causal_mask = masks & causal_mask
                masks = use_causal * causal_mask + (1 - use_causal) * masks
            outs = attn_input_layer(outs)
            outs, _ = attn_layer(outs, masks)
        xs_pad = outs
        if isinstance(xs_pad, tuple):
            xs_pad = xs_pad[0]
        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)
        if self.channel_first:
            xs_pad = xs_pad.permute(0, 2, 1)
        if self.skip:
            xs_pad = xs_pad + raw_input
        # olens = masks.squeeze(1).sum(1)
        # NOTE(review): intermediate_outs is never appended to, so this
        # branch is currently dead code.
        if len(intermediate_outs) > 0:
            return (xs_pad, intermediate_outs), olens, None
        if ilens is not None:
            return xs_pad, olens, None
        else:
            return xs_pad

View File

@ -1,625 +0,0 @@
import logging
from typing import List, Tuple, Dict, Optional, Union
import torch
import torch.nn as nn
from funasr.models.transformer.utils.nets_utils import make_pad_mask
import torch.nn.functional as F
from funasr.train_utils.device_funcs import force_gatherable
from funasr.models.llm_asr.label_smoothing_loss import LabelSmoothingLoss
from copy import deepcopy
from funasr.metrics.compute_acc import th_accuracy
from funasr.models.transformer.utils.nets_utils import pad_list
import random
import numpy as np
from funasr.utils.hinter import hint_once
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.models.llm_asr.tts_models.ctc_alignment import ctc_forced_align
from torch.nn.utils.rnn import pad_sequence
import itertools
from distutils.version import LooseVersion
from contextlib import contextmanager
# Provide a no-op `autocast` fallback so the rest of the file can use the same
# context-manager API on torch < 1.6.0 (which lacks torch.cuda.amp).
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
else:
    # Nothing to do if torch<1.6.0: behave as a transparent context manager.
    @contextmanager
    def autocast(enabled=True):
        yield
class NARCTCModel(nn.Module):
    """Non-autoregressive CTC model mapping text embeddings to speech-token
    sequences, trained with a CTC branch and an optional attention decoder."""

    def __init__(
        self,
        input_size: int,
        vocab_size: int,
        encoder: Union[nn.Module, dict],
        decoder: Optional[nn.Module] = None,
        ctc_weight: float = 0.5,
        ignore_id: int = -1,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
        **kwargs,
    ):
        super().__init__()
        self.input_size = input_size
        self.decoder = decoder
        # `encoder` may be an instantiated module or a config dict.
        self.encoder = encoder if isinstance(encoder, nn.Module) else self.build_encoder(encoder)
        self.output_size = self.encoder.output_size()
        self.ignore_id = ignore_id
        self.vocab_size = vocab_size
        self.ctc_weight = ctc_weight
        # build ctc module
        from funasr.models.llm_asr.tts_models.ctc_alignment import CTC
        ctc_conf = kwargs.pop("ctc_conf", {})
        self.ctc = CTC(vocab_size, encoder_output_size=self.output_size, **ctc_conf)
        # Separate embedding tables for input text and output speech tokens.
        self.text_embedding = torch.nn.Embedding(self.vocab_size, input_size)
        self.token_embedding = torch.nn.Embedding(vocab_size, input_size)
        xvec_size = kwargs.get("xvec_size", None)
        if xvec_size is not None:
            # Projects speaker x-vectors into the model dimension.
            self.xvec_proj = torch.nn.Linear(xvec_size, input_size)
        else:
            self.xvec_proj = None
        self.criterion_att = LabelSmoothingLoss(
            size=vocab_size,
            padding_idx=ignore_id,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )
        # The last two vocabulary entries are reserved for <sos> / <eos>.
        self.sos = vocab_size - 2
        self.eos = vocab_size - 1
        self.length_regulator_conf = kwargs.get("length_regulator_conf", None)
        if self.length_regulator_conf is not None:
            self.length_regulator = self.build_length_regulator()
        else:
            self.length_regulator = None
    def build_encoder(self, encoder_conf: dict):
        """Instantiate the encoder from a config dict.

        The "name" key selects the architecture; it is popped before kwarg
        expansion and restored afterwards so the caller's dict stays intact.
        Returns None for an unrecognized name.
        """
        if encoder_conf is None:
            assert hasattr(self, "encoder_conf"), \
                "function param encoder_conf is None and model doesn't has encoder_conf attribute either."
            encoder_conf = self.encoder_conf
        encoder_name = encoder_conf.pop("name", "transformer")
        model = None
        if encoder_name == "transformer":
            from funasr.models.llm_asr.conformer_encoder import ConformerEncoder
            # "transformer" is a Conformer with convolution / macaron disabled.
            model = ConformerEncoder(
                **encoder_conf,
                input_size=self.input_size,
                use_cnn_module=False,
                macaron_style=False,
            )
        elif encoder_name == "conformer":
            from funasr.models.llm_asr.conformer_encoder import ConformerEncoder
            model = ConformerEncoder(
                **encoder_conf,
                input_size=self.input_size,
            )
        elif encoder_name == "upsampling_conformer":
            from funasr.models.llm_asr.tts_models.encoders import UpsampleConformerEncoder
            model = UpsampleConformerEncoder(
                **encoder_conf,
                input_size=self.input_size,
            )
        # Restore the popped key so the config dict can be reused.
        encoder_conf["name"] = encoder_name
        return model
    def build_length_regulator(self):
        """Instantiate the length regulator named in length_regulator_conf.

        The "name" key is popped for kwarg expansion and restored afterwards.
        Returns None for an unrecognized name.
        """
        name = self.length_regulator_conf.pop("name", None)
        model = None
        if name == "upsampling":
            from funasr.models.llm_asr.diffusion_models.length_regulator import UpSamplingRegulator
            model = UpSamplingRegulator(self.input_size, self.length_regulator_conf.get("sampling_ratios"))
        elif name == "downsampling":
            from funasr.models.llm_asr.diffusion_models.length_regulator import DownSamplingRegulator
            model = DownSamplingRegulator(self.input_size, self.length_regulator_conf.get("sampling_ratios"))
        elif name == "interpolate":
            from funasr.models.llm_asr.diffusion_models.length_regulator import InterpolateRegulator
            model = InterpolateRegulator(self.input_size, **self.length_regulator_conf)
        elif name == "upsampling_cif":
            from funasr.models.llm_asr.diffusion_models.length_regulator import UpsamplingCifRegulator
            model = UpsamplingCifRegulator(self.input_size, **self.length_regulator_conf)
        # Restore the popped key so the config dict can be reused.
        self.length_regulator_conf["name"] = name
        return model
@staticmethod
def norm_and_sample_xvec(xvec, xvec_lengths):
xvec_list = []
for i, ilen in enumerate(xvec_lengths):
idx = random.randint(0, ilen - 1)
while torch.any(~torch.isfinite(xvec[i, idx])):
idx = random.randint(0, ilen - 1)
xvec_list.append(xvec[i, idx])
rand_xvec = torch.vstack(xvec_list)
rand_xvec = F.normalize(rand_xvec, dim=1)
return rand_xvec
def _calc_att_loss(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
ys_pad: torch.Tensor,
ys_pad_lens: torch.Tensor,
):
ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
ys_in_lens = ys_pad_lens + 1
# 1. Forward decoder
decoder_out, _ = self.decoder(
encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
)
# 2. Compute attention loss
loss_att = self.criterion_att(decoder_out, ys_out_pad)
acc_att = th_accuracy(
decoder_out.view(-1, self.decoder.vocab_size),
ys_out_pad,
ignore_label=self.ignore_id,
)
# Compute cer/wer using attention-decoder
if self.training or self.error_calculator is None:
cer_att, wer_att = None, None
else:
ys_hat = decoder_out.argmax(dim=-1)
cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
return loss_att, acc_att, cer_att, wer_att
    def model_forward(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        xvec: Optional[torch.Tensor] = None,
        xvec_lengths: Optional[torch.Tensor] = None,
        **kwargs
    ):
        """Apply length regulation, x-vector conditioning, and the encoder.

        Args:
            text: (B, T, D) text embeddings (already embedded, not ids).
            text_lengths: (B,) valid lengths of `text`.
            xvec: optional (B, N, D_x) candidate speaker x-vectors.
            xvec_lengths: (B,) number of valid x-vectors per utterance.

        Returns:
            (encoder_out, encoder_out_lens): encoder states and their lengths.
        """
        # 0. Up-sampling text length
        if self.length_regulator is not None:
            text, text_lengths = self.length_regulator(text, text_lengths)
        # 1. padding xvec
        if xvec is not None and self.xvec_proj is not None:
            xvec = xvec[:, :xvec_lengths.max()]
            # random select a xvec from xvec matrix
            xvec = self.norm_and_sample_xvec(xvec, xvec_lengths)
            xvec = self.xvec_proj(xvec)
            # Broadcast-add the speaker embedding over all time steps.
            text = text + xvec.unsqueeze(1)
            hint_once("use xvec", "use_xvec")
        # 1. Encoder
        encoder_out, encoder_out_lens, _ = self.encoder(text, text_lengths)
        return encoder_out, encoder_out_lens
    def predictor(
        self,
        am: torch.Tensor,
        am_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
        alignment,
    ):
        """Pool encoder frames into one embedding per aligned token.

        Consecutive frames assigned to the same non-blank label by the CTC
        alignment are mean-pooled. When the number of pooled segments does
        not match the reference length, the reference embeddings `ys_pad`
        are used as a fallback for that utterance.

        Returns:
            acoustic_embeds: (B, U, D) padded per-token embeddings.
            ratio: fraction of utterances where pooled predictions were used.
        """
        acoustic_embeds = []
        use_pred_num = 0
        for am_xs, enc_len, ali, y, y_lens in zip(am, am_lens, alignment, ys_pad, ys_pad_lens):
            # Group consecutive identical alignment labels into segments.
            pred = itertools.groupby(ali[:enc_len])
            acoustic_embed = []
            _start = 0
            for pred_token, pred_frame in pred:
                _end = _start + len(list(pred_frame))
                if pred_token != 0:
                    # Mean-pool the frames of one non-blank segment.
                    acoustic_embed.append(torch.mean(am_xs[_start:_end, :], 0, keepdim=True))
                _start = _end
            if len(acoustic_embed) != y_lens:
                # Length mismatch: fall back to the reference embeddings.
                acoustic_embeds.append(y[:y_lens])
            else:
                acoustic_embeds.append(torch.cat(acoustic_embed, dim=0))
                use_pred_num += 1
        acoustic_embeds = pad_sequence(acoustic_embeds, batch_first=True, padding_value=0)
        return acoustic_embeds, use_pred_num / am.shape[0]
    def force_align_text(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        xvec: Optional[torch.Tensor] = None,
        xvec_lengths: Optional[torch.Tensor] = None,
        **kwargs
    ):
        """Force-align speech tokens to encoder frames and pool embeddings.

        Runs the encoder on `text`, CTC-force-aligns the speech-token targets
        to encoder frames, mean-pools encoder states per aligned token via
        `predictor`, and accumulates the weighted CTC / attention losses.

        Returns:
            (loss, aligned_token_emb, states) where states is a dict of
            scalar statistics.
        """
        # plus one to speech token, to make index 0 represent <blank>,
        # decoder vocab must be: 1 (blank) + num of token + 1 (sos) + 1 (eos)
        speech = torch.where(speech != -1, speech + 1, speech)
        encoder_out, encoder_out_lens = self.model_forward(
            text, text_lengths,
            xvec, xvec_lengths,
            **kwargs
        )
        log_probs = self.ctc.log_softmax(encoder_out)
        # The alignment is only used as an index map; no gradients needed.
        with torch.no_grad():
            alignment = ctc_forced_align(
                log_probs.float(),
                speech.long(),
                encoder_out_lens.long(),
                speech_lengths.long(),
                ignore_id=self.ignore_id,
            )
        aligned_token_emb, use_pred_ratio = self.predictor(
            encoder_out, encoder_out_lens,
            self.token_embedding(speech), speech_lengths,
            alignment,
        )
        loss = 0
        states = dict(
            use_pred_ratio=use_pred_ratio,
        )
        if self.ctc_weight != 0.0:
            loss_ctc, logits = self._calc_ctc_loss(
                encoder_out, encoder_out_lens, speech, speech_lengths
            )
            states["loss_ctc"] = loss_ctc.item()
            loss = loss + self.ctc_weight * loss_ctc
        if self.ctc_weight != 1.0:
            loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
                encoder_out, encoder_out_lens, speech, speech_lengths
            )
            states["loss_att"] = loss_att.item()
            loss = loss + (1.0 - self.ctc_weight) * loss_att
        states["loss"] = loss.item()
        return loss, aligned_token_emb, states
def _calc_ctc_loss(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
ys_pad: torch.Tensor,
ys_pad_lens: torch.Tensor,
):
# Calc CTC loss
loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
logits = self.ctc.log_softmax(encoder_out)
return loss_ctc, logits
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        xvec: Optional[torch.Tensor] = None,
        xvec_lengths: Optional[torch.Tensor] = None,
        **kwargs
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss

        Args:
            speech: (Batch, Length, ...), speech tokens (the targets here)
            speech_lengths: (Batch, )
            text: (Batch, Length), text tokens (the inputs here)
            text_lengths: (Batch, )
            xvec: (Batch, Length, ...) x-vectors
            xvec_lengths: (Batch, )

        Returns:
            (loss, stats, weight) gathered for DataParallel training.
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (
            speech.shape[0]
            == speech_lengths.shape[0]
            == text.shape[0]
            == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
        batch_size = speech.shape[0]
        # for data-parallel: trim padding to the longest item in this batch
        text = text[:, : text_lengths.max()]
        speech = speech[:, : speech_lengths.max()]
        # plus one to speech token, to make index 0 represent <blank>,
        # decoder vocab must be: 1 (blank) + num of token + 1 (sos) + 1 (eos)
        speech = torch.where(speech != -1, speech + 1, speech)
        # embed text inputs; ignore ids (-1) are clamped for the lookup and
        # then zeroed out again through the mask
        mask = (text != -1).float().unsqueeze(-1)
        text = self.text_embedding(torch.clamp(text, min=0)) * mask
        encoder_out, encoder_out_lens = self.model_forward(
            text, text_lengths,
            xvec, xvec_lengths,
            **kwargs,
        )
        loss_att, acc_att, cer_att, wer_att = None, None, None, None
        loss_ctc, cer_ctc = None, None
        stats = dict(
            batch_size=float(batch_size),
            text_len=float(text.shape[1]),
            enc_len=float(encoder_out.shape[1]),
            speech_len=float(speech.shape[1]),
            token_text_ratio=float(speech.shape[1])/float(text.shape[1]),
        )
        # 1. CTC branch
        if self.ctc_weight != 0.0:
            loss_ctc, logits = self._calc_ctc_loss(
                encoder_out, encoder_out_lens, speech, speech_lengths
            )
            # Collect CTC branch stats
            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
            stats["cer_ctc"] = cer_ctc
        # 2b. Attention decoder branch
        if self.ctc_weight != 1.0:
            loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
                encoder_out, encoder_out_lens, speech, speech_lengths
            )
        # 3. CTC-Att loss definition
        if self.ctc_weight == 0.0:
            loss = loss_att
        elif self.ctc_weight == 1.0:
            loss = loss_ctc
        else:
            loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att
        # Collect Attn branch stats
        stats["loss_att"] = loss_att.detach() if loss_att is not None else None
        stats["acc"] = acc_att
        stats["cer"] = cer_att
        stats["wer"] = wer_att
        # Collect total loss stats
        stats["loss"] = torch.clone(loss.detach())
        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight
def topp_sampling(self, probs, top_p=0.8):
sorted_value, sorted_idx = probs.sort(descending=True, stable=True)
cumulative_probs = torch.cumsum(sorted_value, dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = False
indices_to_remove = sorted_idx[sorted_indices_to_remove]
probs[indices_to_remove] = 0
top_ids = torch.multinomial(probs, num_samples=1)
return top_ids
def sampling_ids(self, enc_outs, sampling="greedy", blank_penalty=None, return_probs=False):
probs = self.ctc.softmax(enc_outs)
if blank_penalty > 0:
probs[:, :, 0] = probs[:, :, 0] * blank_penalty
# top-p sampling
if "." in sampling:
sampling = float(sampling)
tokens = self.topp_sampling(probs, top_p=sampling)
tokens = torch.tensor(tokens, dtype=torch.long).to(probs.device)
# top-k sampling
elif sampling.isdigit():
sampling = int(sampling)
probs = probs.topk(sampling)
tokens = probs.multinomial(1, replacement=True)
else:
if sampling == "greedy":
tokens = torch.argmax(probs, dim=-1)
elif "threshold_" in sampling:
threshold = float(sampling.split("_")[1])
hint_once(f"Decoding mode: blank threshold={threshold:.2f}", "decoding_mode")
# mask out blank according to threshold
mask = probs[:, :, 0] > threshold
probs[:, :, 0] = probs[:, :, 0] * mask
tokens = torch.argmax(probs, dim=-1)
else:
raise NotImplementedError(f"sampling method {sampling} not implemented")
if not return_probs:
return tokens
return tokens, probs
    def inference(
        self,
        text: torch.Tensor, text_lengths: torch.Tensor,
        xvec=None, xvec_lengths=None,
        sampling="greedy",
        blank_penalty: float = 0.0,
        text_is_embedding=False,
        return_hidden=False,
        **kwargs,
    ):
        """Decode speech tokens from text (optionally with hidden pooling).

        Returns:
            tokens when return_hidden is False; otherwise
            ((tokens, fa_tokens), acoustic_embs, acoustic_emb_lens), where
            acoustic_embs are encoder states mean-pooled per predicted
            non-blank segment.
        """
        device = text.device
        # use causal mode at inference stage
        self.encoder.use_causal_prob = kwargs.get("use_causal_prob", 1.0)
        hint_once(f"use_causal_prob {self.encoder.use_causal_prob}.", "use_causal_prob")
        # embed text inputs; ignore ids (-1) are clamped for the lookup and
        # then zeroed out again through the mask
        if not text_is_embedding:
            mask = (text != -1).float().unsqueeze(-1)
            text = self.text_embedding(torch.clamp(text, min=0)) * mask
        # 1. Encoder
        encoder_out, encoder_out_lens = self.model_forward(
            text, text_lengths,
            xvec, xvec_lengths,
        )
        fa_tokens, enc_probs = self.sampling_ids(
            encoder_out,
            sampling=sampling,
            blank_penalty=blank_penalty,
            return_probs=True,
        )
        # Collapse repeated non-blank tokens; blank runs are kept frame-wise.
        reduced_fa_tokens = []
        for pred_token, pred_frame in itertools.groupby(fa_tokens[0].cpu().tolist()):
            if pred_token != 0:
                reduced_fa_tokens.append(pred_token)
            else:
                reduced_fa_tokens.extend(list(pred_frame))
        fa_tokens = torch.tensor([reduced_fa_tokens]).to(fa_tokens)
        # remove blanks (id=0) and convert token ids into the original format
        tokens = [[x-1] for x in fa_tokens[0].cpu().tolist() if x > 0]
        tokens = torch.tensor([tokens], dtype=torch.int64, device=device)
        if not return_hidden:
            return tokens
        # Mean-pool encoder frames per predicted non-blank segment.
        acoustic_embs, acoustic_emb_lens = [], []
        for idx, (prob, enc) in enumerate(zip(enc_probs, encoder_out)):
            pred = itertools.groupby(prob.argmax(-1).cpu())
            acs_emb = []
            _start = 0
            for pred_token, pred_frame in pred:
                _end = _start + len(list(pred_frame))
                if pred_token != 0 and pred_token != -1:
                    acs_emb.append(torch.mean(enc[_start:_end, :], 0, keepdim=True))
                _start = _end
            acs_emb = torch.cat(acs_emb, dim=0)
            acoustic_embs.append(acs_emb)
            acoustic_emb_lens.append(acs_emb.shape[0])
        acoustic_embs = pad_list(acoustic_embs, 0.0)
        acoustic_emb_lens = torch.tensor(acoustic_emb_lens, dtype=torch.int64, device=device)
        return (tokens, fa_tokens), acoustic_embs, acoustic_emb_lens
class NARCTCProbModel(NARCTCModel):
    """NARCTCModel variant whose predictor pools CTC output distributions
    (projected through the token embedding table) instead of encoder states."""

    def __init__(self, input_size: int, vocab_size: int, encoder: Union[nn.Module, dict],
                 decoder: Optional[nn.Module] = None, ctc_weight: float = 0.5, ignore_id: int = -1,
                 lsm_weight: float = 0.0, length_normalized_loss: bool = False, **kwargs):
        # Same construction as the parent; only predictor/force_align_text/
        # inference behave differently.
        super().__init__(input_size, vocab_size, encoder, decoder, ctc_weight, ignore_id, lsm_weight,
                         length_normalized_loss, **kwargs)
    def predictor(
        self,
        am_probs: torch.Tensor,
        am_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
        alignment,
    ):
        """Pool CTC output distributions per aligned token and project them.

        Like NARCTCModel.predictor, but pools the distribution over the
        vocabulary (`am_probs`, as provided by the caller) instead of encoder
        states; on a length mismatch it falls back to a one-hot distribution
        of the reference tokens. The pooled distribution is then projected
        through the token embedding table.

        Returns:
            acoustic_embeds: (B, U, D) padded per-token embeddings.
            ratio: fraction of utterances where pooled predictions were used.
        """
        acoustic_embeds = []
        use_pred_num = 0
        for probs, enc_len, ali, y, y_lens in zip(am_probs, am_lens, alignment, ys_pad, ys_pad_lens):
            # Group consecutive identical alignment labels into segments.
            pred = itertools.groupby(ali[:enc_len])
            acoustic_embed = []
            _start = 0
            for pred_token, pred_frame in pred:
                _end = _start + len(list(pred_frame))
                if pred_token != 0:
                    # Mean-pool the distributions of one non-blank segment.
                    acoustic_embed.append(torch.mean(probs[_start:_end, :], 0, keepdim=True))
                _start = _end
            if len(acoustic_embed) != y_lens:
                # Length mismatch: one-hot distribution of reference tokens.
                acoustic_embeds.append(F.one_hot(y[:y_lens], self.vocab_size).float())
            else:
                acoustic_embeds.append(torch.cat(acoustic_embed, dim=0))
                use_pred_num += 1
            # Project the per-token distribution into embedding space.
            acoustic_embeds[-1] = torch.matmul(acoustic_embeds[-1], self.token_embedding.weight)
        acoustic_embeds = pad_sequence(acoustic_embeds, batch_first=True, padding_value=0)
        return acoustic_embeds, use_pred_num / am_probs.shape[0]
    def force_align_text(self, speech: torch.Tensor, speech_lengths: torch.Tensor, text: torch.Tensor,
                         text_lengths: torch.Tensor, xvec: Optional[torch.Tensor] = None,
                         xvec_lengths: Optional[torch.Tensor] = None, **kwargs):
        """Force-align speech tokens and pool CTC log-probabilities per token.

        Same flow as NARCTCModel.force_align_text, but `predictor` receives
        the CTC log-probabilities rather than encoder states.

        Returns:
            (loss, aligned_token_emb, states).
        """
        # plus one to speech token, to make index 0 represent <blank>,
        # decoder vocab must be: 1 (blank) + num of token + 1 (sos) + 1 (eos)
        speech = torch.where(speech != -1, speech + 1, speech)
        encoder_out, encoder_out_lens = self.model_forward(
            text, text_lengths,
            xvec, xvec_lengths,
            **kwargs
        )
        log_probs = self.ctc.log_softmax(encoder_out)
        # The alignment is only used as an index map; no gradients needed.
        with torch.no_grad():
            alignment = ctc_forced_align(
                log_probs.float(),
                speech.long(),
                encoder_out_lens.long(),
                speech_lengths.long(),
                ignore_id=self.ignore_id,
            )
        # Pool the log-probability distributions (not encoder states).
        aligned_token_emb, use_pred_ratio = self.predictor(
            log_probs.float(), encoder_out_lens.long(),
            speech.long(), speech_lengths.long(),
            alignment,
        )
        loss = 0
        states = dict(
            use_pred_ratio=use_pred_ratio,
        )
        if self.ctc_weight != 0.0:
            loss_ctc, logits = self._calc_ctc_loss(
                encoder_out, encoder_out_lens, speech, speech_lengths
            )
            states["loss_ctc"] = loss_ctc.item()
            loss = loss + self.ctc_weight * loss_ctc
        if self.ctc_weight != 1.0:
            loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
                encoder_out, encoder_out_lens, speech, speech_lengths
            )
            states["loss_att"] = loss_att.item()
            loss = loss + (1.0 - self.ctc_weight) * loss_att
        states["loss"] = loss.item()
        return loss, aligned_token_emb, states
def inference(self, text: torch.Tensor, text_lengths: torch.Tensor, xvec=None, xvec_lengths=None, sampling="greedy",
blank_penalty: float = 0.0, text_is_embedding=False, return_hidden=False, **kwargs):
device = text.device
# embed text inputs
if not text_is_embedding:
mask = (text != -1).float().unsqueeze(-1)
text = self.text_embedding(torch.clamp(text, min=0)) * mask
# 0. Up-sampling text length
if self.length_regulator is not None:
text, text_lengths = self.length_regulator(text, text_lengths)
# 1. padding xvec
if xvec is not None and self.xvec_proj is not None:
xvec = xvec[:, :xvec_lengths.max()]
# random select a xvec from xvec matrix
xvec = self.norm_and_sample_xvec(xvec, xvec_lengths)
xvec = self.xvec_proj(xvec)
text = text + xvec.unsqueeze(1)
hint_once("use xvec", "use_xvec")
# 1. Encoder
encoder_out, encoder_out_lens = self.model_forward(
text, text_lengths,
xvec, xvec_lengths,
)
tokens, enc_probs = self.sampling_ids(
encoder_out,
sampling=sampling,
blank_penalty=blank_penalty,
return_probs=True,
)
# remove blanks (id=0) and convert token ids into the original format
tokens = [[x - 1] for x in tokens[0].cpu().tolist() if x > 0]
tokens = torch.tensor([tokens], dtype=torch.int64, device=device)
if not return_hidden:
return tokens
acoustic_embs = self.token_embedding(tokens.squeeze(-1))
acoustic_emb_lens = torch.tensor([acoustic_embs.shape[1]], dtype=torch.int64, device=device)
return tokens, acoustic_embs, acoustic_emb_lens

View File

@ -1,761 +0,0 @@
# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Decoder definition."""
from typing import Any
from typing import List
from typing import Sequence
from typing import Tuple
import torch
from torch import nn
from funasr.models.transformer.attention import MultiHeadedAttention
from funasr.models.transformer.utils.dynamic_conv import DynamicConvolution
from funasr.models.transformer.utils.dynamic_conv2d import DynamicConvolution2D
from funasr.models.transformer.embedding import PositionalEncoding
from funasr.models.transformer.layer_norm import LayerNorm
from funasr.models.transformer.utils.lightconv import LightweightConvolution
from funasr.models.transformer.utils.lightconv2d import LightweightConvolution2D
from funasr.models.transformer.utils.mask import subsequent_mask
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.models.transformer.positionwise_feed_forward import (
PositionwiseFeedForward, # noqa: H301
)
from funasr.models.transformer.utils.repeat import repeat
from funasr.models.transformer.scorers.scorer_interface import BatchScorerInterface
class DecoderLayer(nn.Module):
    """Single decoder layer module.

    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttention` instance can be used as the argument.
        src_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttention` instance can be used as the argument.
        feed_forward (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
            can be used as the argument.
        dropout_rate (float): Dropout rate.
        normalize_before (bool): Whether to use layer_norm before the first block.
        concat_after (bool): Whether to concat attention layer's input and output.
            if True, additional linear will be applied.
            i.e. x -> x + linear(concat(x, att(x)))
            if False, no additional linear will be applied. i.e. x -> x + att(x)
    """

    def __init__(
        self,
        size,
        self_attn,
        src_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
    ):
        """Construct an DecoderLayer object."""
        super(DecoderLayer, self).__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        # One LayerNorm per sub-block: self-attn, src-attn, feed-forward.
        self.norm1 = LayerNorm(size)
        self.norm2 = LayerNorm(size)
        self.norm3 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            # Projections for the concat(x, att(x)) variant of the residual.
            self.concat_linear1 = nn.Linear(size + size, size)
            self.concat_linear2 = nn.Linear(size + size, size)

    def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None):
        """Compute decoded features.

        Args:
            tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
            tgt_mask (torch.Tensor): Mask for input tensor (#batch, 1, maxlen_out).
            memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
            memory_mask (torch.Tensor): Encoded memory mask (#batch, 1, maxlen_in).
            cache (List[torch.Tensor]): List of cached tensors.
                Each tensor shape should be (#batch, maxlen_out - 1, size).

        Returns:
            torch.Tensor: Output tensor(#batch, maxlen_out, size).
            torch.Tensor: Mask for output tensor (#batch, maxlen_out).
            torch.Tensor: Encoded memory (#batch, maxlen_in, size).
            torch.Tensor: Encoded memory mask (#batch, maxlen_in).
        """
        residual = tgt
        if self.normalize_before:
            tgt = self.norm1(tgt)
        if cache is None:
            tgt_q = tgt
            tgt_q_mask = tgt_mask
        else:
            # compute only the last frame query keeping dim: max_time_out -> 1
            assert cache.shape == (
                tgt.shape[0],
                tgt.shape[1] - 1,
                self.size,
            ), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
            tgt_q = tgt[:, -1:, :]
            residual = residual[:, -1:, :]
            tgt_q_mask = None
            if tgt_mask is not None:
                tgt_q_mask = tgt_mask[:, -1:, :]
        # Self-attention sub-block (queries restricted to the new frame
        # when a cache is supplied).
        if self.concat_after:
            tgt_concat = torch.cat(
                (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1
            )
            x = residual + self.concat_linear1(tgt_concat)
        else:
            x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask))
        if not self.normalize_before:
            x = self.norm1(x)
        # Encoder-decoder (source) attention sub-block.
        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        if self.concat_after:
            x_concat = torch.cat(
                (x, self.src_attn(x, memory, memory, memory_mask)), dim=-1
            )
            x = residual + self.concat_linear2(x_concat)
        else:
            x = residual + self.dropout(self.src_attn(x, memory, memory, memory_mask))
        if not self.normalize_before:
            x = self.norm2(x)
        # Position-wise feed-forward sub-block.
        residual = x
        if self.normalize_before:
            x = self.norm3(x)
        x = residual + self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm3(x)
        # Re-attach the cached prefix so callers always see the full sequence.
        if cache is not None:
            x = torch.cat([cache, x], dim=1)
        return x, tgt_mask, memory, memory_mask
class BaseTransformerDecoder(nn.Module, BatchScorerInterface):
    """Base class of Transfomer decoder module.

    Args:
        vocab_size: output dim
        encoder_output_size: dimension of attention
        attention_heads: the number of heads of multi head attention
        linear_units: the number of units of position-wise feed forward
        num_blocks: the number of decoder blocks
        dropout_rate: dropout rate
        self_attention_dropout_rate: dropout rate for attention
        input_layer: input layer type
        use_output_layer: whether to use output layer
        pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
        normalize_before: whether to use layer_norm before the first block
        concat_after: whether to concat attention layer's input and output
            if True, additional linear will be applied.
            i.e. x -> x + linear(concat(x, att(x)))
            if False, no additional linear will be applied.
            i.e. x -> x + att(x)
    """

    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        input_layer: str = "embed",
        use_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        causal=True,
    ):
        super().__init__()
        attention_dim = encoder_output_size
        # When True, a subsequent (lower-triangular) mask is applied to the
        # target self-attention in forward().
        self.causal = causal
        self.vocab_size = vocab_size
        if input_layer == "embed":
            # Token-id inputs: embedding lookup + positional encoding.
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(vocab_size, attention_dim),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        elif input_layer == "linear":
            # Dense inputs: linear projection + positional encoding.
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(vocab_size, attention_dim),
                torch.nn.LayerNorm(attention_dim),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        else:
            raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")
        self.normalize_before = normalize_before
        if self.normalize_before:
            self.after_norm = LayerNorm(attention_dim)
        if use_output_layer:
            self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
        else:
            self.output_layer = None
        # Must set by the inheritance
        self.decoders = None

    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward decoder.

        Args:
            hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
            hlens: (batch)
            ys_in_pad:
                input token ids, int64 (batch, maxlen_out)
                if input_layer == "embed"
                input tensor (batch, maxlen_out, #mels) in the other cases
            ys_in_lens: (batch)

        Returns:
            (tuple): tuple containing:
                x: decoded token score before softmax (batch, maxlen_out, token)
                    if use_output_layer is True,
                olens: (batch, )
        """
        tgt = ys_in_pad
        # tgt_mask: (B, 1, L)
        tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
        if self.causal:
            # m: (1, L, L)
            m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
            # tgt_mask: (B, L, L)
            tgt_mask = tgt_mask & m
        memory = hs_pad
        memory_mask = (~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to(
            memory.device
        )
        # Padding for Longformer
        if memory_mask.shape[-1] != memory.shape[1]:
            padlen = memory.shape[1] - memory_mask.shape[-1]
            memory_mask = torch.nn.functional.pad(
                memory_mask, (0, padlen), "constant", False
            )
        x = self.embed(tgt)
        x, tgt_mask, memory, memory_mask = self.decoders(
            x, tgt_mask, memory, memory_mask
        )
        if self.normalize_before:
            x = self.after_norm(x)
        if self.output_layer is not None:
            x = self.output_layer(x)
        # Valid output lengths recovered from the (possibly causal) mask.
        olens = tgt_mask.sum(1)
        return x, olens

    def forward_one_step(
        self,
        tgt: torch.Tensor,
        tgt_mask: torch.Tensor,
        memory: torch.Tensor,
        cache: List[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Forward one step.

        Args:
            tgt: input token ids, int64 (batch, maxlen_out)
            tgt_mask: input token mask, (batch, maxlen_out)
                dtype=torch.uint8 in PyTorch 1.2-
                dtype=torch.bool in PyTorch 1.2+ (include 1.2)
            memory: encoded memory, float32 (batch, maxlen_in, feat)
            cache: cached output list of (batch, max_time_out-1, size)

        Returns:
            y, cache: NN output value and cache per `self.decoders`.
                y.shape` is (batch, maxlen_out, token)
        """
        x = self.embed(tgt)
        if cache is None:
            cache = [None] * len(self.decoders)
        new_cache = []
        # Each layer consumes its cached prefix and emits the extended one.
        for c, decoder in zip(cache, self.decoders):
            x, tgt_mask, memory, memory_mask = decoder(
                x, tgt_mask, memory, None, cache=c
            )
            new_cache.append(x)
        if self.normalize_before:
            y = self.after_norm(x[:, -1])
        else:
            y = x[:, -1]
        if self.output_layer is not None:
            y = torch.log_softmax(self.output_layer(y), dim=-1)
        return y, new_cache

    def score(self, ys, state, x):
        """Score a single hypothesis (BatchScorerInterface)."""
        ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0)
        logp, state = self.forward_one_step(
            ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state
        )
        return logp.squeeze(0), state

    def batch_score(
        self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
    ) -> Tuple[torch.Tensor, List[Any]]:
        """Score new token batch.

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            states (List[Any]): Scorer states for prefix tokens.
            xs (torch.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).

        Returns:
            tuple[torch.Tensor, List[Any]]: Tuple of
                batchfied scores for next token with shape of `(n_batch, n_vocab)`
                and next state list for ys.
        """
        # merge states
        n_batch = len(ys)
        n_layers = len(self.decoders)
        if states[0] is None:
            batch_state = None
        else:
            # transpose state of [batch, layer] into [layer, batch]
            batch_state = [
                torch.stack([states[b][i] for b in range(n_batch)])
                for i in range(n_layers)
            ]
        # batch decoding
        ys_mask = subsequent_mask(ys.size(-1), device=xs.device).unsqueeze(0)
        logp, states = self.forward_one_step(ys, ys_mask, xs, cache=batch_state)
        # transpose state of [layer, batch] into [batch, layer]
        state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
        return logp, state_list
class TransformerDecoder(BaseTransformerDecoder):
    """Standard Transformer decoder: a stack of DecoderLayer blocks, each
    with masked self-attention, source attention and a feed-forward module."""

    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = "embed",
        use_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
        causal: bool = True,
    ):
        super().__init__(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            dropout_rate=dropout_rate,
            positional_dropout_rate=positional_dropout_rate,
            input_layer=input_layer,
            use_output_layer=use_output_layer,
            pos_enc_class=pos_enc_class,
            normalize_before=normalize_before,
            causal=causal,
        )
        attention_dim = encoder_output_size

        def _build_layer(_layer_idx):
            # One decoder block: masked self-attention, encoder-decoder
            # attention, then a position-wise feed-forward module.
            return DecoderLayer(
                attention_dim,
                MultiHeadedAttention(
                    attention_heads, attention_dim, self_attention_dropout_rate
                ),
                MultiHeadedAttention(
                    attention_heads, attention_dim, src_attention_dropout_rate
                ),
                PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
            )

        self.decoders = repeat(num_blocks, _build_layer)
class ParaformerDecoderSAN(BaseTransformerDecoder):
    """
    author: Speech Lab, Alibaba Group, China
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2006.01713

    Note: forward() bypasses `self.embed` and feeds `ys_in_pad` straight to
    the decoder layers, so inputs are expected to be embeddings already.
    `embeds_id` selects a layer whose hidden states are additionally returned.
    """

    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = "embed",
        use_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
        embeds_id: int = -1,
    ):
        super().__init__(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            dropout_rate=dropout_rate,
            positional_dropout_rate=positional_dropout_rate,
            input_layer=input_layer,
            use_output_layer=use_output_layer,
            pos_enc_class=pos_enc_class,
            normalize_before=normalize_before,
        )
        attention_dim = encoder_output_size
        self.decoders = repeat(
            num_blocks,
            lambda lnum: DecoderLayer(
                attention_dim,
                MultiHeadedAttention(
                    attention_heads, attention_dim, self_attention_dropout_rate
                ),
                MultiHeadedAttention(
                    attention_heads, attention_dim, src_attention_dropout_rate
                ),
                PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        # Index of the layer whose output is returned as `embeds_outputs`
        # (-1 disables it since layer ids are non-negative).
        self.embeds_id = embeds_id
        self.attention_dim = attention_dim

    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward decoder.

        Args:
            hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
            hlens: (batch)
            ys_in_pad:
                input token ids, int64 (batch, maxlen_out)
                if input_layer == "embed"
                input tensor (batch, maxlen_out, #mels) in the other cases
            ys_in_lens: (batch)

        Returns:
            (tuple): tuple containing:
                x: decoded token score before softmax (batch, maxlen_out, token)
                    if use_output_layer is True,
                olens: (batch, )
            When embeds_id matched a layer, (x, olens, embeds_outputs) instead.
        """
        tgt = ys_in_pad
        # Non-causal target mask: only padding is masked (parallel decoding).
        tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
        memory = hs_pad
        memory_mask = (~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to(
            memory.device
        )
        # Padding for Longformer
        if memory_mask.shape[-1] != memory.shape[1]:
            padlen = memory.shape[1] - memory_mask.shape[-1]
            memory_mask = torch.nn.functional.pad(
                memory_mask, (0, padlen), "constant", False
            )
        # x = self.embed(tgt)
        # Inputs are used as-is (already embeddings); see class docstring.
        x = tgt
        embeds_outputs = None
        for layer_id, decoder in enumerate(self.decoders):
            x, tgt_mask, memory, memory_mask = decoder(
                x, tgt_mask, memory, memory_mask
            )
            if layer_id == self.embeds_id:
                embeds_outputs = x
        if self.normalize_before:
            x = self.after_norm(x)
        if self.output_layer is not None:
            x = self.output_layer(x)
        olens = tgt_mask.sum(1)
        if embeds_outputs is not None:
            return x, olens, embeds_outputs
        else:
            return x, olens
class LightweightConvolutionTransformerDecoder(BaseTransformerDecoder):
    """Transformer decoder whose self-attention is replaced by lightweight
    (weight-shared 1D) convolution; cross-attention stays multi-headed.

    Args:
        vocab_size: size of the output vocabulary.
        encoder_output_size: encoder feature dimension, reused as the decoder
            attention dimension.
        attention_heads: number of heads of the source (cross) attention.
        linear_units: hidden units of the position-wise feed-forward layer.
        num_blocks: number of decoder layers.
        dropout_rate: dropout inside each decoder layer.
        positional_dropout_rate: dropout after the positional encoding.
        self_attention_dropout_rate: dropout of the convolution module.
        src_attention_dropout_rate: dropout of the source attention.
        input_layer: input layer type ("embed" for token embeddings).
        use_output_layer: project hidden states to vocab_size if True.
        pos_enc_class: positional encoding class.
        normalize_before: apply LayerNorm before each sub-block.
        concat_after: concatenate attention input/output and project, instead
            of a residual add.
        conv_wshare: number of weight-sharing groups of the convolution.
        conv_kernel_length: per-layer convolution kernel sizes; must contain
            exactly ``num_blocks`` values.
        conv_usebias: whether the convolution has a bias term.
            (Annotation fixed from ``int`` to ``bool`` — default and usage
            are boolean.)

    Raises:
        ValueError: if ``len(conv_kernel_length) != num_blocks``.
    """

    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = "embed",
        use_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
        conv_wshare: int = 4,
        conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
        conv_usebias: bool = False,
    ):
        if len(conv_kernel_length) != num_blocks:
            raise ValueError(
                "conv_kernel_length must have equal number of values to num_blocks: "
                f"{len(conv_kernel_length)} != {num_blocks}"
            )
        super().__init__(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            dropout_rate=dropout_rate,
            positional_dropout_rate=positional_dropout_rate,
            input_layer=input_layer,
            use_output_layer=use_output_layer,
            pos_enc_class=pos_enc_class,
            normalize_before=normalize_before,
        )
        attention_dim = encoder_output_size
        self.decoders = repeat(
            num_blocks,
            lambda lnum: DecoderLayer(
                attention_dim,
                LightweightConvolution(
                    wshare=conv_wshare,
                    n_feat=attention_dim,
                    dropout_rate=self_attention_dropout_rate,
                    kernel_size=conv_kernel_length[lnum],
                    use_kernel_mask=True,
                    use_bias=conv_usebias,
                ),
                MultiHeadedAttention(
                    attention_heads, attention_dim, src_attention_dropout_rate
                ),
                PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
class LightweightConvolution2DTransformerDecoder(BaseTransformerDecoder):
    """Transformer decoder using 2D lightweight convolution in place of
    decoder self-attention; cross-attention stays multi-headed.

    Args mirror :class:`LightweightConvolutionTransformerDecoder` (same
    semantics) with the convolution module swapped for
    ``LightweightConvolution2D``. ``conv_kernel_length`` must contain exactly
    ``num_blocks`` values, and ``conv_usebias`` toggles the convolution bias
    (annotation fixed from ``int`` to ``bool`` to match its boolean default
    and usage).

    Raises:
        ValueError: if ``len(conv_kernel_length) != num_blocks``.
    """

    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = "embed",
        use_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
        conv_wshare: int = 4,
        conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
        conv_usebias: bool = False,
    ):
        if len(conv_kernel_length) != num_blocks:
            raise ValueError(
                "conv_kernel_length must have equal number of values to num_blocks: "
                f"{len(conv_kernel_length)} != {num_blocks}"
            )
        super().__init__(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            dropout_rate=dropout_rate,
            positional_dropout_rate=positional_dropout_rate,
            input_layer=input_layer,
            use_output_layer=use_output_layer,
            pos_enc_class=pos_enc_class,
            normalize_before=normalize_before,
        )
        attention_dim = encoder_output_size
        self.decoders = repeat(
            num_blocks,
            lambda lnum: DecoderLayer(
                attention_dim,
                LightweightConvolution2D(
                    wshare=conv_wshare,
                    n_feat=attention_dim,
                    dropout_rate=self_attention_dropout_rate,
                    kernel_size=conv_kernel_length[lnum],
                    use_kernel_mask=True,
                    use_bias=conv_usebias,
                ),
                MultiHeadedAttention(
                    attention_heads, attention_dim, src_attention_dropout_rate
                ),
                PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
class DynamicConvolutionTransformerDecoder(BaseTransformerDecoder):
    """Transformer decoder using dynamic convolution in place of decoder
    self-attention; cross-attention stays multi-headed.

    Args mirror :class:`LightweightConvolutionTransformerDecoder` (same
    semantics) with the convolution module swapped for ``DynamicConvolution``.
    ``conv_kernel_length`` must contain exactly ``num_blocks`` values, and
    ``conv_usebias`` toggles the convolution bias (annotation fixed from
    ``int`` to ``bool`` to match its boolean default and usage).

    Raises:
        ValueError: if ``len(conv_kernel_length) != num_blocks``.
    """

    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = "embed",
        use_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
        conv_wshare: int = 4,
        conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
        conv_usebias: bool = False,
    ):
        if len(conv_kernel_length) != num_blocks:
            raise ValueError(
                "conv_kernel_length must have equal number of values to num_blocks: "
                f"{len(conv_kernel_length)} != {num_blocks}"
            )
        super().__init__(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            dropout_rate=dropout_rate,
            positional_dropout_rate=positional_dropout_rate,
            input_layer=input_layer,
            use_output_layer=use_output_layer,
            pos_enc_class=pos_enc_class,
            normalize_before=normalize_before,
        )
        attention_dim = encoder_output_size
        self.decoders = repeat(
            num_blocks,
            lambda lnum: DecoderLayer(
                attention_dim,
                DynamicConvolution(
                    wshare=conv_wshare,
                    n_feat=attention_dim,
                    dropout_rate=self_attention_dropout_rate,
                    kernel_size=conv_kernel_length[lnum],
                    use_kernel_mask=True,
                    use_bias=conv_usebias,
                ),
                MultiHeadedAttention(
                    attention_heads, attention_dim, src_attention_dropout_rate
                ),
                PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
class DynamicConvolution2DTransformerDecoder(BaseTransformerDecoder):
    """Transformer decoder using 2D dynamic convolution in place of decoder
    self-attention; cross-attention stays multi-headed.

    Args mirror :class:`LightweightConvolutionTransformerDecoder` (same
    semantics) with the convolution module swapped for
    ``DynamicConvolution2D``. ``conv_kernel_length`` must contain exactly
    ``num_blocks`` values, and ``conv_usebias`` toggles the convolution bias
    (annotation fixed from ``int`` to ``bool`` to match its boolean default
    and usage).

    Raises:
        ValueError: if ``len(conv_kernel_length) != num_blocks``.
    """

    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = "embed",
        use_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
        conv_wshare: int = 4,
        conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
        conv_usebias: bool = False,
    ):
        if len(conv_kernel_length) != num_blocks:
            raise ValueError(
                "conv_kernel_length must have equal number of values to num_blocks: "
                f"{len(conv_kernel_length)} != {num_blocks}"
            )
        super().__init__(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            dropout_rate=dropout_rate,
            positional_dropout_rate=positional_dropout_rate,
            input_layer=input_layer,
            use_output_layer=use_output_layer,
            pos_enc_class=pos_enc_class,
            normalize_before=normalize_before,
        )
        attention_dim = encoder_output_size
        self.decoders = repeat(
            num_blocks,
            lambda lnum: DecoderLayer(
                attention_dim,
                DynamicConvolution2D(
                    wshare=conv_wshare,
                    n_feat=attention_dim,
                    dropout_rate=self_attention_dropout_rate,
                    kernel_size=conv_kernel_length[lnum],
                    use_kernel_mask=True,
                    use_bias=conv_usebias,
                ),
                MultiHeadedAttention(
                    attention_heads, attention_dim, src_attention_dropout_rate
                ),
                PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )

View File

@ -1,346 +0,0 @@
@
@
@
@
@!
@"
@#
@$
@'
@(
@)
@*
@,
@-
@.
@/
@:
@;
@<
@>
@?
@[
@]
@^
@_
@`
@a_c1
@a_c2
@a_c3
@a_c4
@a_c5
@aa0
@aa1
@aa2
@ae0
@ae1
@ae2
@ah0
@ah1
@ah2
@ai_c1
@ai_c2
@ai_c3
@ai_c4
@ai_c5
@an_c1
@an_c2
@an_c3
@an_c4
@an_c5
@ang_c1
@ang_c2
@ang_c3
@ang_c4
@ang_c5
@ao0
@ao1
@ao2
@ao_c1
@ao_c2
@ao_c3
@ao_c4
@ao_c5
@aw0
@aw1
@aw2
@ay0
@ay1
@ay2
@b
@b_c
@c_c
@ch
@ch_c
@d
@d_c
@dh
@e_c1
@e_c2
@e_c3
@e_c4
@e_c5
@eh0
@eh1
@eh2
@ei_c1
@ei_c2
@ei_c3
@ei_c4
@ei_c5
@en_c1
@en_c2
@en_c3
@en_c4
@en_c5
@eng_c1
@eng_c2
@eng_c3
@eng_c4
@eng_c5
@er0
@er1
@er2
@er_c1
@er_c2
@er_c3
@er_c4
@er_c5
@ey0
@ey1
@ey2
@f
@f_c
@g
@g_c
@ga
@ge
@go
@h_c
@hh
@i_c1
@i_c2
@i_c3
@i_c4
@i_c5
@ia_c1
@ia_c2
@ia_c3
@ia_c4
@ia_c5
@ian_c1
@ian_c2
@ian_c3
@ian_c4
@ian_c5
@iang_c1
@iang_c2
@iang_c3
@iang_c4
@iang_c5
@iao_c1
@iao_c2
@iao_c3
@iao_c4
@iao_c5
@ie_c1
@ie_c2
@ie_c3
@ie_c4
@ie_c5
@ih0
@ih1
@ih2
@ih_c1
@ih_c2
@ih_c3
@ih_c4
@ih_c5
@ii_c1
@ii_c2
@ii_c3
@ii_c4
@ii_c5
@in_c1
@in_c2
@in_c3
@in_c4
@in_c5
@ing_c1
@ing_c2
@ing_c3
@ing_c4
@ing_c5
@iong_c1
@iong_c2
@iong_c3
@iong_c4
@iong_c5
@iou_c1
@iou_c2
@iou_c3
@iou_c4
@iou_c5
@iy0
@iy1
@iy2
@j_c
@jh
@k
@k_c
@l
@l_c
@m
@m_c
@n
@n_c
@ng
@o_c1
@o_c2
@o_c3
@o_c4
@o_c5
@ong_c1
@ong_c2
@ong_c3
@ong_c4
@ong_c5
@ou_c1
@ou_c2
@ou_c3
@ou_c4
@ou_c5
@ouh
@ouj
@oull
@ouw
@ow0
@ow1
@ow2
@oy0
@oy1
@oy2
@p
@p_c
@q_c
@r
@r_c
@s
@s_c
@sh
@sh_c
@t
@t_c
@th
@u_c1
@u_c2
@u_c3
@u_c4
@u_c5
@ua_c1
@ua_c2
@ua_c3
@ua_c4
@ua_c5
@uai_c1
@uai_c2
@uai_c3
@uai_c4
@uai_c5
@uan_c1
@uan_c2
@uan_c3
@uan_c4
@uan_c5
@uang_c1
@uang_c2
@uang_c3
@uang_c4
@uang_c5
@uei_c1
@uei_c2
@uei_c3
@uei_c4
@uei_c5
@uen_c1
@uen_c2
@uen_c3
@uen_c4
@uen_c5
@uh0
@uh1
@uh2
@uo_c1
@uo_c2
@uo_c3
@uo_c4
@uo_c5
@uw0
@uw1
@uw2
@v
@v_c1
@v_c2
@v_c3
@v_c4
@v_c5
@van_c1
@van_c2
@van_c3
@van_c4
@van_c5
@ve_c1
@ve_c2
@ve_c3
@ve_c4
@ve_c5
@vn_c1
@vn_c2
@vn_c3
@vn_c4
@vn_c5
@w
@w_c
@xx_c
@y
@y_c
@z
@z_c
@zh
@zh_c
@{
@|
@}
@~
@—
@——
@
@
@“
@”
@…
@……
@‰
@℃
@
@○
@、
@。
@《
@》
@『
@』
@【
@】
@
@
@
@
@
@
@
@
@
@
@¥

View File

@ -1,31 +0,0 @@
from pathlib import Path
from typing import Iterable
from typing import Union
def build_tokenizer(
    token_type: str,
    bpemodel: Union[Path, str, Iterable[str]] = None,
    non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
    remove_non_linguistic_symbols: bool = False,
    space_symbol: str = "<space>",
    delimiter: str = None,
    g2p_type: str = None,
    p_word2phn: float = 0.5,
):
    """Build a TTS text tokenizer.

    Only ``token_type`` values containing ``"whisper_rich_ttsfrd"`` are
    supported by this factory; anything else raises ``ValueError``.

    Args:
        token_type: tokenizer kind; must contain "whisper_rich_ttsfrd".
        bpemodel: passed through as the ttsfrd resource directory
            (``ttsfrd_model``) of the whisper tokenizer.
        non_linguistic_symbols, remove_non_linguistic_symbols, space_symbol,
            delimiter, g2p_type: accepted for interface compatibility but
            unused by the supported tokenizer.
        p_word2phn: probability of mixing word-level phoneme substitution.

    Returns:
        A ``WhisperRichTtsFrdTokenizer`` instance.

    Raises:
        ValueError: if ``token_type`` is not a whisper_rich_ttsfrd variant.
    """
    if "whisper_rich_ttsfrd" in token_type:
        # Imported lazily so this module can be loaded without the whisper
        # tokenizer's heavy dependencies installed.
        from funasr.models.llm_asr.tts_text_tokenizer.whisper_tokenizer import WhisperRichTtsFrdTokenizer

        return WhisperRichTtsFrdTokenizer(
            token_path="multilingual_zh_ja_yue_char_del",
            num_languages=105,
            task=None,
            language=None,
            ttsfrd_type="ttsfrd_rich",
            ttsfrd_model=bpemodel,
            p_word2phn=p_word2phn,
        )
    # The previous message claimed bpe/word/char/phn were accepted, but only
    # the whisper_rich_ttsfrd family is actually implemented here.
    raise ValueError(
        f"token_type must contain 'whisper_rich_ttsfrd', got: {token_type}"
    )

View File

@ -1,175 +0,0 @@
import logging
from pathlib import Path
import re
from typing import Iterable
from typing import List
from typing import Optional
from typing import Union
import warnings
import os
import json
import jamo
class TtsFrdRich:
    """
    rich text info: phoneme + puncs + boundary + [word2phone]

    Thin wrapper around the proprietary ``ttsfrd`` text-frontend engine.
    The engine is created lazily in :meth:`build`; ``token_type`` selects
    which of the extraction modes (:meth:`_get_pronplus`,
    :meth:`_get_word2phn`, :meth:`_get_wordlist`) :meth:`__call__` runs.
    """
    def __init__(self, remove_boundary=True, token_type="pronplus"):
        super().__init__()
        # Strip "/" (word) and "#" (syllable) boundary marks from pronplus output.
        self.remove_boundary = remove_boundary
        # One of: "pronplus", "word2phn", "wordlist" (see set_token_type).
        self.token_type = token_type
        # ttsfrd.TtsFrontendEngine, created on first build().
        self.g2p = None
        # Current engine language mode ("pinyin" or "enus"); None until built.
        self.lang_type = None
        # Maps user-facing language ids to the engine's internal mode names.
        self.lang_type_map = {"zh-cn": "pinyin", "en-us": "enus"}
    @staticmethod
    def contains_chinese(str):
        """Return True if the string contains any CJK unified ideograph.

        NOTE(review): the parameter shadows the builtin ``str``; kept as-is.
        """
        return bool(re.search(r'[\u4e00-\u9fff]', str))
    @staticmethod
    def is_full_half_punctuation_string(s):
        """Return True if *s* consists entirely of punctuation characters."""
        # Covers ASCII punctuation/control ranges plus common full-width
        # (CJK) punctuation ranges.
        punctuation_pattern = r'[\u0000-\u002f\u003a-\u0040\u005b-\u0060\u007b-\u007f\u3000-\u303f\uff00-\uffef]'
        # Collect every character matching the punctuation ranges.
        results = re.findall(punctuation_pattern, s)
        # If every character matched, the whole string is punctuation.
        return len(s) == len("".join(results))
    def build(self, resource_dir, lang_type="Zh-CN"):
        """Initialize the ttsfrd engine from *resource_dir* and select language.

        Safe to call repeatedly; the engine is created only once and the
        language mode is switched when it differs from the current one.
        """
        lang_type = lang_type.lower()
        new_lang_type = self.lang_type_map[lang_type]
        if self.g2p is None:
            # Imported lazily: ttsfrd is an optional native dependency.
            import ttsfrd
            assert os.path.isdir(resource_dir)
            fe = ttsfrd.TtsFrontendEngine()
            fe.initialize(resource_dir)
            self.g2p = fe
            # self.lang_type = new_lang_type
            self.set_lang_type(new_lang_type)
        if self.lang_type != new_lang_type:
            # self.lang_type = new_lang_type
            self.set_lang_type(new_lang_type)
    def set_lang_type(self, lang_type):
        """Switch the engine language mode ("enus" or "pinyin")."""
        if lang_type == "enus":
            self.g2p.set_lang_type(lang_type)
            self.g2p.enable_pinyin_mix(True)
            # self.g2p.set_breakmodel_index(0)
        else:
            self.g2p.set_lang_type(lang_type)
            self.g2p.enable_pinyin_mix(True)
            # self.g2p.set_breakmodel_index(1)
        self.lang_type = lang_type
    def set_token_type(self, token_type):
        """Select the extraction mode used by __call__."""
        assert token_type in ["pronplus", "word2phn", "wordlist"], token_type
        self.token_type = token_type
    def __call__(self, text) -> Union[List[str], str]:
        """Run the configured extraction on *text*.

        The language mode is auto-switched: pinyin if the text contains any
        Chinese character, enus otherwise.
        """
        assert self.g2p is not None
        if not self.contains_chinese(text):
            if self.lang_type != "enus":
                self.set_lang_type("enus")
        else:
            if self.lang_type != "pinyin":
                self.set_lang_type("pinyin")
        if self.token_type == "word2phn":
            return self._get_word2phn(text)
        elif self.token_type == "pronplus":
            return self._get_pronplus(text)
        elif self.token_type == "wordlist":
            return self._get_wordlist(text)
        else:
            raise ValueError(f"only type: [pronplus, word2phn, wordlist] supported, now type: {self.token_type}")
    def _get_pronplus(self, text) -> List[str]:
        """Return a flat list of phoneme/punctuation symbols for *text*."""
        pronplus = self.g2p.get_frd_extra_info(text, 'pronplus')
        if self.remove_boundary:
            pronplus = pronplus.replace("/", "")  # word boundary
            pronplus = pronplus.replace("#", "")  # syllable boundary
        # pronplus = pronplus.replace("\n", "")
        pronplus = pronplus.replace("\n", " ")
        symbols: List[str] = []
        for pron in pronplus.split(" "):
            pron = pron.strip().lower()
            if pron and pron[0].isalpha():
                # A phoneme token: keep it whole.
                symbols.append(pron)
            else:
                # Punctuation run: split into individual marks.
                symbols.extend([mark for mark in pron if mark])
        return symbols
    def text2tokens(self, line: str) -> List[str]:
        """Tokenize *line* to "@"-prefixed phoneme/word tokens via word2phn."""
        json_str = self._get_word2phn(line)
        data = json.loads(json_str)
        retval = []
        for one in data["word2phn"]:
            for key, value in one.items():
                if value is not None:
                    # Word with phonemes: emit one token per phoneme.
                    retval.extend([f"@{x}" for x in value])
                else:
                    # Separator/punctuation entry: emit the symbol itself.
                    if key == " ":
                        key = "<|space|>"
                    retval.append(f"@{key}")
        return retval
    def tokens2text(self, tokens: Iterable[str]) -> str:
        # NOTE(review): intentionally unimplemented stub; returns None.
        pass
    def _get_wordlist(self, text) -> str:
        """Return the raw ttsfrd 'wordlist' JSON-lines string for *text*."""
        wordlist = self.g2p.get_frd_extra_info(text, 'wordlist')
        return wordlist
    def _get_word2phn(self, text) -> str:
        """Build a JSON string mapping each word of *text* to its phonemes.

        Inserts explicit {" ": None} separator entries between words where a
        space belongs (English/Chinese boundaries, after most punctuation).
        """
        wordlist = self.g2p.get_frd_extra_info(text, 'wordlist')
        wordlist_subs = wordlist.split("\n")
        word2phn_info = []
        prev_word_type = None
        prev_word = None
        for json_str in wordlist_subs:
            if len(json_str) == 0:
                continue
            wordlist_info = json.loads(json_str)["wordlist"]
            for word_info in wordlist_info:
                is_english_word = True
                this_phone_list = None
                if word_info["syllables"] is None:
                    # punctuation
                    this_word_type = "punc"
                    pass
                elif self.is_full_half_punctuation_string(word_info["name"]):
                    # punctuation, handle some g2p's mistakes spelling punctuation!!!
                    this_word_type = "punc"
                    pass
                else:
                    this_phone_list = []
                    for syllable_info in word_info["syllables"]:
                        phn_count = syllable_info["phone_count"]
                        syllable_phone_list = syllable_info["pron_text"].split(" ")
                        assert len(syllable_phone_list) == phn_count, len(syllable_phone_list)
                        if "py_text" in syllable_info:
                            # chinese add tone info
                            syllable_phone_list[-1] = syllable_phone_list[-1]+str(syllable_info["tone"])
                            is_english_word = False
                        this_phone_list += syllable_phone_list
                    if is_english_word:
                        this_word_type = "en_word"
                    else:
                        this_word_type = "ch_word"
                # Decide whether a space separator goes before this word.
                if this_word_type == "en_word":
                    if prev_word_type is None:
                        pass
                    elif prev_word_type == "en_word":
                        word2phn_info.append({" ": None})
                    elif prev_word_type == "punc":
                        # No space right after opening quotes/brackets.
                        if (prev_word not in ["\"", "\'", "(", "", "[", ""] and
                                prev_word.split(" ")[-1] not in ["\"", "\'", "(", "", "[", ""]):
                            word2phn_info.append({" ": None})
                    elif prev_word_type == "ch_word":
                        word2phn_info.append({" ": None})
                elif this_word_type == "ch_word":
                    if prev_word_type is not None and prev_word_type == "en_word":
                        word2phn_info.append({" ": None})
                elif this_word_type == "punc":
                    if word_info["name"] in ["("]:
                        word2phn_info.append({" ": None})
                this_word2phn_dict = {word_info["name"]: this_phone_list}
                word2phn_info.append(this_word2phn_dict)
                prev_word_type = this_word_type
                prev_word = list(word2phn_info[-1].keys())[0]
        return json.dumps({"raw": text, "word2phn": word2phn_info}, ensure_ascii=False)

View File

@ -1,462 +0,0 @@
import base64
import os
import string
from dataclasses import dataclass, field
from functools import cached_property, lru_cache
from typing import Dict, List, Optional, Tuple
import tiktoken
# Language codes (mostly ISO 639-1, plus some project-specific entries such as
# "minnan"/"wuyu"/"dialect" and code-switch tags "zh/en"/"en/zh") mapped to
# their English names. Order matters: language special tokens are assigned by
# position (see get_encoding / Tokenizer.__post_init__).
LANGUAGES = {
    "en": "english",
    "zh": "chinese",
    "de": "german",
    "es": "spanish",
    "ru": "russian",
    "ko": "korean",
    "fr": "french",
    "ja": "japanese",
    "pt": "portuguese",
    "tr": "turkish",
    "pl": "polish",
    "ca": "catalan",
    "nl": "dutch",
    "ar": "arabic",
    "sv": "swedish",
    "it": "italian",
    "id": "indonesian",
    "hi": "hindi",
    "fi": "finnish",
    "vi": "vietnamese",
    "he": "hebrew",
    "uk": "ukrainian",
    "el": "greek",
    "ms": "malay",
    "cs": "czech",
    "ro": "romanian",
    "da": "danish",
    "hu": "hungarian",
    "ta": "tamil",
    "no": "norwegian",
    "th": "thai",
    "ur": "urdu",
    "hr": "croatian",
    "bg": "bulgarian",
    "lt": "lithuanian",
    "la": "latin",
    "mi": "maori",
    "ml": "malayalam",
    "cy": "welsh",
    "sk": "slovak",
    "te": "telugu",
    "fa": "persian",
    "lv": "latvian",
    "bn": "bengali",
    "sr": "serbian",
    "az": "azerbaijani",
    "sl": "slovenian",
    "kn": "kannada",
    "et": "estonian",
    "mk": "macedonian",
    "br": "breton",
    "eu": "basque",
    "is": "icelandic",
    "hy": "armenian",
    "ne": "nepali",
    "mn": "mongolian",
    "bs": "bosnian",
    "kk": "kazakh",
    "sq": "albanian",
    "sw": "swahili",
    "gl": "galician",
    "mr": "marathi",
    "pa": "punjabi",
    "si": "sinhala",
    "km": "khmer",
    "sn": "shona",
    "yo": "yoruba",
    "so": "somali",
    "af": "afrikaans",
    "oc": "occitan",
    "ka": "georgian",
    "be": "belarusian",
    "tg": "tajik",
    "sd": "sindhi",
    "gu": "gujarati",
    "am": "amharic",
    "yi": "yiddish",
    "lo": "lao",
    "uz": "uzbek",
    "fo": "faroese",
    "ht": "haitian creole",
    "ps": "pashto",
    "tk": "turkmen",
    "nn": "nynorsk",
    "mt": "maltese",
    "sa": "sanskrit",
    "lb": "luxembourgish",
    "my": "myanmar",
    "bo": "tibetan",
    "tl": "tagalog",
    "mg": "malagasy",
    "as": "assamese",
    "tt": "tatar",
    "haw": "hawaiian",
    "ln": "lingala",
    "ha": "hausa",
    "ba": "bashkir",
    "jw": "javanese",
    "su": "sundanese",
    "yue": "cantonese",
    "minnan": "minnan",
    "wuyu": "wuyu",
    "dialect": "dialect",
    "zh/en": "zh/en",
    "en/zh": "en/zh",
}
# language code lookup by name, with a few language aliases
TO_LANGUAGE_CODE = {
    **{language: code for code, language in LANGUAGES.items()},
    "burmese": "my",
    "valencian": "ca",
    "flemish": "nl",
    "haitian": "ht",
    "letzeburgesch": "lb",
    "pushto": "ps",
    "panjabi": "pa",
    "moldavian": "ro",
    "moldovan": "ro",
    "sinhalese": "si",
    "castilian": "es",
    "mandarin": "zh",
}
# Audio-event / task special-token names; "X" and "/X" pairs delimit spans.
AUDIO_EVENT = {
    "ASR": "ASR",
    "AED": "AED",
    "SER": "SER",
    "Speech": "Speech",
    "/Speech": "/Speech",
    "BGM": "BGM",
    "/BGM": "/BGM",
    "Laughter": "Laughter",
    "/Laughter": "/Laughter",
    "Applause": "Applause",
    "/Applause": "/Applause",
}
# Speech-emotion label special-token names.
EMOTION = {
    "HAPPY": "HAPPY",
    "SAD": "SAD",
    "ANGRY": "ANGRY",
    "NEUTRAL": "NEUTRAL",
}
# TTS-side vocal control special-token names (plus speaker ids SP03..SP13).
TTS_Vocal_Token = {
    "TTS/B": "TTS/B",
    "TTS/O": "TTS/O",
    "TTS/Q": "TTS/Q",
    "TTS/A": "TTS/A",
    "TTS/CO": "TTS/CO",
    "TTS/CL": "TTS/CL",
    "TTS/H": "TTS/H",
    "endofprompt": "endofprompt",
    "sil": "sil",
    **{f"TTS/SP{i:02d}": f"TTS/SP{i:02d}" for i in range(3, 14)}
}
@dataclass
class Tokenizer:
    """A thin wrapper around `tiktoken` providing quick access to special tokens"""

    # The underlying tiktoken encoding (BPE vocabulary + special tokens).
    encoding: tiktoken.Encoding
    # How many leading entries of LANGUAGES have a language token registered.
    num_languages: int
    # Optional language code (key of LANGUAGES), e.g. "en".
    language: Optional[str] = None
    # Optional task: "transcribe" or "translate".
    task: Optional[str] = None
    # Start-of-transcript token id sequence, filled in by __post_init__.
    sot_sequence: Tuple[int] = ()
    # Name -> id map of all special tokens, filled in by __post_init__.
    special_tokens: Dict[str, int] = field(default_factory=dict)
    def __post_init__(self):
        # Cache special-token ids by name for O(1) lookup.
        for special in self.encoding.special_tokens_set:
            special_token = self.encoding.encode_single_token(special)
            self.special_tokens[special] = special_token
        sot: int = self.special_tokens["<|startoftranscript|>"]
        translate: int = self.special_tokens["<|translate|>"]
        transcribe: int = self.special_tokens["<|transcribe|>"]
        langs = tuple(LANGUAGES.keys())[: self.num_languages]
        # Build <sot> [<lang>] [<task>] depending on configuration.
        sot_sequence = [sot]
        if self.language is not None:
            # Language tokens are laid out contiguously right after <sot>.
            sot_sequence.append(sot + 1 + langs.index(self.language))
        if self.task is not None:
            task_token: int = transcribe if self.task == "transcribe" else translate
            sot_sequence.append(task_token)
        self.sot_sequence = tuple(sot_sequence)
    def encode(self, text, **kwargs):
        """Encode *text* to token ids (delegates to tiktoken)."""
        return self.encoding.encode(text, **kwargs)
    def decode(self, token_ids: List[int], **kwargs) -> str:
        """Decode ids to text, dropping timestamp tokens (ids >= timestamp_begin)."""
        token_ids = [t for t in token_ids if t < self.timestamp_begin]
        return self.encoding.decode(token_ids, **kwargs)
    def decode_with_timestamps(self, token_ids: List[int], **kwargs) -> str:
        """
        Timestamp tokens are above other special tokens' id range and are ignored by `decode()`.
        This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
        """
        return self.encoding.decode(token_ids, **kwargs)
    def get_vocab_size(self) -> int:
        """Total vocabulary size including special tokens."""
        return self.encoding.n_vocab
    @cached_property
    def eot(self) -> int:
        # End-of-text token id.
        return self.encoding.eot_token
    @cached_property
    def transcribe(self) -> int:
        return self.special_tokens["<|transcribe|>"]
    @cached_property
    def translate(self) -> int:
        return self.special_tokens["<|translate|>"]
    @cached_property
    def sot(self) -> int:
        return self.special_tokens["<|startoftranscript|>"]
    @cached_property
    def sot_lm(self) -> int:
        return self.special_tokens["<|startoflm|>"]
    @cached_property
    def sot_prev(self) -> int:
        return self.special_tokens["<|startofprev|>"]
    @cached_property
    def no_speech(self) -> int:
        return self.special_tokens["<|nospeech|>"]
    @cached_property
    def no_timestamps(self) -> int:
        return self.special_tokens["<|notimestamps|>"]
    @cached_property
    def timestamp_begin(self) -> int:
        # Id of the first timestamp token; everything >= this is a timestamp.
        return self.special_tokens["<|0.00|>"]
    @cached_property
    def language_token(self) -> int:
        """Returns the token id corresponding to the value of the `language` field"""
        if self.language is None:
            raise ValueError("This tokenizer does not have language token configured")
        return self.to_language_token(self.language)
    def to_language_token(self, language):
        """Return the special-token id for a language code, or raise KeyError."""
        if token := self.special_tokens.get(f"<|{language}|>", None):
            return token
        raise KeyError(f"Language {language} not found in tokenizer.")
    @cached_property
    def all_language_tokens(self) -> Tuple[int]:
        """Ids of all registered language special tokens (first num_languages)."""
        result = []
        for token, token_id in self.special_tokens.items():
            if token.strip("<|>") in LANGUAGES:
                result.append(token_id)
        return tuple(result)[: self.num_languages]
    @cached_property
    def all_language_codes(self) -> Tuple[str]:
        """Language codes corresponding to all_language_tokens."""
        return tuple(self.decode([_l]).strip("<|>") for _l in self.all_language_tokens)
    @cached_property
    def sot_sequence_including_notimestamps(self) -> Tuple[int]:
        return tuple(list(self.sot_sequence) + [self.no_timestamps])
    @cached_property
    def non_speech_tokens(self) -> Tuple[int]:
        """
        Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
        annotations, to prevent sampling texts that are not actually spoken in the audio, e.g.
        - ♪♪♪
        - ( SPEAKING FOREIGN LANGUAGE )
        - [DAVID] Hey there,
        keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
        """
        symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
        symbols += (
            "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
        )
        # symbols that may be a single token or multiple tokens depending on the tokenizer.
        # In case they're multiple tokens, suppress the first token, which is safe because:
        # These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
        # in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
        miscellaneous = set("♩♪♫♬♭♮♯")
        assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)
        # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
        result = {self.encoding.encode(" -")[0], self.encoding.encode(" '")[0]}
        for symbol in symbols + list(miscellaneous):
            for tokens in [
                self.encoding.encode(symbol),
                self.encoding.encode(" " + symbol),
            ]:
                if len(tokens) == 1 or symbol in miscellaneous:
                    result.add(tokens[0])
        return tuple(sorted(result))
    def split_to_word_tokens(self, tokens: List[int]):
        """Split decoded tokens into words with their token groups."""
        if self.language in {"zh", "ja", "th", "lo", "my", "yue"}:
            # These languages don't typically use spaces, so it is difficult to split words
            # without morpheme analysis. Here, we instead split words at any
            # position where the tokens are decoded as valid unicode points
            return self.split_tokens_on_unicode(tokens)
        return self.split_tokens_on_spaces(tokens)
    def split_tokens_on_unicode(self, tokens: List[int]):
        """Group tokens at boundaries where the decode is valid unicode.

        A partial multi-byte character decodes to U+FFFD; a group is closed
        only when its decode contains no replacement char (or the char is
        genuinely present in the full decode at that offset).
        """
        decoded_full = self.decode_with_timestamps(tokens)
        replacement_char = "\ufffd"
        words = []
        word_tokens = []
        current_tokens = []
        unicode_offset = 0
        for token in tokens:
            current_tokens.append(token)
            decoded = self.decode_with_timestamps(current_tokens)
            if (
                replacement_char not in decoded
                or decoded_full[unicode_offset + decoded.index(replacement_char)]
                == replacement_char
            ):
                words.append(decoded)
                word_tokens.append(current_tokens)
                current_tokens = []
                unicode_offset += len(decoded)
        return words, word_tokens
    def split_tokens_on_spaces(self, tokens: List[int]):
        """Merge unicode-split subwords into space-delimited words."""
        subwords, subword_tokens_list = self.split_tokens_on_unicode(tokens)
        words = []
        word_tokens = []
        for subword, subword_tokens in zip(subwords, subword_tokens_list):
            # Start a new word at special tokens, leading spaces, punctuation.
            special = subword_tokens[0] >= self.eot
            with_space = subword.startswith(" ")
            punctuation = subword.strip() in string.punctuation
            if special or with_space or punctuation or len(words) == 0:
                words.append(subword)
                word_tokens.append(subword_tokens)
            else:
                words[-1] = words[-1] + subword
                word_tokens[-1].extend(subword_tokens)
        return words, word_tokens
@lru_cache(maxsize=None)
def get_encoding(name: str = "gpt2", num_languages: int = 99, ttsfrd_name: Optional[str] = None):
    """Build a tiktoken Encoding from a bundled `.tiktoken` vocab file.

    Args:
        name: base vocabulary name; "gpt2"/"multilingual" use the original
            whisper special-token layout, anything else adds the extended
            layout (audio events, emotions, ASR/TTS special tokens).
        num_languages: number of language tokens to register (prefix of
            LANGUAGES).
        ttsfrd_name: optional `.token` file of extra phoneme special tokens
            to append.

    Returns:
        A configured ``tiktoken.Encoding``.
    """
    vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
    # Use a context manager so the vocab file handle is closed deterministically
    # (the original relied on garbage collection).
    with open(vocab_path) as vocab_file:
        ranks = {
            base64.b64decode(token): int(rank)
            for token, rank in (line.split() for line in vocab_file if line)
        }
    n_vocab = len(ranks)
    special_tokens = {}
    if name == "gpt2" or name == "multilingual":
        # Original whisper layout: languages + task + timestamp tokens.
        specials = [
            "<|endoftext|>",
            "<|startoftranscript|>",
            *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
            "<|translate|>",
            "<|transcribe|>",
            "<|startoflm|>",
            "<|startofprev|>",
            "<|nospeech|>",
            "<|notimestamps|>",
            *[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
        ]
    else:
        # Extended layout with audio events, emotions, and ASR/TTS tokens.
        specials = [
            "<|endoftext|>",
            "<|startoftranscript|>",
            *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
            *[f"<|{audio_event}|>" for audio_event in list(AUDIO_EVENT.keys())],
            *[f"<|{emotion}|>" for emotion in list(EMOTION.keys())],
            "<|translate|>",
            "<|transcribe|>",
            "<|startoflm|>",
            "<|startofprev|>",
            "<|nospeech|>",
            "<|notimestamps|>",
            *[f"<|SPECIAL_TOKEN_{i}|>" for i in range(1, 31)],  # register special tokens for ASR
            *[f"<|{tts}|>" for tts in list(TTS_Vocal_Token.keys())],  # register special tokens for TTS
            *[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
        ]
    if ttsfrd_name is not None:
        # Append phoneme tokens produced by the ttsfrd frontend.
        ttsfrd_vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{ttsfrd_name}.token")
        assert os.path.isfile(ttsfrd_vocab_path), f"{ttsfrd_vocab_path} missing"
        with open(ttsfrd_vocab_path, "r") as fr:
            specials.extend([f"<|{line.strip()}|>" for line in fr if line])
    # Special tokens get ids directly above the BPE vocabulary.
    for token in specials:
        special_tokens[token] = n_vocab
        n_vocab += 1
    return tiktoken.Encoding(
        name=os.path.basename(vocab_path),
        explicit_n_vocab=n_vocab,
        pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
        mergeable_ranks=ranks,
        special_tokens=special_tokens,
    )
@lru_cache(maxsize=None)
def get_tokenizer(
    multilingual: bool,
    *,
    num_languages: int = 99,
    language: Optional[str] = None,
    task: Optional[str] = None,  # Literal["transcribe", "translate", None]
    encoding_path: Optional[str] = None,
    ttsfrd_name: Optional[str] = None,
) -> Tokenizer:
    """Construct a cached Tokenizer.

    Normalizes *language* (code or English name) to a code, picks the
    "multilingual" or "gpt2" base encoding (overridable via *encoding_path*),
    and wires defaults: multilingual tokenizers default to English/transcribe,
    monolingual ones carry no language or task.

    Raises:
        ValueError: if *language* is neither a known code nor a known name.
    """
    if language is not None:
        language = language.lower()
        if language not in LANGUAGES:
            # Accept English language names as aliases for their codes.
            if language not in TO_LANGUAGE_CODE:
                raise ValueError(f"Unsupported language: {language}")
            language = TO_LANGUAGE_CODE[language]
    if multilingual:
        encoding_name = "multilingual"
        language = language or "en"
        task = task or "transcribe"
    else:
        encoding_name, language, task = "gpt2", None, None
    # An explicit encoding path overrides the derived base name.
    if encoding_path is not None:
        encoding_name = encoding_path
    encoding = get_encoding(name=encoding_name, num_languages=num_languages, ttsfrd_name=ttsfrd_name)
    return Tokenizer(
        encoding=encoding, num_languages=num_languages, language=language, task=task
    )

View File

@ -1,164 +0,0 @@
import copy
import json
import os
import random
import re
from typing import Iterable, List, Union
import numpy as np
class WhisperRichTtsFrdTokenizer:
    """Whisper-style tiktoken tokenizer with optional ttsfrd phoneme mixing.

    Wraps the whisper rich tokenizer and, when a ttsfrd resource directory is
    given, can stochastically replace words by their "<|@phone|>" special
    tokens (controlled by ``p_word2phn``).
    """

    def __init__(
        self,
        token_path: str,
        num_languages: int,
        task: str = None,
        language: str = None,
        ttsfrd_type: str = None,
        p_word2phn: float = 0.5,
        ttsfrd_model: str = None,
    ):
        """
        Args:
            token_path: encoding name ("gpt2"/"multilingual"/custom path).
            num_languages: number of language special tokens.
            task: optional whisper task ("transcribe"/"translate").
            language: optional language code.
            ttsfrd_type: name of the ttsfrd phoneme token file to register.
            p_word2phn: probability of substituting a word with its phonemes.
            ttsfrd_model: ttsfrd resource directory; enables phonemization
                when it exists.
        """
        import funasr.models.llm_asr.tts_text_tokenizer.voice_echo_rich_tokenizer as tokenizer
        self.token_path = token_path
        self.num_languages = num_languages
        self.language = language
        self.task = task
        self.ttsfrd_type = ttsfrd_type
        self.p_word2phn = p_word2phn
        if token_path == "whisper_en" or token_path == "whisper_gpt2" or token_path == "gpt2":
            self.tokenizer = tokenizer.get_tokenizer(multilingual=False, num_languages=num_languages)
        elif token_path == "whisper_multilingual" or token_path == "multilingual":
            self.tokenizer = tokenizer.get_tokenizer(
                multilingual=True, language=self.language, task=self.task, num_languages=num_languages
            )
        else:
            # Custom encoding file, optionally extended with ttsfrd tokens.
            self.tokenizer = tokenizer.get_tokenizer(
                multilingual=True, language=self.language, task=self.task, num_languages=num_languages,
                encoding_path=token_path, ttsfrd_name=ttsfrd_type
            )
        if ttsfrd_model is not None and os.path.isdir(ttsfrd_model):
            from funasr.models.llm_asr.tts_text_tokenizer.phoneme_tokenizer import TtsFrdRich
            self.ttsfrd_tokenizer = TtsFrdRich(remove_boundary=True, token_type="word2phn")
            self.ttsfrd_tokenizer.build(ttsfrd_model)
        else:
            self.ttsfrd_tokenizer = None

    def text_mixing(self, line: str) -> str:
        """Stochastically mix words and phoneme tokens from a word2phn JSON.

        *line* is expected to be the JSON produced by TtsFrdRich._get_word2phn
        ({"raw": ..., "word2phn": [...]}); anything else is returned verbatim.
        Each word is replaced by its "<|@phone|>" tokens with probability
        ``p_word2phn`` (and the whole line is phonemized at all only with that
        same probability).
        """
        try:
            data_info = json.loads(line)
            if isinstance(data_info, dict) and "raw" in data_info and "word2phn" in data_info:
                raw_text = data_info["raw"]
                ttsfrd_word2phn = data_info["word2phn"]
                if random.random() < self.p_word2phn:
                    ret_text = ""
                    for ttsfrd_word in ttsfrd_word2phn:
                        for word_str, phn_list in ttsfrd_word.items():
                            if phn_list is not None:
                                if random.random() < self.p_word2phn:
                                    ret_text = ret_text + "".join([f"<|@{p}|>" for p in phn_list])
                                else:
                                    ret_text += word_str
                            else:
                                # Separator/punctuation entries have no phonemes.
                                ret_text += word_str
                else:
                    ret_text = raw_text
            else:
                # Valid JSON but not a word2phn payload: leave untouched.
                ret_text = line
        except json.JSONDecodeError:
            ret_text = line
        return ret_text

    def get_num_vocabulary_size(self) -> int:
        """Total vocabulary size including special tokens."""
        return self.tokenizer.get_vocab_size()

    def text2ids(self, line: str, language: str) -> List[int]:
        """Encode *line* to ids with [<lang>, <transcribe>, ...] prefix.

        Timestamp markers like "<|1.08|>" in the text switch encoding to
        allow timestamp special tokens and drop <|notimestamps|>.
        """
        language_tok = "<|" + language + "|>"
        assert language_tok in self.tokenizer.special_tokens, "Language token not found, lang: {}, line: {}".format(language_tok, line)
        # BUGFIX: the pipes must be escaped; the previous pattern
        # r'<|(\d+\.\d+)|>' was the alternation "< OR digits.digits OR >" and
        # misdetected timestamps in any line containing '<', '>' or a decimal.
        pattern = re.compile(r'<\|(\d+\.\d+)\|>')
        with_timestamps = pattern.search(line)
        if with_timestamps:
            sot_tok = [self.tokenizer.special_tokens.get(language_tok), self.tokenizer.transcribe]
            allowed_special = set([f"<|{i * 0.02:.2f}|>" for i in range(1501)])
            encoded_line = self.tokenizer.encode(line, allowed_special=allowed_special)
        else:
            sot_tok = [self.tokenizer.special_tokens.get(language_tok), self.tokenizer.transcribe, self.tokenizer.no_timestamps]
            encoded_line = self.tokenizer.encode(line)
        return sot_tok + encoded_line

    def ids2text(self, integers: Union[np.ndarray, Iterable[int]]) -> str:
        """Decode ids to text, keeping timestamp annotations."""
        return self.tokenizer.decode_with_timestamps(integers)

    def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
        """Decode each id to its individual token string."""
        return [self.tokenizer.decode_with_timestamps([i]) for i in integers]

    def text2tokens(self, line: str, endofprompt="<|endofprompt|>", sil="<|sil|>") -> List[str]:
        """Encode *line* to ids, applying ttsfrd phonemization to the body.

        The prompt prefix (up to *endofprompt*) and leading/trailing *sil*
        markers are preserved unmodified around the phonemized body.
        """
        # keep prompt and sil unchanged
        prompt_text = ""
        st_sil, ed_sil = False, False
        if endofprompt in line:
            pos = line.find(endofprompt)
            prompt_text = line[:pos+len(endofprompt)]
            line = line[pos+len(endofprompt):]
        if line.startswith(sil):
            line = line[len(sil):]
            st_sil = True
        if line.endswith(sil):
            line = line[:-len(sil)]
            ed_sil = True
        # token to phone and mixup
        if self.ttsfrd_tokenizer is not None:
            line = self.ttsfrd_tokenizer(line)
            if self.ttsfrd_type is not None:
                line = self.text_mixing(line)
        # add prompt text and sil back
        if st_sil:
            line = sil + line
        if ed_sil:
            line = line + sil
        line = prompt_text + line
        return self.tokenizer.encode(line, allowed_special="all")

    def tokens2text(self, tokens: Iterable[str]) -> str:
        # NOTE(review): despite the annotation, this forwards to
        # decode_with_timestamps, which expects token *ids* — confirm callers.
        return self.tokenizer.decode_with_timestamps(tokens)

    def get_sot(self, language: str = None, with_timestamps: bool = False) -> List[int]:
        """Build the start-of-transcript id sequence.

        With a language: [<sot>, <lang>, <transcribe>] plus <|notimestamps|>
        unless *with_timestamps*; without: just [<sot>].
        """
        if language is not None:
            language_tok = "<|" + language + "|>"
            assert language_tok in self.tokenizer.special_tokens
            if with_timestamps:
                sot_tok = [self.tokenizer.sot, self.tokenizer.special_tokens.get(language_tok), self.tokenizer.transcribe]
            else:
                sot_tok = [self.tokenizer.sot, self.tokenizer.special_tokens.get(language_tok), self.tokenizer.transcribe, self.tokenizer.no_timestamps]
        else:
            sot_tok = [self.tokenizer.sot]
        return sot_tok

    def get_all_languages(self) -> List[str]:
        """All language codes registered in the underlying tokenizer."""
        return list(self.tokenizer.all_language_codes)

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(model_type={self.token_path}, "
            f"language={self.language}, ttsfrd={self.ttsfrd_type})"
        )