add nar model

志浩 2024-08-15 23:14:07 +08:00
parent a7b7d993ad
commit b09155c83b
17 changed files with 63375 additions and 71 deletions

View File

@ -498,6 +498,7 @@ class MaskedDiffWithXvec(BaseDiffWithXvec):
audio_lengths: torch.Tensor,
xvec: Optional[torch.Tensor] = None,
xvec_lengths: Optional[torch.Tensor] = None,
**kwargs
):
batch_size = audio.shape[0]
# for data parallel

View File

@ -2593,44 +2593,58 @@ class LLMASR6(nn.Module):
self.eos = kwargs.get("eos", 151645)
# audio decoder related
self.codebook_dim = audio_decoder_conf.get("codebook_dim", 1024)
self.codebook_size = audio_decoder_conf.get("codebook_size", 4096)
self.lm_out_voc_size = self.codebook_size + 1
self.audio_decoder = self.build_audio_decoder(name=audio_decoder, conf=audio_decoder_conf)
self.concat_emb_hidden = audio_decoder_conf.get("concat_emb_hidden", False)
self.concat_emb_hidden_norm = audio_decoder_conf.get("concat_emb_hidden_norm", False)
if self.concat_emb_hidden_norm:
self.hidden_norm = LayerNorm(llm_dim)
self.fusion_dropout = nn.Dropout(audio_decoder_conf.get("fusion_drop_rate", 0.0))
self.emb_norm = LayerNorm(llm_dim)
self.fusion_norm = LayerNorm(self.audio_decoder.embed_unit)
self.fusion_act = Swish()
audio_decoder_in_proj_dim = llm_dim * 2 if self.concat_emb_hidden else llm_dim
self.audio_decoder_in_proj = torch.nn.Linear(
audio_decoder_in_proj_dim, self.audio_decoder.embed_unit
)
# tts text tokenizer related
tts_token_type = audio_decoder_conf.get("tts_token_type", "whisper_rich_ttsfrd")
ttsfrd_res_dir = audio_decoder_conf.get("ttsfrd_res_dir", "./ttsfrd/9.5.5")
from funasr.models.llm_asr.tts_text_tokenizer.build_tokenizer import build_tokenizer
self.tts_text_tokenizer = build_tokenizer(
tts_token_type,
bpemodel=ttsfrd_res_dir,
p_word2phn=1.0,
)
self.codec_embedder = torch.nn.Embedding(self.codebook_size, self.codebook_dim)
self.audio_decoder_embedding = torch.nn.Embedding(2, self.audio_decoder.embed_unit)
self.ad_sos_eos = 0
self.ad_task_id = 1
self.ad_ignore_id = -1
self.predict_nq = 1
from .label_smoothing_loss import LabelSmoothingLoss
self.criterion_ce = LabelSmoothingLoss(
size=self.lm_out_voc_size // self.predict_nq,
padding_idx=self.ad_ignore_id,
smoothing=lsm_weight,
normalize_length=length_normalized_loss,
reduction=False,
)
from funasr.models.llm_asr.tts_models.e2e_model import UCTDXvecSlotModel
self.tts_model = UCTDXvecSlotModel(
**kwargs.get("tts_model_conf", {})
)
self.tts_dim_proj = nn.Linear(llm_dim, self.tts_model.output_size)
mel_decoder_name = kwargs.get("mel_decoder", None)
mel_decoder_conf = kwargs.get("mel_decoder_conf", None)
self.mel_decoder = self.build_mel_decoder(name=mel_decoder_name, conf=mel_decoder_conf)
# self.codebook_dim = audio_decoder_conf.get("codebook_dim", 1024)
# self.codebook_size = audio_decoder_conf.get("codebook_size", 4096)
# self.lm_out_voc_size = self.codebook_size + 1
# self.audio_decoder = self.build_audio_decoder(name=audio_decoder, conf=audio_decoder_conf)
# self.concat_emb_hidden = audio_decoder_conf.get("concat_emb_hidden", False)
# self.concat_emb_hidden_norm = audio_decoder_conf.get("concat_emb_hidden_norm", False)
# if self.concat_emb_hidden_norm:
# self.hidden_norm = LayerNorm(llm_dim)
# self.fusion_dropout = nn.Dropout(audio_decoder_conf.get("fusion_drop_rate", 0.0))
# self.emb_norm = LayerNorm(llm_dim)
# self.fusion_norm = LayerNorm(self.audio_decoder.embed_unit)
# self.fusion_act = Swish()
# audio_decoder_in_proj_dim = llm_dim * 2 if self.concat_emb_hidden else llm_dim
# self.audio_decoder_in_proj = torch.nn.Linear(
# audio_decoder_in_proj_dim, self.audio_decoder.embed_unit
# )
# self.codec_embedder = torch.nn.Embedding(self.codebook_size, self.codebook_dim)
# self.audio_decoder_embedding = torch.nn.Embedding(2, self.audio_decoder.embed_unit)
# self.ad_sos_eos = 0
# self.ad_task_id = 1
# self.ad_ignore_id = -1
# self.predict_nq = 1
#
# from .label_smoothing_loss import LabelSmoothingLoss
#
# self.criterion_ce = LabelSmoothingLoss(
# size=self.lm_out_voc_size // self.predict_nq,
# padding_idx=self.ad_ignore_id,
# smoothing=lsm_weight,
# normalize_length=length_normalized_loss,
# reduction=False,
# )
#
# mel_decoder_name = kwargs.get("mel_decoder", None)
# mel_decoder_conf = kwargs.get("mel_decoder_conf", None)
# self.mel_decoder = self.build_mel_decoder(name=mel_decoder_name, conf=mel_decoder_conf)
vocoder_name = kwargs.get("vocoder", None)
vocoder_conf = kwargs.get("vocoder_conf", None)
self.vocoder = self.build_vocoder(name=vocoder_name, conf=vocoder_conf)
@ -2931,45 +2945,69 @@ class LLMASR6(nn.Module):
hidden_states_his_select = hidden_states_his_select.to(device=input_ids.device)
hidden_states_his_select_len = input_mask.sum(-1)
# import pdb
# pdb.set_trace()
# if self.concat_emb_hidden:
# if not self.concat_emb_hidden_norm:
# hidden_states_select = torch.concat((hidden_states_select, target_emb), dim=-1)
# hidden_states_select = self.audio_decoder_in_proj(hidden_states_select)
# else:
# outs = self.hidden_norm(hidden_states_select)
# outs = self.fusion_dropout(self.fusion_act(outs))
# # emb = model_outputs.hidden_states[0]
# emb = self.fusion_dropout(self.fusion_act(self.emb_norm(target_emb)))
# outs = self.audio_decoder_in_proj(torch.cat([outs, emb], dim=-1))
# hidden_states_select = self.fusion_act(self.fusion_norm(outs))
#
# nll, logits, target, target_lengths = self.nll(
# hidden_states_select, target_ids_len, codec[:, :, None], codec_len
# )
# output_mask = (
# ~make_pad_mask(target_lengths, maxlen=target_lengths.max())
# .to(hidden_states_select.device)
# .unsqueeze(-1)
# )
# total, batch_size = output_mask.sum() * self.predict_nq, nll.shape[0] * self.predict_nq
# denom = total if self.length_normalized_loss else batch_size
# loss = (nll * output_mask).sum() / denom
#
# with torch.no_grad():
# preds = torch.argmax(model_outputs.logits, -1)
# acc_att = compute_accuracy(preds[:, :-1], labels_ids[:, 1:], ignore_label=-100)
# stats["acc"] = acc_att
#
# cc = logits.shape[-1]
# for i in range(self.predict_nq):
# acc = th_accuracy(
# logits[:, :, i, :].reshape(-1, cc), target[:, :, i], self.ad_ignore_id
# )
# stats[f"codec_acc_{i + 1}"] = acc
if self.concat_emb_hidden:
if not self.concat_emb_hidden_norm:
hidden_states_select = torch.concat((hidden_states_select, target_emb), dim=-1)
hidden_states_select = self.audio_decoder_in_proj(hidden_states_select)
else:
outs = self.hidden_norm(hidden_states_select)
outs = self.fusion_dropout(self.fusion_act(outs))
# emb = model_outputs.hidden_states[0]
emb = self.fusion_dropout(self.fusion_act(self.emb_norm(target_emb)))
outs = self.audio_decoder_in_proj(torch.cat([outs, emb], dim=-1))
hidden_states_select = self.fusion_act(self.fusion_norm(outs))
nll, logits, target, target_lengths = self.nll(
hidden_states_select, target_ids_len, codec[:, :, None], codec_len
)
# nar tts model related
device = hidden_states_his_select.device
text = [self.tts_text_tokenizer.text2tokens(x) for x in target_ids]
text_lengths = [len(x) for x in text]
text = pad_list(text, pad_value=-1).long().to(device)
text_lengths = torch.tensor(text_lengths).to(audio_len)
# mute the "da" noise.
# TODO: make sure the sample rate is 22050.
audio[:, :int(0.02*22050)] = 0
hidden_states_his_select = self.tts_dim_proj(hidden_states_his_select)
tts_loss, tts_states, tts_weight = self.tts_model.forward(
text=text,
text_lengths=text_lengths,
speech_token=codec,
speech_token_lengths=codec_len,
audio=audio,
audio_lengths=audio_len,
prompt=hidden_states_his_select,
prompt_len=hidden_states_his_select_len
)
output_mask = (
~make_pad_mask(target_lengths, maxlen=target_lengths.max())
.to(hidden_states_select.device)
.unsqueeze(-1)
)
total, batch_size = output_mask.sum() * self.predict_nq, nll.shape[0] * self.predict_nq
denom = total if self.length_normalized_loss else batch_size
loss = (nll * output_mask).sum() / denom
with torch.no_grad():
preds = torch.argmax(model_outputs.logits, -1)
acc_att = compute_accuracy(preds[:, :-1], labels_ids[:, 1:], ignore_label=-100)
stats["acc"] = acc_att
cc = logits.shape[-1]
for i in range(self.predict_nq):
acc = th_accuracy(
logits[:, :, i, :].reshape(-1, cc), target[:, :, i], self.ad_ignore_id
)
stats[f"codec_acc_{i + 1}"] = acc
loss = loss + tts_loss
for key, value in tts_states.items():
stats[f"tts_{key}"] = value
stats["loss"] = torch.clone(loss.detach())
stats["batch_size"] = batch_size

View File

@ -0,0 +1,365 @@
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
from collections import OrderedDict
import torch
from torch import nn
import torch.nn.functional as F
def get_nonlinear(config_str, channels):
nonlinear = nn.Sequential()
for name in config_str.split('-'):
if name == 'relu':
nonlinear.add_module('relu', nn.ReLU(inplace=True))
elif name == 'prelu':
nonlinear.add_module('prelu', nn.PReLU(channels))
elif name == 'batchnorm':
nonlinear.add_module('batchnorm', nn.BatchNorm1d(channels))
elif name == 'batchnorm_':
nonlinear.add_module('batchnorm',
nn.BatchNorm1d(channels, affine=False))
else:
raise ValueError('Unexpected module ({}).'.format(name))
return nonlinear
def statistics_pooling(x, dim=-1, keepdim=False, unbiased=True, eps=1e-2):
mean = x.mean(dim=dim)
std = x.std(dim=dim, unbiased=unbiased)
stats = torch.cat([mean, std], dim=-1)
if keepdim:
stats = stats.unsqueeze(dim=dim)
return stats
class StatsPool(nn.Module):
def forward(self, x):
return statistics_pooling(x)
class TDNNLayer(nn.Module):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
bias=False,
config_str='batchnorm-relu'):
super(TDNNLayer, self).__init__()
if padding < 0:
assert kernel_size % 2 == 1, 'Expect equal paddings, but got even kernel size ({})'.format(
kernel_size)
padding = (kernel_size - 1) // 2 * dilation
self.linear = nn.Conv1d(in_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
bias=bias)
self.nonlinear = get_nonlinear(config_str, out_channels)
def forward(self, x):
x = self.linear(x)
x = self.nonlinear(x)
return x
class CAMLayer(nn.Module):
def __init__(self,
bn_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
bias,
reduction=2):
super(CAMLayer, self).__init__()
self.linear_local = nn.Conv1d(bn_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
bias=bias)
self.linear1 = nn.Conv1d(bn_channels, bn_channels // reduction, 1)
self.relu = nn.ReLU(inplace=True)
self.linear2 = nn.Conv1d(bn_channels // reduction, out_channels, 1)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
y = self.linear_local(x)
context = x.mean(-1, keepdim=True) + self.seg_pooling(x)
context = self.relu(self.linear1(context))
m = self.sigmoid(self.linear2(context))
return y * m
def seg_pooling(self, x, seg_len=100, stype='avg'):
if stype == 'avg':
seg = F.avg_pool1d(x, kernel_size=seg_len, stride=seg_len, ceil_mode=True)
elif stype == 'max':
seg = F.max_pool1d(x, kernel_size=seg_len, stride=seg_len, ceil_mode=True)
else:
raise ValueError('Wrong segment pooling type.')
shape = seg.shape
seg = seg.unsqueeze(-1).expand(*shape, seg_len).reshape(*shape[:-1], -1)
seg = seg[..., :x.shape[-1]]
return seg
class CAMDenseTDNNLayer(nn.Module):
def __init__(self,
in_channels,
out_channels,
bn_channels,
kernel_size,
stride=1,
dilation=1,
bias=False,
config_str='batchnorm-relu',
memory_efficient=False):
super(CAMDenseTDNNLayer, self).__init__()
assert kernel_size % 2 == 1, 'Expect equal paddings, but got even kernel size ({})'.format(
kernel_size)
padding = (kernel_size - 1) // 2 * dilation
self.memory_efficient = memory_efficient
self.nonlinear1 = get_nonlinear(config_str, in_channels)
self.linear1 = nn.Conv1d(in_channels, bn_channels, 1, bias=False)
self.nonlinear2 = get_nonlinear(config_str, bn_channels)
self.cam_layer = CAMLayer(bn_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
bias=bias)
def bn_function(self, x):
return self.linear1(self.nonlinear1(x))
def forward(self, x):
x = self.bn_function(x)
x = self.cam_layer(self.nonlinear2(x))
return x
class CAMDenseTDNNBlock(nn.ModuleList):
def __init__(self,
num_layers,
in_channels,
out_channels,
bn_channels,
kernel_size,
stride=1,
dilation=1,
bias=False,
config_str='batchnorm-relu',
memory_efficient=False):
super(CAMDenseTDNNBlock, self).__init__()
for i in range(num_layers):
layer = CAMDenseTDNNLayer(in_channels=in_channels + i * out_channels,
out_channels=out_channels,
bn_channels=bn_channels,
kernel_size=kernel_size,
stride=stride,
dilation=dilation,
bias=bias,
config_str=config_str,
memory_efficient=memory_efficient)
self.add_module('tdnnd%d' % (i + 1), layer)
def forward(self, x):
for layer in self:
x = torch.cat([x, layer(x)], dim=1)
return x
class TransitLayer(nn.Module):
def __init__(self,
in_channels,
out_channels,
bias=True,
config_str='batchnorm-relu'):
super(TransitLayer, self).__init__()
self.nonlinear = get_nonlinear(config_str, in_channels)
self.linear = nn.Conv1d(in_channels, out_channels, 1, bias=bias)
def forward(self, x):
x = self.nonlinear(x)
x = self.linear(x)
return x
class DenseLayer(nn.Module):
def __init__(self,
in_channels,
out_channels,
bias=False,
config_str='batchnorm-relu'):
super(DenseLayer, self).__init__()
self.linear = nn.Conv1d(in_channels, out_channels, 1, bias=bias)
self.nonlinear = get_nonlinear(config_str, out_channels)
def forward(self, x):
if len(x.shape) == 2:
x = self.linear(x.unsqueeze(dim=-1)).squeeze(dim=-1)
else:
x = self.linear(x)
x = self.nonlinear(x)
return x
class BasicResBlock(nn.Module):
expansion = 1
def __init__(self, in_planes, planes, stride=1):
super(BasicResBlock, self).__init__()
self.conv1 = nn.Conv2d(in_planes,
planes,
kernel_size=3,
stride=(stride, 1),
padding=1,
bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes,
planes,
kernel_size=3,
stride=1,
padding=1,
bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.shortcut = nn.Sequential()
if stride != 1 or in_planes != self.expansion * planes:
self.shortcut = nn.Sequential(
nn.Conv2d(in_planes,
self.expansion * planes,
kernel_size=1,
stride=(stride, 1),
bias=False),
nn.BatchNorm2d(self.expansion * planes))
def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
out += self.shortcut(x)
out = F.relu(out)
return out
class FCM(nn.Module):
def __init__(self,
block=BasicResBlock,
num_blocks=[2, 2],
m_channels=32,
feat_dim=80):
super(FCM, self).__init__()
self.in_planes = m_channels
self.conv1 = nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(m_channels)
self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=2)
self.layer2 = self._make_layer(block, m_channels, num_blocks[1], stride=2)
self.conv2 = nn.Conv2d(m_channels, m_channels, kernel_size=3, stride=(2, 1), padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(m_channels)
self.out_channels = m_channels * (feat_dim // 8)
def _make_layer(self, block, planes, num_blocks, stride):
strides = [stride] + [1] * (num_blocks - 1)
layers = []
for stride in strides:
layers.append(block(self.in_planes, planes, stride))
self.in_planes = planes * block.expansion
return nn.Sequential(*layers)
def forward(self, x):
x = x.unsqueeze(1)
out = F.relu(self.bn1(self.conv1(x)))
out = self.layer1(out)
out = self.layer2(out)
out = F.relu(self.bn2(self.conv2(out)))
shape = out.shape
out = out.reshape(shape[0], shape[1] * shape[2], shape[3])
return out
class CAMPPlus(nn.Module):
def __init__(self,
feat_dim=80,
embedding_size=512,
growth_rate=32,
bn_size=4,
init_channels=128,
config_str='batchnorm-relu',
memory_efficient=True):
super(CAMPPlus, self).__init__()
self.head = FCM(feat_dim=feat_dim)
channels = self.head.out_channels
self.xvector = nn.Sequential(
OrderedDict([
('tdnn',
TDNNLayer(channels,
init_channels,
5,
stride=2,
dilation=1,
padding=-1,
config_str=config_str)),
]))
channels = init_channels
for i, (num_layers, kernel_size,
dilation) in enumerate(zip((12, 24, 16), (3, 3, 3), (1, 2, 2))):
block = CAMDenseTDNNBlock(num_layers=num_layers,
in_channels=channels,
out_channels=growth_rate,
bn_channels=bn_size * growth_rate,
kernel_size=kernel_size,
dilation=dilation,
config_str=config_str,
memory_efficient=memory_efficient)
self.xvector.add_module('block%d' % (i + 1), block)
channels = channels + num_layers * growth_rate
self.xvector.add_module(
'transit%d' % (i + 1),
TransitLayer(channels,
channels // 2,
bias=False,
config_str=config_str))
channels //= 2
self.xvector.add_module(
'out_nonlinear', get_nonlinear(config_str, channels))
self.xvector.add_module('stats', StatsPool())
self.xvector.add_module(
'dense',
DenseLayer(channels * 2, embedding_size, config_str='prelu'))
for m in self.modules():
if isinstance(m, (nn.Conv1d, nn.Linear)):
nn.init.kaiming_normal_(m.weight.data)
if m.bias is not None:
nn.init.zeros_(m.bias)
def forward(self, x):
x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
x = self.head(x)
x = self.xvector(x)
return x
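# Hedged usage sketch (not part of the original commit): extracting a speaker
# embedding from a batch of 80-dim fbank features; the batch size and frame
# count below are illustrative assumptions.
if __name__ == "__main__":
    model = CAMPPlus(feat_dim=80, embedding_size=512)
    feats = torch.randn(4, 200, 80)  # (B, T, F)
    emb = model(feats)               # -> (B, embedding_size)
    print(emb.shape)                 # torch.Size([4, 512])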

View File

@ -0,0 +1,77 @@
import torch
def ctc_forced_align(
log_probs: torch.Tensor,
targets: torch.Tensor,
input_lengths: torch.Tensor,
target_lengths: torch.Tensor,
blank: int = 0,
ignore_id: int = -1,
) -> torch.Tensor:
"""Align a CTC label sequence to an emission.
Args:
log_probs (Tensor): log probability of CTC emission output.
Tensor of shape `(B, T, C)`. where `B` is the batch size, `T` is the input length,
`C` is the number of characters in alphabet including blank.
targets (Tensor): Target sequence. Tensor of shape `(B, L)`,
where `L` is the target length.
input_lengths (Tensor):
Lengths of the inputs (max value must each be <= `T`). 1-D Tensor of shape `(B,)`.
target_lengths (Tensor):
Lengths of the targets. 1-D Tensor of shape `(B,)`.
blank (int, optional): The index of the blank symbol in the CTC emission. (Default: 0)
ignore_id (int, optional): Label index treated as padding and replaced by blank. (Default: -1)
Returns:
Tensor of shape `(B, T)` containing the frame-level alignment (including blanks).
"""
# map ignored labels to blank (note: this modifies `targets` in place)
targets[targets == ignore_id] = blank
batch_size, input_time_size, _ = log_probs.size()
bsz_indices = torch.arange(batch_size, device=input_lengths.device)
_t_a_r_g_e_t_s_ = torch.cat(
(
torch.stack((torch.full_like(targets, blank), targets), dim=-1).flatten(start_dim=1),
torch.full_like(targets[:, :1], blank),
),
dim=-1,
)
diff_labels = torch.cat(
(
torch.as_tensor([[False, False]], device=targets.device).expand(batch_size, -1),
_t_a_r_g_e_t_s_[:, 2:] != _t_a_r_g_e_t_s_[:, :-2],
),
dim=1,
)
neg_inf = torch.tensor(float("-inf"), device=log_probs.device, dtype=log_probs.dtype)
padding_num = 2
padded_t = padding_num + _t_a_r_g_e_t_s_.size(-1)
best_score = torch.full((batch_size, padded_t), neg_inf, device=log_probs.device, dtype=log_probs.dtype)
best_score[:, padding_num + 0] = log_probs[:, 0, blank]
best_score[:, padding_num + 1] = log_probs[bsz_indices, 0, _t_a_r_g_e_t_s_[:, 1]]
backpointers = torch.zeros((batch_size, input_time_size, padded_t), device=log_probs.device, dtype=targets.dtype)
for t in range(1, input_time_size):
prev = torch.stack(
(best_score[:, 2:], best_score[:, 1:-1], torch.where(diff_labels, best_score[:, :-2], neg_inf))
)
prev_max_value, prev_max_idx = prev.max(dim=0)
best_score[:, padding_num:] = log_probs[:, t].gather(-1, _t_a_r_g_e_t_s_) + prev_max_value
backpointers[:, t, padding_num:] = prev_max_idx
l1l2 = best_score.gather(
-1, torch.stack((padding_num + target_lengths * 2 - 1, padding_num + target_lengths * 2), dim=-1)
)
path = torch.zeros((batch_size, input_time_size), device=best_score.device, dtype=torch.long)
path[bsz_indices, input_lengths - 1] = padding_num + target_lengths * 2 - 1 + l1l2.argmax(dim=-1)
for t in range(input_time_size - 1, 0, -1):
target_indices = path[:, t]
prev_max_idx = backpointers[bsz_indices, t, target_indices]
path[:, t - 1] += target_indices - prev_max_idx
alignments = _t_a_r_g_e_t_s_.gather(dim=-1, index=(path - padding_num).clamp(min=0))
return alignments
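# Hedged usage sketch (illustrative only): aligning a toy target sequence to
# random emissions; the sizes follow the shape conventions in the docstring above.
if __name__ == "__main__":
    B, T, C, L = 2, 50, 10, 6
    log_probs = torch.randn(B, T, C).log_softmax(dim=-1)
    targets = torch.randint(1, C, (B, L))
    input_lengths = torch.full((B,), T, dtype=torch.long)
    target_lengths = torch.full((B,), L, dtype=torch.long)
    ali = ctc_forced_align(log_probs, targets, input_lengths, target_lengths, blank=0)
    print(ali.shape)  # (B, T): one (possibly blank) label per input frame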

View File

@ -0,0 +1,866 @@
import logging
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple, Dict, Optional, Union
from funasr.train_utils.device_funcs import force_gatherable
from funasr.utils.hinter import hint_once
from funasr.models.transformer.utils.nets_utils import pad_list
import numpy as np
import random
def norm_and_sample_xvec(xvec, xvec_lengths):
"""Randomly pick one finite x-vector per utterance and L2-normalize it."""
xvec_list = []
for i, ilen in enumerate(xvec_lengths):
idx = random.randint(0, ilen - 1)
while torch.any(~torch.isfinite(xvec[i, idx])):
idx = random.randint(0, ilen - 1)
xvec_list.append(xvec[i, idx])
rand_xvec = torch.vstack(xvec_list)
rand_xvec = F.normalize(rand_xvec, dim=1)
return rand_xvec
class UpsampleCtcTokenDiffModel(nn.Module):
def __init__(
self,
input_size: int,
output_size: int,
vocab_size: int,
token_list: list,
token_vocab_size: int,
endofprompt_token_id: int = None,
text_encoder_conf: dict = None,
aggregator_conf: dict = None,
am_config: dict = None,
fm_config: dict = None,
xvec_size: int = None,
**kwargs
):
super().__init__()
self.input_size = input_size
self.output_size = output_size
self.text_vocab_size = vocab_size
self.token_list = token_list
self.token_vocab_size = token_vocab_size
self.endofprompt_token_id = endofprompt_token_id
self.text_encoder_conf = text_encoder_conf
self.aggregator_conf = aggregator_conf
self.am_config = am_config
self.fm_config = fm_config
self.xvec_size = xvec_size
# build nn
self.text_embedding = nn.Embedding(vocab_size, output_size)
self.xvec_proj = None
if xvec_size is not None:
self.xvec_proj = nn.Linear(xvec_size, output_size)
self.text_encoder = self.build_text_encoder()
self.am_model = self.build_am_model()
self.fm_model = self.build_fm_model()
self.am_aggregator = self.build_aggregator()
self.fm_aggregator = self.build_aggregator()
# set optional parameters
self.xvec_drop_rate = kwargs.get('xvec_drop_rate', None)
self.use_prompt_as_xvec = kwargs.get('use_prompt_as_xvec', False)
if self.use_prompt_as_xvec:
self.spk_aggregator = self.build_aggregator()
spk_query = torch.randn(1, 1, self.output_size)
torch.nn.init.xavier_normal_(spk_query)
self.spk_query = torch.nn.Parameter(spk_query, requires_grad=True)
def build_aggregator(self):
name = self.aggregator_conf.pop("name", None)
model = None
if name == "transformer":
from funasr.models.llm_asr.tts_models.transformer_decoder import TransformerDecoder
model = TransformerDecoder(self.output_size, **self.aggregator_conf)
self.aggregator_conf["name"] = name
return model
def build_text_encoder(self):
name = self.text_encoder_conf.pop("name", None)
model = None
if name == "transformer":
from funasr.models.llm_asr.conformer_encoder import ConformerEncoder
model = ConformerEncoder(
**self.text_encoder_conf,
input_size=self.output_size,
use_cnn_module=False,
macaron_style=False,
)
elif name == "conformer":
from funasr.models.llm_asr.conformer_encoder import ConformerEncoder
model = ConformerEncoder(
**self.text_encoder_conf,
input_size=self.output_size,
)
self.text_encoder_conf["name"] = name
return model
def build_am_model(self):
name = self.am_config.pop("name", None)
model = None
if name == "nar_ctc_model":
from funasr.models.llm_asr.tts_models.nar_acoustic_model import NARCTCModel
model = NARCTCModel(**self.am_config)
elif name == "nar_ctc_prob_model":
from funasr.models.llm_asr.tts_models.nar_acoustic_model import NARCTCProbModel
model = NARCTCProbModel(**self.am_config)
self.am_config["name"] = name
return model
def build_fm_model(self):
name = self.fm_config.pop("name", None)
model = None
if name == "masked_diff_with_xvec":
from funasr.models.llm_asr.flow_matching import MaskedDiffWithXvec
model = MaskedDiffWithXvec(**self.fm_config)
self.fm_config["name"] = name
return model
def split_prompt(
self,
text_emb: torch.Tensor,
text_emb_lens: torch.Tensor,
text: torch.Tensor,
text_lens: torch.Tensor,
):
prompts, prompt_lens = [], []
outs, outs_lens = [], []
batch_size = text.shape[0]
for i in range(batch_size):
delta = text_emb_lens[i] - text_lens[i]
# +1 so the <|endofprompt|> token stays in the prompt and is excluded from the output
pos = torch.where(text[i] == self.endofprompt_token_id)[0][0].item() + delta + 1
_x = text_emb[i, pos:text_emb_lens[i]]
outs.append(_x)
outs_lens.append(_x.shape[0])
_prompt = text_emb[i, :pos]
prompts.append(_prompt)
prompt_lens.append(_prompt.shape[0])
outs = pad_list(outs, pad_value=0.0)
outs_lens = torch.tensor(outs_lens, dtype=torch.int64, device=text_emb.device)
prompts = pad_list(prompts, pad_value=0.0)
prompt_lens = torch.tensor(prompt_lens, dtype=torch.int64, device=text_emb.device)
return prompts, prompt_lens, outs, outs_lens
def forward(
self,
text: torch.Tensor,
text_lengths: torch.Tensor,
speech_token: torch.Tensor,
speech_token_lengths: torch.Tensor,
audio: torch.Tensor,
audio_lengths: torch.Tensor,
xvec: Optional[torch.Tensor] = None,
xvec_lengths: Optional[torch.Tensor] = None,
**kwargs
):
text = text[:, :text_lengths.max()]
speech_token = speech_token[:, :speech_token_lengths.max()]
audio = audio[:, :audio_lengths.max()]
batch_size = text.shape[0]
# embed text inputs
mask = (text != -1).float().unsqueeze(-1)
text_emb = self.text_embedding(torch.clamp(text, min=0)) * mask
text_emb_lengths = text_lengths
prompt, prompt_lens, text_emb, text_emb_lengths = self.split_prompt(
text_emb, text_emb_lengths, text, text_lengths
)
if self.use_prompt_as_xvec:
prompt_xvec, _ = self.spk_aggregator(
prompt, prompt_lens,
self.spk_query.expand([batch_size, -1, -1]), torch.tensor([1]*batch_size).to(prompt_lens)
)
endofprompt_emb = self.text_embedding(torch.tensor([self.endofprompt_token_id]*batch_size).to(text).unsqueeze(1))
prompt = torch.cat([prompt_xvec, endofprompt_emb], dim=1)
prompt_lens = torch.tensor([2]*batch_size).to(prompt_lens)
hint_once("using prompt as speaker embedding.", "use_prompt_spk_emb")
# random select a xvec from xvec matrix
if not self.use_prompt_as_xvec and self.xvec_proj is not None and xvec is not None:
xvec = xvec[:, :xvec_lengths.max()]
rand_xvec = norm_and_sample_xvec(xvec, xvec_lengths)
rand_xvec = self.xvec_proj(rand_xvec)
if self.xvec_drop_rate is not None:
xvec_mask = (torch.rand((rand_xvec.shape[0], 1)) >= self.xvec_drop_rate).to(rand_xvec)
rand_xvec = rand_xvec * xvec_mask
hint_once(f"randomly drop out xvec with mask {xvec_mask.squeeze()}", "xvec_drop_out")
rand_xvec = rand_xvec.unsqueeze(1)
prompt = torch.cat([rand_xvec, prompt], dim=1)
prompt_lens = prompt_lens + 1
hint_once("using speaker embedding as slot.", "use_spk_emb")
# remove prompt text
# prompt, prompt_lens, text_emb, text_emb_lengths = self.split_prompt(
# text_emb, text_emb_lengths, text, text_lengths
# )
outs_tuple = self.text_encoder(text_emb, ilens=text_emb_lengths)
text_enc = outs_tuple[0]
text_enc_lens = text_emb_lengths
text_enc, _ = self.am_aggregator(
prompt, prompt_lens,
text_enc, text_enc_lens
)
states = dict(
batch_size=float(batch_size),
text_len=float(text_emb.shape[1]),
speech_len=float(speech_token.shape[1]),
token_text_ratio=float(speech_token.shape[1]) / float(text_emb.shape[1]),
)
# forward AM model
am_retvals = self.am_model.force_align_text(
speech_token, speech_token_lengths,
text_enc, text_enc_lens,
**kwargs
)
am_loss, aligned_token_emb, am_states = am_retvals
# update AM states
for key, val in am_states.items():
states[f"am_{key}"] = val
aligned_token_emb, _ = self.fm_aggregator(
prompt, prompt_lens,
aligned_token_emb, speech_token_lengths
)
# forward FM model
fm_loss, fm_states, _ = self.fm_model.forward(
aligned_token_emb, speech_token_lengths,
audio, audio_lengths,
**kwargs,
)
# update FM states
for key, val in fm_states.items():
states[f"fm_{key}"] = val
loss = am_loss + fm_loss
states["loss"] = loss.item()
loss, states, weight = force_gatherable((loss, states, batch_size), loss.device)
return loss, states, weight
def inference(
self,
text: torch.Tensor,
text_lengths: torch.Tensor,
xvec: Optional[torch.Tensor] = None,
xvec_lengths: Optional[torch.Tensor] = None,
**kwargs
):
blank_penalty = kwargs.get("blank_penalty", 0.0)
sampling = kwargs.get("sampling", "greedy")
prompt_dict = kwargs.get("prompt_dict", {})
prompt_token = prompt_dict.get("prompt_token", (None, None))
prompt_audio = prompt_dict.get("prompt_audio", (None, None))
# fully non-causal mode
use_causal_prob = kwargs.get("use_causal_prob", 1.0)
# embed text inputs
mask = (text != -1).float().unsqueeze(-1)
text_emb = self.text_embedding(torch.clamp(text, min=0)) * mask
text_emb_lengths = text_lengths
batch_size = text.shape[0]
prompt, prompt_lens, text_emb, text_emb_lengths = self.split_prompt(
text_emb, text_emb_lengths, text, text_lengths
)
if self.use_prompt_as_xvec:
prompt_xvec, _ = self.spk_aggregator(
prompt, prompt_lens,
self.spk_query.expand([batch_size, -1, -1]), torch.tensor([1] * batch_size).to(prompt_lens)
)
endofprompt_emb = self.text_embedding(
torch.tensor([self.endofprompt_token_id] * batch_size).to(text).unsqueeze(1))
prompt = torch.cat([prompt_xvec, endofprompt_emb], dim=1)
prompt_lens = torch.tensor([2] * batch_size).to(prompt_lens)
hint_once("using prompt as speaker embedding.", "use_prompt_spk_emb")
# using the xvec
if self.xvec_proj is not None and not self.use_prompt_as_xvec:
if xvec is not None:
hint_once("using speaker embedding as slot.", "use_spk_emb")
xvec = xvec[:, :xvec_lengths.max()]
rand_xvec = norm_and_sample_xvec(xvec, xvec_lengths)
rand_xvec = self.xvec_proj(rand_xvec)
rand_xvec = rand_xvec.unsqueeze(1)
else:
hint_once("using zeros as speaker embedding.", "use_spk_emb")
rand_xvec = torch.zeros([text_emb.shape[0], 1, self.output_size]).to(text_emb)
prompt = torch.cat([rand_xvec, prompt], dim=1)
prompt_lens = prompt_lens + 1
outs_tuple = self.text_encoder(text_emb, ilens=text_emb_lengths)
text_enc = outs_tuple[0]
text_enc_lens = text_emb_lengths
text_enc, _ = self.am_aggregator(
prompt, prompt_lens,
text_enc, text_enc_lens
)
# forward AM model
tokens, aligned_token_emb, aligned_token_lens = self.am_model.inference(
text_enc, text_enc_lens,
sampling=sampling,
blank_penalty=blank_penalty,
text_is_embedding=True,
return_hidden=True,
use_causal_prob=use_causal_prob,
)
if isinstance(tokens, tuple):
tokens, fa_tokens = tokens
aligned_token_emb, _ = self.fm_aggregator(
prompt, prompt_lens,
aligned_token_emb, aligned_token_lens
)
# forward FM model
feat = self.fm_model.inference(
aligned_token_emb, aligned_token_lens,
prompt=dict(
prompt_text=prompt_token,
prompt_audio=prompt_audio,
),
**kwargs,
)
feat = self.rms_rescale_feat(feat)
return tokens, feat
def rms_rescale_feat(self, feat, target_feat_rms=3.5, feat_sil_th=0.1):
"""Rescale non-silent frames of the (log-mel) feature to a target RMS power."""
feat_power = feat.exp().sum(1)
# only rescale when the segment is not pure silence
if feat_power.max() > feat_sil_th:
mask = feat_power > feat_sil_th
feat_rms = torch.sqrt(torch.mean(torch.square(feat_power)))
feat = feat + mask.unsqueeze(1) * np.log(target_feat_rms / feat_rms.cpu().numpy().item())
return feat
def get_hop_lens(self, fa_tokens, lookahead_size):
if lookahead_size == 0:
# callers expect a single hop length
return 0
fa_tokens = fa_tokens[0].cpu().tolist()
upsample_rate = np.cumprod(self.am_model.encoder.upsample_ratios)[-1]
lookahead_tokens = [[x-1] for x in fa_tokens[-lookahead_size*upsample_rate:] if x > 0]
lookahead_token_len = len(lookahead_tokens)
return lookahead_token_len
def streaming_inference(
self,
text: torch.Tensor,
text_lengths: torch.Tensor,
xvec: Optional[torch.Tensor] = None,
xvec_lengths: Optional[torch.Tensor] = None,
**kwargs
):
device = text.device
use_causal_prob = kwargs.get("use_causal_prob", 1.0)
# streaming related config
chunk_size = kwargs.get("streaming_chunk_size", 1)
chunk_size_maxium = kwargs.get("chunk_size_maxium", 16)
try:
lookahead_size = self.am_model.encoder.pre_lookahead_len
except AttributeError:
lookahead_size = 0
hint_once(f"chunk_size={chunk_size}, chunk_size_maxium={chunk_size_maxium}, "
f"pre lookahead size={lookahead_size}.",
"pre_lookahead_len")
given_rtf = kwargs.get("given_rtf", 0.5)
blank_penalty = kwargs.get("blank_penalty", 0.0)
sampling = kwargs.get("sampling", "greedy")
prompt_dict = kwargs.get("prompt_dict", {})
prompt_token = list(prompt_dict.get("prompt_token", [None, None]))
prompt_audio = list(prompt_dict.get("prompt_audio", [None, None]))
streaming_mode = kwargs.get("streaming_mode", "v2")
if prompt_token[0] is None:
prompt_token[0] = torch.zeros([1, 0, self.output_size], device=device, dtype=torch.float32)
prompt_token[1] = torch.tensor([0], device=device, dtype=torch.long)
if prompt_audio[0] is None:
prompt_audio[0] = torch.zeros(
[1, 0, self.fm_model.mel_extractor.num_mels],
device=device, dtype=torch.float32
)
prompt_audio[1] = torch.tensor([0], device=device, dtype=torch.long)
# embed text inputs
mask = (text != -1).float().unsqueeze(-1)
text_emb = self.text_embedding(torch.clamp(text, min=0)) * mask
text_emb_lengths = text_lengths
batch_size = text.shape[0]
prompt, prompt_lens, text_emb, text_emb_lengths = self.split_prompt(
text_emb, text_emb_lengths, text, text_lengths
)
if self.use_prompt_as_xvec:
prompt_xvec, _ = self.spk_aggregator(
prompt, prompt_lens,
self.spk_query.expand([batch_size, -1, -1]), torch.tensor([1] * batch_size).to(prompt_lens)
)
endofprompt_emb = self.text_embedding(
torch.tensor([self.endofprompt_token_id] * batch_size).to(text).unsqueeze(1))
prompt = torch.cat([prompt_xvec, endofprompt_emb], dim=1)
prompt_lens = torch.tensor([2] * batch_size).to(prompt_lens)
hint_once("using prompt as speaker embedding.", "use_prompt_spk_emb")
# using the xvec
if self.xvec_proj is not None and not self.use_prompt_as_xvec:
if xvec is not None:
hint_once("using speaker embedding as slot.", "use_spk_emb")
xvec = xvec[:, :xvec_lengths.max()]
rand_xvec = norm_and_sample_xvec(xvec, xvec_lengths)
rand_xvec = self.xvec_proj(rand_xvec)
rand_xvec = rand_xvec.unsqueeze(1)
else:
hint_once("using zeros as speaker embedding.", "use_spk_emb")
rand_xvec = torch.zeros([text_emb.shape[0], 1, self.output_size]).to(text_emb)
prompt = torch.cat([rand_xvec, prompt], dim=1)
prompt_lens = prompt_lens + 1
chunk_id = 0
chunk_start = 0
while True:
_st_time = time.time()
_size = max(int(round(chunk_size / (given_rtf ** chunk_id))), chunk_size_maxium)
chunk_end = chunk_start + _size
chunk_text_emb = text_emb[:, :chunk_end+lookahead_size]
chunk_text_emb_lengths = torch.tensor([chunk_text_emb.shape[1]], dtype=torch.long, device=device)
outs_tuple = self.text_encoder(chunk_text_emb, ilens=chunk_text_emb_lengths)
text_enc = outs_tuple[0]
text_enc_lens = chunk_text_emb_lengths
text_enc, _ = self.am_aggregator(
prompt, prompt_lens,
text_enc, text_enc_lens
)
# forward AM model
tokens, aligned_token_emb, aligned_token_lens = self.am_model.inference(
text_enc, text_enc_lens,
sampling=sampling,
blank_penalty=blank_penalty,
text_is_embedding=True,
return_hidden=True,
use_causal_prob=use_causal_prob,
)
token_hop_len, mel_hop_len = 0, 0
if isinstance(tokens, tuple):
tokens, fa_tokens = tokens
token_hop_len = self.get_hop_lens(fa_tokens, lookahead_size)
mel_hop_len = int(round(token_hop_len * self.fm_model.length_normalizer_ratio))
# exclude empty tokens.
if aligned_token_emb.shape[1] > prompt_token[0].shape[1]:
aligned_token_emb, _ = self.fm_aggregator(
prompt, prompt_lens,
aligned_token_emb, aligned_token_lens
)
cur_token = aligned_token_emb[:, prompt_token[0].shape[1]:]
cur_token_len = aligned_token_lens - prompt_token[1]
# v2: exclude lookahead tokens for non-final packages
if streaming_mode == "v2":
if chunk_end + lookahead_size < text_emb.shape[1]:
cur_token = cur_token[:, :cur_token.shape[1]-token_hop_len, :]
cur_token_len = cur_token_len - token_hop_len
# forward FM model
feat = self.fm_model.inference(
cur_token, cur_token_len,
prompt=dict(
prompt_text=prompt_token,
prompt_audio=prompt_audio,
),
**kwargs,
)
feat = self.rms_rescale_feat(feat)
cost = time.time() - _st_time
if chunk_id == 0:
logging.info(f"First package delay: {cost*1000.0:.2f}ms")
print_token = tokens.cpu().squeeze().tolist()
logging.info(f"pack {chunk_id}: valid_tokens: {print_token[:len(print_token)-token_hop_len]}, "
f"pad_tokens: {print_token[len(print_token)-token_hop_len:]}.")
if streaming_mode == "v1":
# v1: exclude lookahead parts for non-final packages
if chunk_end + lookahead_size < text_emb.shape[1]:
cur_token = cur_token[:, :cur_token.shape[1]-token_hop_len, :]
feat = feat[:, :, :feat.shape[2] - mel_hop_len]
if streaming_mode == "v2":
# v2: roll back tokens and mel features
if chunk_end + lookahead_size < text_emb.shape[1]:
text_reback = 2 if chunk_id == 0 else 4
token_hop_len_2 = self.get_hop_lens(fa_tokens, lookahead_size + text_reback)
token_reback = token_hop_len_2 - token_hop_len
cur_token = cur_token[:, :cur_token.shape[1] - token_reback, :]
feat_reback = int(round(token_reback * self.fm_model.length_normalizer_ratio))
feat = feat[:, :, :feat.shape[2] - feat_reback]
chunk_end = chunk_end - text_reback
# update values and lens of prompt token and audio
prompt_token[1] = prompt_token[1] + cur_token.shape[1]
prompt_token[0] = torch.concat([prompt_token[0], cur_token], dim=1)
prompt_audio[1] = prompt_audio[1] + feat.shape[2]
prompt_audio[0] = torch.concat([prompt_audio[0], feat.transpose(1, 2)], dim=1)
chunk_id += 1
chunk_start = chunk_end
if chunk_end + lookahead_size >= text_emb.shape[1]:
break
return tokens, prompt_audio[0].transpose(1, 2)
def collect_feats(self, **batch: torch.Tensor) -> Dict[str, torch.Tensor]:
pass
class UCTDXvecSlotModel(UpsampleCtcTokenDiffModel):
def __init__(self, input_size: int, output_size: int, vocab_size: int, token_list: list, token_vocab_size: int,
endofprompt_token_id: int = None, text_encoder_conf: dict = None, aggregator_conf: dict = None,
am_config: dict = None, fm_config: dict = None, xvec_size: int = None, **kwargs):
super().__init__(input_size, output_size, vocab_size, token_list, token_vocab_size, endofprompt_token_id,
text_encoder_conf, aggregator_conf, am_config, fm_config, xvec_size, **kwargs)
# remove am and fm aggregator
self.am_aggregator = None
self.fm_aggregator = None
# build speaker aggregator for Prompt Text
self.spk_aggregator = self.build_aggregator()
spk_query = torch.randn(1, 1, self.output_size)
torch.nn.init.xavier_normal_(spk_query)
self.spk_query = torch.nn.Parameter(spk_query, requires_grad=True)
self.prompt_xvec_proj = nn.Linear(self.output_size, self.xvec_size)
# build xvec extractor for Mel spectrum
# self.mel_spec_fn = self.fm_model.mel_extractor
# from funasr.models.llm_asr.tts_models.campp_encoder import CAMPPlus
# self.mel_xvec_fn = CAMPPlus(self.mel_spec_fn.num_mels, self.xvec_size)
# self.audio_prompt_lens = kwargs.get("audio_prompt_lens", [0.3, 1.0])
# text_mel_xvec_rand_ratios = kwargs.get("text_mel_xvec_rand_ratio", [0.3, 0.3, 0.4])
# self.register_buffer("text_mel_xvec_rand_ratios", torch.tensor(text_mel_xvec_rand_ratios))
def forward(self, text: torch.Tensor, text_lengths: torch.Tensor, speech_token: torch.Tensor,
speech_token_lengths: torch.Tensor, audio: torch.Tensor, audio_lengths: torch.Tensor,
xvec: Optional[torch.Tensor] = None, xvec_lengths: Optional[torch.Tensor] = None, **kwargs):
text = text[:, :text_lengths.max()]
speech_token = speech_token[:, :speech_token_lengths.max()]
audio = audio[:, :audio_lengths.max()]
batch_size = text.shape[0]
# embed text inputs
mask = (text != -1).float().unsqueeze(-1)
text_emb = self.text_embedding(torch.clamp(text, min=0)) * mask
text_emb_lengths = text_lengths
if "prompt" in kwargs and "prompt_len" in kwargs:
prompt = kwargs["prompt"]
prompt_lens = kwargs["prompt_len"]
else:
prompt, prompt_lens, text_emb, text_emb_lengths = self.split_prompt(
text_emb, text_emb_lengths, text, text_lengths
)
# textual prompt xvec
prompt_xvec, _ = self.spk_aggregator(
prompt, prompt_lens,
self.spk_query.expand([batch_size, -1, -1]), torch.tensor([1] * batch_size).to(prompt_lens)
)
prompt_xvec = self.prompt_xvec_proj(prompt_xvec)
# # mel prompt xvec
# audio_rand_lens = torch.rand_like(audio_lengths, dtype=torch.float32) * (self.audio_prompt_lens[1] - self.audio_prompt_lens[0]) + self.audio_prompt_lens[0]
# audio_rand_lens = (audio_rand_lens * audio_lengths).round().long()
# audio_rand_lens = torch.clamp(audio_rand_lens, min=round(self.mel_spec_fn.sampling_rate*0.5))
# audio_prompt = [audio[i, :audio_rand_lens[i]] for i in range(batch_size)]
# audio_rand_lens = torch.tensor([x.shape[0] for x in audio_prompt]).to(text_lengths)
# audio_prompt = pad_list(audio_prompt, 0.0)
# mel_feat, feat_lens = self.mel_spec_fn(audio_prompt, audio_rand_lens)
# mel_xvec = self.mel_xvec_fn(mel_feat).unsqueeze(1)
#
# # random select a xvec from xvec matrix
# xvec = xvec[:, :xvec_lengths.max()]
#
# # random using prompt, mel and input xvecs
# mixup_rand = self.text_mel_xvec_rand_ratios.multinomial(batch_size, replacement=True).unsqueeze(1).unsqueeze(2)
# rand_xvec = (
# (mixup_rand == 0) * prompt_xvec +
# (mixup_rand == 1) * mel_xvec +
# (mixup_rand == 2) * xvec
# )
rand_xvec = prompt_xvec
rand_xvec_lens = torch.tensor([1] * batch_size).to(text_emb_lengths)
outs_tuple = self.text_encoder(text_emb, ilens=text_emb_lengths)
text_enc = outs_tuple[0]
text_enc_lens = text_emb_lengths
states = dict(
batch_size=float(batch_size),
text_len=float(text_emb.shape[1]),
speech_len=float(speech_token.shape[1]),
token_text_ratio=float(speech_token.shape[1]) / float(text_emb.shape[1]),
)
# forward AM model
am_retvals = self.am_model.force_align_text(
speech_token, speech_token_lengths,
text_enc, text_enc_lens,
rand_xvec, rand_xvec_lens,
**kwargs
)
am_loss, aligned_token_emb, am_states = am_retvals
# update AM states
for key, val in am_states.items():
states[f"am_{key}"] = val
# forward FM model
fm_loss, fm_states, _ = self.fm_model.forward(
aligned_token_emb, speech_token_lengths,
audio, audio_lengths,
rand_xvec, rand_xvec_lens,
**kwargs,
)
# update FM states
for key, val in fm_states.items():
states[f"fm_{key}"] = val
loss = am_loss + fm_loss
states["loss"] = loss.item()
loss, states, weight = force_gatherable((loss, states, batch_size), loss.device)
return loss, states, weight
def inference(self, text: torch.Tensor, text_lengths: torch.Tensor, xvec: Optional[torch.Tensor] = None,
xvec_lengths: Optional[torch.Tensor] = None, **kwargs):
blank_penalty = kwargs.get("blank_penalty", 0.0)
sampling = kwargs.get("sampling", "greedy")
prompt_dict = kwargs.get("prompt_dict", {})
prompt_token = prompt_dict.get("prompt_token", (None, None))
prompt_audio = prompt_dict.get("prompt_audio", (None, None))
# fully non-causal mode
use_causal_prob = kwargs.get("use_causal_prob", 1.0)
# embed text inputs
mask = (text != -1).float().unsqueeze(-1)
text_emb = self.text_embedding(torch.clamp(text, min=0)) * mask
text_emb_lengths = text_lengths
batch_size = text.shape[0]
prompt, prompt_lens, text_emb, text_emb_lengths = self.split_prompt(
text_emb, text_emb_lengths, text, text_lengths
)
if xvec is not None:
# using the xvec
hint_once("using speaker embedding for slot.", "use_spk_emb")
xvec = xvec[:, :xvec_lengths.max()]
else:
# textual prompt xvec
hint_once("using textual prompt for slot.", "use_spk_emb")
prompt_xvec, _ = self.spk_aggregator(
prompt, prompt_lens,
self.spk_query.expand([batch_size, -1, -1]), torch.tensor([1] * batch_size).to(prompt_lens)
)
xvec = self.prompt_xvec_proj(prompt_xvec)
xvec_lengths = torch.tensor([1] * batch_size).to(text_lengths)
outs_tuple = self.text_encoder(text_emb, ilens=text_emb_lengths)
text_enc = outs_tuple[0]
text_enc_lens = text_emb_lengths
# forward AM model
tokens, aligned_token_emb, aligned_token_lens = self.am_model.inference(
text_enc, text_enc_lens,
xvec, xvec_lengths,
sampling=sampling,
blank_penalty=blank_penalty,
text_is_embedding=True,
return_hidden=True,
use_causal_prob=use_causal_prob,
)
if isinstance(tokens, tuple):
tokens, fa_tokens = tokens
# forward FM model
feat = self.fm_model.inference(
aligned_token_emb, aligned_token_lens,
xvec, xvec_lengths,
prompt=dict(
prompt_text=prompt_token,
prompt_audio=prompt_audio,
),
**kwargs,
)
feat = self.rms_rescale_feat(feat)
return tokens, feat
def streaming_inference(self, text: torch.Tensor, text_lengths: torch.Tensor, xvec: Optional[torch.Tensor] = None,
xvec_lengths: Optional[torch.Tensor] = None, **kwargs):
device = text.device
use_causal_prob = kwargs.get("use_causal_prob", 1.0)
# streaming related config
chunk_size = kwargs.get("streaming_chunk_size", 1)
chunk_size_maxium = kwargs.get("chunk_size_maxium", 16)
try:
lookahead_size = self.am_model.encoder.pre_lookahead_len
except AttributeError:
lookahead_size = 0
hint_once(f"chunk_size={chunk_size}, chunk_size_maxium={chunk_size_maxium}, "
f"pre lookahead size={lookahead_size}.",
"pre_lookahead_len")
given_rtf = kwargs.get("given_rtf", 0.5)
blank_penalty = kwargs.get("blank_penalty", 0.0)
sampling = kwargs.get("sampling", "greedy")
prompt_dict = kwargs.get("prompt_dict", {})
prompt_token = list(prompt_dict.get("prompt_token", [None, None]))
prompt_audio = list(prompt_dict.get("prompt_audio", [None, None]))
streaming_mode = kwargs.get("streaming_mode", "v2")
if prompt_token[0] is None:
prompt_token[0] = torch.zeros([1, 0, self.output_size], device=device, dtype=torch.float32)
prompt_token[1] = torch.tensor([0], device=device, dtype=torch.long)
if prompt_audio[0] is None:
prompt_audio[0] = torch.zeros(
[1, 0, self.fm_model.mel_extractor.num_mels],
device=device, dtype=torch.float32
)
prompt_audio[1] = torch.tensor([0], device=device, dtype=torch.long)
# embed text inputs
mask = (text != -1).float().unsqueeze(-1)
text_emb = self.text_embedding(torch.clamp(text, min=0)) * mask
text_emb_lengths = text_lengths
batch_size = text.shape[0]
prompt, prompt_lens, text_emb, text_emb_lengths = self.split_prompt(
text_emb, text_emb_lengths, text, text_lengths
)
if xvec is not None:
# using speaker embedding
hint_once("using speaker embedding for slot.", "use_spk_emb")
xvec = xvec[:, :xvec_lengths.max()]
else:
# textual prompt xvec
hint_once("using textual prompt for slot.", "use_spk_emb")
prompt_xvec, _ = self.spk_aggregator(
prompt, prompt_lens,
self.spk_query.expand([batch_size, -1, -1]), torch.tensor([1] * batch_size).to(prompt_lens)
)
xvec = self.prompt_xvec_proj(prompt_xvec)
xvec_lengths = torch.tensor([1] * batch_size).to(text_lengths)
chunk_id = 0
chunk_start = 0
while True:
_st_time = time.time()
_size = max(int(round(chunk_size / (given_rtf ** chunk_id))), chunk_size_maxium)
chunk_end = chunk_start + _size
chunk_text_emb = text_emb[:, :chunk_end + lookahead_size]
chunk_text_emb_lengths = torch.tensor([chunk_text_emb.shape[1]], dtype=torch.long, device=device)
outs_tuple = self.text_encoder(chunk_text_emb, ilens=chunk_text_emb_lengths)
text_enc = outs_tuple[0]
text_enc_lens = chunk_text_emb_lengths
# forward AM model
tokens, aligned_token_emb, aligned_token_lens = self.am_model.inference(
text_enc, text_enc_lens,
xvec, xvec_lengths,
sampling=sampling,
blank_penalty=blank_penalty,
text_is_embedding=True,
return_hidden=True,
use_causal_prob=use_causal_prob,
)
token_hop_len, mel_hop_len = 0, 0
if isinstance(tokens, tuple):
tokens, fa_tokens = tokens
token_hop_len = self.get_hop_lens(fa_tokens, lookahead_size)
mel_hop_len = int(round(token_hop_len * self.fm_model.length_normalizer_ratio))
# exclude empty tokens.
if aligned_token_emb.shape[1] > prompt_token[0].shape[1]:
cur_token = aligned_token_emb[:, prompt_token[0].shape[1]:]
cur_token_len = aligned_token_lens - prompt_token[1]
# v2: exclude lookahead tokens for non-final packages
if streaming_mode == "v2":
if chunk_end + lookahead_size < text_emb.shape[1]:
cur_token = cur_token[:, :cur_token.shape[1] - token_hop_len, :]
cur_token_len = cur_token_len - token_hop_len
# forward FM model
feat = self.fm_model.inference(
cur_token, cur_token_len,
xvec, xvec_lengths,
prompt=dict(
prompt_text=prompt_token,
prompt_audio=prompt_audio,
),
**kwargs,
)
feat = self.rms_rescale_feat(feat)
cost = time.time() - _st_time
if chunk_id == 0:
logging.info(f"First package delay: {cost * 1000.0:.2f}ms")
print_token = tokens.cpu().squeeze().tolist()
logging.info(f"pack {chunk_id}: valid_tokens: {print_token[:len(print_token) - token_hop_len]}, "
f"pad_tokens: {print_token[len(print_token) - token_hop_len:]}.")
if streaming_mode == "v1":
# v1: exclude lookahead parts for non-final packages
if chunk_end + lookahead_size < text_emb.shape[1]:
cur_token = cur_token[:, :cur_token.shape[1] - token_hop_len, :]
feat = feat[:, :, :feat.shape[2] - mel_hop_len]
if streaming_mode == "v2":
# v2: roll back tokens and mel features
if chunk_end + lookahead_size < text_emb.shape[1]:
text_reback = 2 if chunk_id == 0 else 4
token_hop_len_2 = self.get_hop_lens(fa_tokens, lookahead_size + text_reback)
token_reback = token_hop_len_2 - token_hop_len
cur_token = cur_token[:, :cur_token.shape[1] - token_reback, :]
feat_reback = int(round(token_reback * self.fm_model.length_normalizer_ratio))
feat = feat[:, :, :feat.shape[2] - feat_reback]
chunk_end = chunk_end - text_reback
# update values and lens of prompt token and audio
prompt_token[1] = prompt_token[1] + cur_token.shape[1]
prompt_token[0] = torch.concat([prompt_token[0], cur_token], dim=1)
prompt_audio[1] = prompt_audio[1] + feat.shape[2]
prompt_audio[0] = torch.concat([prompt_audio[0], feat.transpose(1, 2)], dim=1)
chunk_id += 1
chunk_start = chunk_end
if chunk_end + lookahead_size >= text_emb.shape[1]:
break
return tokens, prompt_audio[0].transpose(1, 2)

View File

@ -0,0 +1,542 @@
import logging
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import torch
from torch import nn
from funasr.models.transformer.attention import (
MultiHeadedAttention, # noqa: H301
RelPositionMultiHeadedAttention, # noqa: H301
LegacyRelPositionMultiHeadedAttention, # noqa: H301
)
from funasr.models.transformer.embedding import (
PositionalEncoding, # noqa: H301
ScaledPositionalEncoding, # noqa: H301
RelPositionalEncoding, # noqa: H301
LegacyRelPositionalEncoding, # noqa: H301
)
from funasr.models.transformer.layer_norm import LayerNorm
from funasr.models.transformer.utils.multi_layer_conv import Conv1dLinear
from funasr.models.transformer.utils.multi_layer_conv import MultiLayeredConv1d
from funasr.models.transformer.utils.nets_utils import get_activation
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.models.transformer.utils.mask import subsequent_mask, causal_block_mask
from funasr.models.transformer.positionwise_feed_forward import (
PositionwiseFeedForward, # noqa: H301
)
from funasr.models.transformer.utils.repeat import repeat
from funasr.models.transformer.utils.subsampling import (
Conv2dSubsampling, Conv2dSubsampling2, Conv2dSubsampling6, Conv2dSubsampling8, TooShortUttError,
check_short_utt, Conv2dSubsamplingPad
)
import torch.nn.functional as F
from funasr.models.llm_asr.conformer_encoder import ConvolutionModule, EncoderLayer
from funasr.models.ctc.ctc import CTC
class Upsample1D(nn.Module):
"""A 1D upsampling layer with an optional convolution.
Parameters:
channels (`int`):
number of channels in the inputs and outputs.
use_conv (`bool`, default `False`):
option to use a convolution.
use_conv_transpose (`bool`, default `False`):
option to use a convolution transpose.
out_channels (`int`, optional):
number of output channels. Defaults to `channels`.
"""
def __init__(self, channels, use_conv=False, use_conv_transpose=True,
out_channels=None, name="conv", channel_first=True, stride=2, causal=False):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.use_conv_transpose = use_conv_transpose
self.name = name
self.channel_first = channel_first
self.stride = stride
self.causal = causal
self.conv = None
if use_conv_transpose:
# transposed conv doesn't support causal mode.
assert not causal
kernel_size = stride*2 + stride % 2
padding = (kernel_size - stride) // 2
self.conv = nn.ConvTranspose1d(channels, self.out_channels, kernel_size, stride, padding)
elif use_conv:
# In this mode, first repeat-interpolate, then conv with stride=1
self.conv = nn.Conv1d(
self.channels, self.out_channels, stride*2+1, stride=1,
padding=0,
)
def forward(self, inputs, input_lengths=None):
if not self.channel_first:
inputs = inputs.transpose(1, 2).contiguous()
assert inputs.shape[1] == self.channels
if self.use_conv_transpose:
outputs = self.conv(inputs)
if not self.channel_first:
outputs = outputs.transpose(1, 2).contiguous()
return outputs, input_lengths * self.stride
outputs = F.interpolate(inputs, scale_factor=self.stride, mode="nearest")
if self.use_conv:
if not self.causal:
outputs = F.pad(outputs, (self.stride, self.stride))
else:
outputs = F.pad(outputs, (self.stride*2, 0))
outputs = self.conv(outputs)
if not self.channel_first:
outputs = outputs.transpose(1, 2).contiguous()
return outputs, input_lengths * self.stride
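# Hedged usage sketch (shapes are assumptions): the default transposed-conv
# path doubles the temporal length of a channel-first feature map.
if __name__ == "__main__":
    up = Upsample1D(channels=256, stride=2)
    x = torch.randn(3, 256, 40)                    # (B, C, T)
    lens = torch.full((3,), 40, dtype=torch.long)
    y, y_lens = up(x, lens)
    print(y.shape, y_lens)                         # (3, 256, 80), lengths * 2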
class PreLookaheadLayer(nn.Module):
def __init__(self, channels: int, pre_lookahead_len:int = 1):
super().__init__()
self.channels = channels
self.pre_lookahead_len = pre_lookahead_len
self.conv1 = nn.Conv1d(
channels, channels,
kernel_size=pre_lookahead_len+1,
stride=1, padding=0,
)
self.conv2 = nn.Conv1d(
channels, channels,
kernel_size=3, stride=1, padding=0,
)
def forward(self, inputs, ilens):
"""
inputs: (batch_size, seq_len, channels)
"""
outputs = inputs.transpose(1, 2).contiguous()
# look ahead
outputs = F.pad(outputs, (0, self.pre_lookahead_len), mode='constant', value=0)
outputs = F.leaky_relu(self.conv1(outputs))
# causal conv over current and previous frames (left padding only)
outputs = F.pad(outputs, (2, 0), mode='constant', value=0)
outputs = self.conv2(outputs)
outputs = outputs.transpose(1, 2).contiguous()
mask = (~make_pad_mask(ilens).unsqueeze(-1).to(inputs.device))
# residual connection
outputs = (outputs + inputs) * mask
return outputs, ilens
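# Hedged usage sketch (sizes assumed): the layer mixes in pre_lookahead_len
# future frames while keeping the sequence length unchanged.
if __name__ == "__main__":
    layer = PreLookaheadLayer(channels=256, pre_lookahead_len=3)
    x = torch.randn(2, 50, 256)      # (B, T, C)
    ilens = torch.tensor([50, 42])
    y, olens = layer(x, ilens)
    print(y.shape)                   # torch.Size([2, 50, 256])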
class UpsampleConformerEncoder(nn.Module):
"""Progressive upsampling Conformer encoder module.
Args:
input_size (int): Input dimension.
output_size (int): Dimension of attention.
attention_heads (int): The number of heads of multi head attention.
linear_units (int): The number of units of position-wise feed forward.
num_blocks (int): The number of decoder blocks.
dropout_rate (float): Dropout rate.
attention_dropout_rate (float): Dropout rate in attention.
positional_dropout_rate (float): Dropout rate after adding positional encoding.
input_layer (Union[str, torch.nn.Module]): Input layer type.
normalize_before (bool): Whether to use layer_norm before the first block.
concat_after (bool): Whether to concat attention layer's input and output.
If True, additional linear will be applied.
i.e. x -> x + linear(concat(x, att(x)))
If False, no additional linear will be applied. i.e. x -> x + att(x)
positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear".
positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer.
rel_pos_type (str): Whether to use the latest relative positional encoding or
the legacy one. The legacy relative positional encoding will be deprecated
in the future. More Details can be found in
https://github.com/espnet/espnet/pull/2816.
encoder_pos_enc_layer_type (str): Encoder positional encoding layer type.
encoder_attn_layer_type (str): Encoder attention layer type.
activation_type (str): Encoder activation function type.
macaron_style (bool): Whether to use macaron style for positionwise layer.
use_cnn_module (bool): Whether to use convolution module.
zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
cnn_module_kernel (int): Kernel size of convolution module.
padding_idx (int): Padding idx for input_layer=embed.
"""
def __init__(
self,
input_size: int,
output_size: int = 256,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
upsample_blocks: int = 3,
upsample_attn_layers: int = 2,
upsample_ratios: tuple = None,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
attention_dropout_rate: float = 0.0,
input_layer: Optional[str] = "conv2d",
normalize_before: bool = True,
concat_after: bool = False,
positionwise_layer_type: str = "linear",
positionwise_conv_kernel_size: int = 3,
macaron_style: bool = False,
rel_pos_type: str = "legacy",
pos_enc_layer_type: str = "rel_pos",
selfattention_layer_type: str = "rel_selfattn",
activation_type: str = "swish",
use_cnn_module: bool = True,
zero_triu: bool = False,
cnn_module_kernel: int = 31,
padding_idx: int = -1,
causal: bool = False,
skip: bool = False,
channel_first: bool = False,
use_causal_prob: float = None,
pre_lookahead_len: int = None,
):
super().__init__()
self._output_size = output_size
self.causal = causal
self.skip = skip
self.channel_first = channel_first
self.pre_lookahead_len = pre_lookahead_len
self.use_causal_prob = use_causal_prob
if rel_pos_type == "legacy":
if pos_enc_layer_type == "rel_pos":
pos_enc_layer_type = "legacy_rel_pos"
if selfattention_layer_type == "rel_selfattn":
selfattention_layer_type = "legacy_rel_selfattn"
elif rel_pos_type == "latest":
assert selfattention_layer_type != "legacy_rel_selfattn"
assert pos_enc_layer_type != "legacy_rel_pos"
else:
raise ValueError("unknown rel_pos_type: " + rel_pos_type)
activation = get_activation(activation_type)
if pos_enc_layer_type == "abs_pos":
pos_enc_class = PositionalEncoding
elif pos_enc_layer_type == "scaled_abs_pos":
pos_enc_class = ScaledPositionalEncoding
elif pos_enc_layer_type == "rel_pos":
assert selfattention_layer_type == "rel_selfattn"
pos_enc_class = RelPositionalEncoding
elif pos_enc_layer_type == "legacy_rel_pos":
assert selfattention_layer_type == "legacy_rel_selfattn"
pos_enc_class = LegacyRelPositionalEncoding
logging.warning(
"Using legacy_rel_pos and it will be deprecated in the future."
)
else:
raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)
if input_layer == "linear":
self.embed = torch.nn.Sequential(
torch.nn.Linear(input_size, output_size),
torch.nn.LayerNorm(output_size),
torch.nn.Dropout(dropout_rate),
pos_enc_class(output_size, positional_dropout_rate),
)
elif input_layer == "conv2d":
self.embed = Conv2dSubsampling(
input_size,
output_size,
dropout_rate,
pos_enc_class(output_size, positional_dropout_rate),
)
elif input_layer == "conv2dpad":
self.embed = Conv2dSubsamplingPad(
input_size,
output_size,
dropout_rate,
pos_enc_class(output_size, positional_dropout_rate),
)
elif input_layer == "conv2d2":
self.embed = Conv2dSubsampling2(
input_size,
output_size,
dropout_rate,
pos_enc_class(output_size, positional_dropout_rate),
)
elif input_layer == "conv2d6":
self.embed = Conv2dSubsampling6(
input_size,
output_size,
dropout_rate,
pos_enc_class(output_size, positional_dropout_rate),
)
elif input_layer == "conv2d8":
self.embed = Conv2dSubsampling8(
input_size,
output_size,
dropout_rate,
pos_enc_class(output_size, positional_dropout_rate),
)
elif input_layer == "embed":
self.embed = torch.nn.Sequential(
torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
pos_enc_class(output_size, positional_dropout_rate),
)
elif isinstance(input_layer, torch.nn.Module):
self.embed = torch.nn.Sequential(
input_layer,
pos_enc_class(output_size, positional_dropout_rate),
)
elif input_layer is None:
self.embed = torch.nn.Sequential(
pos_enc_class(output_size, positional_dropout_rate)
)
else:
raise ValueError("unknown input_layer: " + input_layer)
self.normalize_before = normalize_before
if positionwise_layer_type == "linear":
positionwise_layer = PositionwiseFeedForward
positionwise_layer_args = (
output_size,
linear_units,
dropout_rate,
activation,
)
elif positionwise_layer_type == "conv1d":
positionwise_layer = MultiLayeredConv1d
positionwise_layer_args = (
output_size,
linear_units,
positionwise_conv_kernel_size,
dropout_rate,
)
elif positionwise_layer_type == "conv1d-linear":
positionwise_layer = Conv1dLinear
positionwise_layer_args = (
output_size,
linear_units,
positionwise_conv_kernel_size,
dropout_rate,
)
else:
raise NotImplementedError("Support only linear or conv1d.")
if pre_lookahead_len is not None:
self.pre_lookahead_layer = PreLookaheadLayer(output_size, pre_lookahead_len)
if selfattention_layer_type == "selfattn":
encoder_selfattn_layer = MultiHeadedAttention
encoder_selfattn_layer_args = (
attention_heads,
output_size,
attention_dropout_rate,
)
elif selfattention_layer_type == "legacy_rel_selfattn":
assert pos_enc_layer_type == "legacy_rel_pos"
encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention
encoder_selfattn_layer_args = (
attention_heads,
output_size,
attention_dropout_rate,
)
logging.warning(
"Using legacy_rel_selfattn and it will be deprecated in the future."
)
elif selfattention_layer_type == "rel_selfattn":
assert pos_enc_layer_type == "rel_pos"
encoder_selfattn_layer = RelPositionMultiHeadedAttention
encoder_selfattn_layer_args = (
attention_heads,
output_size,
attention_dropout_rate,
zero_triu,
)
else:
raise ValueError("unknown encoder_attn_layer: " + selfattention_layer_type)
convolution_layer = ConvolutionModule
convolution_layer_args = (output_size, cnn_module_kernel, activation)
self.encoders = repeat(
num_blocks,
lambda lnum: EncoderLayer(
output_size,
encoder_selfattn_layer(*encoder_selfattn_layer_args),
positionwise_layer(*positionwise_layer_args),
positionwise_layer(*positionwise_layer_args) if macaron_style else None,
convolution_layer(*convolution_layer_args) if use_cnn_module else None,
dropout_rate,
normalize_before,
concat_after,
stochastic_depth_rate=0.0,
),
)
self.upsample_blocks = nn.ModuleList()
if upsample_ratios is None:
upsample_ratios = [2] * upsample_blocks
self.upsample_ratios = upsample_ratios
assert upsample_blocks == len(upsample_ratios)
for i in range(upsample_blocks):
if not causal:
upsample_conv_block = Upsample1D(
channels=output_size, use_conv=False, use_conv_transpose=True,
out_channels=output_size, channel_first=False, stride=upsample_ratios[i], causal=False,
)
else:
upsample_conv_block = Upsample1D(
channels=output_size, use_conv=True, use_conv_transpose=False,
out_channels=output_size, channel_first=False, stride=upsample_ratios[i], causal=True,
)
upsample_attn_block = repeat(
upsample_attn_layers,
lambda lnum: EncoderLayer(
output_size,
encoder_selfattn_layer(*encoder_selfattn_layer_args),
positionwise_layer(*positionwise_layer_args),
positionwise_layer(*positionwise_layer_args) if macaron_style else None,
convolution_layer(*convolution_layer_args) if use_cnn_module else None,
dropout_rate,
normalize_before,
concat_after,
stochastic_depth_rate=0.0,
),
)
attn_input_layer = torch.nn.Sequential(
torch.nn.Linear(output_size, output_size),
torch.nn.LayerNorm(output_size),
torch.nn.Dropout(dropout_rate),
pos_enc_class(output_size, positional_dropout_rate),
)
self.upsample_blocks.append(nn.ModuleList([upsample_conv_block, attn_input_layer, upsample_attn_block]))
if self.normalize_before:
self.after_norm = LayerNorm(output_size)
def output_size(self) -> int:
return self._output_size
def rand_mix_masks(self, causal, noncausal):
use_causal = (torch.rand([causal.shape[0], 1, 1]) <= self.use_causal_prob).to(causal)
masks = use_causal * causal + (1 - use_causal) * noncausal
return masks
def forward(
self,
xs_pad: torch.Tensor,
ilens: torch.Tensor = None,
prev_states: torch.Tensor = None,
ctc: CTC = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""Calculate forward propagation.
Args:
xs_pad (torch.Tensor): Input tensor (#batch, L, input_size).
ilens (torch.Tensor): Input length (#batch).
prev_states (torch.Tensor): Not to be used now.
Returns:
torch.Tensor: Output tensor (#batch, L, output_size).
torch.Tensor: Output length (#batch).
torch.Tensor: Not to be used now.
"""
raw_input = xs_pad
if self.channel_first:
xs_pad = xs_pad.permute(0, 2, 1)
if ilens is not None:
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
else:
masks = torch.ones(
xs_pad.shape[0], 1, xs_pad.shape[1],
dtype=torch.bool, device=xs_pad.device
)
if self.use_causal_prob is not None:
use_causal = (torch.rand([xs_pad.shape[0], 1, 1]) <= self.use_causal_prob).to(xs_pad)
else:
use_causal = torch.ones([xs_pad.shape[0], 1, 1]).to(xs_pad)
if self.causal:
causal_mask = subsequent_mask(
xs_pad.shape[1], device=xs_pad.device, dtype=masks.dtype
).unsqueeze(0)
causal_mask = masks & causal_mask
# whether to train causal & non-causal in a single model
masks = use_causal * causal_mask + (1 - use_causal) * masks
if (
isinstance(self.embed, Conv2dSubsampling)
or isinstance(self.embed, Conv2dSubsampling2)
or isinstance(self.embed, Conv2dSubsampling6)
or isinstance(self.embed, Conv2dSubsampling8)
or isinstance(self.embed, Conv2dSubsamplingPad)
):
short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
if short_status:
raise TooShortUttError(
f"has {xs_pad.size(1)} frames and is too short for subsampling "
+ f"(it needs more than {limit_size} frames), return empty results",
xs_pad.size(1),
limit_size,
)
xs_pad, masks = self.embed(xs_pad, masks)
else:
xs_pad = self.embed(xs_pad)
if self.pre_lookahead_len is not None:
xs = xs_pad
if isinstance(xs_pad, tuple):
xs = xs_pad[0]
xs, _ = self.pre_lookahead_layer(xs, ilens)
if isinstance(xs_pad, tuple):
xs_pad = (xs, xs_pad[1])
else:
xs_pad = xs
# 1. modeling on inputs
intermediate_outs = []
xs_pad, masks = self.encoders(xs_pad, masks)
# 2. progressive upsampling
outs, olens = xs_pad, ilens
total_ratio = 1
for up_ratio, layer in zip(self.upsample_ratios, self.upsample_blocks):
up_layer, attn_input_layer, attn_layer = layer
if isinstance(outs, tuple):
outs = outs[0]
outs, olens = up_layer(outs, olens)
masks = (~make_pad_mask(olens)[:, None, :]).to(outs.device)
total_ratio = total_ratio * up_ratio
if self.causal:
causal_mask = causal_block_mask(
outs.shape[1], total_ratio, device=outs.device, dtype=masks.dtype
).unsqueeze(0)
causal_mask = masks & causal_mask
masks = use_causal * causal_mask + (1 - use_causal) * masks
outs = attn_input_layer(outs)
outs, _ = attn_layer(outs, masks)
xs_pad = outs
if isinstance(xs_pad, tuple):
xs_pad = xs_pad[0]
if self.normalize_before:
xs_pad = self.after_norm(xs_pad)
if self.channel_first:
xs_pad = xs_pad.permute(0, 2, 1)
if self.skip:
xs_pad = xs_pad + raw_input
# olens = masks.squeeze(1).sum(1)
if len(intermediate_outs) > 0:
return (xs_pad, intermediate_outs), olens, None
if ilens is not None:
return xs_pad, olens, None
else:
return xs_pad
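# ---------------------------------------------------------------------------
# Editor's hedged usage sketch (not part of the original model code): it only
# illustrates the expected shapes, i.e. that the output length is roughly the
# input length times the product of `upsample_ratios`.  All constructor
# arguments below are illustrative assumptions, not values from any recipe.
if __name__ == "__main__":
    _enc = UpsampleConformerEncoder(
        input_size=256,
        output_size=256,
        attention_heads=4,
        linear_units=1024,
        num_blocks=2,
        upsample_blocks=2,
        upsample_ratios=(2, 2),
        input_layer="linear",
        causal=False,
    )
    _xs = torch.randn(1, 50, 256)       # (batch, token-rate length, dim)
    _ilens = torch.tensor([50])
    _ys, _olens, _ = _enc(_xs, _ilens)
    print(_ys.shape, _olens)            # expected: roughly (1, 200, 256), [~200]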

View File

@@ -0,0 +1,618 @@
import logging
from typing import List, Tuple, Dict, Optional, Union
import torch
import torch.nn as nn
from funasr.models.transformer.utils.nets_utils import make_pad_mask
import torch.nn.functional as F
from funasr.train_utils.device_funcs import force_gatherable
from funasr.models.llm_asr.label_smoothing_loss import LabelSmoothingLoss
from copy import deepcopy
from funasr.metrics.compute_acc import th_accuracy
from funasr.models.transformer.utils.nets_utils import pad_list
import random
import numpy as np
from funasr.utils.hinter import hint_once
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.models.llm_asr.tts_models.ctc_alignment import ctc_forced_align
from torch.nn.utils.rnn import pad_sequence
import itertools
from distutils.version import LooseVersion
from contextlib import contextmanager
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
from torch.cuda.amp import autocast
else:
# Nothing to do if torch<1.6.0
@contextmanager
def autocast(enabled=True):
yield
class NARCTCModel(nn.Module):
def __init__(
self,
input_size: int,
vocab_size: int,
encoder: Union[nn.Module, dict],
decoder: Optional[nn.Module] = None,
ctc_weight: float = 0.5,
ignore_id: int = -1,
lsm_weight: float = 0.0,
length_normalized_loss: bool = False,
**kwargs,
):
super().__init__()
self.input_size = input_size
self.decoder = decoder
self.encoder = encoder if isinstance(encoder, nn.Module) else self.build_encoder(encoder)
self.output_size = self.encoder.output_size()
self.ignore_id = ignore_id
self.vocab_size = vocab_size
self.ctc_weight = ctc_weight
# build ctc module
from funasr.models.ctc.ctc import CTC
ctc_conf = kwargs.pop("ctc_conf", {})
self.ctc = CTC(vocab_size, encoder_output_size=self.output_size, **ctc_conf)
self.text_embedding = torch.nn.Embedding(self.vocab_size, input_size)
self.token_embedding = torch.nn.Embedding(vocab_size, input_size)
xvec_size = kwargs.get("xvec_size", None)
if xvec_size is not None:
self.xvec_proj = torch.nn.Linear(xvec_size, input_size)
else:
self.xvec_proj = None
self.criterion_att = LabelSmoothingLoss(
size=vocab_size,
padding_idx=ignore_id,
smoothing=lsm_weight,
normalize_length=length_normalized_loss,
)
self.sos = vocab_size - 2
self.eos = vocab_size - 1
self.length_regulator_conf = kwargs.get("length_regulator_conf", None)
if self.length_regulator_conf is not None:
self.length_regulator = self.build_length_regulator()
else:
self.length_regulator = None
def build_encoder(self, encoder_conf: dict):
if encoder_conf is None:
assert hasattr(self, "encoder_conf"), \
"function param encoder_conf is None and model doesn't has encoder_conf attribute either."
encoder_conf = self.encoder_conf
encoder_name = encoder_conf.pop("name", "transformer")
model = None
if encoder_name == "transformer":
from funasr.models.llm_asr.conformer_encoder import ConformerEncoder
model = ConformerEncoder(
**encoder_conf,
input_size=self.input_size,
use_cnn_module=False,
macaron_style=False,
)
elif encoder_name == "conformer":
from funasr.models.llm_asr.conformer_encoder import ConformerEncoder
model = ConformerEncoder(
**encoder_conf,
input_size=self.input_size,
)
elif encoder_name == "upsampling_conformer":
from funasr.models.llm_asr.tts_models.encoders import UpsampleConformerEncoder
model = UpsampleConformerEncoder(
**encoder_conf,
input_size=self.input_size,
)
encoder_conf["name"] = encoder_name
return model
def build_length_regulator(self):
name = self.length_regulator_conf.pop("name", None)
model = None
if name == "upsampling":
from funasr.models.llm_asr.diffusion_models.length_regulator import UpSamplingRegulator
model = UpSamplingRegulator(self.input_size, self.length_regulator_conf.get("sampling_ratios"))
elif name == "downsampling":
from funasr.models.llm_asr.diffusion_models.length_regulator import DownSamplingRegulator
model = DownSamplingRegulator(self.input_size, self.length_regulator_conf.get("sampling_ratios"))
elif name == "interpolate":
from funasr.models.llm_asr.diffusion_models.length_regulator import InterpolateRegulator
model = InterpolateRegulator(self.input_size, **self.length_regulator_conf)
elif name == "upsampling_cif":
from funasr.models.llm_asr.diffusion_models.length_regulator import UpsamplingCifRegulator
model = UpsamplingCifRegulator(self.input_size, **self.length_regulator_conf)
self.length_regulator_conf["name"] = name
return model
@staticmethod
def norm_and_sample_xvec(xvec, xvec_lengths):
xvec_list = []
for i, ilen in enumerate(xvec_lengths):
idx = random.randint(0, ilen - 1)
while torch.any(~torch.isfinite(xvec[i, idx])):
idx = random.randint(0, ilen - 1)
xvec_list.append(xvec[i, idx])
rand_xvec = torch.vstack(xvec_list)
rand_xvec = F.normalize(rand_xvec, dim=1)
return rand_xvec
def _calc_att_loss(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
ys_pad: torch.Tensor,
ys_pad_lens: torch.Tensor,
):
ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
ys_in_lens = ys_pad_lens + 1
# 1. Forward decoder
decoder_out, _ = self.decoder(
encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
)
# 2. Compute attention loss
loss_att = self.criterion_att(decoder_out, ys_out_pad)
acc_att = th_accuracy(
decoder_out.view(-1, self.decoder.vocab_size),
ys_out_pad,
ignore_label=self.ignore_id,
)
# Compute cer/wer using attention-decoder
if self.training or getattr(self, "error_calculator", None) is None:
cer_att, wer_att = None, None
else:
ys_hat = decoder_out.argmax(dim=-1)
cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
return loss_att, acc_att, cer_att, wer_att
def model_forward(
self,
text: torch.Tensor,
text_lengths: torch.Tensor,
xvec: Optional[torch.Tensor] = None,
xvec_lengths: Optional[torch.Tensor] = None,
**kwargs
):
# 0. Up-sampling text length
if self.length_regulator is not None:
text, text_lengths = self.length_regulator(text, text_lengths)
# 1. padding xvec
if xvec is not None and self.xvec_proj is not None:
xvec = xvec[:, :xvec_lengths.max()]
# random select a xvec from xvec matrix
xvec = self.norm_and_sample_xvec(xvec, xvec_lengths)
xvec = self.xvec_proj(xvec)
text = text + xvec.unsqueeze(1)
hint_once("use xvec", "use_xvec")
# 1. Encoder
encoder_out, encoder_out_lens, _ = self.encoder(text, text_lengths)
return encoder_out, encoder_out_lens
def predictor(
self,
am: torch.Tensor,
am_lens: torch.Tensor,
ys_pad: torch.Tensor,
ys_pad_lens: torch.Tensor,
alignment,
):
acoustic_embeds = []
use_pred_num = 0
for am_xs, enc_len, ali, y, y_lens in zip(am, am_lens, alignment, ys_pad, ys_pad_lens):
pred = itertools.groupby(ali[:enc_len])
acoustic_embed = []
_start = 0
for pred_token, pred_frame in pred:
_end = _start + len(list(pred_frame))
if pred_token != 0:
acoustic_embed.append(torch.mean(am_xs[_start:_end, :], 0, keepdim=True))
_start = _end
if len(acoustic_embed) != y_lens:
acoustic_embeds.append(y[:y_lens])
else:
acoustic_embeds.append(torch.cat(acoustic_embed, dim=0))
use_pred_num += 1
acoustic_embeds = pad_sequence(acoustic_embeds, batch_first=True, padding_value=0)
return acoustic_embeds, use_pred_num / am.shape[0]
def force_align_text(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
xvec: Optional[torch.Tensor] = None,
xvec_lengths: Optional[torch.Tensor] = None,
**kwargs
):
# add one to every speech token so that index 0 represents <blank>;
# the decoder vocab must be: 1 (blank) + num of tokens + 1 (sos) + 1 (eos)
speech = torch.where(speech != -1, speech + 1, speech)
encoder_out, encoder_out_lens = self.model_forward(
text, text_lengths,
xvec, xvec_lengths,
**kwargs
)
log_probs = self.ctc.log_softmax(encoder_out)
with torch.no_grad():
alignment = ctc_forced_align(
log_probs.float(),
speech.long(),
encoder_out_lens.long(),
speech_lengths.long(),
ignore_id=self.ignore_id,
)
aligned_token_emb, use_pred_ratio = self.predictor(
encoder_out, encoder_out_lens,
self.token_embedding(speech), speech_lengths,
alignment,
)
loss = 0
states = dict(
use_pred_ratio=use_pred_ratio,
)
if self.ctc_weight != 0.0:
loss_ctc, logits = self._calc_ctc_loss(
encoder_out, encoder_out_lens, speech, speech_lengths
)
states["loss_ctc"] = loss_ctc.item()
loss = loss + self.ctc_weight * loss_ctc
if self.ctc_weight != 1.0:
loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
encoder_out, encoder_out_lens, speech, speech_lengths
)
states["loss_att"] = loss_att.item()
loss = loss + (1.0 - self.ctc_weight) * loss_att
states["loss"] = loss.item()
return loss, aligned_token_emb, states
def _calc_ctc_loss(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
ys_pad: torch.Tensor,
ys_pad_lens: torch.Tensor,
):
# Calc CTC loss
loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
logits = self.ctc.log_softmax(encoder_out)
return loss_ctc, logits
def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
xvec: Optional[torch.Tensor] = None,
xvec_lengths: Optional[torch.Tensor] = None,
**kwargs
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...), speech tokens
speech_lengths: (Batch, )
text: (Batch, Length), text tokens
text_lengths: (Batch, )
xvec: (Batch, Length, ...) x-vectors
xvec_lengths: (Batch, )
"""
assert text_lengths.dim() == 1, text_lengths.shape
# Check that batch_size is unified
assert (
speech.shape[0]
== speech_lengths.shape[0]
== text.shape[0]
== text_lengths.shape[0]
), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
batch_size = speech.shape[0]
# for data-parallel
text = text[:, : text_lengths.max()]
speech = speech[:, : speech_lengths.max()]
# add one to every speech token so that index 0 represents <blank>;
# the decoder vocab must be: 1 (blank) + num of tokens + 1 (sos) + 1 (eos)
speech = torch.where(speech != -1, speech + 1, speech)
# embed text inputs
mask = (text != -1).float().unsqueeze(-1)
text = self.text_embedding(torch.clamp(text, min=0)) * mask
encoder_out, encoder_out_lens = self.model_forward(
text, text_lengths,
xvec, xvec_lengths,
**kwargs,
)
loss_att, acc_att, cer_att, wer_att = None, None, None, None
loss_ctc, cer_ctc = None, None
stats = dict(
batch_size=float(batch_size),
text_len=float(text.shape[1]),
enc_len=float(encoder_out.shape[1]),
speech_len=float(speech.shape[1]),
token_text_ratio=float(speech.shape[1])/float(text.shape[1]),
)
# 1. CTC branch
if self.ctc_weight != 0.0:
loss_ctc, logits = self._calc_ctc_loss(
encoder_out, encoder_out_lens, speech, speech_lengths
)
# Collect CTC branch stats
stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
stats["cer_ctc"] = cer_ctc
# 2b. Attention decoder branch
if self.ctc_weight != 1.0:
loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
encoder_out, encoder_out_lens, speech, speech_lengths
)
# 3. CTC-Att loss definition
if self.ctc_weight == 0.0:
loss = loss_att
elif self.ctc_weight == 1.0:
loss = loss_ctc
else:
loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att
# Collect Attn branch stats
stats["loss_att"] = loss_att.detach() if loss_att is not None else None
stats["acc"] = acc_att
stats["cer"] = cer_att
stats["wer"] = wer_att
# Collect total loss stats
stats["loss"] = torch.clone(loss.detach())
# force_gatherable: to-device and to-tensor if scalar for DataParallel
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
return loss, stats, weight
def topp_sampling(self, probs, top_p=0.8):
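# nucleus (top-p) sampling over a single 1-D probability vector; note that
# `probs` is modified in place, so callers should pass a clone if it is reused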
sorted_value, sorted_idx = probs.sort(descending=True, stable=True)
cumulative_probs = torch.cumsum(sorted_value, dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = False
indices_to_remove = sorted_idx[sorted_indices_to_remove]
probs[indices_to_remove] = 0
top_ids = torch.multinomial(probs, num_samples=1)
return top_ids
def sampling_ids(self, enc_outs, sampling="greedy", blank_penalty=None, return_probs=False):
probs = self.ctc.softmax(enc_outs)
if blank_penalty > 0:
probs[:, :, 0] = probs[:, :, 0] * blank_penalty
# top-p (nucleus) sampling, applied independently to every frame
if "." in sampling:
top_p = float(sampling)
flat_probs = probs.reshape(-1, probs.shape[-1])
tokens = torch.stack([self.topp_sampling(p.clone(), top_p=top_p) for p in flat_probs])
tokens = tokens.reshape(probs.shape[0], probs.shape[1]).long().to(probs.device)
# top-k sampling
elif sampling.isdigit():
k = int(sampling)
topk_probs, topk_ids = probs.topk(k, dim=-1)
# sample within the top-k set for every frame, then map back to vocabulary ids
sampled = torch.multinomial(topk_probs.reshape(-1, k), 1, replacement=True)
tokens = topk_ids.reshape(-1, k).gather(1, sampled).reshape(probs.shape[0], probs.shape[1])
else:
if sampling == "greedy":
tokens = torch.argmax(probs, dim=-1)
elif "threshold_" in sampling:
threshold = float(sampling.split("_")[1])
hint_once(f"Decoding mode: blank threshold={threshold:.2f}", "decoding_mode")
# mask out blank according to threshold
mask = probs[:, :, 0] > threshold
probs[:, :, 0] = probs[:, :, 0] * mask
tokens = torch.argmax(probs, dim=-1)
else:
raise NotImplementedError(f"sampling method {sampling} not implemented")
if not return_probs:
return tokens
return tokens, probs
def inference(
self,
text: torch.Tensor, text_lengths: torch.Tensor,
xvec=None, xvec_lengths=None,
sampling="greedy",
blank_penalty: float = 0.0,
text_is_embedding=False,
return_hidden=False,
**kwargs,
):
device = text.device
# use casual mode at inference stage
self.encoder.use_causal_prob = kwargs.get("use_causal_prob", 1.0)
hint_once(f"use_causal_prob {self.encoder.use_causal_prob}.", "use_causal_prob")
# embed text inputs
if not text_is_embedding:
mask = (text != -1).float().unsqueeze(-1)
text = self.text_embedding(torch.clamp(text, min=0)) * mask
# 1. Encoder
encoder_out, encoder_out_lens = self.model_forward(
text, text_lengths,
xvec, xvec_lengths,
)
fa_tokens, enc_probs = self.sampling_ids(
encoder_out,
sampling=sampling,
blank_penalty=blank_penalty,
return_probs=True,
)
# remove blanks (id=0) and convert token ids into the original format
tokens = [[x-1] for x in fa_tokens[0].cpu().tolist() if x > 0]
tokens = torch.tensor([tokens], dtype=torch.int64, device=device)
if not return_hidden:
return tokens
acoustic_embs, acoustic_emb_lens = [], []
for idx, (prob, enc) in enumerate(zip(enc_probs, encoder_out)):
pred = itertools.groupby(prob.argmax(-1).cpu())
acs_emb = []
_start = 0
for pred_token, pred_frame in pred:
_end = _start + len(list(pred_frame))
if pred_token != 0 and pred_token != -1:
acs_emb.append(torch.mean(enc[_start:_end, :], 0, keepdim=True))
_start = _end
acs_emb = torch.cat(acs_emb, dim=0)
acoustic_embs.append(acs_emb)
acoustic_emb_lens.append(acs_emb.shape[0])
acoustic_embs = pad_list(acoustic_embs, 0.0)
acoustic_emb_lens = torch.tensor(acoustic_emb_lens, dtype=torch.int64, device=device)
return (tokens, fa_tokens), acoustic_embs, acoustic_emb_lens
class NARCTCProbModel(NARCTCModel):
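"""NARCTCModel variant whose predictor averages the frame-level CTC
log-probabilities (rather than the encoder hidden states) over each
force-aligned segment, then projects the result through the token embedding
matrix; when the alignment length disagrees with the target length it falls
back to one-hot targets."""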
def __init__(self, input_size: int, vocab_size: int, encoder: Union[nn.Module, dict],
decoder: Optional[nn.Module] = None, ctc_weight: float = 0.5, ignore_id: int = -1,
lsm_weight: float = 0.0, length_normalized_loss: bool = False, **kwargs):
super().__init__(input_size, vocab_size, encoder, decoder, ctc_weight, ignore_id, lsm_weight,
length_normalized_loss, **kwargs)
def predictor(
self,
am_probs: torch.Tensor,
am_lens: torch.Tensor,
ys_pad: torch.Tensor,
ys_pad_lens: torch.Tensor,
alignment,
):
acoustic_embeds = []
use_pred_num = 0
for probs, enc_len, ali, y, y_lens in zip(am_probs, am_lens, alignment, ys_pad, ys_pad_lens):
pred = itertools.groupby(ali[:enc_len])
acoustic_embed = []
_start = 0
for pred_token, pred_frame in pred:
_end = _start + len(list(pred_frame))
if pred_token != 0:
acoustic_embed.append(torch.mean(probs[_start:_end, :], 0, keepdim=True))
_start = _end
if len(acoustic_embed) != y_lens:
acoustic_embeds.append(F.one_hot(y[:y_lens], self.vocab_size).float())
else:
acoustic_embeds.append(torch.cat(acoustic_embed, dim=0))
use_pred_num += 1
acoustic_embeds[-1] = torch.matmul(acoustic_embeds[-1], self.token_embedding.weight)
acoustic_embeds = pad_sequence(acoustic_embeds, batch_first=True, padding_value=0)
return acoustic_embeds, use_pred_num / am_probs.shape[0]
def force_align_text(self, speech: torch.Tensor, speech_lengths: torch.Tensor, text: torch.Tensor,
text_lengths: torch.Tensor, xvec: Optional[torch.Tensor] = None,
xvec_lengths: Optional[torch.Tensor] = None, **kwargs):
# add one to every speech token so that index 0 represents <blank>;
# the decoder vocab must be: 1 (blank) + num of tokens + 1 (sos) + 1 (eos)
speech = torch.where(speech != -1, speech + 1, speech)
encoder_out, encoder_out_lens = self.model_forward(
text, text_lengths,
xvec, xvec_lengths,
**kwargs
)
log_probs = self.ctc.log_softmax(encoder_out)
with torch.no_grad():
alignment = ctc_forced_align(
log_probs.float(),
speech.long(),
encoder_out_lens.long(),
speech_lengths.long(),
ignore_id=self.ignore_id,
)
aligned_token_emb, use_pred_ratio = self.predictor(
log_probs.float(), encoder_out_lens.long(),
speech.long(), speech_lengths.long(),
alignment,
)
loss = 0
states = dict(
use_pred_ratio=use_pred_ratio,
)
if self.ctc_weight != 0.0:
loss_ctc, logits = self._calc_ctc_loss(
encoder_out, encoder_out_lens, speech, speech_lengths
)
states["loss_ctc"] = loss_ctc.item()
loss = loss + self.ctc_weight * loss_ctc
if self.ctc_weight != 1.0:
loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
encoder_out, encoder_out_lens, speech, speech_lengths
)
states["loss_att"] = loss_att.item()
loss = loss + (1.0 - self.ctc_weight) * loss_att
states["loss"] = loss.item()
return loss, aligned_token_emb, states
def inference(self, text: torch.Tensor, text_lengths: torch.Tensor, xvec=None, xvec_lengths=None, sampling="greedy",
blank_penalty: float = 0.0, text_is_embedding=False, return_hidden=False, **kwargs):
device = text.device
# embed text inputs
if not text_is_embedding:
mask = (text != -1).float().unsqueeze(-1)
text = self.text_embedding(torch.clamp(text, min=0)) * mask
# 0. Up-sampling text length
if self.length_regulator is not None:
text, text_lengths = self.length_regulator(text, text_lengths)
# 1. padding xvec
if xvec is not None and self.xvec_proj is not None:
xvec = xvec[:, :xvec_lengths.max()]
# random select a xvec from xvec matrix
xvec = self.norm_and_sample_xvec(xvec, xvec_lengths)
xvec = self.xvec_proj(xvec)
text = text + xvec.unsqueeze(1)
hint_once("use xvec", "use_xvec")
# 1. Encoder
encoder_out, encoder_out_lens = self.model_forward(
text, text_lengths,
xvec, xvec_lengths,
)
tokens, enc_probs = self.sampling_ids(
encoder_out,
sampling=sampling,
blank_penalty=blank_penalty,
return_probs=True,
)
# remove blanks (id=0) and convert token ids into the original format
tokens = [[x - 1] for x in tokens[0].cpu().tolist() if x > 0]
tokens = torch.tensor([tokens], dtype=torch.int64, device=device)
if not return_hidden:
return tokens
acoustic_embs = self.token_embedding(tokens.squeeze(-1))
acoustic_emb_lens = torch.tensor([acoustic_embs.shape[1]], dtype=torch.int64, device=device)
return tokens, acoustic_embs, acoustic_emb_lens
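# ---------------------------------------------------------------------------
# Editor's hedged illustration (not part of the original model code) of the
# blank-removal step used by `inference`: speech-token ids are shifted by +1
# during training so that id 0 can act as <blank>, hence decoding drops blanks
# and shifts the ids back.  Repeats are intentionally kept because the encoder
# output is already at the speech-token frame rate.
if __name__ == "__main__":
    _fa_tokens = torch.tensor([[0, 3, 3, 0, 7, 0, 12]])
    _tokens = [[x - 1] for x in _fa_tokens[0].cpu().tolist() if x > 0]
    print(_tokens)  # expected: [[2], [2], [6], [11]]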

View File

@@ -0,0 +1,761 @@
# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Decoder definition."""
from typing import Any
from typing import List
from typing import Sequence
from typing import Tuple
import torch
from torch import nn
from funasr.models.transformer.attention import MultiHeadedAttention
from funasr.models.transformer.utils.dynamic_conv import DynamicConvolution
from funasr.models.transformer.utils.dynamic_conv2d import DynamicConvolution2D
from funasr.models.transformer.embedding import PositionalEncoding
from funasr.models.transformer.layer_norm import LayerNorm
from funasr.models.transformer.utils.lightconv import LightweightConvolution
from funasr.models.transformer.utils.lightconv2d import LightweightConvolution2D
from funasr.models.transformer.utils.mask import subsequent_mask
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.models.transformer.positionwise_feed_forward import (
PositionwiseFeedForward, # noqa: H301
)
from funasr.models.transformer.utils.repeat import repeat
from funasr.models.transformer.scorers.scorer_interface import BatchScorerInterface
class DecoderLayer(nn.Module):
"""Single decoder layer module.
Args:
size (int): Input dimension.
self_attn (torch.nn.Module): Self-attention module instance.
`MultiHeadedAttention` instance can be used as the argument.
src_attn (torch.nn.Module): Source attention (encoder-decoder cross-attention) module instance.
`MultiHeadedAttention` instance can be used as the argument.
feed_forward (torch.nn.Module): Feed-forward module instance.
`PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
can be used as the argument.
dropout_rate (float): Dropout rate.
normalize_before (bool): Whether to use layer_norm before the first block.
concat_after (bool): Whether to concat attention layer's input and output.
if True, additional linear will be applied.
i.e. x -> x + linear(concat(x, att(x)))
if False, no additional linear will be applied. i.e. x -> x + att(x)
"""
def __init__(
self,
size,
self_attn,
src_attn,
feed_forward,
dropout_rate,
normalize_before=True,
concat_after=False,
):
"""Construct an DecoderLayer object."""
super(DecoderLayer, self).__init__()
self.size = size
self.self_attn = self_attn
self.src_attn = src_attn
self.feed_forward = feed_forward
self.norm1 = LayerNorm(size)
self.norm2 = LayerNorm(size)
self.norm3 = LayerNorm(size)
self.dropout = nn.Dropout(dropout_rate)
self.normalize_before = normalize_before
self.concat_after = concat_after
if self.concat_after:
self.concat_linear1 = nn.Linear(size + size, size)
self.concat_linear2 = nn.Linear(size + size, size)
def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None):
"""Compute decoded features.
Args:
tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
tgt_mask (torch.Tensor): Mask for input tensor (#batch, 1, maxlen_out).
memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
memory_mask (torch.Tensor): Encoded memory mask (#batch, 1, maxlen_in).
cache (List[torch.Tensor]): List of cached tensors.
Each tensor shape should be (#batch, maxlen_out - 1, size).
Returns:
torch.Tensor: Output tensor(#batch, maxlen_out, size).
torch.Tensor: Mask for output tensor (#batch, maxlen_out).
torch.Tensor: Encoded memory (#batch, maxlen_in, size).
torch.Tensor: Encoded memory mask (#batch, maxlen_in).
"""
residual = tgt
if self.normalize_before:
tgt = self.norm1(tgt)
if cache is None:
tgt_q = tgt
tgt_q_mask = tgt_mask
else:
# compute only the last frame query keeping dim: max_time_out -> 1
assert cache.shape == (
tgt.shape[0],
tgt.shape[1] - 1,
self.size,
), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
tgt_q = tgt[:, -1:, :]
residual = residual[:, -1:, :]
tgt_q_mask = None
if tgt_mask is not None:
tgt_q_mask = tgt_mask[:, -1:, :]
if self.concat_after:
tgt_concat = torch.cat(
(tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1
)
x = residual + self.concat_linear1(tgt_concat)
else:
x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask))
if not self.normalize_before:
x = self.norm1(x)
residual = x
if self.normalize_before:
x = self.norm2(x)
if self.concat_after:
x_concat = torch.cat(
(x, self.src_attn(x, memory, memory, memory_mask)), dim=-1
)
x = residual + self.concat_linear2(x_concat)
else:
x = residual + self.dropout(self.src_attn(x, memory, memory, memory_mask))
if not self.normalize_before:
x = self.norm2(x)
residual = x
if self.normalize_before:
x = self.norm3(x)
x = residual + self.dropout(self.feed_forward(x))
if not self.normalize_before:
x = self.norm3(x)
if cache is not None:
x = torch.cat([cache, x], dim=1)
return x, tgt_mask, memory, memory_mask
class BaseTransformerDecoder(nn.Module, BatchScorerInterface):
"""Base class of Transfomer decoder module.
Args:
vocab_size: output dim
encoder_output_size: dimension of attention
attention_heads: the number of heads of multi head attention
linear_units: the number of units of position-wise feed forward
num_blocks: the number of decoder blocks
dropout_rate: dropout rate
self_attention_dropout_rate: dropout rate for attention
input_layer: input layer type
use_output_layer: whether to use output layer
pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
normalize_before: whether to use layer_norm before the first block
concat_after: whether to concat attention layer's input and output
if True, additional linear will be applied.
i.e. x -> x + linear(concat(x, att(x)))
if False, no additional linear will be applied.
i.e. x -> x + att(x)
"""
def __init__(
self,
vocab_size: int,
encoder_output_size: int,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
input_layer: str = "embed",
use_output_layer: bool = True,
pos_enc_class=PositionalEncoding,
normalize_before: bool = True,
causal=True,
):
super().__init__()
attention_dim = encoder_output_size
self.causal = causal
self.vocab_size = vocab_size
if input_layer == "embed":
self.embed = torch.nn.Sequential(
torch.nn.Embedding(vocab_size, attention_dim),
pos_enc_class(attention_dim, positional_dropout_rate),
)
elif input_layer == "linear":
self.embed = torch.nn.Sequential(
torch.nn.Linear(vocab_size, attention_dim),
torch.nn.LayerNorm(attention_dim),
torch.nn.Dropout(dropout_rate),
torch.nn.ReLU(),
pos_enc_class(attention_dim, positional_dropout_rate),
)
else:
raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")
self.normalize_before = normalize_before
if self.normalize_before:
self.after_norm = LayerNorm(attention_dim)
if use_output_layer:
self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
else:
self.output_layer = None
# Must set by the inheritance
self.decoders = None
def forward(
self,
hs_pad: torch.Tensor,
hlens: torch.Tensor,
ys_in_pad: torch.Tensor,
ys_in_lens: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward decoder.
Args:
hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
hlens: (batch)
ys_in_pad:
input token ids, int64 (batch, maxlen_out)
if input_layer == "embed"
input tensor (batch, maxlen_out, #mels) in the other cases
ys_in_lens: (batch)
Returns:
(tuple): tuple containing:
x: decoded token score before softmax (batch, maxlen_out, token)
if use_output_layer is True,
olens: (batch, )
"""
tgt = ys_in_pad
# tgt_mask: (B, 1, L)
tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
if self.causal:
# m: (1, L, L)
m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
# tgt_mask: (B, L, L)
tgt_mask = tgt_mask & m
memory = hs_pad
memory_mask = (~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to(
memory.device
)
# Padding for Longformer
if memory_mask.shape[-1] != memory.shape[1]:
padlen = memory.shape[1] - memory_mask.shape[-1]
memory_mask = torch.nn.functional.pad(
memory_mask, (0, padlen), "constant", False
)
x = self.embed(tgt)
x, tgt_mask, memory, memory_mask = self.decoders(
x, tgt_mask, memory, memory_mask
)
if self.normalize_before:
x = self.after_norm(x)
if self.output_layer is not None:
x = self.output_layer(x)
olens = tgt_mask.sum(1)
return x, olens
def forward_one_step(
self,
tgt: torch.Tensor,
tgt_mask: torch.Tensor,
memory: torch.Tensor,
cache: List[torch.Tensor] = None,
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
"""Forward one step.
Args:
tgt: input token ids, int64 (batch, maxlen_out)
tgt_mask: input token mask, (batch, maxlen_out)
dtype=torch.uint8 in PyTorch 1.2-
dtype=torch.bool in PyTorch 1.2+ (include 1.2)
memory: encoded memory, float32 (batch, maxlen_in, feat)
cache: cached output list of (batch, max_time_out-1, size)
Returns:
y, cache: NN output value and cache per `self.decoders`.
`y.shape` is (batch, maxlen_out, token)
"""
x = self.embed(tgt)
if cache is None:
cache = [None] * len(self.decoders)
new_cache = []
for c, decoder in zip(cache, self.decoders):
x, tgt_mask, memory, memory_mask = decoder(
x, tgt_mask, memory, None, cache=c
)
new_cache.append(x)
if self.normalize_before:
y = self.after_norm(x[:, -1])
else:
y = x[:, -1]
if self.output_layer is not None:
y = torch.log_softmax(self.output_layer(y), dim=-1)
return y, new_cache
def score(self, ys, state, x):
"""Score."""
ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0)
logp, state = self.forward_one_step(
ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state
)
return logp.squeeze(0), state
def batch_score(
self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
) -> Tuple[torch.Tensor, List[Any]]:
"""Score new token batch.
Args:
ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
states (List[Any]): Scorer states for prefix tokens.
xs (torch.Tensor):
The encoder feature that generates ys (n_batch, xlen, n_feat).
Returns:
tuple[torch.Tensor, List[Any]]: Tuple of
batchfied scores for next token with shape of `(n_batch, n_vocab)`
and next state list for ys.
"""
# merge states
n_batch = len(ys)
n_layers = len(self.decoders)
if states[0] is None:
batch_state = None
else:
# transpose state of [batch, layer] into [layer, batch]
batch_state = [
torch.stack([states[b][i] for b in range(n_batch)])
for i in range(n_layers)
]
# batch decoding
ys_mask = subsequent_mask(ys.size(-1), device=xs.device).unsqueeze(0)
logp, states = self.forward_one_step(ys, ys_mask, xs, cache=batch_state)
# transpose state of [layer, batch] into [batch, layer]
state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
return logp, state_list
class TransformerDecoder(BaseTransformerDecoder):
def __init__(
self,
vocab_size: int,
encoder_output_size: int,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
self_attention_dropout_rate: float = 0.0,
src_attention_dropout_rate: float = 0.0,
input_layer: str = "embed",
use_output_layer: bool = True,
pos_enc_class=PositionalEncoding,
normalize_before: bool = True,
concat_after: bool = False,
causal: bool = True,
):
super().__init__(
vocab_size=vocab_size,
encoder_output_size=encoder_output_size,
dropout_rate=dropout_rate,
positional_dropout_rate=positional_dropout_rate,
input_layer=input_layer,
use_output_layer=use_output_layer,
pos_enc_class=pos_enc_class,
normalize_before=normalize_before,
causal=causal,
)
attention_dim = encoder_output_size
self.decoders = repeat(
num_blocks,
lambda lnum: DecoderLayer(
attention_dim,
MultiHeadedAttention(
attention_heads, attention_dim, self_attention_dropout_rate
),
MultiHeadedAttention(
attention_heads, attention_dim, src_attention_dropout_rate
),
PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
dropout_rate,
normalize_before,
concat_after,
),
)
class ParaformerDecoderSAN(BaseTransformerDecoder):
"""
author: Speech Lab, Alibaba Group, China
Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
https://arxiv.org/abs/2006.01713
"""
def __init__(
self,
vocab_size: int,
encoder_output_size: int,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
self_attention_dropout_rate: float = 0.0,
src_attention_dropout_rate: float = 0.0,
input_layer: str = "embed",
use_output_layer: bool = True,
pos_enc_class=PositionalEncoding,
normalize_before: bool = True,
concat_after: bool = False,
embeds_id: int = -1,
):
super().__init__(
vocab_size=vocab_size,
encoder_output_size=encoder_output_size,
dropout_rate=dropout_rate,
positional_dropout_rate=positional_dropout_rate,
input_layer=input_layer,
use_output_layer=use_output_layer,
pos_enc_class=pos_enc_class,
normalize_before=normalize_before,
)
attention_dim = encoder_output_size
self.decoders = repeat(
num_blocks,
lambda lnum: DecoderLayer(
attention_dim,
MultiHeadedAttention(
attention_heads, attention_dim, self_attention_dropout_rate
),
MultiHeadedAttention(
attention_heads, attention_dim, src_attention_dropout_rate
),
PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
dropout_rate,
normalize_before,
concat_after,
),
)
self.embeds_id = embeds_id
self.attention_dim = attention_dim
def forward(
self,
hs_pad: torch.Tensor,
hlens: torch.Tensor,
ys_in_pad: torch.Tensor,
ys_in_lens: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward decoder.
Args:
hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
hlens: (batch)
ys_in_pad:
input token ids, int64 (batch, maxlen_out)
if input_layer == "embed"
input tensor (batch, maxlen_out, #mels) in the other cases
ys_in_lens: (batch)
Returns:
(tuple): tuple containing:
x: decoded token score before softmax (batch, maxlen_out, token)
if use_output_layer is True,
olens: (batch, )
"""
tgt = ys_in_pad
tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
memory = hs_pad
memory_mask = (~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to(
memory.device
)
# Padding for Longformer
if memory_mask.shape[-1] != memory.shape[1]:
padlen = memory.shape[1] - memory_mask.shape[-1]
memory_mask = torch.nn.functional.pad(
memory_mask, (0, padlen), "constant", False
)
# the inputs are already continuous embeddings, so the token embedding layer is bypassed
x = tgt
embeds_outputs = None
for layer_id, decoder in enumerate(self.decoders):
x, tgt_mask, memory, memory_mask = decoder(
x, tgt_mask, memory, memory_mask
)
if layer_id == self.embeds_id:
embeds_outputs = x
if self.normalize_before:
x = self.after_norm(x)
if self.output_layer is not None:
x = self.output_layer(x)
olens = tgt_mask.sum(1)
if embeds_outputs is not None:
return x, olens, embeds_outputs
else:
return x, olens
class LightweightConvolutionTransformerDecoder(BaseTransformerDecoder):
def __init__(
self,
vocab_size: int,
encoder_output_size: int,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
self_attention_dropout_rate: float = 0.0,
src_attention_dropout_rate: float = 0.0,
input_layer: str = "embed",
use_output_layer: bool = True,
pos_enc_class=PositionalEncoding,
normalize_before: bool = True,
concat_after: bool = False,
conv_wshare: int = 4,
conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
conv_usebias: bool = False,
):
if len(conv_kernel_length) != num_blocks:
raise ValueError(
"conv_kernel_length must have equal number of values to num_blocks: "
f"{len(conv_kernel_length)} != {num_blocks}"
)
super().__init__(
vocab_size=vocab_size,
encoder_output_size=encoder_output_size,
dropout_rate=dropout_rate,
positional_dropout_rate=positional_dropout_rate,
input_layer=input_layer,
use_output_layer=use_output_layer,
pos_enc_class=pos_enc_class,
normalize_before=normalize_before,
)
attention_dim = encoder_output_size
self.decoders = repeat(
num_blocks,
lambda lnum: DecoderLayer(
attention_dim,
LightweightConvolution(
wshare=conv_wshare,
n_feat=attention_dim,
dropout_rate=self_attention_dropout_rate,
kernel_size=conv_kernel_length[lnum],
use_kernel_mask=True,
use_bias=conv_usebias,
),
MultiHeadedAttention(
attention_heads, attention_dim, src_attention_dropout_rate
),
PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
dropout_rate,
normalize_before,
concat_after,
),
)
class LightweightConvolution2DTransformerDecoder(BaseTransformerDecoder):
def __init__(
self,
vocab_size: int,
encoder_output_size: int,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
self_attention_dropout_rate: float = 0.0,
src_attention_dropout_rate: float = 0.0,
input_layer: str = "embed",
use_output_layer: bool = True,
pos_enc_class=PositionalEncoding,
normalize_before: bool = True,
concat_after: bool = False,
conv_wshare: int = 4,
conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
conv_usebias: bool = False,
):
if len(conv_kernel_length) != num_blocks:
raise ValueError(
"conv_kernel_length must have equal number of values to num_blocks: "
f"{len(conv_kernel_length)} != {num_blocks}"
)
super().__init__(
vocab_size=vocab_size,
encoder_output_size=encoder_output_size,
dropout_rate=dropout_rate,
positional_dropout_rate=positional_dropout_rate,
input_layer=input_layer,
use_output_layer=use_output_layer,
pos_enc_class=pos_enc_class,
normalize_before=normalize_before,
)
attention_dim = encoder_output_size
self.decoders = repeat(
num_blocks,
lambda lnum: DecoderLayer(
attention_dim,
LightweightConvolution2D(
wshare=conv_wshare,
n_feat=attention_dim,
dropout_rate=self_attention_dropout_rate,
kernel_size=conv_kernel_length[lnum],
use_kernel_mask=True,
use_bias=conv_usebias,
),
MultiHeadedAttention(
attention_heads, attention_dim, src_attention_dropout_rate
),
PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
dropout_rate,
normalize_before,
concat_after,
),
)
class DynamicConvolutionTransformerDecoder(BaseTransformerDecoder):
def __init__(
self,
vocab_size: int,
encoder_output_size: int,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
self_attention_dropout_rate: float = 0.0,
src_attention_dropout_rate: float = 0.0,
input_layer: str = "embed",
use_output_layer: bool = True,
pos_enc_class=PositionalEncoding,
normalize_before: bool = True,
concat_after: bool = False,
conv_wshare: int = 4,
conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
conv_usebias: bool = False,
):
if len(conv_kernel_length) != num_blocks:
raise ValueError(
"conv_kernel_length must have equal number of values to num_blocks: "
f"{len(conv_kernel_length)} != {num_blocks}"
)
super().__init__(
vocab_size=vocab_size,
encoder_output_size=encoder_output_size,
dropout_rate=dropout_rate,
positional_dropout_rate=positional_dropout_rate,
input_layer=input_layer,
use_output_layer=use_output_layer,
pos_enc_class=pos_enc_class,
normalize_before=normalize_before,
)
attention_dim = encoder_output_size
self.decoders = repeat(
num_blocks,
lambda lnum: DecoderLayer(
attention_dim,
DynamicConvolution(
wshare=conv_wshare,
n_feat=attention_dim,
dropout_rate=self_attention_dropout_rate,
kernel_size=conv_kernel_length[lnum],
use_kernel_mask=True,
use_bias=conv_usebias,
),
MultiHeadedAttention(
attention_heads, attention_dim, src_attention_dropout_rate
),
PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
dropout_rate,
normalize_before,
concat_after,
),
)
class DynamicConvolution2DTransformerDecoder(BaseTransformerDecoder):
def __init__(
self,
vocab_size: int,
encoder_output_size: int,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
self_attention_dropout_rate: float = 0.0,
src_attention_dropout_rate: float = 0.0,
input_layer: str = "embed",
use_output_layer: bool = True,
pos_enc_class=PositionalEncoding,
normalize_before: bool = True,
concat_after: bool = False,
conv_wshare: int = 4,
conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
conv_usebias: bool = False,
):
if len(conv_kernel_length) != num_blocks:
raise ValueError(
"conv_kernel_length must have equal number of values to num_blocks: "
f"{len(conv_kernel_length)} != {num_blocks}"
)
super().__init__(
vocab_size=vocab_size,
encoder_output_size=encoder_output_size,
dropout_rate=dropout_rate,
positional_dropout_rate=positional_dropout_rate,
input_layer=input_layer,
use_output_layer=use_output_layer,
pos_enc_class=pos_enc_class,
normalize_before=normalize_before,
)
attention_dim = encoder_output_size
self.decoders = repeat(
num_blocks,
lambda lnum: DecoderLayer(
attention_dim,
DynamicConvolution2D(
wshare=conv_wshare,
n_feat=attention_dim,
dropout_rate=self_attention_dropout_rate,
kernel_size=conv_kernel_length[lnum],
use_kernel_mask=True,
use_bias=conv_usebias,
),
MultiHeadedAttention(
attention_heads, attention_dim, src_attention_dropout_rate
),
PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
dropout_rate,
normalize_before,
concat_after,
),
)
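# ---------------------------------------------------------------------------
# Editor's hedged smoke test (not part of the original decoder code): checks
# output shapes only, and the constructor arguments are illustrative
# assumptions rather than values used by any FunASR recipe.
if __name__ == "__main__":
    _dec = TransformerDecoder(
        vocab_size=32,
        encoder_output_size=64,
        attention_heads=4,
        linear_units=128,
        num_blocks=2,
    )
    _memory = torch.randn(2, 25, 64)             # encoder output (batch, T, dim)
    _hlens = torch.tensor([25, 20])
    _ys = torch.randint(0, 32, (2, 7))            # target token ids
    _ylens = torch.tensor([7, 5])
    _logits, _olens = _dec(_memory, _hlens, _ys, _ylens)
    print(_logits.shape)                          # expected: torch.Size([2, 7, 32])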

View File

@@ -0,0 +1,346 @@
@
@
@
@
@!
@"
@#
@$
@'
@(
@)
@*
@,
@-
@.
@/
@:
@;
@<
@>
@?
@[
@]
@^
@_
@`
@a_c1
@a_c2
@a_c3
@a_c4
@a_c5
@aa0
@aa1
@aa2
@ae0
@ae1
@ae2
@ah0
@ah1
@ah2
@ai_c1
@ai_c2
@ai_c3
@ai_c4
@ai_c5
@an_c1
@an_c2
@an_c3
@an_c4
@an_c5
@ang_c1
@ang_c2
@ang_c3
@ang_c4
@ang_c5
@ao0
@ao1
@ao2
@ao_c1
@ao_c2
@ao_c3
@ao_c4
@ao_c5
@aw0
@aw1
@aw2
@ay0
@ay1
@ay2
@b
@b_c
@c_c
@ch
@ch_c
@d
@d_c
@dh
@e_c1
@e_c2
@e_c3
@e_c4
@e_c5
@eh0
@eh1
@eh2
@ei_c1
@ei_c2
@ei_c3
@ei_c4
@ei_c5
@en_c1
@en_c2
@en_c3
@en_c4
@en_c5
@eng_c1
@eng_c2
@eng_c3
@eng_c4
@eng_c5
@er0
@er1
@er2
@er_c1
@er_c2
@er_c3
@er_c4
@er_c5
@ey0
@ey1
@ey2
@f
@f_c
@g
@g_c
@ga
@ge
@go
@h_c
@hh
@i_c1
@i_c2
@i_c3
@i_c4
@i_c5
@ia_c1
@ia_c2
@ia_c3
@ia_c4
@ia_c5
@ian_c1
@ian_c2
@ian_c3
@ian_c4
@ian_c5
@iang_c1
@iang_c2
@iang_c3
@iang_c4
@iang_c5
@iao_c1
@iao_c2
@iao_c3
@iao_c4
@iao_c5
@ie_c1
@ie_c2
@ie_c3
@ie_c4
@ie_c5
@ih0
@ih1
@ih2
@ih_c1
@ih_c2
@ih_c3
@ih_c4
@ih_c5
@ii_c1
@ii_c2
@ii_c3
@ii_c4
@ii_c5
@in_c1
@in_c2
@in_c3
@in_c4
@in_c5
@ing_c1
@ing_c2
@ing_c3
@ing_c4
@ing_c5
@iong_c1
@iong_c2
@iong_c3
@iong_c4
@iong_c5
@iou_c1
@iou_c2
@iou_c3
@iou_c4
@iou_c5
@iy0
@iy1
@iy2
@j_c
@jh
@k
@k_c
@l
@l_c
@m
@m_c
@n
@n_c
@ng
@o_c1
@o_c2
@o_c3
@o_c4
@o_c5
@ong_c1
@ong_c2
@ong_c3
@ong_c4
@ong_c5
@ou_c1
@ou_c2
@ou_c3
@ou_c4
@ou_c5
@ouh
@ouj
@oull
@ouw
@ow0
@ow1
@ow2
@oy0
@oy1
@oy2
@p
@p_c
@q_c
@r
@r_c
@s
@s_c
@sh
@sh_c
@t
@t_c
@th
@u_c1
@u_c2
@u_c3
@u_c4
@u_c5
@ua_c1
@ua_c2
@ua_c3
@ua_c4
@ua_c5
@uai_c1
@uai_c2
@uai_c3
@uai_c4
@uai_c5
@uan_c1
@uan_c2
@uan_c3
@uan_c4
@uan_c5
@uang_c1
@uang_c2
@uang_c3
@uang_c4
@uang_c5
@uei_c1
@uei_c2
@uei_c3
@uei_c4
@uei_c5
@uen_c1
@uen_c2
@uen_c3
@uen_c4
@uen_c5
@uh0
@uh1
@uh2
@uo_c1
@uo_c2
@uo_c3
@uo_c4
@uo_c5
@uw0
@uw1
@uw2
@v
@v_c1
@v_c2
@v_c3
@v_c4
@v_c5
@van_c1
@van_c2
@van_c3
@van_c4
@van_c5
@ve_c1
@ve_c2
@ve_c3
@ve_c4
@ve_c5
@vn_c1
@vn_c2
@vn_c3
@vn_c4
@vn_c5
@w
@w_c
@xx_c
@y
@y_c
@z
@z_c
@zh
@zh_c
@{
@|
@}
@~
@—
@——
@
@
@“
@”
@…
@……
@‰
@℃
@
@○
@、
@。
@《
@》
@『
@』
@【
@】
@
@
@
@
@
@
@
@
@
@
@¥

View File

@@ -0,0 +1,31 @@
from pathlib import Path
from typing import Iterable
from typing import Union
def build_tokenizer(
token_type: str,
bpemodel: Union[Path, str, Iterable[str]] = None,
non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
remove_non_linguistic_symbols: bool = False,
space_symbol: str = "<space>",
delimiter: str = None,
g2p_type: str = None,
p_word2phn: float = 0.5,
):
if "whisper_rich_ttsfrd" in token_type:
from funasr.models.llm_asr.tts_text_tokenizer.whisper_tokenizer import WhisperRichTtsFrdTokenizer
return WhisperRichTtsFrdTokenizer(
token_path="multilingual_zh_ja_yue_char_del",
num_languages=105,
task=None,
language=None,
ttsfrd_type="ttsfrd_rich",
ttsfrd_model=bpemodel,
p_word2phn=p_word2phn,
)
else:
raise ValueError(
f"unsupported token_type (only 'whisper_rich_ttsfrd' is supported): {token_type}"
)
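# ---------------------------------------------------------------------------
# Editor's hedged usage sketch (not part of the original code).  The resource
# directory is a hypothetical placeholder: a real ttsfrd resource bundle must
# exist locally, so the call is guarded instead of executed unconditionally.
if __name__ == "__main__":
    import os
    _res_dir = "/path/to/ttsfrd_resources"  # hypothetical path
    if os.path.isdir(_res_dir):
        _tokenizer = build_tokenizer("whisper_rich_ttsfrd", bpemodel=_res_dir, p_word2phn=1.0)
        print(type(_tokenizer).__name__)     # WhisperRichTtsFrdTokenizer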

View File

@@ -0,0 +1,175 @@
import logging
from pathlib import Path
import re
from typing import Iterable
from typing import List
from typing import Optional
from typing import Union
import warnings
import os
import json
import jamo
class TtsFrdRich:
"""
rich text info: phoneme + puncs + boundary + [word2phone]
"""
def __init__(self, remove_boundary=True, token_type="pronplus"):
super().__init__()
self.remove_boundary = remove_boundary
self.token_type = token_type
self.g2p = None
self.lang_type = None
self.lang_type_map = {"zh-cn": "pinyin", "en-us": "enus"}
@staticmethod
def contains_chinese(text):
return bool(re.search(r'[\u4e00-\u9fff]', text))
@staticmethod
def is_full_half_punctuation_string(s):
# covers ASCII punctuation and common full-width punctuation
punctuation_pattern = r'[\u0000-\u002f\u003a-\u0040\u005b-\u0060\u007b-\u007f\u3000-\u303f\uff00-\uffef]'
# use re.findall to collect every matching character
results = re.findall(punctuation_pattern, s)
# if the string length equals the number of matched characters, the string is all punctuation
return len(s) == len("".join(results))
def build(self, resource_dir, lang_type="Zh-CN"):
lang_type = lang_type.lower()
new_lang_type = self.lang_type_map[lang_type]
if self.g2p is None:
import ttsfrd
assert os.path.isdir(resource_dir)
fe = ttsfrd.TtsFrontendEngine()
fe.initialize(resource_dir)
self.g2p = fe
# self.lang_type = new_lang_type
self.set_lang_type(new_lang_type)
if self.lang_type != new_lang_type:
# self.lang_type = new_lang_type
self.set_lang_type(new_lang_type)
def set_lang_type(self, lang_type):
if lang_type == "enus":
self.g2p.set_lang_type(lang_type)
self.g2p.enable_pinyin_mix(True)
# self.g2p.set_breakmodel_index(0)
else:
self.g2p.set_lang_type(lang_type)
self.g2p.enable_pinyin_mix(True)
# self.g2p.set_breakmodel_index(1)
self.lang_type = lang_type
def set_token_type(self, token_type):
assert token_type in ["pronplus", "word2phn", "wordlist"], token_type
self.token_type = token_type
def __call__(self, text) -> Union[List[str], str]:
assert self.g2p is not None
if not self.contains_chinese(text):
if self.lang_type != "enus":
self.set_lang_type("enus")
else:
if self.lang_type != "pinyin":
self.set_lang_type("pinyin")
if self.token_type == "word2phn":
return self._get_word2phn(text)
elif self.token_type == "pronplus":
return self._get_pronplus(text)
elif self.token_type == "wordlist":
return self._get_wordlist(text)
else:
raise ValueError(f"only type: [pronplus, word2phn, wordlist] supported, now type: {self.token_type}")
def _get_pronplus(self, text) -> List[str]:
pronplus = self.g2p.get_frd_extra_info(text, 'pronplus')
if self.remove_boundary:
pronplus = pronplus.replace("/", "") # word boundary
pronplus = pronplus.replace("#", "") # syllable boundary
# pronplus = pronplus.replace("\n", "")
pronplus = pronplus.replace("\n", " ")
symbols: List[str] = []
for pron in pronplus.split(" "):
pron = pron.strip().lower()
if pron and pron[0].isalpha():
symbols.append(pron)
else:
symbols.extend([mark for mark in pron if mark])
return symbols
def text2tokens(self, line: str) -> List[str]:
json_str = self._get_word2phn(line)
data = json.loads(json_str)
retval = []
for one in data["word2phn"]:
for key, value in one.items():
if value is not None:
retval.extend([f"@{x}" for x in value])
else:
if key == " ":
key = "<|space|>"
retval.append(f"@{key}")
return retval
def tokens2text(self, tokens: Iterable[str]) -> str:
raise NotImplementedError("tokens2text is not implemented for TtsFrdRich")
def _get_wordlist(self, text) -> str:
wordlist = self.g2p.get_frd_extra_info(text, 'wordlist')
return wordlist
def _get_word2phn(self, text) -> str:
wordlist = self.g2p.get_frd_extra_info(text, 'wordlist')
wordlist_subs = wordlist.split("\n")
word2phn_info = []
prev_word_type = None
prev_word = None
for json_str in wordlist_subs:
if len(json_str) == 0:
continue
wordlist_info = json.loads(json_str)["wordlist"]
for word_info in wordlist_info:
is_english_word = True
this_phone_list = None
if word_info["syllables"] is None:
# punctuation
this_word_type = "punc"
pass
elif self.is_full_half_punctuation_string(word_info["name"]):
# punctuation; also handles cases where the g2p mistakenly spells out punctuation
this_word_type = "punc"
pass
else:
this_phone_list = []
for syllable_info in word_info["syllables"]:
phn_count = syllable_info["phone_count"]
syllable_phone_list = syllable_info["pron_text"].split(" ")
assert len(syllable_phone_list) == phn_count, len(syllable_phone_list)
if "py_text" in syllable_info:
# Chinese syllable: append the tone to the last phone
syllable_phone_list[-1] = syllable_phone_list[-1]+str(syllable_info["tone"])
is_english_word = False
this_phone_list += syllable_phone_list
if is_english_word:
this_word_type = "en_word"
else:
this_word_type = "ch_word"
if this_word_type == "en_word":
if prev_word_type is None:
pass
elif prev_word_type == "en_word":
word2phn_info.append({" ": None})
elif prev_word_type == "punc":
if (prev_word not in ["\"", "\'", "(", "", "[", ""] and
prev_word.split(" ")[-1] not in ["\"", "\'", "(", "", "[", ""]):
word2phn_info.append({" ": None})
elif prev_word_type == "ch_word":
word2phn_info.append({" ": None})
elif this_word_type == "ch_word":
if prev_word_type is not None and prev_word_type == "en_word":
word2phn_info.append({" ": None})
elif this_word_type == "punc":
if word_info["name"] in ["("]:
word2phn_info.append({" ": None})
this_word2phn_dict = {word_info["name"]: this_phone_list}
word2phn_info.append(this_word2phn_dict)
prev_word_type = this_word_type
prev_word = list(word2phn_info[-1].keys())[0]
return json.dumps({"raw": text, "word2phn": word2phn_info}, ensure_ascii=False)

View File

@@ -0,0 +1,462 @@
import base64
import os
import string
from dataclasses import dataclass, field
from functools import cached_property, lru_cache
from typing import Dict, List, Optional, Tuple
import tiktoken
LANGUAGES = {
"en": "english",
"zh": "chinese",
"de": "german",
"es": "spanish",
"ru": "russian",
"ko": "korean",
"fr": "french",
"ja": "japanese",
"pt": "portuguese",
"tr": "turkish",
"pl": "polish",
"ca": "catalan",
"nl": "dutch",
"ar": "arabic",
"sv": "swedish",
"it": "italian",
"id": "indonesian",
"hi": "hindi",
"fi": "finnish",
"vi": "vietnamese",
"he": "hebrew",
"uk": "ukrainian",
"el": "greek",
"ms": "malay",
"cs": "czech",
"ro": "romanian",
"da": "danish",
"hu": "hungarian",
"ta": "tamil",
"no": "norwegian",
"th": "thai",
"ur": "urdu",
"hr": "croatian",
"bg": "bulgarian",
"lt": "lithuanian",
"la": "latin",
"mi": "maori",
"ml": "malayalam",
"cy": "welsh",
"sk": "slovak",
"te": "telugu",
"fa": "persian",
"lv": "latvian",
"bn": "bengali",
"sr": "serbian",
"az": "azerbaijani",
"sl": "slovenian",
"kn": "kannada",
"et": "estonian",
"mk": "macedonian",
"br": "breton",
"eu": "basque",
"is": "icelandic",
"hy": "armenian",
"ne": "nepali",
"mn": "mongolian",
"bs": "bosnian",
"kk": "kazakh",
"sq": "albanian",
"sw": "swahili",
"gl": "galician",
"mr": "marathi",
"pa": "punjabi",
"si": "sinhala",
"km": "khmer",
"sn": "shona",
"yo": "yoruba",
"so": "somali",
"af": "afrikaans",
"oc": "occitan",
"ka": "georgian",
"be": "belarusian",
"tg": "tajik",
"sd": "sindhi",
"gu": "gujarati",
"am": "amharic",
"yi": "yiddish",
"lo": "lao",
"uz": "uzbek",
"fo": "faroese",
"ht": "haitian creole",
"ps": "pashto",
"tk": "turkmen",
"nn": "nynorsk",
"mt": "maltese",
"sa": "sanskrit",
"lb": "luxembourgish",
"my": "myanmar",
"bo": "tibetan",
"tl": "tagalog",
"mg": "malagasy",
"as": "assamese",
"tt": "tatar",
"haw": "hawaiian",
"ln": "lingala",
"ha": "hausa",
"ba": "bashkir",
"jw": "javanese",
"su": "sundanese",
"yue": "cantonese",
"minnan": "minnan",
"wuyu": "wuyu",
"dialect": "dialect",
"zh/en": "zh/en",
"en/zh": "en/zh",
}
# language code lookup by name, with a few language aliases
TO_LANGUAGE_CODE = {
**{language: code for code, language in LANGUAGES.items()},
"burmese": "my",
"valencian": "ca",
"flemish": "nl",
"haitian": "ht",
"letzeburgesch": "lb",
"pushto": "ps",
"panjabi": "pa",
"moldavian": "ro",
"moldovan": "ro",
"sinhalese": "si",
"castilian": "es",
"mandarin": "zh",
}
AUDIO_EVENT = {
"ASR": "ASR",
"AED": "AED",
"SER": "SER",
"Speech": "Speech",
"/Speech": "/Speech",
"BGM": "BGM",
"/BGM": "/BGM",
"Laughter": "Laughter",
"/Laughter": "/Laughter",
"Applause": "Applause",
"/Applause": "/Applause",
}
EMOTION = {
"HAPPY": "HAPPY",
"SAD": "SAD",
"ANGRY": "ANGRY",
"NEUTRAL": "NEUTRAL",
}
TTS_Vocal_Token = {
"TTS/B": "TTS/B",
"TTS/O": "TTS/O",
"TTS/Q": "TTS/Q",
"TTS/A": "TTS/A",
"TTS/CO": "TTS/CO",
"TTS/CL": "TTS/CL",
"TTS/H": "TTS/H",
"endofprompt": "endofprompt",
"sil": "sil",
**{f"TTS/SP{i:02d}": f"TTS/SP{i:02d}" for i in range(3, 14)}
}
@dataclass
class Tokenizer:
"""A thin wrapper around `tiktoken` providing quick access to special tokens"""
encoding: tiktoken.Encoding
num_languages: int
language: Optional[str] = None
task: Optional[str] = None
sot_sequence: Tuple[int] = ()
special_tokens: Dict[str, int] = field(default_factory=dict)
def __post_init__(self):
for special in self.encoding.special_tokens_set:
special_token = self.encoding.encode_single_token(special)
self.special_tokens[special] = special_token
sot: int = self.special_tokens["<|startoftranscript|>"]
translate: int = self.special_tokens["<|translate|>"]
transcribe: int = self.special_tokens["<|transcribe|>"]
langs = tuple(LANGUAGES.keys())[: self.num_languages]
sot_sequence = [sot]
if self.language is not None:
sot_sequence.append(sot + 1 + langs.index(self.language))
if self.task is not None:
task_token: int = transcribe if self.task == "transcribe" else translate
sot_sequence.append(task_token)
self.sot_sequence = tuple(sot_sequence)
def encode(self, text, **kwargs):
return self.encoding.encode(text, **kwargs)
def decode(self, token_ids: List[int], **kwargs) -> str:
token_ids = [t for t in token_ids if t < self.timestamp_begin]
return self.encoding.decode(token_ids, **kwargs)
def decode_with_timestamps(self, token_ids: List[int], **kwargs) -> str:
"""
Timestamp tokens are above other special tokens' id range and are ignored by `decode()`.
This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
"""
return self.encoding.decode(token_ids, **kwargs)
def get_vocab_size(self) -> int:
return self.encoding.n_vocab
@cached_property
def eot(self) -> int:
return self.encoding.eot_token
@cached_property
def transcribe(self) -> int:
return self.special_tokens["<|transcribe|>"]
@cached_property
def translate(self) -> int:
return self.special_tokens["<|translate|>"]
@cached_property
def sot(self) -> int:
return self.special_tokens["<|startoftranscript|>"]
@cached_property
def sot_lm(self) -> int:
return self.special_tokens["<|startoflm|>"]
@cached_property
def sot_prev(self) -> int:
return self.special_tokens["<|startofprev|>"]
@cached_property
def no_speech(self) -> int:
return self.special_tokens["<|nospeech|>"]
@cached_property
def no_timestamps(self) -> int:
return self.special_tokens["<|notimestamps|>"]
@cached_property
def timestamp_begin(self) -> int:
return self.special_tokens["<|0.00|>"]
@cached_property
def language_token(self) -> int:
"""Returns the token id corresponding to the value of the `language` field"""
if self.language is None:
raise ValueError("This tokenizer does not have language token configured")
return self.to_language_token(self.language)
def to_language_token(self, language):
if token := self.special_tokens.get(f"<|{language}|>", None):
return token
raise KeyError(f"Language {language} not found in tokenizer.")
@cached_property
def all_language_tokens(self) -> Tuple[int]:
result = []
for token, token_id in self.special_tokens.items():
if token.strip("<|>") in LANGUAGES:
result.append(token_id)
return tuple(result)[: self.num_languages]
@cached_property
def all_language_codes(self) -> Tuple[str]:
return tuple(self.decode([_l]).strip("<|>") for _l in self.all_language_tokens)
@cached_property
def sot_sequence_including_notimestamps(self) -> Tuple[int]:
return tuple(list(self.sot_sequence) + [self.no_timestamps])
@cached_property
def non_speech_tokens(self) -> Tuple[int]:
"""
Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
annotations, to prevent sampling texts that are not actually spoken in the audio, e.g.
- ♪♪♪
- ( SPEAKING FOREIGN LANGUAGE )
- [DAVID] Hey there,
keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
"""
symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
symbols += (
"<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
)
# symbols that may be a single token or multiple tokens depending on the tokenizer.
# In case they're multiple tokens, suppress the first token, which is safe because:
# These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
# in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
miscellaneous = set("♩♪♫♬♭♮♯")
assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)
# allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
result = {self.encoding.encode(" -")[0], self.encoding.encode(" '")[0]}
for symbol in symbols + list(miscellaneous):
for tokens in [
self.encoding.encode(symbol),
self.encoding.encode(" " + symbol),
]:
if len(tokens) == 1 or symbol in miscellaneous:
result.add(tokens[0])
return tuple(sorted(result))
def split_to_word_tokens(self, tokens: List[int]):
if self.language in {"zh", "ja", "th", "lo", "my", "yue"}:
# These languages don't typically use spaces, so it is difficult to split words
# without morpheme analysis. Here, we instead split words at any
# position where the tokens are decoded as valid unicode points
return self.split_tokens_on_unicode(tokens)
return self.split_tokens_on_spaces(tokens)
def split_tokens_on_unicode(self, tokens: List[int]):
decoded_full = self.decode_with_timestamps(tokens)
replacement_char = "\ufffd"
words = []
word_tokens = []
current_tokens = []
unicode_offset = 0
for token in tokens:
current_tokens.append(token)
decoded = self.decode_with_timestamps(current_tokens)
if (
replacement_char not in decoded
or decoded_full[unicode_offset + decoded.index(replacement_char)]
== replacement_char
):
words.append(decoded)
word_tokens.append(current_tokens)
current_tokens = []
unicode_offset += len(decoded)
return words, word_tokens
def split_tokens_on_spaces(self, tokens: List[int]):
subwords, subword_tokens_list = self.split_tokens_on_unicode(tokens)
words = []
word_tokens = []
for subword, subword_tokens in zip(subwords, subword_tokens_list):
special = subword_tokens[0] >= self.eot
with_space = subword.startswith(" ")
punctuation = subword.strip() in string.punctuation
if special or with_space or punctuation or len(words) == 0:
words.append(subword)
word_tokens.append(subword_tokens)
else:
words[-1] = words[-1] + subword
word_tokens[-1].extend(subword_tokens)
return words, word_tokens
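# Standalone sketch (not part of this class) of the idea behind
# split_tokens_on_unicode: keep accumulating byte-level pieces until decoding no
# longer yields the U+FFFD replacement character, at which point a complete
# unicode character has been formed. The real method additionally checks
# decoded_full so a genuine U+FFFD in the source text is not mistaken for an
# incomplete byte sequence.
raw = "你好".encode("utf-8")  # 6 bytes, 2 characters
pieces, current = [], b""
for b in raw:
    current += bytes([b])
    decoded = current.decode("utf-8", errors="replace")
    if "\ufffd" not in decoded:  # all accumulated bytes form complete characters
        pieces.append(decoded)
        current = b""
print(pieces)  # ['你', '好']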
@lru_cache(maxsize=None)
def get_encoding(name: str = "gpt2", num_languages: int = 99, ttsfrd_name: Optional[str] = None):
vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
ranks = {
base64.b64decode(token): int(rank)
for token, rank in (line.split() for line in open(vocab_path) if line)
}
n_vocab = len(ranks)
special_tokens = {}
if name == "gpt2" or name == "multilingual":
specials = [
"<|endoftext|>",
"<|startoftranscript|>",
*[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
"<|translate|>",
"<|transcribe|>",
"<|startoflm|>",
"<|startofprev|>",
"<|nospeech|>",
"<|notimestamps|>",
*[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
]
else:
specials = [
"<|endoftext|>",
"<|startoftranscript|>",
*[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
*[f"<|{audio_event}|>" for audio_event in list(AUDIO_EVENT.keys())],
*[f"<|{emotion}|>" for emotion in list(EMOTION.keys())],
"<|translate|>",
"<|transcribe|>",
"<|startoflm|>",
"<|startofprev|>",
"<|nospeech|>",
"<|notimestamps|>",
*[f"<|SPECIAL_TOKEN_{i}|>" for i in range(1, 31)], # register special tokens for ASR
*[f"<|{tts}|>" for tts in list(TTS_Vocal_Token.keys())], # register special tokens for TTS
*[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
]
if ttsfrd_name is not None:
ttsfrd_vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{ttsfrd_name}.token")
assert os.path.isfile(ttsfrd_vocab_path), f"{ttsfrd_vocab_path} missing"
with open(ttsfrd_vocab_path, "r") as fr:
specials.extend([f"<|{line.strip()}|>" for line in fr if line])
for token in specials:
special_tokens[token] = n_vocab
n_vocab += 1
return tiktoken.Encoding(
name=os.path.basename(vocab_path),
explicit_n_vocab=n_vocab,
pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
mergeable_ranks=ranks,
special_tokens=special_tokens,
)
@lru_cache(maxsize=None)
def get_tokenizer(
multilingual: bool,
*,
num_languages: int = 99,
language: Optional[str] = None,
task: Optional[str] = None, # Literal["transcribe", "translate", None]
encoding_path: Optional[str] = None,
ttsfrd_name: Optional[str] = None,
) -> Tokenizer:
if language is not None:
language = language.lower()
if language not in LANGUAGES:
if language in TO_LANGUAGE_CODE:
language = TO_LANGUAGE_CODE[language]
else:
raise ValueError(f"Unsupported language: {language}")
if multilingual:
encoding_name = "multilingual"
language = language or "en"
task = task or "transcribe"
else:
encoding_name = "gpt2"
language = None
task = None
if encoding_path is not None:
encoding_name = encoding_path
encoding = get_encoding(name=encoding_name, num_languages=num_languages, ttsfrd_name=ttsfrd_name)
return Tokenizer(
encoding=encoding, num_languages=num_languages, language=language, task=task
)
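# Hedged usage sketch, assuming the corresponding "multilingual.tiktoken" vocab
# file is present under ./assets next to this module (arguments are illustrative).
if __name__ == "__main__":
    tok = get_tokenizer(multilingual=True, language="zh", task="transcribe")
    print(tok.sot_sequence)                # ids of <|startoftranscript|>, <|zh|>, <|transcribe|>
    print(tok.decode(tok.encode("你好")))  # ordinary text round-trips through the encoding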

View File

@@ -0,0 +1,164 @@
import copy
import json
import os
import random
import re
from typing import Iterable, List, Union
import numpy as np
class WhisperRichTtsFrdTokenizer:
def __init__(
self,
token_path: str,
num_languages: int,
task: str = None,
language: str = None,
ttsfrd_type: str = None,
p_word2phn: float = 0.5,
ttsfrd_model: str = None,
):
import funasr.models.llm_asr.tts_text_tokenizer.voice_echo_rich_tokenizer as tokenizer
self.token_path = token_path
self.num_languages = num_languages
self.language = language
self.task = task
self.ttsfrd_type = ttsfrd_type
self.p_word2phn = p_word2phn
# print('token_path:',token_path)
if token_path == "whisper_en" or token_path == "whisper_gpt2" or token_path == "gpt2":
self.tokenizer = tokenizer.get_tokenizer(multilingual=False, num_languages=num_languages)
elif token_path == "whisper_multilingual" or token_path == "multilingual":
self.tokenizer = tokenizer.get_tokenizer(
multilingual=True, language=self.language, task=self.task, num_languages=num_languages
)
else:
self.tokenizer = tokenizer.get_tokenizer(
multilingual=True, language=self.language, task=self.task, num_languages=num_languages,
encoding_path=token_path, ttsfrd_name=ttsfrd_type
)
if ttsfrd_model is not None and os.path.isdir(ttsfrd_model):
from funasr.models.llm_asr.tts_text_tokenizer.phoneme_tokenizer import TtsFrdRich
self.ttsfrd_tokenizer = TtsFrdRich(remove_boundary=True, token_type="word2phn")
self.ttsfrd_tokenizer.build(ttsfrd_model)
else:
self.ttsfrd_tokenizer = None
# self.tokenizer = copy.deepcopy(self.tokenizer)
def text_mixing(self, line: str) -> str:
try:
data_info = json.loads(line)
# ttsfrd_word2phn info
if isinstance(data_info, dict) and "raw" in data_info and "word2phn" in data_info:
raw_text = data_info["raw"]
ttsfrd_word2phn = data_info["word2phn"]
if random.random() < self.p_word2phn:
ret_text = ""
for ttsfrd_word in ttsfrd_word2phn:
for word_str, phn_list in ttsfrd_word.items():
if phn_list is not None:
if random.random() < self.p_word2phn:
ret_text = ret_text + "".join([f"<|@{p}|>" for p in phn_list])
else:
ret_text += word_str
else:
ret_text += word_str
else:
ret_text = raw_text
else:
ret_text = line
except json.JSONDecodeError:
ret_text = line
return ret_text
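# Illustration with a hypothetical word2phn line (phone symbols are made up):
#   {"raw": "hi 你好", "word2phn": [{"hi": ["hh", "ay1"]}, {" ": null}, {"你好": ["ni3", "hao3"]}]}
# With p_word2phn close to 1.0, one possible output is "<|@hh|><|@ay1|> <|@ni3|><|@hao3|>";
# words whose inner draw fails, and entries with null phones (the space), are copied
# through as plain text. A line that is not valid JSON is returned unchanged.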
def get_num_vocabulary_size(self) -> int:
return self.tokenizer.get_vocab_size()
def text2ids(self, line: str, language: str) -> List[int]:
language_tok = "<|" + language + "|>"
assert language_tok in self.tokenizer.special_tokens, "Language token not found, lang: {}, line: {}".format(language_tok, line)
# line = re.sub(r'<(\d+\.\d+)>', r'<|\1|>', line)
pattern = re.compile(r'<\|(\d+\.\d+)\|>')  # escape "|": unescaped it acts as alternation and matches almost any line instead of timestamp tokens like <|1.08|>
with_timestamps = pattern.search(line)
if with_timestamps:
sot_tok = [self.tokenizer.special_tokens.get(language_tok), self.tokenizer.transcribe]
allowed_special = set([f"<|{i * 0.02:.2f}|>" for i in range(1501)])
encoded_line = self.tokenizer.encode(line, allowed_special=allowed_special)
else:
sot_tok = [self.tokenizer.special_tokens.get(language_tok), self.tokenizer.transcribe, self.tokenizer.no_timestamps]
encoded_line = self.tokenizer.encode(line)
return sot_tok + encoded_line
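# Example of the prefix produced above (illustrative): for language "zh" and a line
# without <|x.xx|> timestamp tokens, the returned ids start with
# [<|zh|>, <|transcribe|>, <|notimestamps|>] followed by the BPE ids of the line;
# when timestamp tokens are present, <|notimestamps|> is omitted and the timestamp
# specials are passed through via allowed_special.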
def ids2text(self, integers: Union[np.ndarray, Iterable[int]]) -> str:
return self.tokenizer.decode_with_timestamps(integers)
def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
return [self.tokenizer.decode_with_timestamps([i]) for i in integers]
def text2tokens(self, line: str, endofprompt="<|endofprompt|>", sil="<|sil|>") -> List[str]:
# keep prompt and sil unchanged
prompt_text = ""
st_sil, ed_sil = False, False
if endofprompt in line:
pos = line.find(endofprompt)
prompt_text = line[:pos+len(endofprompt)]
line = line[pos+len(endofprompt):]
if line.startswith(sil):
line = line[len(sil):]
st_sil = True
if line.endswith(sil):
line = line[:-len(sil)]
ed_sil = True
# token to phone and mixup
if self.ttsfrd_tokenizer is not None:
line = self.ttsfrd_tokenizer(line)
if self.ttsfrd_type is not None:
line = self.text_mixing(line)
# add prompt text and sil back
if st_sil:
line = sil + line
if ed_sil:
line = line + sil
line = prompt_text + line
return self.tokenizer.encode(line, allowed_special="all")
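# Illustration (hypothetical input): for "prompt text<|endofprompt|><|sil|>你好<|sil|>",
# only the middle "你好" goes through the ttsfrd phonemizer and text_mixing; the prompt
# prefix and the leading/trailing <|sil|> markers are re-attached verbatim before the
# final encode(..., allowed_special="all").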
def tokens2text(self, tokens: Iterable[str]) -> str:
return self.tokenizer.decode_with_timestamps(tokens)
# def get_sot(self, sot_template: str, lang: str = None) -> List[int]:
# if lang is not None:
# lang = lang.replace("<", "").replace(">", "").replace("|", "")
# sot = sot_template.replace("LANG", lang)
# else:
# if "<|LANG|>" in sot_template:
# sot = sot_template.split("<|LANG|>", 1)[0]
# else:
# sot = sot_template
# sot_tok = self.tokenizer.encode(sot, allowed_special="all")
# return sot_tok
def get_sot(self, language: str = None, with_timestamps: bool = False) -> List[int]:
if language is not None:
language_tok = "<|" + language + "|>"
assert language_tok in self.tokenizer.special_tokens
if with_timestamps:
sot_tok = [self.tokenizer.sot, self.tokenizer.special_tokens.get(language_tok), self.tokenizer.transcribe]
else:
sot_tok = [self.tokenizer.sot, self.tokenizer.special_tokens.get(language_tok), self.tokenizer.transcribe, self.tokenizer.no_timestamps]
else:
sot_tok = [self.tokenizer.sot]
return sot_tok
def get_all_languages(self) -> List[str]:
return list(self.tokenizer.all_language_codes)
def __repr__(self):
return (
f"{self.__class__.__name__}(model_type={self.token_path}, "
f"language={self.language}, ttsfrd={self.ttsfrd_type})"
)

View File

@@ -50,3 +50,25 @@ def vad_mask(size, vad_pos, device="cpu", dtype=torch.bool):
sub_corner = torch.zeros(vad_pos - 1, size - vad_pos, device=device, dtype=dtype)
ret[0 : vad_pos - 1, vad_pos:] = sub_corner
return ret
def causal_block_mask(size, block_size=1, device="cpu", dtype=torch.bool):
"""Create mask for subsequent steps (size, size).
:param int size: size of mask
:param int block_size: block size of mask
:param str device: "cpu" or "cuda" or torch.Tensor.device
:param torch.dtype dtype: result dtype
:rtype: torch.Tensor
>>> causal_block_mask(4, 2)
[[1, 1, 0, 0],
[1, 1, 0, 0],
[1, 1, 1, 1],
[1, 1, 1, 1]]
"""
assert size % block_size == 0
pos_idx = torch.arange(size, device=device)
block_value = (torch.div(pos_idx, block_size, rounding_mode='trunc') + 1) * block_size
ret = pos_idx.unsqueeze(0) < block_value.unsqueeze(1)
return ret.to(dtype)
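# Quick standalone check of the block pattern shown in the docstring; this mirrors
# the function body so it can be run without importing this module.
import torch

size, block_size = 4, 2
pos_idx = torch.arange(size)
block_value = (torch.div(pos_idx, block_size, rounding_mode='trunc') + 1) * block_size
print((pos_idx.unsqueeze(0) < block_value.unsqueeze(1)).int())
# tensor([[1, 1, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 1],
#         [1, 1, 1, 1]])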