refine decoding process, merge flow and vocoder

This commit is contained in:
志浩 2024-07-10 11:17:33 +08:00
parent e49e54596c
commit 573ae881cd
18 changed files with 5166 additions and 44 deletions

View File

@ -0,0 +1,628 @@
# Copyright 2020 Tomoki Hayashi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Conformer encoder definition."""
import logging
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import torch
from torch import nn
from funasr.models.transformer.attention import (
MultiHeadedAttention, # noqa: H301
RelPositionMultiHeadedAttention, # noqa: H301
LegacyRelPositionMultiHeadedAttention, # noqa: H301
)
from funasr.models.transformer.embedding import (
PositionalEncoding, # noqa: H301
ScaledPositionalEncoding, # noqa: H301
RelPositionalEncoding, # noqa: H301
LegacyRelPositionalEncoding, # noqa: H301
)
from funasr.models.transformer.layer_norm import LayerNorm
from funasr.models.transformer.utils.multi_layer_conv import Conv1dLinear
from funasr.models.transformer.utils.multi_layer_conv import MultiLayeredConv1d
from funasr.models.transformer.utils.nets_utils import get_activation
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.models.transformer.utils.mask import subsequent_mask
from funasr.models.transformer.positionwise_feed_forward import (
PositionwiseFeedForward, # noqa: H301
)
from funasr.models.transformer.utils.repeat import repeat
from funasr.models.transformer.utils.subsampling import (
Conv2dSubsampling, Conv2dSubsampling2, Conv2dSubsampling6, Conv2dSubsampling8, TooShortUttError,
check_short_utt, Conv2dSubsamplingPad
)
class ConvolutionModule(nn.Module):
    """Convolution block used inside a Conformer layer.

    Applies a pointwise GLU gate, a depthwise convolution with 'SAME'
    padding, batch normalization, an activation, and a final pointwise
    projection.

    Args:
        channels (int): Number of input/output channels.
        kernel_size (int): Depthwise kernel size (must be odd).
        activation (nn.Module): Activation applied after normalization.
        bias (bool): Whether the conv layers use bias terms.
    """

    def __init__(self, channels, kernel_size, activation=nn.ReLU(), bias=True):
        """Build the pointwise/depthwise convolution stack."""
        super(ConvolutionModule, self).__init__()
        # An odd kernel keeps the sequence length unchanged ('SAME' padding).
        assert (kernel_size - 1) % 2 == 0
        self.pointwise_conv1 = nn.Conv1d(
            channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=bias
        )
        self.depthwise_conv = nn.Conv1d(
            channels,
            channels,
            kernel_size,
            stride=1,
            padding=(kernel_size - 1) // 2,
            groups=channels,
            bias=bias,
        )
        self.norm = nn.BatchNorm1d(channels)
        self.pointwise_conv2 = nn.Conv1d(
            channels, channels, kernel_size=1, stride=1, padding=0, bias=bias
        )
        self.activation = activation

    def forward(self, x):
        """Run the convolution module.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, channels).

        Returns:
            torch.Tensor: Output tensor (#batch, time, channels).
        """
        # Conv1d expects (batch, channels, time).
        y = x.transpose(1, 2)
        # Pointwise expansion followed by a GLU gate halves channels back.
        y = nn.functional.glu(self.pointwise_conv1(y), dim=1)
        # Depthwise conv mixes information along time, one channel at a time.
        y = self.activation(self.norm(self.depthwise_conv(y)))
        return self.pointwise_conv2(y).transpose(1, 2)
class EncoderLayer(nn.Module):
    """Encoder layer module.

    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance
            can be used as the argument.
        feed_forward (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
            can be used as the argument.
        feed_forward_macaron (torch.nn.Module): Additional feed-forward module instance.
            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
            can be used as the argument.
        conv_module (torch.nn.Module): Convolution module instance.
            `ConvolutionModule` instance can be used as the argument.
        dropout_rate (float): Dropout rate.
        normalize_before (bool): Whether to use layer_norm before the first block.
        concat_after (bool): Whether to concat attention layer's input and output.
            if True, additional linear will be applied.
            i.e. x -> x + linear(concat(x, att(x)))
            if False, no additional linear will be applied. i.e. x -> x + att(x)
        stochastic_depth_rate (float): Probability to skip this layer.
            During training, the layer may skip residual computation and return input
            as-is with given probability.
    """

    def __init__(
        self,
        size,
        self_attn,
        feed_forward,
        feed_forward_macaron,
        conv_module,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
        stochastic_depth_rate=0.0,
    ):
        """Construct an EncoderLayer object."""
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.feed_forward_macaron = feed_forward_macaron
        self.conv_module = conv_module
        self.norm_ff = LayerNorm(size)  # for the FNN module
        self.norm_mha = LayerNorm(size)  # for the MHA module
        if feed_forward_macaron is not None:
            self.norm_ff_macaron = LayerNorm(size)
            # Macaron style: two half-weighted FFNs sandwich the layer.
            self.ff_scale = 0.5
        else:
            self.ff_scale = 1.0
        if self.conv_module is not None:
            self.norm_conv = LayerNorm(size)  # for the CNN module
            self.norm_final = LayerNorm(size)  # for the final output of the block
        self.dropout = nn.Dropout(dropout_rate)
        self.size = size
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear = nn.Linear(size + size, size)
        self.stochastic_depth_rate = stochastic_depth_rate

    def forward(self, x_input, mask, cache=None):
        """Compute encoded features.

        Args:
            x_input (Union[Tuple, torch.Tensor]): Input tensor w/ or w/o pos emb.
                - w/ pos emb: Tuple of tensors [(#batch, time, size), (1, time, size)].
                - w/o pos emb: Tensor (#batch, time, size).
            mask (torch.Tensor): Mask tensor for the input (#batch, time).
            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).

        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, time).
        """
        if isinstance(x_input, tuple):
            x, pos_emb = x_input[0], x_input[1]
        else:
            x, pos_emb = x_input, None
        skip_layer = False
        # with stochastic depth, residual connection `x + f(x)` becomes
        # `x <- x + 1 / (1 - p) * f(x)` at training time.
        stoch_layer_coeff = 1.0
        if self.training and self.stochastic_depth_rate > 0:
            skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
            stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
        if skip_layer:
            # Layer skipped: pass the (cache-extended) input straight through.
            if cache is not None:
                x = torch.cat([cache, x], dim=1)
            if pos_emb is not None:
                return (x, pos_emb), mask
            return x, mask
        # whether to use macaron style
        if self.feed_forward_macaron is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_ff_macaron(x)
            x = residual + stoch_layer_coeff * self.ff_scale * self.dropout(
                self.feed_forward_macaron(x)
            )
            if not self.normalize_before:
                x = self.norm_ff_macaron(x)
        # multi-headed self-attention module
        residual = x
        if self.normalize_before:
            x = self.norm_mha(x)
        if cache is None:
            x_q = x
        else:
            # Incremental decoding: only the newest frame queries, attending
            # over the full (cached + new) key/value sequence.
            assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
            x_q = x[:, -1:, :]
            residual = residual[:, -1:, :]
            mask = None if mask is None else mask[:, -1:, :]
        if pos_emb is not None:
            x_att = self.self_attn(x_q, x, x, pos_emb, mask)
        else:
            x_att = self.self_attn(x_q, x, x, mask)
        if self.concat_after:
            x_concat = torch.cat((x, x_att), dim=-1)
            x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
        else:
            x = residual + stoch_layer_coeff * self.dropout(x_att)
        if not self.normalize_before:
            x = self.norm_mha(x)
        # convolution module
        if self.conv_module is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_conv(x)
            x = residual + stoch_layer_coeff * self.dropout(self.conv_module(x))
            if not self.normalize_before:
                x = self.norm_conv(x)
        # feed forward module
        residual = x
        if self.normalize_before:
            x = self.norm_ff(x)
        x = residual + stoch_layer_coeff * self.ff_scale * self.dropout(
            self.feed_forward(x)
        )
        if not self.normalize_before:
            x = self.norm_ff(x)
        if self.conv_module is not None:
            x = self.norm_final(x)
        if cache is not None:
            x = torch.cat([cache, x], dim=1)
        if pos_emb is not None:
            return (x, pos_emb), mask
        return x, mask
class ConformerEncoder(nn.Module):
    """Conformer encoder module.

    Args:
        input_size (int): Input dimension.
        output_size (int): Dimension of attention.
        attention_heads (int): The number of heads of multi head attention.
        linear_units (int): The number of units of position-wise feed forward.
        num_blocks (int): The number of decoder blocks.
        dropout_rate (float): Dropout rate.
        attention_dropout_rate (float): Dropout rate in attention.
        positional_dropout_rate (float): Dropout rate after adding positional encoding.
        input_layer (Union[str, torch.nn.Module]): Input layer type.
        normalize_before (bool): Whether to use layer_norm before the first block.
        concat_after (bool): Whether to concat attention layer's input and output.
            If True, additional linear will be applied.
            i.e. x -> x + linear(concat(x, att(x)))
            If False, no additional linear will be applied. i.e. x -> x + att(x)
        positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear".
        positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer.
        rel_pos_type (str): Whether to use the latest relative positional encoding or
            the legacy one. The legacy relative positional encoding will be deprecated
            in the future. More Details can be found in
            https://github.com/espnet/espnet/pull/2816.
        encoder_pos_enc_layer_type (str): Encoder positional encoding layer type.
        encoder_attn_layer_type (str): Encoder attention layer type.
        activation_type (str): Encoder activation function type.
        macaron_style (bool): Whether to use macaron style for positionwise layer.
        use_cnn_module (bool): Whether to use convolution module.
        zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
        cnn_module_kernel (int): Kernel size of convolution module.
        padding_idx (int): Padding idx for input_layer=embed.
        interctc_layer_idx (List[int]): 1-based layer indices whose intermediate
            outputs are also returned (intermediate CTC).
        interctc_use_conditioning (bool): Whether to condition later layers on the
            intermediate CTC posteriors (self-conditioned CTC).
        stochastic_depth_rate (Union[float, List[float]]): Shared or per-layer
            probability of skipping a layer during training.
        causal (bool): Whether to additionally apply a causal (subsequent) mask.
        skip (bool): Whether to add the raw input to the final output as a global
            residual (shapes must match for this to work).
        channel_first (bool): Whether inputs/outputs are (B, D, T) instead of (B, T, D).
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: Optional[str] = "conv2d",
        normalize_before: bool = True,
        concat_after: bool = False,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 3,
        macaron_style: bool = False,
        rel_pos_type: str = "legacy",
        pos_enc_layer_type: str = "rel_pos",
        selfattention_layer_type: str = "rel_selfattn",
        activation_type: str = "swish",
        use_cnn_module: bool = True,
        zero_triu: bool = False,
        cnn_module_kernel: int = 31,
        padding_idx: int = -1,
        interctc_layer_idx: List[int] = [],
        interctc_use_conditioning: bool = False,
        stochastic_depth_rate: Union[float, List[float]] = 0.0,
        causal: bool = False,
        skip: bool = False,
        channel_first: bool = False,
    ):
        super().__init__()
        self._output_size = output_size
        self.causal = causal
        self.skip = skip
        self.channel_first = channel_first
        # Map the "rel_pos"/"rel_selfattn" aliases to their legacy variants
        # when the legacy relative positional encoding is requested.
        if rel_pos_type == "legacy":
            if pos_enc_layer_type == "rel_pos":
                pos_enc_layer_type = "legacy_rel_pos"
            if selfattention_layer_type == "rel_selfattn":
                selfattention_layer_type = "legacy_rel_selfattn"
        elif rel_pos_type == "latest":
            assert selfattention_layer_type != "legacy_rel_selfattn"
            assert pos_enc_layer_type != "legacy_rel_pos"
        else:
            raise ValueError("unknown rel_pos_type: " + rel_pos_type)
        activation = get_activation(activation_type)
        # Positional encoding class must match the attention type chosen below.
        if pos_enc_layer_type == "abs_pos":
            pos_enc_class = PositionalEncoding
        elif pos_enc_layer_type == "scaled_abs_pos":
            pos_enc_class = ScaledPositionalEncoding
        elif pos_enc_layer_type == "rel_pos":
            assert selfattention_layer_type == "rel_selfattn"
            pos_enc_class = RelPositionalEncoding
        elif pos_enc_layer_type == "legacy_rel_pos":
            assert selfattention_layer_type == "legacy_rel_selfattn"
            pos_enc_class = LegacyRelPositionalEncoding
            logging.warning(
                "Using legacy_rel_pos and it will be deprecated in the future."
            )
        else:
            raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)
        # Input front-end: linear projection, various conv subsampling
        # variants, token embedding, a user-supplied module, or none.
        if input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(input_size, output_size),
                torch.nn.LayerNorm(output_size),
                torch.nn.Dropout(dropout_rate),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d":
            self.embed = Conv2dSubsampling(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2dpad":
            self.embed = Conv2dSubsamplingPad(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d2":
            self.embed = Conv2dSubsampling2(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d6":
            self.embed = Conv2dSubsampling6(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d8":
            self.embed = Conv2dSubsampling8(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif isinstance(input_layer, torch.nn.Module):
            self.embed = torch.nn.Sequential(
                input_layer,
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer is None:
            self.embed = torch.nn.Sequential(
                pos_enc_class(output_size, positional_dropout_rate)
            )
        else:
            raise ValueError("unknown input_layer: " + input_layer)
        self.normalize_before = normalize_before
        # Position-wise feed-forward factory shared by all encoder layers.
        if positionwise_layer_type == "linear":
            positionwise_layer = PositionwiseFeedForward
            positionwise_layer_args = (
                output_size,
                linear_units,
                dropout_rate,
                activation,
            )
        elif positionwise_layer_type == "conv1d":
            positionwise_layer = MultiLayeredConv1d
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d-linear":
            positionwise_layer = Conv1dLinear
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        else:
            raise NotImplementedError("Support only linear or conv1d.")
        # Self-attention factory; must be paired with the matching pos-enc type.
        if selfattention_layer_type == "selfattn":
            encoder_selfattn_layer = MultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
            )
        elif selfattention_layer_type == "legacy_rel_selfattn":
            assert pos_enc_layer_type == "legacy_rel_pos"
            encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
            )
            logging.warning(
                "Using legacy_rel_selfattn and it will be deprecated in the future."
            )
        elif selfattention_layer_type == "rel_selfattn":
            assert pos_enc_layer_type == "rel_pos"
            encoder_selfattn_layer = RelPositionMultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
                zero_triu,
            )
        else:
            raise ValueError("unknown encoder_attn_layer: " + selfattention_layer_type)
        convolution_layer = ConvolutionModule
        convolution_layer_args = (output_size, cnn_module_kernel, activation)
        # Broadcast a scalar stochastic-depth rate to every layer.
        if isinstance(stochastic_depth_rate, float):
            stochastic_depth_rate = [stochastic_depth_rate] * num_blocks
        if len(stochastic_depth_rate) != num_blocks:
            raise ValueError(
                f"Length of stochastic_depth_rate ({len(stochastic_depth_rate)}) "
                f"should be equal to num_blocks ({num_blocks})"
            )
        self.encoders = repeat(
            num_blocks,
            lambda lnum: EncoderLayer(
                output_size,
                encoder_selfattn_layer(*encoder_selfattn_layer_args),
                positionwise_layer(*positionwise_layer_args),
                positionwise_layer(*positionwise_layer_args) if macaron_style else None,
                convolution_layer(*convolution_layer_args) if use_cnn_module else None,
                dropout_rate,
                normalize_before,
                concat_after,
                stochastic_depth_rate[lnum],
            ),
        )
        if self.normalize_before:
            self.after_norm = LayerNorm(output_size)
        self.interctc_layer_idx = interctc_layer_idx
        if len(interctc_layer_idx) > 0:
            assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
        self.interctc_use_conditioning = interctc_use_conditioning
        # Set externally when interctc_use_conditioning is enabled.
        self.conditioning_layer = None

    def output_size(self) -> int:
        """Return the encoder output feature dimension."""
        return self._output_size

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor = None,
        prev_states: torch.Tensor = None,
        ctc = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Calculate forward propagation.

        Args:
            xs_pad (torch.Tensor): Input tensor (#batch, L, input_size).
            ilens (torch.Tensor): Input length (#batch).
            prev_states (torch.Tensor): Not to be used now.
            ctc: CTC module used only for intermediate-CTC conditioning.

        Returns:
            torch.Tensor: Output tensor (#batch, L, output_size).
            torch.Tensor: Output length (#batch).
            torch.Tensor: Not to be used now.

        NOTE(review): when ``ilens`` is None only the output tensor is
        returned (no lengths / states) — callers must handle both arities.
        """
        # Kept for the optional global residual connection (self.skip).
        raw_input = xs_pad
        if self.channel_first:
            # (B, D, T) -> (B, T, D)
            xs_pad = xs_pad.permute(0, 2, 1)
        if ilens is not None:
            masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
        else:
            # No lengths given: treat every frame as valid.
            masks = torch.ones(xs_pad.shape[0], 1, xs_pad.shape[1],
                               dtype=torch.bool, device=xs_pad.device)
        if self.causal:
            # Restrict attention to current and past frames.
            causal_mask = subsequent_mask(
                xs_pad.shape[1], device=xs_pad.device, dtype=masks.dtype
            ).unsqueeze(0)
            masks = masks & causal_mask
        if (
            isinstance(self.embed, Conv2dSubsampling)
            or isinstance(self.embed, Conv2dSubsampling2)
            or isinstance(self.embed, Conv2dSubsampling6)
            or isinstance(self.embed, Conv2dSubsampling8)
            or isinstance(self.embed, Conv2dSubsamplingPad)
        ):
            # Conv front-ends shrink the time axis; guard against utterances
            # too short to survive the subsampling stack.
            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
            if short_status:
                raise TooShortUttError(
                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
                    + f"(it needs more than {limit_size} frames), return empty results",
                    xs_pad.size(1),
                    limit_size,
                )
            xs_pad, masks = self.embed(xs_pad, masks)
        else:
            xs_pad = self.embed(xs_pad)
        intermediate_outs = []
        if len(self.interctc_layer_idx) == 0:
            xs_pad, masks = self.encoders(xs_pad, masks)
        else:
            # Run layer by layer so intermediate outputs can be captured.
            for layer_idx, encoder_layer in enumerate(self.encoders):
                xs_pad, masks = encoder_layer(xs_pad, masks)
                if layer_idx + 1 in self.interctc_layer_idx:
                    encoder_out = xs_pad
                    if isinstance(encoder_out, tuple):
                        encoder_out = encoder_out[0]
                    # intermediate outputs are also normalized
                    if self.normalize_before:
                        encoder_out = self.after_norm(encoder_out)
                    intermediate_outs.append((layer_idx + 1, encoder_out))
                    if self.interctc_use_conditioning:
                        # Self-conditioned CTC: feed the CTC posterior back
                        # into the encoder stream for the following layers.
                        ctc_out = ctc.softmax(encoder_out)
                        if isinstance(xs_pad, tuple):
                            x, pos_emb = xs_pad
                            x = x + self.conditioning_layer(ctc_out)
                            xs_pad = (x, pos_emb)
                        else:
                            xs_pad = xs_pad + self.conditioning_layer(ctc_out)
        if isinstance(xs_pad, tuple):
            # Drop the positional embedding carried alongside the features.
            xs_pad = xs_pad[0]
        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)
        if self.channel_first:
            # (B, T, D) -> (B, D, T)
            xs_pad = xs_pad.permute(0, 2, 1)
        if self.skip:
            # Global residual over the whole encoder.
            xs_pad = xs_pad + raw_input
        olens = masks.squeeze(1).sum(1)
        if len(intermediate_outs) > 0:
            return (xs_pad, intermediate_outs), olens, None
        if ilens is not None:
            return xs_pad, olens, None
        else:
            return xs_pad

View File

@ -0,0 +1,178 @@
from abc import ABC
import torch
import torch.nn.functional as F
from funasr.models.llm_asr.diffusion_models.matcha_decoder import (Decoder, ConditionalDecoder)
import logging
from funasr.utils.hinter import hint_once
class BASECFM(torch.nn.Module, ABC):
    """Base conditional flow-matching (CFM) model.

    Provides the shared machinery for flow-matching decoders: Euler ODE
    sampling with optional classifier-free guidance (CFG), the regression
    loss, and the flow-matching training objective. Subclasses must assign
    ``self.estimator`` (the vector-field network).

    Args:
        n_feats (int): Number of feature channels (e.g. mel bins).
        cfm_params (dict): Hyper-parameters; recognized keys are ``solver``,
            ``sigma_min``, ``t_scheduler``, ``training_cfg_rate``,
            ``inference_cfg_rate`` and ``reg_loss_type``.
        n_spks (int): Number of speakers.
        spk_emb_dim (int): Speaker embedding dimension.
    """

    def __init__(
        self,
        n_feats: int,
        cfm_params: dict,
        n_spks: int = 1,
        spk_emb_dim: int = 128,
    ):
        super().__init__()
        self.n_feats = n_feats
        self.n_spks = n_spks
        self.spk_emb_dim = spk_emb_dim
        self.solver = cfm_params.get("solver", "euler")
        self.sigma_min = cfm_params.get("sigma_min", 1e-4)
        # Must be set by the subclass before forward/compute_loss is called.
        self.estimator = None
        self.t_scheduler = cfm_params.get("t_scheduler", "linear")
        self.training_cfg_rate = cfm_params.get("training_cfg_rate", 0.0)
        self.inference_cfg_rate = cfm_params.get("inference_cfg_rate", 0.0)
        self.reg_loss_type = cfm_params.get("reg_loss_type", "l2")

    @torch.inference_mode()
    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
        """Forward diffusion

        Args:
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            n_timesteps (int): number of diffusion steps
            temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
            spks (torch.Tensor, optional): speaker ids. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond: Not used but kept for future purposes

        Returns:
            sample: generated mel-spectrogram
                shape: (batch_size, n_feats, mel_timesteps)
        """
        z = torch.randn_like(mu) * temperature
        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
        if self.t_scheduler == 'cosine':
            # Warp the time grid so more steps land near t=0.
            t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)

    def solve_euler(self, x, t_span, mu, mask, spks=None, cond=None):
        """
        Fixed euler solver for ODEs.

        Args:
            x (torch.Tensor): random noise
            t_span (torch.Tensor): n_timesteps interpolated
                shape: (n_timesteps + 1,)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            spks (torch.Tensor, optional): speaker ids. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond: Not used but kept for future purposes
        """
        t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
        # I am storing this because I can later plot it by putting a debugger here and saving it to a file
        # Or in future might add like a return_all_steps flag
        sol = []
        steps = 1
        while steps <= len(t_span) - 1:
            dphi_dt = self.estimator(x, mask, mu, t, spks, cond)
            # Classifier-Free Guidance inference introduced in VoiceBox
            if self.inference_cfg_rate > 0:
                cfg_dphi_dt = self.estimator(
                    x, mask,
                    torch.zeros_like(mu), t,
                    torch.zeros_like(spks) if spks is not None else None,
                    # BUGFIX: cond may be None (see the forward docstring);
                    # torch.zeros_like(None) raised a TypeError here.
                    torch.zeros_like(cond) if cond is not None else None,
                )
                dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt -
                           self.inference_cfg_rate * cfg_dphi_dt)
            x = x + dt * dphi_dt
            t = t + dt
            sol.append(x)
            # Steps may be non-uniform under the cosine scheduler.
            if steps < len(t_span) - 1:
                dt = t_span[steps + 1] - t
            steps += 1
        return sol[-1]

    def calc_reg_loss(self, prediction, target, loss_mask):
        """Masked regression loss (l1, l2 or an l1+l2 mix) per element."""
        if self.reg_loss_type == 'l1':
            hint_once("use l1 loss to train CFM", "CFM_LOSS_L1")
            l1_loss = F.l1_loss(prediction, target, reduction="none")
            l1_loss = l1_loss * loss_mask
            return l1_loss
        elif self.reg_loss_type == 'l2':
            hint_once("use l2 loss to train CFM", "CFM_LOSS_L2")
            l2_loss = F.mse_loss(prediction, target, reduction="none")
            l2_loss = l2_loss * loss_mask
            return l2_loss
        else:
            hint_once("use l1+l2 loss to train CFM", "CFM_LOSS_L1_L2")
            l1_loss = F.l1_loss(prediction, target, reduction="none")
            l1_loss = l1_loss * loss_mask
            l2_loss = 0.5 * F.mse_loss(prediction, target, reduction="none")
            l2_loss = l2_loss * loss_mask
            return l1_loss * 0.5 + l2_loss * 0.5

    def compute_loss(self, x1, mask, mu, spks=None, cond=None, reduction='none'):
        """Computes diffusion loss

        Args:
            x1 (torch.Tensor): Target
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): target mask
                shape: (batch_size, 1, mel_timesteps)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            spks (torch.Tensor, optional): speaker embedding. Defaults to None.
                shape: (batch_size, spk_emb_dim)

        Returns:
            loss: conditional flow matching loss
            y: conditional flow
                shape: (batch_size, n_feats, mel_timesteps)
        """
        b, _, t = mu.shape
        # random timestep
        t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
        if self.t_scheduler == 'cosine':
            t = 1 - torch.cos(t * 0.5 * torch.pi)
        # sample noise p(x_0)
        z = torch.randn_like(x1)
        # Linear interpolation path and its target vector field.
        y = (1 - (1 - self.sigma_min) * t) * z + t * x1
        u = x1 - (1 - self.sigma_min) * z
        if self.training_cfg_rate > 0:
            # Randomly drop conditioning to enable classifier-free guidance.
            cfg_mask = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype) > self.training_cfg_rate
            mu = mu * cfg_mask
            if spks is not None:
                spks = spks * cfg_mask.squeeze(-1)
            if cond is not None:
                cond = cond * cfg_mask
        pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
        loss = self.calc_reg_loss(pred, u, mask)
        if reduction == "mean":
            loss = loss.sum() / (torch.sum(mask) * u.shape[1])
        return loss, y
class CFM(BASECFM):
    """Conditional flow-matching model with a configurable estimator network."""

    def __init__(self, in_channels, out_channel, cfm_params, decoder_params,
                 n_spks=1, spk_emb_dim=64, decoder_name="Decoder"):
        super().__init__(
            n_feats=in_channels,
            cfm_params=cfm_params,
            n_spks=n_spks,
            spk_emb_dim=spk_emb_dim,
        )
        # Speaker conditioning widens the estimator's input channels.
        estimator_in = in_channels + (spk_emb_dim if n_spks > 0 else 0)
        # Just change the architecture of the estimator here
        estimator_cls = Decoder if decoder_name == "Decoder" else ConditionalDecoder
        self.estimator = estimator_cls(
            in_channels=estimator_in, out_channels=out_channel, **decoder_params
        )

View File

@ -0,0 +1,219 @@
import torch
from typing import Tuple
import torch.nn as nn
from torch.nn import functional as F
from funasr.models.llm_asr.diffusion_models.matcha_decoder import Upsample1D, Downsample1D
from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
from einops import repeat, pack
import logging
class UpSamplingRegulator(nn.Module):
    """Length regulator that upsamples features to match a target length.

    Args:
        channels: Input feature dimension.
        sampling_ratios: Per-stage ratios; a ratio > 1 adds an Upsample1D
            stage, otherwise a Linear stage (each followed by LayerNorm
            and LeakyReLU).
        out_channels: Output feature dimension; defaults to ``channels``.
        groups: Unused, kept for interface compatibility.
    """

    def __init__(
        self,
        channels: int,
        sampling_ratios: Tuple,
        out_channels: int = None,
        groups: int = 1,
    ):
        super().__init__()
        self.sampling_ratios = sampling_ratios
        out_channels = out_channels or channels
        layers = nn.ModuleList([])
        if len(sampling_ratios) > 0:
            for ratio in sampling_ratios:
                stage = (
                    Upsample1D(channels=channels, channel_first=False)
                    if ratio > 1
                    else nn.Linear(channels, channels)
                )
                layers.extend([stage, nn.LayerNorm(channels), nn.LeakyReLU()])
        layers.append(nn.Linear(channels, out_channels))
        self.model = nn.Sequential(*layers)

    def forward(self, x, xlens, y=None, y_lens=None, cond=None):
        """Upsample ``x`` and truncate on time to match ``y``.

        x, the output and y are all (B, T, D); xlens and cond are unused.
        Returns (out, y_lens).
        """
        out = self.model(x)[:, :y.shape[1]]
        return out, y_lens
class DownSamplingRegulator(nn.Module):
    """Length regulator that downsamples features to match a target length.

    Args:
        channels: Input feature dimension.
        sampling_ratios: Per-stage ratios; a ratio > 1 adds a Downsample1D
            stage, otherwise a Linear stage (each followed by LayerNorm
            and LeakyReLU).
        out_channels: Output feature dimension; defaults to ``channels``.
        groups: Unused, kept for interface compatibility.
    """

    def __init__(
        self,
        channels: int,
        sampling_ratios: Tuple,
        out_channels: int = None,
        groups: int = 1,
    ):
        super().__init__()
        self.sampling_ratios = sampling_ratios
        out_channels = out_channels or channels
        layers = nn.ModuleList([])
        if len(sampling_ratios) > 0:
            for ratio in sampling_ratios:
                stage = (
                    Downsample1D(dim=channels, channel_first=False, padding=2)
                    if ratio > 1
                    else nn.Linear(channels, channels)
                )
                layers.extend([stage, nn.LayerNorm(channels), nn.LeakyReLU()])
        layers.append(nn.Linear(channels, out_channels))
        self.model = nn.Sequential(*layers)

    def forward(self, x, xlens, y=None, y_lens=None, cond=None):
        """Downsample ``x`` and truncate on time to match ``y``.

        x, the output and y are all (B, T, D); xlens and cond are unused.
        Returns (out, y_lens).
        """
        out = self.model(x)[:, :y.shape[1]]
        return out, y_lens
class InterpolateRegulator(nn.Module):
    """Length regulator based on torch interpolation plus 1D conv stages.

    Args:
        channels: Input feature dimension.
        sampling_ratios: One Conv1d/GroupNorm/Mish stage is added per entry
            (the ratio values themselves are not used by the stages).
        out_channels: Output feature dimension; defaults to ``channels``.
        groups: GroupNorm group count.
        mode: Interpolation mode passed to ``F.interpolate``.
        align_corners: Only forwarded for modes that accept it.
    """

    def __init__(
        self,
        channels: int,
        sampling_ratios: Tuple,
        out_channels: int = None,
        groups: int = 1,
        mode="nearest",
        align_corners=False,
    ):
        super().__init__()
        self.sampling_ratios = sampling_ratios
        out_channels = out_channels or channels
        layers = nn.ModuleList([])
        if len(sampling_ratios) > 0:
            for _ in sampling_ratios:
                layers.extend([
                    nn.Conv1d(channels, channels, 3, 1, 1),
                    nn.GroupNorm(groups, channels),
                    nn.Mish(),
                ])
        layers.append(nn.Conv1d(channels, out_channels, 1, 1))
        self.model = nn.Sequential(*layers)
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x, xlens, y=None, ylens=None, cond=None):
        """Stretch ``x`` (B, T, D) along time to y's length; returns (out, ylens)."""
        mask = (~make_pad_mask(ylens)).to(x).unsqueeze(-1)
        interp_kwargs = {}
        if self.mode in ["linear", "bilinear", "bicubic", "trilinear"]:
            # align_corners is only accepted by these interpolation modes.
            interp_kwargs = dict(align_corners=self.align_corners)
        resized = F.interpolate(x.transpose(1, 2).contiguous(), size=y.shape[1],
                                mode=self.mode, **interp_kwargs)
        out = self.model(resized).transpose(1, 2).contiguous()
        # Zero out padded target frames.
        return out * mask, ylens
class UpsamplingBlock(nn.Module):
    """Transposed-conv upsampler followed by two masked residual conv blocks.

    Args:
        in_channels: Input feature dimension.
        channels: Output feature dimension.
        stride: Temporal upsampling factor.
        groups: GroupNorm group count.
        channel_first: If False, inputs/outputs are (B, T, D) and are
            transposed internally for the Conv1d layers.
    """

    def __init__(
        self,
        in_channels: int,
        channels: int,
        stride=2,
        groups=1,
        channel_first=False,
    ):
        super().__init__()
        self.channel_first = channel_first
        self.stride = stride
        self.up_conv = nn.ConvTranspose1d(in_channels, channels, stride * 2, stride, 1)

        def conv_block():
            # Conv -> GroupNorm -> Mish, preserving the temporal length.
            return torch.nn.Sequential(
                torch.nn.Conv1d(channels, channels, 3, padding=1),
                torch.nn.GroupNorm(groups, channels),
                nn.Mish(),
            )

        self.block1 = conv_block()
        self.block2 = conv_block()
        self.res_conv = torch.nn.Conv1d(channels, channels, 1)

    def forward(self, x, ilens):
        """Upsample ``x`` by ``self.stride``; returns (out, olens)."""
        if not self.channel_first:
            x = x.transpose(1, 2)
        olens = ilens * self.stride
        # Mask padded frames at the upsampled resolution.
        o_masks = (~make_pad_mask(olens))[:, None, :].to(x)
        res = out = self.up_conv(x) * o_masks
        out = out + self.block1(out) * o_masks
        out = out + self.block2(out) * o_masks
        out = out + self.res_conv(res) * o_masks
        if not self.channel_first:
            out = out.transpose(1, 2)
        return out, olens
class RepeatLengthRegulator(torch.nn.Module):
    """Repeat Length regulator module for feed-forward Transformer.

    This is a module of the length regulator described in
    `FastSpeech: Fast, Robust and Controllable Text to Speech`_.
    The length regulator expands char or phoneme-level embedding features
    to frame-level by repeating each feature based on the corresponding
    predicted durations.

    .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
        https://arxiv.org/pdf/1905.09263.pdf
    """

    def __init__(self, pad_value=0.0):
        """Initialize length regulator module.

        Args:
            pad_value (float, optional): Value used for padding.
        """
        super().__init__()
        self.pad_value = pad_value

    def forward(self, xs, ds, alpha=1.0):
        """Calculate forward propagation.

        Args:
            xs (Tensor): Batch of sequences of char or phoneme embeddings (B, Tmax, D).
            ds (LongTensor): Batch of durations of each frame (B, T).
            alpha (float, optional): Alpha value to control speed of speech.

        Returns:
            Tensor: Replicated input tensor based on durations (B, T*, D).
        """
        if alpha != 1.0:
            assert alpha > 0
            ds = torch.round(ds.float() * alpha).long()
        if ds.sum() == 0:
            logging.warning(
                "predicted durations includes all 0 sequences. "
                "fill the first element with 1."
            )
            # NOTE(kan-bayashi): This case must not be happened in teacher forcing.
            # It will be happened in inference with a bad duration predictor.
            # So we do not need to care the padded sequence case here.
            ds[ds.sum(dim=1).eq(0)] = 1
        # FIX: renamed local from `repeat` to `repeated` — the old name
        # shadowed the module-level `repeat` imported from einops.
        repeated = [torch.repeat_interleave(x, d, dim=0) for x, d in zip(xs, ds)]
        return pad_list(repeated, self.pad_value)

View File

@ -0,0 +1,844 @@
import math
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from conformer import ConformerBlock
from diffusers.models.activations import get_activation
from einops import pack, rearrange, repeat
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.models.llm_asr.diffusion_models.transformer import BasicTransformerBlock
class SinusoidalPosEmb(torch.nn.Module):
    """Sinusoidal positional embedding for (diffusion) timesteps.

    Args:
        dim: Embedding dimension; must be even (half sin, half cos).
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even"

    def forward(self, x, scale=1000):
        """Embed positions ``x`` into ``dim`` sinusoidal features.

        Args:
            x: Positions; a scalar tensor is promoted to a length-1 batch.
            scale: Multiplier applied to the positions before embedding.

        Returns:
            Tensor of shape (len(x), dim).
        """
        if x.ndim < 1:
            x = x.unsqueeze(0)
        half = self.dim // 2
        # Geometric frequency ladder from 1 down to 1/10000.
        freqs = torch.exp(
            torch.arange(half, device=x.device).float()
            * -(math.log(10000) / (half - 1))
        )
        args = scale * x.unsqueeze(1) * freqs.unsqueeze(0)
        return torch.cat((args.sin(), args.cos()), dim=-1)
class Block1D(torch.nn.Module):
    """Masked Conv1d -> GroupNorm -> Mish block operating on (B, C, T)."""

    def __init__(self, dim, dim_out, groups=8):
        super().__init__()
        self.block = torch.nn.Sequential(
            torch.nn.Conv1d(dim, dim_out, 3, padding=1),
            torch.nn.GroupNorm(groups, dim_out),
            nn.Mish(),
        )

    def forward(self, x, mask):
        """Apply the block; ``mask`` zeroes padded frames before and after."""
        return self.block(x * mask) * mask
class ResnetBlock1D(torch.nn.Module):
    """1D ResNet block conditioned on a time embedding."""

    def __init__(self, dim, dim_out, time_emb_dim, groups=8):
        super().__init__()
        # Projects the time embedding to the block's channel width.
        self.mlp = torch.nn.Sequential(nn.Mish(), torch.nn.Linear(time_emb_dim, dim_out))
        self.block1 = Block1D(dim, dim_out, groups=groups)
        self.block2 = Block1D(dim_out, dim_out, groups=groups)
        # 1x1 conv matches channels for the skip connection.
        self.res_conv = torch.nn.Conv1d(dim, dim_out, 1)

    def forward(self, x, mask, time_emb):
        """Conditioned residual transform of ``x`` (B, C, T)."""
        h = self.block1(x, mask)
        # Broadcast the time embedding over the temporal axis.
        h = h + self.mlp(time_emb).unsqueeze(-1)
        h = self.block2(h, mask)
        return h + self.res_conv(x * mask)
class Downsample1D(nn.Module):
    """Strided Conv1d that roughly halves the temporal resolution.

    Args:
        dim: Channel count (unchanged by the conv).
        channel_first: If False, input/output are (B, T, C) and are
            transposed around the conv.
        padding: Conv padding.
    """

    def __init__(self, dim, channel_first=True, padding=1):
        super().__init__()
        self.channel_first = channel_first
        self.conv = torch.nn.Conv1d(dim, dim, 3, 2, padding)

    def forward(self, x):
        if self.channel_first:
            return self.conv(x)
        # Move channels to dim 1 for the conv, then restore the layout.
        out = self.conv(x.transpose(1, 2).contiguous())
        return out.transpose(1, 2).contiguous()
class TimestepEmbedding(nn.Module):
    """Two-layer MLP that maps a timestep embedding to the model width.

    Optionally fuses an auxiliary condition into the input and applies
    activations after either linear layer.

    Args:
        in_channels: Input embedding size.
        time_embed_dim: Hidden (and default output) size.
        act_fn: Name of the activation between the two linear layers.
        out_dim: Output size; defaults to ``time_embed_dim``.
        post_act_fn: Optional activation name applied to the output.
        cond_proj_dim: If given, size of the condition projected into the input.
    """

    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        act_fn: str = "silu",
        out_dim: int = None,
        post_act_fn: Optional[str] = None,
        cond_proj_dim=None,
    ):
        super().__init__()
        self.linear_1 = nn.Linear(in_channels, time_embed_dim)
        # Bias-free projection for the optional conditioning input.
        self.cond_proj = (
            nn.Linear(cond_proj_dim, in_channels, bias=False)
            if cond_proj_dim is not None
            else None
        )
        self.act = get_activation(act_fn)
        out_features = out_dim if out_dim is not None else time_embed_dim
        self.linear_2 = nn.Linear(time_embed_dim, out_features)
        self.post_act = get_activation(post_act_fn) if post_act_fn is not None else None

    def forward(self, sample, condition=None):
        """Project ``sample`` (optionally fused with ``condition``)."""
        if condition is not None:
            sample = sample + self.cond_proj(condition)
        sample = self.linear_1(sample)
        if self.act is not None:
            sample = self.act(sample)
        sample = self.linear_2(sample)
        return sample if self.post_act is None else self.post_act(sample)
class Upsample1D(nn.Module):
    """A 1D upsampling layer with an optional convolution.

    Parameters:
        channels (`int`): number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`): apply a 3-tap Conv1d after interpolation.
        use_conv_transpose (`bool`, default `True` here): upsample with a
            transposed convolution instead of nearest-neighbour interpolation.
        out_channels (`int`, optional): number of output channels; defaults to `channels`.
        channel_first (`bool`): if False, tensors are (B, T, C) and transposed internally.
        stride (`int`): upsampling factor.
    """
    def __init__(self, channels, use_conv=False, use_conv_transpose=True,
                 out_channels=None, name="conv", channel_first=True, stride=2):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name
        self.channel_first = channel_first
        self.stride = stride
        if use_conv_transpose:
            # kernel = 2*stride, padding = 1 gives an exact stride-fold upsample.
            self.conv = nn.ConvTranspose1d(channels, self.out_channels, stride * 2, stride, 1)
        elif use_conv:
            self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
        else:
            self.conv = None
    def forward(self, inputs):
        if not self.channel_first:
            inputs = inputs.transpose(1, 2).contiguous()
        assert inputs.shape[1] == self.channels
        if self.use_conv_transpose:
            # NOTE: this path returns in (B, C, T) layout regardless of
            # `channel_first`, matching the original behavior.
            return self.conv(inputs)
        # Nearest-neighbour upsample, then an optional smoothing convolution.
        outputs = F.interpolate(inputs, scale_factor=self.stride, mode="nearest")
        if self.use_conv:
            outputs = self.conv(outputs)
        if not self.channel_first:
            outputs = outputs.transpose(1, 2).contiguous()
        return outputs
class ConformerWrapper(ConformerBlock):
    """Adapter giving `conformer.ConformerBlock` the (hidden_states,
    attention_mask, timestep, ...) calling convention shared by the other
    decoder attention blocks.

    Cross-attention inputs and the timestep are accepted but ignored: the
    wrapped conformer block is self-attention only.
    """
    def __init__(  # pylint: disable=useless-super-delegation
        self,
        *,
        dim,
        dim_head=64,
        heads=8,
        ff_mult=4,
        conv_expansion_factor=2,
        conv_kernel_size=31,
        attn_dropout=0,
        ff_dropout=0,
        conv_dropout=0,
        conv_causal=False,
    ):
        cfg = dict(
            dim=dim,
            dim_head=dim_head,
            heads=heads,
            ff_mult=ff_mult,
            conv_expansion_factor=conv_expansion_factor,
            conv_kernel_size=conv_kernel_size,
            attn_dropout=attn_dropout,
            ff_dropout=ff_dropout,
            conv_dropout=conv_dropout,
            conv_causal=conv_causal,
        )
        super().__init__(**cfg)
    def forward(
        self,
        hidden_states,
        attention_mask,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        timestep=None,
        **kwargs
    ):
        # Only the self-attention inputs are forwarded; the mask is binarized.
        return super().forward(x=hidden_states, mask=attention_mask.bool())
class TransformerDecoderWrapper(nn.Module):
    """Single funasr transformer decoder layer with linear adapters.

    Projects the hidden states (and a conditioning prompt) into the attention
    width, runs one self+cross-attention decoder layer with the prompt as
    memory, and projects back to the original width. `timestep`,
    `activation_fn` and extra kwargs are accepted for interface parity with
    the other decoder blocks and are not used.
    """
    def __init__(
        self,
        dim: int,
        num_attention_heads: int = 8,
        attention_head_dim: int = 64,
        dropout: float = 0.1,
        activation_fn: str = "snakebeta",
        cond_dim: int = 80,
        concat_after: bool = False,
    ):
        super().__init__()
        attn_dim = num_attention_heads * attention_head_dim
        self.input_proj = torch.nn.Linear(dim, attn_dim)
        self.cond_proj = torch.nn.Linear(cond_dim, attn_dim)
        self.output_proj = torch.nn.Linear(attn_dim, dim)
        # Local import keeps the funasr dependency lazy, as in the rest of the file.
        from funasr.models.transformer.decoder import (
            DecoderLayer, MultiHeadedAttention, PositionwiseFeedForward
        )
        self.decoder_layer = DecoderLayer(
            attn_dim,
            MultiHeadedAttention(
                num_attention_heads, attn_dim, dropout
            ),
            MultiHeadedAttention(
                num_attention_heads, attn_dim, dropout
            ),
            PositionwiseFeedForward(attn_dim, attn_dim * 4, dropout),
            dropout,
            normalize_before=True,
            concat_after=concat_after,
        )
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        timestep: torch.Tensor = None,
        prompt: torch.Tensor = None,
        prompt_lengths: torch.Tensor = None,
        **kwargs
    ):
        # (B, 1, T) attention-mask row -> (B, T, 1) multiplicative feature mask.
        x_mask = attention_mask[:, :1].transpose(1, 2).contiguous()
        prompt_mask = (~make_pad_mask(prompt_lengths)[:, :, None]).to(hidden_states.device)
        x = self.input_proj(hidden_states) * x_mask
        memory = self.cond_proj(prompt) * prompt_mask
        # DecoderLayer expects (B, 1, T)-shaped masks, hence the transposes.
        x = self.decoder_layer(
            x,
            x_mask.transpose(1, 2).contiguous(),
            memory,
            prompt_mask.transpose(1, 2).contiguous(),
        )[0]
        return self.output_proj(x) * x_mask
class Decoder(nn.Module):
    """U-Net style 1-D decoder used as the flow/diffusion estimator.

    Structure: down blocks (ResNet -> attention -> downsample), mid blocks,
    and up blocks with skip connections. Conditioned on a sinusoidal timestep
    embedding and, optionally, speaker embedding (`spks`), x-vectors and a
    text/audio prompt supplied via `conditions` / `cond`.
    """
    def __init__(
        self,
        in_channels,
        out_channels,
        channels=(256, 256),
        dropout=0.05,
        attention_head_dim=64,
        n_blocks=1,
        num_mid_blocks=2,
        num_heads=4,
        act_fn="snake",
        down_block_type="transformer",
        mid_block_type="transformer",
        up_block_type="transformer",
        conditions: dict = None,
        concat_after: bool = False,
    ):
        """
        This decoder requires an input with the same shape of the target. So, if your text content
        is shorter or longer than the outputs, please re-sampling it before feeding to the decoder.
        """
        super().__init__()
        channels = tuple(channels)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.conditions = conditions
        self.time_embeddings = SinusoidalPosEmb(in_channels)
        time_embed_dim = channels[0] * 4
        self.time_mlp = TimestepEmbedding(
            in_channels=in_channels,
            time_embed_dim=time_embed_dim,
            act_fn="silu",
        )
        self.down_block_type = down_block_type
        self.mid_block_type = mid_block_type
        self.up_block_type = up_block_type
        self.down_blocks = nn.ModuleList([])
        self.mid_blocks = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])
        # Down path: every stage but the last halves the time resolution.
        output_channel = in_channels
        for i in range(len(channels)): # pylint: disable=consider-using-enumerate
            input_channel = output_channel
            output_channel = channels[i]
            is_last = i == len(channels) - 1
            # An x-vector condition is concatenated on channels, widening the input.
            if conditions is not None and "xvec" in conditions:
                input_channel = input_channel + conditions["xvec"]["dim"]
            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
            transformer_blocks = nn.ModuleList(
                [
                    self.get_block(
                        down_block_type,
                        output_channel,
                        attention_head_dim,
                        num_heads,
                        dropout,
                        act_fn,
                        concat_after=concat_after,
                    )
                    for _ in range(n_blocks)
                ]
            )
            downsample = (
                Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
            )
            self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
        # Mid path: keeps the resolution and width of the last down stage.
        for i in range(num_mid_blocks):
            input_channel = channels[-1]
            # NOTE(review): this rebinds the local `out_channels` but it is never
            # read below — the resnet uses `output_channel` (== channels[-1] after
            # the down loop). Looks like dead code; confirm intent.
            out_channels = channels[-1]
            if conditions is not None and "xvec" in conditions:
                input_channel = input_channel + conditions["xvec"]["dim"]
            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
            transformer_blocks = nn.ModuleList(
                [
                    self.get_block(
                        mid_block_type,
                        output_channel,
                        attention_head_dim,
                        num_heads,
                        dropout,
                        act_fn,
                        concat_after=concat_after,
                    )
                    for _ in range(n_blocks)
                ]
            )
            self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
        # Up path mirrors the down path; inputs are doubled by skip concatenation.
        # Appending channels[0] keeps the final up stage at the narrowest width.
        channels = channels[::-1] + (channels[0],)
        for i in range(len(channels) - 1):
            input_channel = channels[i] * 2
            output_channel = channels[i + 1]
            is_last = i == len(channels) - 2
            if conditions is not None and "xvec" in conditions:
                input_channel = input_channel + conditions["xvec"]["dim"]
            resnet = ResnetBlock1D(
                dim=input_channel,
                dim_out=output_channel,
                time_emb_dim=time_embed_dim,
            )
            transformer_blocks = nn.ModuleList(
                [
                    self.get_block(
                        up_block_type,
                        output_channel,
                        attention_head_dim,
                        num_heads,
                        dropout,
                        act_fn,
                        concat_after=concat_after,
                    )
                    for _ in range(n_blocks)
                ]
            )
            upsample = (
                Upsample1D(output_channel, use_conv_transpose=True)
                if not is_last
                else nn.Conv1d(output_channel, output_channel, 3, padding=1)
            )
            self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
        self.final_block = Block1D(channels[-1], channels[-1])
        self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
        self.initialize_weights()
        # nn.init.normal_(self.final_proj.weight)
    def get_block(self, block_type, dim, attention_head_dim, num_heads, dropout, act_fn, **kwargs):
        """Factory for the attention sub-block used at every U-Net stage."""
        if block_type == "conformer":
            block = ConformerWrapper(
                dim=dim,
                dim_head=attention_head_dim,
                heads=num_heads,
                ff_mult=1,
                conv_expansion_factor=2,
                ff_dropout=dropout,
                attn_dropout=dropout,
                conv_dropout=dropout,
                conv_kernel_size=31,
            )
        elif block_type == "transformer":
            block = BasicTransformerBlock(
                dim=dim,
                num_attention_heads=num_heads,
                attention_head_dim=attention_head_dim,
                dropout=dropout,
                activation_fn=act_fn,
            )
        elif block_type == "transformer_decoder":
            # Cross-attends to a prompt; requires conditions["prompt"]["dim"].
            block = TransformerDecoderWrapper(
                dim=dim,
                num_attention_heads=num_heads,
                attention_head_dim=attention_head_dim,
                dropout=dropout,
                activation_fn=act_fn,
                cond_dim=self.conditions["prompt"]["dim"],
                concat_after=kwargs.get("concat_after", False)
            )
        else:
            raise ValueError(f"Unknown block type {block_type}")
        return block
    def initialize_weights(self):
        """Kaiming init for convs/linears; unit weight, zero bias for GroupNorm."""
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.GroupNorm):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
    def forward(self, x, mask, mu, t, spks=None, cond=None):
        """Forward pass of the UNet1DConditional model.
        Args:
            x (torch.Tensor): shape (batch_size, in_channels, time)
            mask (_type_): shape (batch_size, 1, time)
            t (_type_): shape (batch_size)
            spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
            cond (_type_, optional): placeholder for future use. Defaults to None.
        Raises:
            ValueError: _description_
            ValueError: _description_
        Returns:
            _type_: _description_
        """
        t = self.time_embeddings(t)
        t = self.time_mlp(t)
        # Channel-concatenate the noisy input with the conditioning mean `mu`.
        x = pack([x, mu], "b * t")[0]
        if spks is not None:
            # Broadcast the speaker embedding over time and concatenate it too.
            spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
            x = pack([x, spks], "b * t")[0]
        hiddens = []
        masks = [mask]
        for resnet, transformer_blocks, downsample in self.down_blocks:
            mask_down = masks[-1]
            if (cond is not None and "xvec" in cond and
                    self.conditions is not None and "xvec" in self.conditions):
                xvec = repeat(cond["xvec"], "b c -> b c t", t=x.shape[-1])
                x = pack([x, xvec], "b * t")[0]
            x = resnet(x, mask_down, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            # mask_down = rearrange(mask_down, "b 1 t -> b t")
            # "transformer" blocks take a full (B, T, T) attention matrix;
            # the other block types take a per-frame (B, T) mask.
            if self.down_block_type == "transformer":
                attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
            else:
                attn_mask = mask_down.squeeze(1)
            for transformer_block in transformer_blocks:
                cond_kwargs = {}
                if (cond is not None and "prompt" in cond and
                        self.conditions is not None and "prompt" in self.conditions):
                    cond_kwargs["prompt"], cond_kwargs["prompt_lengths"] = cond["prompt"]
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                    **cond_kwargs
                )
            x = rearrange(x, "b t c -> b c t").contiguous()
            # mask_down = rearrange(mask_down, "b t -> b 1 t")
            hiddens.append(x) # Save hidden states for skip connections
            x = downsample(x * mask_down)
            # Stride-2 downsampling halves the mask as well.
            masks.append(mask_down[:, :, ::2])
        masks = masks[:-1]
        mask_mid = masks[-1]
        for resnet, transformer_blocks in self.mid_blocks:
            if (cond is not None and "xvec" in cond and
                    self.conditions is not None and "xvec" in self.conditions):
                xvec = repeat(cond["xvec"], "b c -> b c t", t=x.shape[-1])
                x = pack([x, xvec], "b * t")[0]
            x = resnet(x, mask_mid, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            # mask_mid = rearrange(mask_mid, "b 1 t -> b t")
            if self.mid_block_type == "transformer":
                attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
            else:
                attn_mask = mask_mid.squeeze(1)
            for transformer_block in transformer_blocks:
                cond_kwargs = {}
                if (cond is not None and "prompt" in cond and
                        self.conditions is not None and "prompt" in self.conditions):
                    cond_kwargs["prompt"], cond_kwargs["prompt_lengths"] = cond["prompt"]
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                    **cond_kwargs
                )
            x = rearrange(x, "b t c -> b c t").contiguous()
            # mask_mid = rearrange(mask_mid, "b t -> b 1 t")
        for resnet, transformer_blocks, upsample in self.up_blocks:
            mask_up = masks.pop()
            skip = hiddens.pop()
            # Trim x to the skip length (upsampling may overshoot by a frame)
            # before concatenating the skip connection on channels.
            x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
            if (cond is not None and "xvec" in cond and
                    self.conditions is not None and "xvec" in self.conditions):
                xvec = repeat(cond["xvec"], "b c -> b c t", t=x.shape[-1])
                x = pack([x, xvec], "b * t")[0]
            x = resnet(x, mask_up, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            # mask_up = rearrange(mask_up, "b 1 t -> b t")
            if self.up_block_type == "transformer":
                attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
            else:
                attn_mask = mask_up.squeeze(1)
            for transformer_block in transformer_blocks:
                cond_kwargs = {}
                if (cond is not None and "prompt" in cond and
                        self.conditions is not None and "prompt" in self.conditions):
                    cond_kwargs["prompt"], cond_kwargs["prompt_lengths"] = cond["prompt"]
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                    **cond_kwargs
                )
            x = rearrange(x, "b t c -> b c t").contiguous()
            # mask_up = rearrange(mask_up, "b t -> b 1 t")
            x = upsample(x * mask_up)
        x = self.final_block(x, mask_up)
        output = self.final_proj(x * mask_up)
        return output * mask
class ConditionalDecoder(nn.Module):
    """U-Net style 1-D decoder like `Decoder`, but with simpler conditioning:
    `spks` and `cond` are channel-concatenated onto the input once, up front,
    instead of being injected at every stage."""
    def __init__(
        self,
        in_channels,
        out_channels,
        channels=(256, 256),
        dropout=0.05,
        attention_head_dim=64,
        n_blocks=1,
        num_mid_blocks=2,
        num_heads=4,
        act_fn="snake",
        down_block_type="transformer",
        mid_block_type="transformer",
        up_block_type="transformer",
    ):
        """
        This decoder requires an input with the same shape of the target. So, if your text content
        is shorter or longer than the outputs, please re-sampling it before feeding to the decoder.
        """
        super().__init__()
        channels = tuple(channels)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.time_embeddings = SinusoidalPosEmb(in_channels)
        time_embed_dim = channels[0] * 4
        self.time_mlp = TimestepEmbedding(
            in_channels=in_channels,
            time_embed_dim=time_embed_dim,
            act_fn="silu",
        )
        self.down_block_type = down_block_type
        self.mid_block_type = mid_block_type
        self.up_block_type = up_block_type
        self.down_blocks = nn.ModuleList([])
        self.mid_blocks = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])
        # Down path: every stage but the last halves the time resolution.
        output_channel = in_channels
        for i in range(len(channels)): # pylint: disable=consider-using-enumerate
            input_channel = output_channel
            output_channel = channels[i]
            is_last = i == len(channels) - 1
            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
            transformer_blocks = nn.ModuleList(
                [
                    self.get_block(
                        down_block_type,
                        output_channel,
                        attention_head_dim,
                        num_heads,
                        dropout,
                        act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            downsample = (
                Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
            )
            self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
        # Mid path: keeps the resolution and width of the last down stage.
        for i in range(num_mid_blocks):
            input_channel = channels[-1]
            # NOTE(review): this rebinds the local `out_channels` but it is never
            # read below — the resnet uses `output_channel`. Looks like dead code.
            out_channels = channels[-1]
            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
            transformer_blocks = nn.ModuleList(
                [
                    self.get_block(
                        mid_block_type,
                        output_channel,
                        attention_head_dim,
                        num_heads,
                        dropout,
                        act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
        # Up path mirrors the down path; inputs are doubled by skip concatenation.
        channels = channels[::-1] + (channels[0],)
        for i in range(len(channels) - 1):
            input_channel = channels[i] * 2
            output_channel = channels[i + 1]
            is_last = i == len(channels) - 2
            resnet = ResnetBlock1D(
                dim=input_channel,
                dim_out=output_channel,
                time_emb_dim=time_embed_dim,
            )
            transformer_blocks = nn.ModuleList(
                [
                    self.get_block(
                        up_block_type,
                        output_channel,
                        attention_head_dim,
                        num_heads,
                        dropout,
                        act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            upsample = (
                Upsample1D(output_channel, use_conv_transpose=True)
                if not is_last
                else nn.Conv1d(output_channel, output_channel, 3, padding=1)
            )
            self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
        self.final_block = Block1D(channels[-1], channels[-1])
        self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
        self.initialize_weights()
        # nn.init.normal_(self.final_proj.weight)
    def get_block(self, block_type, dim, attention_head_dim, num_heads, dropout, act_fn, **kwargs):
        """Factory for the attention sub-block used at every U-Net stage."""
        if block_type == "conformer":
            block = ConformerWrapper(
                dim=dim,
                dim_head=attention_head_dim,
                heads=num_heads,
                ff_mult=1,
                conv_expansion_factor=2,
                ff_dropout=dropout,
                attn_dropout=dropout,
                conv_dropout=dropout,
                conv_kernel_size=31,
            )
        elif block_type == "transformer":
            block = BasicTransformerBlock(
                dim=dim,
                num_attention_heads=num_heads,
                attention_head_dim=attention_head_dim,
                dropout=dropout,
                activation_fn=act_fn,
            )
        else:
            raise ValueError(f"Unknown block type {block_type}")
        return block
    def initialize_weights(self):
        """Kaiming init for convs/linears; unit weight, zero bias for GroupNorm."""
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.GroupNorm):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
    def forward(self, x, mask, mu, t, spks=None, cond=None):
        """Forward pass of the UNet1DConditional model.
        Args:
            x (torch.Tensor): shape (batch_size, in_channels, time)
            mask (_type_): shape (batch_size, 1, time)
            t (_type_): shape (batch_size)
            spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
            cond (_type_, optional): placeholder for future use. Defaults to None.
        Raises:
            ValueError: _description_
            ValueError: _description_
        Returns:
            _type_: _description_
        """
        t = self.time_embeddings(t)
        t = self.time_mlp(t)
        # Channel-concatenate the noisy input, the conditioning mean, and any
        # speaker / extra conditions before entering the U-Net.
        x = pack([x, mu], "b * t")[0]
        if spks is not None:
            spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
            x = pack([x, spks], "b * t")[0]
        if cond is not None:
            x = pack([x, cond], "b * t")[0]
        hiddens = []
        masks = [mask]
        for resnet, transformer_blocks, downsample in self.down_blocks:
            mask_down = masks[-1]
            x = resnet(x, mask_down, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            # mask_down = rearrange(mask_down, "b 1 t -> b t")
            # "transformer" blocks take a full (B, T, T) attention matrix;
            # the other block types take a per-frame (B, T) mask.
            if self.down_block_type == "transformer":
                attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
            else:
                attn_mask = mask_down.squeeze(1)
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t").contiguous()
            # mask_down = rearrange(mask_down, "b t -> b 1 t")
            hiddens.append(x) # Save hidden states for skip connections
            x = downsample(x * mask_down)
            # Stride-2 downsampling halves the mask as well.
            masks.append(mask_down[:, :, ::2])
        masks = masks[:-1]
        mask_mid = masks[-1]
        for resnet, transformer_blocks in self.mid_blocks:
            x = resnet(x, mask_mid, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            # mask_mid = rearrange(mask_mid, "b 1 t -> b t")
            if self.mid_block_type == "transformer":
                attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
            else:
                attn_mask = mask_mid.squeeze(1)
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t").contiguous()
            # mask_mid = rearrange(mask_mid, "b t -> b 1 t")
        for resnet, transformer_blocks, upsample in self.up_blocks:
            mask_up = masks.pop()
            skip = hiddens.pop()
            # Trim x to the skip length before concatenating the skip connection.
            x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
            x = resnet(x, mask_up, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            # mask_up = rearrange(mask_up, "b 1 t -> b t")
            if self.up_block_type == "transformer":
                attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
            else:
                attn_mask = mask_up.squeeze(1)
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t").contiguous()
            # mask_up = rearrange(mask_up, "b t -> b 1 t")
            x = upsample(x * mask_up)
        x = self.final_block(x, mask_up)
        output = self.final_proj(x * mask_up)
        return output * mask

View File

@ -0,0 +1,317 @@
from typing import Any, Dict, Optional
import torch
import torch.nn as nn
from diffusers.models.attention import (
GEGLU,
GELU,
AdaLayerNorm,
AdaLayerNormZero,
ApproximateGELU,
)
from diffusers.models.attention_processor import Attention
from diffusers.models.lora import LoRACompatibleLinear
from diffusers.utils.torch_utils import maybe_allow_in_graph
class SnakeBeta(nn.Module):
    """
    A modified Snake activation with separate frequency/magnitude parameters,
    preceded by a linear projection: out = proj(x) + sin^2(alpha * proj(x)) / beta.
    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha: trainable parameter that controls frequency
        - beta: trainable parameter that controls magnitude
    References:
        - Modified version based on Liu Ziyin, Tilman Hartwig, Masahito Ueda:
          https://arxiv.org/abs/2006.08195
    Examples:
        >>> a1 = snakebeta(256)
        >>> x = torch.randn(256)
        >>> x = a1(x)
    """
    def __init__(self, in_features, out_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
        """
        Initialization.
        INPUT:
            - in_features: input width of the projection
            - out_features: output width, also the per-channel parameter shape
            - alpha / beta initialized to 1 by default; higher alpha = higher
              frequency, higher beta = higher magnitude. Both are trained along
              with the rest of the model when `alpha_trainable` is True.
        """
        super().__init__()
        self.in_features = out_features if isinstance(out_features, list) else [out_features]
        self.proj = LoRACompatibleLinear(in_features, out_features)
        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:
            # Log scale: start at 0 so exp() gives an effective value of 1.
            self.alpha = nn.Parameter(torch.zeros(self.in_features) * alpha)
            self.beta = nn.Parameter(torch.zeros(self.in_features) * alpha)
        else:
            # Linear scale: start directly at `alpha`.
            self.alpha = nn.Parameter(torch.ones(self.in_features) * alpha)
            self.beta = nn.Parameter(torch.ones(self.in_features) * alpha)
        self.alpha.requires_grad = alpha_trainable
        self.beta.requires_grad = alpha_trainable
        self.no_div_by_zero = 0.000000001
    def forward(self, x):
        """
        Forward pass: linear projection, then the elementwise activation
        SnakeBeta = x + 1/b * sin^2 (xa)
        """
        x = self.proj(x)
        if self.alpha_logscale:
            alpha = torch.exp(self.alpha)
            beta = torch.exp(self.beta)
        else:
            alpha = self.alpha
            beta = self.beta
        x = x + (1.0 / (beta + self.no_div_by_zero)) * torch.pow(torch.sin(x * alpha), 2)
        return x
class FeedForward(nn.Module):
    r"""
    A feed-forward layer.
    Parameters:
        dim (`int`): The number of channels in the input.
        dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
        mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
    Raises:
        ValueError: If `activation_fn` is not one of the supported names.
    """
    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        mult: int = 4,
        dropout: float = 0.0,
        activation_fn: str = "geglu",
        final_dropout: bool = False,
    ):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim
        # Dispatch on the activation name. Previously this was a broken
        # if/if/elif chain with no else: an unknown name left `act_fn`
        # unbound and crashed with a NameError below — raise a clear
        # ValueError instead.
        if activation_fn == "gelu":
            act_fn = GELU(dim, inner_dim)
        elif activation_fn == "gelu-approximate":
            act_fn = GELU(dim, inner_dim, approximate="tanh")
        elif activation_fn == "geglu":
            act_fn = GEGLU(dim, inner_dim)
        elif activation_fn == "geglu-approximate":
            act_fn = ApproximateGELU(dim, inner_dim)
        elif activation_fn == "snakebeta":
            act_fn = SnakeBeta(dim, inner_dim)
        else:
            raise ValueError(f"Unknown activation function: {activation_fn}")
        self.net = nn.ModuleList([])
        # project in
        self.net.append(act_fn)
        # project dropout
        self.net.append(nn.Dropout(dropout))
        # project out
        self.net.append(LoRACompatibleLinear(inner_dim, dim_out))
        # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
        if final_dropout:
            self.net.append(nn.Dropout(dropout))
    def forward(self, hidden_states):
        # Sub-modules are applied strictly in insertion order.
        for module in self.net:
            hidden_states = module(hidden_states)
        return hidden_states
@maybe_allow_in_graph
class BasicTransformerBlock(nn.Module):
    r"""
    A basic Transformer block: self-attention, optional cross-attention, and a
    feed-forward network, each preceded by its own normalization layer
    (plain LayerNorm, AdaLayerNorm, or AdaLayerNormZero depending on config).
    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
        only_cross_attention (`bool`, *optional*):
            Whether to use only cross-attention layers. In this case two cross attention layers are used.
        double_self_attention (`bool`, *optional*):
            Whether to use two self-attention layers. In this case no cross attention layers are used.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        num_embeds_ada_norm (:
            obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
        attention_bias (:
            obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
    """
    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        double_self_attention: bool = False,
        upcast_attention: bool = False,
        norm_elementwise_affine: bool = True,
        norm_type: str = "layer_norm",
        final_dropout: bool = False,
    ):
        super().__init__()
        self.only_cross_attention = only_cross_attention
        # Adaptive norms require the diffusion-timestep embedding table size.
        self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
        self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
        if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
            raise ValueError(
                f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
                f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
            )
        # Define 3 blocks. Each block has its own normalization layer.
        # 1. Self-Attn
        if self.use_ada_layer_norm:
            self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
        elif self.use_ada_layer_norm_zero:
            self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
        else:
            self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
        self.attn1 = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
            upcast_attention=upcast_attention,
        )
        # 2. Cross-Attn
        if cross_attention_dim is not None or double_self_attention:
            # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
            # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
            # the second cross attention block.
            self.norm2 = (
                AdaLayerNorm(dim, num_embeds_ada_norm)
                if self.use_ada_layer_norm
                else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
            )
            self.attn2 = Attention(
                query_dim=dim,
                cross_attention_dim=cross_attention_dim if not double_self_attention else None,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
                # scale_qk=False, # uncomment this to not to use flash attention
            ) # is self-attn if encoder_hidden_states is none
        else:
            self.norm2 = None
            self.attn2 = None
        # 3. Feed-forward
        self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
        # let chunk size default to None
        self._chunk_size = None
        self._chunk_dim = 0
    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
        # Sets chunk feed-forward: the FF is applied to `chunk_size`-sized
        # slices along `dim` to trade speed for memory.
        self._chunk_size = chunk_size
        self._chunk_dim = dim
    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Dict[str, Any] = None,
        class_labels: Optional[torch.LongTensor] = None,
        **kwargs
    ):
        # Notice that normalization is always applied before the real computation in the following blocks.
        # 1. Self-Attention
        if self.use_ada_layer_norm:
            norm_hidden_states = self.norm1(hidden_states, timestep)
        elif self.use_ada_layer_norm_zero:
            # AdaLayerNormZero also returns the gates/shift/scale that modulate
            # the attention and FF residual branches below.
            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
                hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
            )
        else:
            norm_hidden_states = self.norm1(hidden_states)
        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
        attn_output = self.attn1(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
            attention_mask=encoder_attention_mask if self.only_cross_attention else attention_mask,
            **cross_attention_kwargs,
        )
        if self.use_ada_layer_norm_zero:
            attn_output = gate_msa.unsqueeze(1) * attn_output
        hidden_states = attn_output + hidden_states
        # 2. Cross-Attention
        if self.attn2 is not None:
            norm_hidden_states = (
                self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
            )
            attn_output = self.attn2(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                **cross_attention_kwargs,
            )
            hidden_states = attn_output + hidden_states
        # 3. Feed-forward
        norm_hidden_states = self.norm3(hidden_states)
        if self.use_ada_layer_norm_zero:
            norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
        if self._chunk_size is not None:
            # "feed_forward_chunk_size" can be used to save memory
            if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
                raise ValueError(
                    f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
                )
            num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
            ff_output = torch.cat(
                [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
                dim=self._chunk_dim,
            )
        else:
            ff_output = self.ff(norm_hidden_states)
        if self.use_ada_layer_norm_zero:
            ff_output = gate_mlp.unsqueeze(1) * ff_output
        hidden_states = ff_output + hidden_states
        return hidden_states

View File

@ -0,0 +1,847 @@
import logging
import math
from typing import Dict, List, Optional
import torch
import torch.nn as nn
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.train_utils.device_funcs import force_gatherable
import random
from funasr.models.llm_asr.mel_spectrum import (
mel_spectrogram, power_spectrogram, mel_from_power_spectrogram
)
from torch.nn import functional as F
from funasr.models.transformer.utils.nets_utils import pad_list
from distutils.version import LooseVersion
from contextlib import contextmanager
from funasr.utils.hinter import hint_once
# torch>=1.6 ships AMP autocast; on older versions fall back to a no-op
# context manager so call sites can use `with autocast(...)` unconditionally.
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
else:
    # Nothing to do if torch<1.6.0
    @contextmanager
    def autocast(enabled=True):
        yield
class MelSpectrumExtractor(nn.Module):
    """Configurable waveform -> mel/power spectrogram front-end.

    All STFT parameters are stored on the module (so they appear in repr and
    checkpoints); the actual DSP is delegated to the funasr mel_spectrum helpers.
    """
    def __init__(
        self,
        n_fft=1024,
        num_mels=80,
        sampling_rate=22050,
        hop_size=256,
        win_size=1024,
        fmin=0,
        fmax=8000,
        spec_type="mel",
    ):
        super().__init__()
        self.n_fft = n_fft
        self.num_mels = num_mels
        self.sampling_rate = sampling_rate
        self.hop_size = hop_size
        self.win_size = win_size
        self.fmin = fmin
        self.fmax = fmax
        self.spec_type = spec_type
    def extra_repr(self):
        fields = (
            f"n_fft={self.n_fft}",
            f"num_mels={self.num_mels}",
            f"sampling_rate={self.sampling_rate}",
            f"hop_size={self.hop_size}",
            f"win_size={self.win_size}",
            f"fmin={self.fmin}",
            f"fmax={self.fmax}",
        )
        return ", ".join(fields)
    def forward(self, x, ilens):
        # Select the spectrogram flavour; both helpers share one argument list.
        extract = power_spectrogram if self.spec_type == "power" else mel_spectrogram
        feat = extract(x, self.n_fft, self.num_mels, self.sampling_rate,
                       self.hop_size, self.win_size, self.fmin, self.fmax)
        # determine olens by compare the lengths of inputs and outputs
        olens = ilens // (x.shape[1] // feat.shape[2])
        return feat.transpose(1, 2), olens
    def convert_power_to_mel(self, x, ilens):
        # Map an already-computed power spectrogram onto the mel scale;
        # frame count (and hence lengths) is unchanged.
        feat = mel_from_power_spectrogram(x, self.n_fft, self.num_mels, self.sampling_rate,
                                          self.hop_size, self.win_size, self.fmin, self.fmax)
        return feat, ilens
class QuantizerCodebook(torch.nn.Module):
    """Codebook lookup for residual-VQ codec tokens.

    Holds an embedding table of shape (num_quantizers, codebook_size,
    codebook_dim) registered as a buffer (i.e. expected to be loaded from a
    pretrained codec and kept frozen), and converts integer codec indices to
    dense embeddings by summing the per-quantizer embedding vectors.
    """

    def __init__(
        self,
        num_quantizers,
        codebook_size,
        codebook_dim,
        hop_length,
        sampling_rate
    ):
        super().__init__()
        self.num_quantizers = num_quantizers
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.hop_size = hop_length
        self.sampling_rate = sampling_rate
        # Buffer, not Parameter: weights come from a pretrained codec checkpoint.
        embed = torch.zeros(num_quantizers, codebook_size, codebook_dim)
        self.register_buffer("embed", embed)
        # Per-quantizer offset so all quantizers share one flattened embedding
        # matrix: indices of quantizer q live in [q*1024, (q+1)*1024).
        # NOTE(review): hard-codes codebook_size == 1024 and <= 32 quantizers
        # instead of using the constructor arguments -- confirm against the codec.
        codec_index_shift = 1024 * torch.arange(32, dtype=torch.float32)[None, None, :]
        self.register_buffer("codec_index_shift", codec_index_shift)

    def save_embedding(self, file_name, dense_emb, emb_lengths):
        """Debug helper: dump up to 10 dense embeddings to Kaldi ark/scp files."""
        import kaldiio
        wav_writer = kaldiio.WriteHelper("ark,scp,f:{}.ark,{}.scp".format(file_name, file_name))
        dense_emb = dense_emb.cpu().numpy()
        for i in range(min(dense_emb.shape[0], 10)):
            wav_writer(str(i), dense_emb[i, :emb_lengths[i]])
        wav_writer.close()

    def forward(self, codec: torch.Tensor, codec_lengths, return_subs=False):
        """Embed codec indices.

        Args:
            codec: (B, T) single-codebook or (B, T, Nq) multi-codebook indices.
            codec_lengths: valid frame counts per utterance.
            return_subs: also return per-quantizer embeddings (B, T, Nq, D).
        Returns:
            (dense embeddings (B, T, D) [, sub embeddings], codec_lengths).
        """
        if len(codec.shape) == 2:
            codec = codec.unsqueeze(-1)
        bz, tt, nq = codec.shape[0], codec.shape[1], codec.shape[2]
        # True on valid (non-padded) frames; also zeroes padded indices below.
        codec_mask = ~make_pad_mask(codec_lengths, maxlen=codec.shape[1]).unsqueeze(-1).to(codec.device)
        # Shift quantizer q's indices into its own slice of the flattened table.
        codec = codec * codec_mask + self.codec_index_shift[:, :, :nq].long()
        codec = codec.reshape(-1, nq)
        emb = self.embed.reshape(-1, self.codebook_dim)
        codec_emb = F.embedding(codec, emb)  # (BT, Nq, D)
        # Residual-VQ convention: the dense vector is the sum over quantizers.
        dense_emb = codec_emb.sum(dim=1)
        dense_emb = dense_emb.reshape(bz, tt, self.codebook_dim)
        if return_subs:
            sub_embs = codec_emb.reshape(bz, tt, nq, self.codebook_dim) * codec_mask.unsqueeze(-2)
            return (dense_emb * codec_mask, sub_embs), codec_lengths
        return dense_emb * codec_mask, codec_lengths
class BaseDiffWithXvec(nn.Module):
    """Text-to-spectrum diffusion (flow-matching) model conditioned on a speaker x-vector.

    Pipeline: (optional) token embedding -> text encoder -> linear projection
    to the feature dimension -> length regulator (align text frames to output
    frames) -> conditional flow-matching decoder generating mel or codec features.

    Args:
        input_size: encoder input (token embedding) dimension.
        output_size: generated feature dimension (e.g. 80 mel bins).
        xvec_size: raw speaker-embedding dimension (projected to output_size).
        output_type: "mel" (waveform->mel front-end), "codec" (codec-token
            embedding front-end), anything else passes targets through.
        encoder_conf / decoder_conf / mel_feat_conf / codec_conf: sub-module configs.
        length_regulator_conf: config for the text->feature length regulator.
        prompt_conf: optional config enabling target-prompt conditioning.
        vocab_size: if set (> 0), inputs are token ids and an embedding is used.
        token_list: vocabulary (kept for reference only).
        **kwargs: "input_frame_rate" (default 50) and subclass options.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 80,
        xvec_size: int = 198,
        output_type: str = "mel",
        encoder_conf: Dict = {},
        decoder_conf: Dict = {},
        mel_feat_conf: Dict = {},
        codec_conf: Dict = {},
        length_regulator_conf: Dict = None,
        prompt_conf: Dict = None,
        vocab_size: int = None,
        token_list: List = None,
        **kwargs,
    ):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.encoder_conf = encoder_conf
        self.decoder_conf = decoder_conf
        self.mel_feat_conf = mel_feat_conf
        self.vocab_size = vocab_size
        self.token_list = token_list
        self.output_type = output_type
        self.prompt_conf = prompt_conf
        self.input_frame_rate = kwargs.get("input_frame_rate", 50)
        logging.info(f"input frame rate={self.input_frame_rate}")
        if output_type == 'mel':
            self.mel_extractor = MelSpectrumExtractor(**mel_feat_conf)
        elif output_type == 'codec':
            num_quantizers = codec_conf.get("num_quantizers", 32)
            codebook_size = codec_conf.get("codebook_size", 1024)
            codebook_dim = codec_conf.get("codebook_dim", 128)
            hop_length = codec_conf.get("hop_length", 640)
            sampling_rate = codec_conf.get("sampling_rate", 16000)
            self.quantizer_codebook = QuantizerCodebook(num_quantizers, codebook_size, codebook_dim,
                                                        hop_length, sampling_rate)
        if vocab_size is not None and vocab_size > 0:
            self.input_embedding = nn.Embedding(vocab_size, input_size)
        self.xvec_proj = torch.nn.Linear(xvec_size, output_size)
        self.encoder = self.build_encoder()
        self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
        self.decoder = self.build_decoder()
        self.length_regulator_conf = length_regulator_conf
        self.length_regulator = self.build_length_regulator()

    def build_encoder(self):
        """Build the text encoder named by ``encoder_conf["name"]``.

        The "name" key is popped before expanding the config into the
        constructor and restored afterwards so the config dict stays intact.
        """
        encoder_name = self.encoder_conf.pop("name", "transformer")
        model = None
        if encoder_name == "transformer":
            from funasr.models.llm_asr.conformer_encoder import ConformerEncoder
            # "transformer" = conformer with convolution and macaron disabled.
            model = ConformerEncoder(
                **self.encoder_conf,
                input_size=self.input_size,
                use_cnn_module=False,
                macaron_style=False,
            )
        elif encoder_name == "conformer":
            from funasr.models.llm_asr.conformer_encoder import ConformerEncoder
            model = ConformerEncoder(
                **self.encoder_conf,
                input_size=self.input_size,
            )
        self.encoder_conf["name"] = encoder_name
        return model

    def build_decoder(self):
        """Build the conditional flow-matching decoder ("matcha" -> CFM)."""
        decoder_name = self.decoder_conf.pop("name", "transformer")
        model = None
        if decoder_name == "matcha":
            from funasr.models.llm_asr.diffusion_models.flow_matching import CFM
            model = CFM(
                **self.decoder_conf,
                in_channels=self.output_size * 2,  # 2 for noise_y and mu
                out_channel=self.output_size,
                spk_emb_dim=self.output_size
            )
        self.decoder_conf["name"] = decoder_name
        return model

    def select_target_prompt(self, y, y_lengths):
        """Cut a random span of each target as an acoustic prompt.

        Returns:
            (padded prompts, prompt lengths).  With ``cgf_prob`` set, whole
            prompts are randomly zeroed for classifier-free guidance.
            NOTE(review): here prompts are *kept* only with probability
            cgf_prob (``rand < p``), while MaskedDiffWithXvec drops them with
            probability cgf_prob (``rand > p``) -- confirm which polarity is intended.
        """
        prompt_conf = self.prompt_conf
        prompt_list = []
        prompt_lengths = []
        for i, y_len in enumerate(y_lengths):
            prompt_len = random.randint(
                int(y_len * prompt_conf["prompt_with_range_ratio"][0]),
                int(y_len * prompt_conf["prompt_with_range_ratio"][1])
            )
            prompt_pos = random.randint(0, y_len - prompt_len)
            prompt_list.append(y[i, prompt_pos:prompt_pos+prompt_len])
            prompt_lengths.append(prompt_len)
        prompt = pad_list(prompt_list, 0.0)
        prompt_lengths = torch.tensor(prompt_lengths, dtype=torch.int64, device=y.device)
        if "cgf_prob" in prompt_conf and prompt_conf["cgf_prob"] > 0:
            cgf_mask = torch.rand([y.shape[0], 1, 1], dtype=torch.float32, device=y.device) < prompt_conf["cgf_prob"]
            prompt = prompt * cgf_mask
        return prompt, prompt_lengths

    def build_length_regulator(self):
        """Build the module aligning encoder frames to output-feature frames."""
        name = self.length_regulator_conf.pop("name", None)
        model = None
        if name == "upsampling":
            from funasr.models.llm_asr.diffusion_models.length_regulator import UpSamplingRegulator
            model = UpSamplingRegulator(self.output_size, self.length_regulator_conf.get("sampling_ratios"))
        elif name == "downsampling":
            from funasr.models.llm_asr.diffusion_models.length_regulator import DownSamplingRegulator
            model = DownSamplingRegulator(self.output_size, self.length_regulator_conf.get("sampling_ratios"))
        elif name == "interpolate":
            from funasr.models.llm_asr.diffusion_models.length_regulator import InterpolateRegulator
            model = InterpolateRegulator(self.output_size, **self.length_regulator_conf)
        else:
            raise ValueError(f"Unknown length_regulator {name}")
        self.length_regulator_conf["name"] = name
        return model

    def forward(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        audio: torch.Tensor,
        audio_lengths: torch.Tensor,
        xvec: torch.Tensor,
        xvec_lengths: torch.Tensor,
    ):
        """Compute the flow-matching training loss.

        Returns:
            (loss, stats, weight) packaged via force_gatherable for DataParallel.
        """
        batch_size = audio.shape[0]
        # for data parallel: trim padding to the longest item in the batch
        x = text[:, :text_lengths.max()]
        y = audio[:, :audio_lengths.max()]
        xvec = xvec[:, :xvec_lengths.max()]
        if self.vocab_size is not None and self.vocab_size > 0:
            # -1 marks padding; clamp for the embedding lookup, zero afterwards.
            mask = (x != -1).float().unsqueeze(-1)
            x = self.input_embedding(torch.clamp(x, min=0)) * mask
        # Randomly select one finite x-vector per utterance.
        # NOTE(review): the retry loop does not terminate if every x-vector of
        # an utterance is non-finite -- confirm upstream guarantees finiteness.
        xvec_list = []
        for i, ilen in enumerate(xvec_lengths):
            idx = random.randint(0, ilen-1)
            while torch.any(~torch.isfinite(xvec[i, idx])):
                idx = random.randint(0, ilen - 1)
            xvec_list.append(xvec[i, idx])
        rand_xvec = torch.vstack(xvec_list)
        rand_xvec = self.xvec_proj(rand_xvec)
        y, y_lengths = self.extract_feat(y, audio_lengths)
        h, h_lengths, _ = self.encoder(x, text_lengths)
        h = self.encoder_proj(h)
        h, h_lengths = self.length_regulator(h, h_lengths, y, y_lengths)
        if self.prompt_conf is not None:
            # target_prompt is a (prompt, prompt_lengths) tuple; the decoder is
            # expected to unpack it from the condition dict.
            target_prompt = self.select_target_prompt(y, y_lengths)
            conditions = dict(
                xvec=rand_xvec,
                target_prompt=target_prompt,
            )
        else:
            conditions = None
        mask = (~make_pad_mask(y_lengths)).to(y)
        # y, h in (B, T, D); the decoder works channel-first.
        loss, _ = self.decoder.compute_loss(
            y.transpose(1, 2).contiguous(),
            mask.unsqueeze(1),
            h.transpose(1, 2).contiguous(),
            rand_xvec,
            cond=conditions
        )
        stats = dict(loss=torch.clone(loss.detach()))
        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

    @torch.no_grad()
    def extract_feat(self, y: torch.Tensor, y_lengths: torch.Tensor):
        """Convert raw targets to training features according to output_type."""
        if self.output_type == 'mel':
            return self.mel_extractor(y, y_lengths)
        elif self.output_type == "codec":
            return self.quantizer_codebook(y.long(), y_lengths)
        else:
            # Pass-through: targets are already features.
            return y, y_lengths

    @torch.no_grad()
    def inference(self, text, text_lens, xvec, xvec_lens, diff_steps=10, temperature=1.0, prompt=None):
        """Generate features for ``text`` using the time-averaged x-vector.

        Args:
            diff_steps: number of flow-matching integration steps.
            temperature: noise scale for sampling.
            prompt: unused in the base class (kept for interface compatibility).
        Returns:
            Generated features of shape (1, output_size, T').
        """
        avg_xvec = torch.mean(xvec, dim=1)
        avg_xvec = self.xvec_proj(avg_xvec)
        if self.vocab_size is not None and self.vocab_size > 0:
            mask = (text != -1).float().unsqueeze(-1)
            text = self.input_embedding(torch.clamp(text, min=0)) * mask
        h, h_lengths, _ = self.encoder(text, text_lens)
        h = self.encoder_proj(h)
        # Target length follows the ratio of output feature rate to input token rate.
        if self.output_type == "mel":
            coeff = ((self.mel_extractor.sampling_rate / self.mel_extractor.hop_size) /
                     self.input_frame_rate)
        else:
            coeff = ((self.quantizer_codebook.sampling_rate / self.quantizer_codebook.hop_size) /
                     self.input_frame_rate)
        # Fixed: use the configured feature dimension instead of hard-coded 80
        # (consistent with MaskedDiffWithXvec.inference).
        y = torch.zeros([1, int(h.shape[1] * coeff), self.output_size], device=text.device)
        y_lens = (text_lens * coeff).long()
        h, h_lengths = self.length_regulator(h, h_lengths, y, y_lens)
        mask = (~make_pad_mask(y_lens)).to(y)
        feat = self.decoder.forward(
            mu=h.transpose(1, 2).contiguous(),
            mask=mask.unsqueeze(1),
            n_timesteps=diff_steps,
            temperature=temperature,
            spks=avg_xvec,
            cond=None,
        )
        return feat

    def collect_feats(self, **batch: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Unused hook required by the funasr task interface."""
        pass
class MaskedDiffWithXvec(BaseDiffWithXvec):
    """BaseDiffWithXvec trained with masked-span (in-filling) conditioning.

    A random span of the target feature sequence stays visible as an acoustic
    prompt (condition) while the flow-matching loss focuses on the masked
    remainder; at inference time prompt text/audio enable zero-shot speaker
    cloning.  Extra kwargs: "only_mask_loss", "io_ratio" (number or "auto"),
    "first_package_conf", "length_normalizer_ratio".
    """

    def __init__(self, input_size: int, output_size: int = 80, xvec_size: int = 198, output_type: str = "mel",
                 encoder_conf: Dict = {}, decoder_conf: Dict = {}, mel_feat_conf: Dict = {}, codec_conf: Dict = {},
                 length_regulator_conf: Dict = None, prompt_conf: Dict = None, vocab_size: int = None,
                 token_list: List = None, **kwargs):
        super().__init__(input_size, output_size, xvec_size, output_type, encoder_conf, decoder_conf, mel_feat_conf,
                         codec_conf, length_regulator_conf, prompt_conf, vocab_size, token_list, **kwargs)
        if self.prompt_conf is not None:
            self.masker = self.build_masker()
            # Classifier-free guidance: probability of dropping the acoustic condition.
            self.cgf_prob = prompt_conf.get("cgf_prob", 0.0)
            self.prompt_dropout_rate = prompt_conf.get("prompt_dropout", 0.0)
            if self.prompt_dropout_rate > 0:
                self.prompt_dropout = nn.Dropout(self.prompt_dropout_rate)
            else:
                self.prompt_dropout = None
        self.only_mask_loss = kwargs.get("only_mask_loss", False)
        # Expected output-frames-per-input-token ratio used to clip over-long targets.
        self.io_ratio = kwargs.get("io_ratio", None)
        if self.io_ratio == "auto":
            self.io_ratio = mel_feat_conf["sampling_rate"] / mel_feat_conf["hop_size"] / self.input_frame_rate
        self.first_package_conf = kwargs.get("first_package_conf", None)
        self.length_normalizer_ratio = kwargs.get("length_normalizer_ratio", None)

    def build_masker(self):
        """Build the span masker selecting which target frames become the prompt."""
        prompt_type = self.prompt_conf.get("prompt_type", "free")
        if prompt_type == "prefix":
            from funasr.models.specaug.mask_along_axis import PrefixMaskVariableMaxWidth
            masker = PrefixMaskVariableMaxWidth(
                mask_width_ratio_range=self.prompt_conf["prompt_width_ratio_range"],
            )
        else:
            from funasr.models.specaug.mask_along_axis import MaskAlongAxisVariableMaxWidth
            masker = MaskAlongAxisVariableMaxWidth(
                mask_width_ratio_range=self.prompt_conf["prompt_width_ratio_range"],
                num_mask=1,
            )
        return masker

    @staticmethod
    def norm_and_sample_xvec(xvec, xvec_lengths):
        """Pick one finite x-vector per utterance and L2-normalize the batch.

        NOTE(review): with more than one candidate, the retry loop cannot
        terminate if every x-vector is non-finite; the zero-vector fallback
        only covers the single-candidate case -- confirm upstream guarantees.
        """
        xvec_list = []
        for i, ilen in enumerate(xvec_lengths):
            if ilen == 1:
                idx = 0
            else:
                idx = random.randint(0, ilen - 1)
                while torch.any(~torch.isfinite(xvec[i, idx])):
                    idx = random.randint(0, ilen - 1)
            if torch.any(~torch.isfinite(xvec[i, idx])):
                to_add = torch.zeros_like(xvec[i, idx])
            else:
                to_add = xvec[i, idx]
            xvec_list.append(to_add)
        rand_xvec = torch.vstack(xvec_list)
        rand_xvec = F.normalize(rand_xvec, dim=1)
        return rand_xvec

    def select_target_prompt(self, y: torch.Tensor, y_lengths: torch.Tensor):
        """Return a boolean mask marking frames kept as the acoustic prompt."""
        _, _, cond_mask = self.masker(y, y_lengths, return_mask=True)
        # The masker marks the to-be-predicted region; the prompt is its complement.
        cond_mask = ~cond_mask
        if self.cgf_prob > 0:
            # Classifier-free guidance: drop the whole condition with prob. cgf_prob.
            cgf_mask = torch.rand([y.shape[0], 1, 1], dtype=torch.float32, device=y.device)
            cond_mask = cond_mask * (cgf_mask > self.cgf_prob)
        return cond_mask

    def build_decoder(self):
        """Build the CFM decoder; unlike the base class, all sizes come from the config."""
        decoder_name = self.decoder_conf.pop("name", "transformer")
        model = None
        if decoder_name == "matcha":
            from funasr.models.llm_asr.diffusion_models.flow_matching import CFM
            model = CFM(
                **self.decoder_conf,
            )
        self.decoder_conf["name"] = decoder_name
        return model

    def sample_first_package(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        audio: torch.Tensor,
        audio_lengths: torch.Tensor,
    ):
        """Randomly truncate a subset of utterances to short "first package" segments.

        Exposes the model to very short (streaming first-chunk) inputs.  Bounds
        in "token_len_range" may be absolute ints or ratios in (0, 1].
        Returns possibly shortened (text, text_lengths, audio, audio_lengths).
        """
        sample_rate = self.first_package_conf["sample_rate"]
        min_token_len, max_token_len = self.first_package_conf["token_len_range"]
        random_start = self.first_package_conf.get("random_start", False)
        bs = text.shape[0]
        sample_mask = torch.rand((bs, ), device=text_lengths.device) < sample_rate
        if random_start:
            text_list, text_lengths_list, audio_list, audio_lengths_list = [], [], [], []
            for i, total_len in enumerate(text_lengths):
                total_len = total_len.item()
                if sample_mask[i].item():
                    # Fixed: resolve ratio-style bounds into per-utterance locals.
                    # Previously min/max_token_len themselves were overwritten,
                    # freezing the bounds at the first sampled utterance's
                    # absolute lengths for every later utterance.
                    lo, hi = min_token_len, max_token_len
                    if isinstance(lo, float) and 0.0 < lo <= 1.0:
                        lo = math.floor(lo * total_len)
                    if isinstance(hi, float) and 0.0 < hi <= 1.0:
                        hi = math.floor(hi * total_len)
                    if total_len > hi > lo:
                        fp_len = random.randint(lo, hi)
                        start = random.randint(0, total_len - fp_len)
                        audio_st, audio_len = self.calc_target_len(torch.tensor(start)), self.calc_target_len(torch.tensor(fp_len))
                    else:
                        start, fp_len = 0, total_len
                        audio_st, audio_len = 0, self.calc_target_len(fp_len)
                    text_list.append(text[i, start: start+fp_len])
                    text_lengths_list.append(fp_len)
                    audio_list.append(audio[i, audio_st: audio_st+audio_len])
                    audio_lengths_list.append(audio_list[-1].shape[0])
                else:
                    text_list.append(text[i])
                    text_lengths_list.append(text_lengths[i])
                    audio_list.append(audio[i, :min(self.calc_target_len(text_lengths[i]), audio_lengths[i])])
                    audio_lengths_list.append(audio_list[-1].shape[0])
            text = pad_list(text_list, pad_value=0.0).to(text)
            new_text_lengths = torch.tensor(text_lengths_list, dtype=torch.int64, device=text.device)
            audio = pad_list(audio_list, pad_value=0.0).to(audio)
            new_audio_lengths = torch.tensor(audio_lengths_list, dtype=torch.int64, device=audio.device)
        else:
            # Prefix truncation: just shorten lengths and zero the padding.
            fp_token_len = torch.randint(min_token_len, max_token_len + 1, (bs,))
            fp_token_len = torch.minimum(fp_token_len.to(text_lengths), text_lengths)
            fp_audio_len = self.calc_target_len(fp_token_len)
            fp_audio_len = torch.minimum(fp_audio_len.to(audio_lengths), audio_lengths)
            new_text_lengths = torch.where(sample_mask, fp_token_len, text_lengths)
            new_audio_lengths = torch.where(sample_mask, fp_audio_len, audio_lengths)
            text = text * (~make_pad_mask(new_text_lengths, maxlen=text.shape[1]).unsqueeze(-1)).to(text.device)
            audio = audio * (~make_pad_mask(new_audio_lengths, maxlen=audio.shape[1]).unsqueeze(-1)).to(audio.device)
        return text, new_text_lengths, audio, new_audio_lengths

    @staticmethod
    def clip_both_side(y, y_lengths, raw_lengths):
        """Trim each sequence to its reduced length, cutting evenly from both ends."""
        res_list = []
        new_length = []
        for i, (new_len, org_len) in enumerate(zip(y_lengths, raw_lengths)):
            if new_len >= org_len:
                res_list.append(y[i, :new_len])
            else:
                left = (org_len - new_len) // 2
                right = org_len - new_len - left
                res_list.append(y[i, left: org_len-right])
            new_length.append(res_list[-1].shape[0])
        new_length = torch.tensor(new_length).to(y_lengths)
        return pad_list(res_list, 0.0), new_length

    def forward(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        audio: torch.Tensor,
        audio_lengths: torch.Tensor,
        xvec: Optional[torch.Tensor] = None,
        xvec_lengths: Optional[torch.Tensor] = None,
    ):
        """Compute the masked flow-matching training loss.

        Returns:
            (loss, stats, weight) packaged via force_gatherable for DataParallel.
        """
        batch_size = audio.shape[0]
        # Run in fp32: mel extraction and the flow-matching loss are
        # numerically sensitive under autocast.
        with autocast(False):
            # for data parallel: trim padding to the longest item in the batch
            x = text[:, :text_lengths.max()]
            y = audio[:, :audio_lengths.max()]
            if self.vocab_size is not None and self.vocab_size > 0:
                # -1 marks padding; clamp for the embedding lookup, zero afterwards.
                mask = (x != -1).float().unsqueeze(-1)
                x = self.input_embedding(torch.clamp(x, min=0)) * mask
            # Randomly select one x-vector per utterance.
            rand_xvec = None
            if xvec is not None:
                xvec = xvec[:, :xvec_lengths.max()]
                rand_xvec = self.norm_and_sample_xvec(xvec, xvec_lengths)
                rand_xvec = self.xvec_proj(rand_xvec)
            y, y_lengths = self.extract_feat(y, audio_lengths)
            if self.length_normalizer_ratio is not None:
                # Hard-normalize target lengths to ratio * text length,
                # clipping extra frames evenly from both sides.
                max_y_lengths = torch.round(text_lengths * self.length_normalizer_ratio).long()
                raw_lengths = y_lengths.clone()
                y_lengths = torch.where(y_lengths > max_y_lengths, max_y_lengths, y_lengths)
                y, new_y_lengths = self.clip_both_side(y, y_lengths, raw_lengths)
                logging.info(f"normalized y_length from {raw_lengths.cpu().tolist()} to {y_lengths.cpu().tolist()} "
                             f"new_y_length {new_y_lengths.cpu().tolist()}, with text_lengths {text_lengths.cpu().tolist()}")
                y = y[:, :new_y_lengths.max()]
            elif self.io_ratio is not None:
                # Soft clip: cap target length at io_ratio * text length (+3 slack).
                hint_once(f"cut output with ratio {self.io_ratio}", "print_ratio", rank=0)
                max_y_lengths = (text_lengths * self.io_ratio + 3).long()
                if y_lengths.max() > max_y_lengths.max():
                    logging.info(f"cut output with ratio {self.io_ratio} from {y_lengths.max()} to {max_y_lengths.max()}")
                y_lengths = torch.where(y_lengths > max_y_lengths, max_y_lengths, y_lengths)
                y = y[:, :y_lengths.max()]
            if self.first_package_conf is not None:
                x, text_lengths, y, y_lengths = self.sample_first_package(
                    x, text_lengths, y, y_lengths
                )
                x = x[:, :text_lengths.max()]
                y = y[:, :y_lengths.max()]
            h, _, _ = self.encoder(x, text_lengths)
            h_lengths = text_lengths
            h = self.encoder_proj(h)
            h, h_lengths = self.length_regulator(h, h_lengths, y, y_lengths)
            if self.prompt_conf is not None:
                cond_mask = self.select_target_prompt(y, y_lengths)
                if self.prompt_dropout is not None:
                    hint_once(f"prompt dropout {self.prompt_dropout_rate}", "prompt dropout")
                    # NOTE(review): dropout is applied to y itself, so the
                    # regression target (not only the condition) is perturbed
                    # -- confirm this is intended.
                    y = self.prompt_dropout(y)
                conditions = (y * cond_mask).transpose(1, 2)
            else:
                cond_mask, conditions = None, None
            stats = dict(
                batch_size=batch_size,
                in_lengths=text_lengths.max(),
                out_lengths=y_lengths.max(),
            )
            mask = (~make_pad_mask(y_lengths)).to(y)
            # y, h in (B, T, D); the decoder works channel-first.
            loss, _ = self.decoder.compute_loss(
                y.transpose(1, 2).contiguous(),
                mask.unsqueeze(1),
                h.transpose(1, 2).contiguous(),
                rand_xvec,
                cond=conditions,
                reduction="none",
            )
            loss = loss.transpose(1, 2)
            # "all": averaged over every valid frame; "masked": only over the
            # frames hidden from the condition.
            all_loss = (loss * mask.unsqueeze(-1)).sum() / (mask.sum() * loss.shape[-1])
            if cond_mask is not None:
                masked_loss_mask = mask.unsqueeze(-1) * (~cond_mask)
            else:
                masked_loss_mask = mask.unsqueeze(-1)
            masked_loss = (loss * masked_loss_mask).sum() / (masked_loss_mask.sum() * loss.shape[-1])
            stats["all_loss"] = all_loss.item()
            stats["masked_loss"] = masked_loss.item()
            loss = masked_loss if self.only_mask_loss else all_loss
            stats["loss"] = loss.item()
        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

    @staticmethod
    def concat_prompt(prompt, prompt_lengths, text, text_lengths):
        """Prepend each prompt sequence to its text sequence, re-padding the batch."""
        xs_list, x_len_list = [], []
        for idx, (_prompt_len, _text_len) in enumerate(zip(prompt_lengths, text_lengths)):
            xs_list.append(torch.concat([prompt[idx, :_prompt_len], text[idx, :_text_len]], dim=0))
            x_len_list.append(_prompt_len + _text_len)
        xs = pad_list(xs_list, pad_value=0.0)
        x_lens = torch.tensor(x_len_list, dtype=torch.int64).to(xs.device)
        return xs, x_lens

    @staticmethod
    def remove_prompt(prompt, prompt_lengths, padded, padded_lengths):
        """Drop the leading prompt region from each padded sequence (``prompt`` unused)."""
        xs_list = []
        for idx, (_prompt_len, _x_len) in enumerate(zip(prompt_lengths, padded_lengths)):
            xs_list.append(padded[idx, _prompt_len: _x_len])
        xs = pad_list(xs_list, pad_value=0.0)
        return xs, padded_lengths - prompt_lengths

    @staticmethod
    def norm_and_avg_xvec(xvec: torch.Tensor, xvec_lens: torch.Tensor):
        """Average the finite, normalized x-vectors over time and re-normalize.

        NOTE(review): mask.sum() counts over the whole batch rather than per
        row; the final normalize cancels the scale, so only the direction matters.
        """
        mask = torch.isfinite(xvec.norm(dim=-1, keepdim=True))
        norm_xvec = F.normalize(xvec, dim=-1) * mask
        avg_xvec = F.normalize(torch.sum(norm_xvec, dim=1) / mask.sum(), dim=-1)
        return avg_xvec

    def calc_target_len(self, in_len):
        """Map an input token count to the expected number of output feature frames.

        The 25/50 Hz branches model the upstream token->16 kHz-sample mapping
        ((n*4+4)*160+400 resp. (n*2+2)*160+400 samples) and rescale to the mel
        frame rate.  Accepts int or tensor and returns the same kind.
        """
        if self.input_frame_rate == 25 and self.output_type == "mel":
            if self.length_normalizer_ratio is not None:
                if isinstance(in_len, int):
                    in_len = torch.tensor(in_len)
                ll = torch.round(in_len * self.length_normalizer_ratio)
            else:
                ll = (in_len * 4 + 4) * 160 + 400
                ll = ll / 16000 * self.mel_extractor.sampling_rate / self.mel_extractor.hop_size
            if isinstance(in_len, int):
                ll = int(round(ll))
            else:
                ll = torch.round(ll).long()
            return ll
        if self.input_frame_rate == 50 and self.output_type == "mel":
            if self.length_normalizer_ratio is not None:
                if isinstance(in_len, int):
                    in_len = torch.tensor(in_len)
                ll = torch.round(in_len * self.length_normalizer_ratio)
            else:
                ll = (in_len * 2 + 2) * 160 + 400
                ll = ll / 16000 * self.mel_extractor.sampling_rate / self.mel_extractor.hop_size
            if isinstance(in_len, int):
                ll = int(round(ll))
            else:
                ll = torch.round(ll).long()
            return ll
        elif self.output_type == "codec":
            # Codec targets share the input token rate one-to-one.
            return in_len
        else:
            raise ValueError(f"Frame rate {self.input_frame_rate} has not implemented.")

    @torch.no_grad()
    def inference(self, text, text_lens,
                  xvec=None, xvec_lens=None,
                  diff_steps=10, temperature=1.0, prompt: dict = None, y_lens=None):
        """Generate features for ``text``, optionally cloning a prompt speaker.

        Args:
            xvec/xvec_lens: optional speaker embeddings, averaged over time.
            prompt: dict with optional "prompt_text" and "prompt_audio" entries,
                each a (tensor, lengths) pair; prompt frames are injected as the
                condition and stripped from the returned features.
            y_lens: optional explicit target frame counts (else predicted).
        Returns:
            Generated features of shape (1, output_size, T').
        """
        rand_xvec = None
        if xvec is not None:
            if xvec.dim() == 2:
                xvec = xvec.unsqueeze(1)
                xvec_lens = torch.ones_like(xvec_lens)
            rand_xvec = self.norm_and_avg_xvec(xvec, xvec_lens)
            rand_xvec = self.xvec_proj(rand_xvec)
        prompt_text, prompt_text_lens = prompt.get("prompt_text", (None, None))
        prompt_audio, prompt_audio_lens = prompt.get("prompt_audio", (None, None))
        if self.vocab_size is not None and self.vocab_size > 0:
            if prompt_text is not None:
                text, text_lens = self.concat_prompt(prompt_text, prompt_text_lens, text, text_lens)
            mask = (text != -1).float().unsqueeze(-1)
            text = self.input_embedding(torch.clamp(text, min=0)) * mask
        h, h_lengths, _ = self.encoder(text, text_lens)
        h = self.encoder_proj(h)
        if y_lens is None:
            y_lens = self.calc_target_len(text_lens)
        y = torch.zeros([1, y_lens.max().item(), self.output_size], device=text.device)
        h, h_lengths = self.length_regulator(h, h_lengths, y, y_lens)
        # get conditions: fill the leading frames with prompt features.
        if prompt_audio is not None:
            if prompt_audio.ndim == 2:
                prompt_audio, prompt_audio_lens = self.extract_feat(prompt_audio, prompt_audio_lens)
            for i, _len in enumerate(prompt_audio_lens):
                y[i, :_len] = prompt_audio[i]
        conds = y.transpose(1, 2)
        mask = (~make_pad_mask(y_lens)).to(y)
        feat = self.decoder.forward(
            mu=h.transpose(1, 2).contiguous(),
            mask=mask.unsqueeze(1),
            n_timesteps=diff_steps,
            temperature=temperature,
            spks=rand_xvec,
            cond=conds,
        )
        if prompt_text is not None and prompt_audio is not None:
            # Strip the prompt region so only newly generated frames are returned.
            feat = feat.transpose(1, 2)
            feat_lens = torch.tensor([feat.shape[1]], dtype=torch.int64, device=feat.device)
            feat, feat_lens = self.remove_prompt(None, prompt_audio_lens, feat, feat_lens)
            feat = feat.transpose(1, 2)
        return feat
class MaskedDiffTTS(MaskedDiffWithXvec):
    """MaskedDiffWithXvec with a learned utterance-duration predictor.

    Instead of deriving the target frame count from a fixed token->frame rule,
    per-token durations are predicted from encoder outputs and supervised with
    an L1 loss (weighted by "length_loss_weight") against the true lengths.
    """

    def __init__(self, input_size: int, output_size: int = 80, xvec_size: int = 198, output_type: str = "mel",
                 encoder_conf: Dict = {}, decoder_conf: Dict = {}, mel_feat_conf: Dict = {}, codec_conf: Dict = {},
                 length_regulator_conf: Dict = None, prompt_conf: Dict = None, vocab_size: int = None,
                 token_list: List = None, **kwargs):
        super().__init__(input_size, output_size, xvec_size, output_type, encoder_conf, decoder_conf, mel_feat_conf,
                         codec_conf, length_regulator_conf, prompt_conf, vocab_size, token_list, **kwargs)
        self.length_loss_weight = kwargs.get("length_loss_weight", 0.0)
        # NOTE(review): forward()/inference() call calc_target_len()
        # unconditionally, which needs length_predictor, so a zero
        # length_loss_weight cannot actually run -- confirm intended.
        if self.length_loss_weight > 0.0:
            self.length_predictor = nn.Linear(self.encoder.output_size(), 1)

    def calc_target_len(self, enc_outs, enc_lens):
        """Predict the total number of output frames per utterance.

        Per-token durations are predicted in log-space (exp guarantees
        positivity) and summed over non-padded positions.
        """
        text_durs = self.length_predictor(enc_outs)
        text_durs = torch.exp(text_durs)
        mask = ~make_pad_mask(enc_lens, xs=text_durs)
        utt_durs = (text_durs * mask).sum(dim=1).squeeze(-1)
        return utt_durs

    def forward(self, text: torch.Tensor, text_lengths: torch.Tensor, audio: torch.Tensor, audio_lengths: torch.Tensor,
                xvec: Optional[torch.Tensor] = None, xvec_lengths: Optional[torch.Tensor] = None):
        """Masked flow-matching loss plus weighted utterance-duration L1 loss."""
        batch_size = audio.shape[0]
        # for data parallel: trim padding to the longest item in the batch
        x = text[:, :text_lengths.max()]
        y = audio[:, :audio_lengths.max()]
        if self.vocab_size is not None and self.vocab_size > 0:
            # -1 marks padding; clamp for the embedding lookup, zero afterwards.
            mask = (x != -1).float().unsqueeze(-1)
            x = self.input_embedding(torch.clamp(x, min=0)) * mask
        # Randomly select one x-vector per utterance.
        rand_xvec = None
        if xvec is not None:
            xvec = xvec[:, :xvec_lengths.max()]
            rand_xvec = self.norm_and_sample_xvec(xvec, xvec_lengths)
            rand_xvec = self.xvec_proj(rand_xvec)
        y, y_lengths = self.extract_feat(y, audio_lengths)
        h, _, _ = self.encoder(x, text_lengths)
        h_lengths = text_lengths
        h_durs = self.calc_target_len(h, h_lengths)
        # Fixed: cast integer frame counts to the prediction dtype --
        # F.l1_loss raises "Found dtype Long but expected Float" otherwise.
        utt_dur_loss = self.length_loss_weight * F.l1_loss(h_durs, y_lengths.to(h_durs.dtype))
        h = self.encoder_proj(h)
        h, h_lengths = self.length_regulator(h, h_lengths, y, y_lengths)
        if self.prompt_conf is not None:
            cond_mask = self.select_target_prompt(y, y_lengths)
            conditions = (y * cond_mask).transpose(1, 2)
        else:
            cond_mask, conditions = None, None
        stats = dict(
            batch_size=batch_size,
            in_lengths=text_lengths.max(),
            out_lengths=y_lengths.max(),
        )
        mask = (~make_pad_mask(y_lengths)).to(y)
        # y, h in (B, T, D); the decoder works channel-first.
        loss, _ = self.decoder.compute_loss(
            y.transpose(1, 2).contiguous(),
            mask.unsqueeze(1),
            h.transpose(1, 2).contiguous(),
            rand_xvec,
            cond=conditions,
            reduction="none",
        )
        loss = loss.transpose(1, 2)
        # "all": averaged over every valid frame; "masked": only over the
        # frames hidden from the condition.
        all_loss = (loss * mask.unsqueeze(-1)).sum() / (mask.sum() * loss.shape[-1])
        if cond_mask is not None:
            masked_loss_mask = mask.unsqueeze(-1) * (~cond_mask)
        else:
            masked_loss_mask = mask.unsqueeze(-1)
        masked_loss = (loss * masked_loss_mask).sum() / (masked_loss_mask.sum() * loss.shape[-1])
        stats["all_loss"] = all_loss.item()
        stats["masked_loss"] = masked_loss.item()
        loss = masked_loss if self.only_mask_loss else all_loss
        stats["mel_loss"] = loss.item()
        loss = loss + utt_dur_loss
        stats["loss"] = loss.item()
        stats["utt_dur_loss"] = utt_dur_loss.item()
        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

    def inference(self, text, text_lens, xvec=None, xvec_lens=None, diff_steps=10, temperature=1.0,
                  prompt: dict = None):
        """Generate features with the target length predicted by the duration model.

        See MaskedDiffWithXvec.inference for argument semantics; the only
        difference is that y_lens comes from the learned predictor.
        """
        rand_xvec = None
        if xvec is not None:
            if xvec.dim() == 2:
                xvec = xvec.unsqueeze(1)
                xvec_lens = torch.ones_like(xvec_lens)
            rand_xvec = self.norm_and_avg_xvec(xvec, xvec_lens)
            rand_xvec = self.xvec_proj(rand_xvec)
        prompt_text, prompt_text_lens = prompt.get("prompt_text", (None, None))
        prompt_audio, prompt_audio_lens = prompt.get("prompt_audio", (None, None))
        if self.vocab_size is not None and self.vocab_size > 0:
            if prompt_text is not None:
                text, text_lens = self.concat_prompt(prompt_text, prompt_text_lens, text, text_lens)
            mask = (text != -1).float().unsqueeze(-1)
            text = self.input_embedding(torch.clamp(text, min=0)) * mask
        h, _, _ = self.encoder(text, text_lens)
        h_lengths = text_lens
        y_lens = self.calc_target_len(h, h_lengths).round().long()
        y = torch.zeros([1, y_lens.max().item(), self.output_size], device=text.device)
        h = self.encoder_proj(h)
        h, h_lengths = self.length_regulator(h, h_lengths, y, y_lens)
        # get conditions: fill the leading frames with prompt features.
        if prompt_audio is not None:
            if prompt_audio.ndim == 2:
                prompt_audio, prompt_audio_lens = self.extract_feat(prompt_audio, prompt_audio_lens)
            for i, _len in enumerate(prompt_audio_lens):
                y[i, :_len] = prompt_audio[i]
        conds = y.transpose(1, 2)
        mask = (~make_pad_mask(y_lens)).to(y)
        feat = self.decoder.forward(
            mu=h.transpose(1, 2).contiguous(),
            mask=mask.unsqueeze(1),
            n_timesteps=diff_steps,
            temperature=temperature,
            spks=rand_xvec,
            cond=conds,
        )
        if prompt_text is not None and prompt_audio is not None:
            # Strip the prompt region so only newly generated frames are returned.
            feat = feat.transpose(1, 2)
            feat_lens = torch.tensor([feat.shape[1]], dtype=torch.int64, device=feat.device)
            feat, feat_lens = self.remove_prompt(None, prompt_audio_lens, feat, feat_lens)
            feat = feat.transpose(1, 2)
        return feat

View File

@ -0,0 +1,477 @@
# Copyright 2023 KaiHu
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""HIFI-GAN"""
from typing import Any
from typing import Dict
from typing import Optional
from typing import Tuple, List, Union
import typing as tp
import torch
import torchaudio
from torch import nn
import torch.nn.functional as F
from typeguard import check_argument_types
from funasr.train_utils.device_funcs import force_gatherable
from librosa.filters import mel as librosa_mel_fn
import logging
from funasr.utils.hinter import hint_once
class Audio2Mel(nn.Module):
    """Waveform -> (log-)mel-spectrogram front-end for vocoder reconstruction losses.

    Args:
        n_fft / hop_length / win_length: STFT parameters.
        sampling_rate: audio sample rate the mel basis is built for.
        n_mel_channels: number of mel bins.
        mel_fmin / mel_fmax: mel filter-bank frequency range.
        center: passed to torch.stft; reflect padding is applied manually instead.
        device: device the window / mel-basis buffers are created on.
        feat_type: "mag_log10" for log10-compressed magnitude mel, anything
            else ("power_log") for natural-log compressed mel.
    """

    def __init__(
        self,
        n_fft=1024,
        hop_length=256,
        win_length=1024,
        sampling_rate=22050,
        n_mel_channels=80,
        mel_fmin=0.0,
        mel_fmax=None,
        center=False,
        device='cuda',
        feat_type="power_log",
    ):
        super().__init__()
        ##############################################
        # FFT Parameters                              #
        ##############################################
        window = torch.hann_window(win_length, device=device).float()
        mel_basis = librosa_mel_fn(
            sr=sampling_rate, n_fft=n_fft, n_mels=n_mel_channels, fmin=mel_fmin, fmax=mel_fmax
        )
        mel_basis = torch.from_numpy(mel_basis).float().to(device)
        self.register_buffer("mel_basis", mel_basis)
        self.register_buffer("window", window)
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.sampling_rate = sampling_rate
        self.n_mel_channels = n_mel_channels
        self.center = center
        self.feat_type = feat_type

    def forward(self, audioin):
        """Compute mel features; returns (B, n_mel_channels, frames)."""
        # Manual reflect padding so frame centers line up with hop positions.
        p = (self.n_fft - self.hop_length) // 2
        audio = F.pad(audioin, (p, p), "reflect").squeeze(1)
        # NOTE(review): torch.stft without return_complex is deprecated and
        # removed in torch >= 2.0; this call relies on the legacy real-valued
        # (..., 2) output layout -- confirm the supported torch version.
        fft = torch.stft(
            audio,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=self.window,
            center=self.center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
        )
        if self.feat_type == "mag_log10":
            # Magnitude mel with log10 compression.
            power_spec = torch.sqrt(torch.sum(torch.pow(fft, 2), dim=[-1]))
            mel_output = torch.matmul(self.mel_basis, power_spec)
            return torch.log10(torch.clamp(mel_output, min=1e-5))
        # Default "power_log": magnitude mel with natural-log compression
        # (+1e-9 guards the sqrt gradient at zero).
        power_spec = torch.sum(torch.pow(fft, 2), dim=[-1])
        mel_spec = torch.matmul(self.mel_basis, torch.sqrt(power_spec + 1e-9))
        return self.spectral_normalize(mel_spec)

    @classmethod
    def spectral_normalize(cls, spec, C=1, clip_val=1e-5):
        """Dynamic-range compress a spectrogram: log(clamp(spec, clip_val) * C)."""
        output = cls.dynamic_range_compression(spec, C, clip_val)
        return output

    @classmethod
    def spectral_de_normalize_torch(cls, spec, C=1, clip_val=1e-5):
        """Invert spectral_normalize: exp(spec) / C."""
        output = cls.dynamic_range_decompression(spec, C, clip_val)
        return output

    @staticmethod
    def dynamic_range_compression(x, C=1, clip_val=1e-5):
        """log(clamp(x, clip_val) * C) -- clip_val avoids log(0)."""
        return torch.log(torch.clamp(x, min=clip_val) * C)

    @staticmethod
    def dynamic_range_decompression(x, C=1, clip_val=1e-5):
        """exp(x) / C.

        Fixed: accept (and ignore) clip_val -- spectral_de_normalize_torch
        forwards it, which previously raised TypeError.  Clamping has no
        inverse, so the value is unused.
        """
        return torch.exp(x) / C
class HifiGan(nn.Module):
    """HIFIGAN-style vocoders (generator [stack of time-level-upsampling blocks] + discriminator).
    NSF-HIFIGAN, HiFTNet Optional.

    Training alternates between a generator turn (``forward_generator=True``,
    optimizer index 0) and a discriminator turn (optimizer index 1); see
    :meth:`forward`.
    """
    def __init__(
        self,
        input_size: int,
        frontend: torch.nn.Module = None,
        nsf_augmented: bool = False,
        f0_predictor: dict = None,
        generator: dict = None,
        discriminator: dict = None,
        target_sample_hz: int = 22_050,
        multi_mel_spectral_window: Union[Tuple, List] = tuple([1024]),
        multi_mel_spectral_hop: Union[Tuple, List] = tuple([256]),
        multi_mel_spectral_fft: Union[Tuple, List] = tuple([1024]),
        multi_mel_spectral_n_mels: Union[Tuple, List] = tuple([80]),
        mel_fmin: float = 0,
        mel_fmax: float = 8000,
        mel_fmax_for_loss: Optional[float] = None,
        multi_mel_spectral_recon_loss_weight: Union[Tuple[float], List[float]] = tuple([45]),
        adversarial_loss_weight: float = 1.0,
        feat_match_loss_weight: float = 2.0,
        tpr_loss_params: tp.Dict[str, tp.Any] = {"weight": 0.0, "tau": 0.04},
        mel_feat_type="power_log",
    ):
        """Initialize HifiGan model.
        Args:
            f0_predictor: f0 predictor (pretrained && frozen) for NSF-HIFIGAN, Optional.
            generator: hifigan generator
            discriminator: several discriminators, such as MSD, MPD, MRD
            multi_mel_spectral_window: stft window length
            multi_mel_spectral_hop: stft hop length
            multi_mel_spectral_fft: fft bins
            multi_mel_spectral_n_mels: Mel frequency bins
            mel_fmin: fmin for mel
            mel_fmax: fmax for mel
            mel_fmax_for_loss: fmax for multi mel spectral loss
            multi_mel_spectral_recon_loss_weight: the weight of frequency-domain reconstruction loss
            adversarial_loss_weight: the weight of adversarial loss from discriminator
            feat_match_loss_weight: the weight of intermediate feature loss from discriminator
            tpr_loss_params: the weight and tau of Truncated Pointwise Relativistic (TPR) loss from discriminator.
        """
        super().__init__()
        self.decoder = self.build_decoder(generator)
        # Used by task and trainer
        self.gen_model_list = [self.decoder]
        # nsf-hifigan or original hifigan
        self.nsf_augmented = nsf_augmented
        if nsf_augmented:
            assert f0_predictor is not None
            self.f0_predictor = self.build_f0_predictor(f0_predictor)
            # frozen: the F0 predictor is pretrained and must not be updated
            for param in self.f0_predictor.parameters():
                param.requires_grad = False
            self.gen_model_list.append(self.f0_predictor)
        self.discriminator = self.build_discriminator(discriminator)
        # One Audio2Mel per (fft, hop, win, n_mels) tuple — used only for the
        # multi-resolution mel reconstruction loss (fmax = mel_fmax_for_loss).
        self.multi_mel_spec_transforms = nn.ModuleList()
        for n_fft, hop_len, win_len, n_mel in zip(multi_mel_spectral_fft, multi_mel_spectral_hop,
                                                  multi_mel_spectral_window, multi_mel_spectral_n_mels):
            self.multi_mel_spec_transforms.append(
                Audio2Mel(
                    n_fft=n_fft,
                    hop_length=hop_len,
                    win_length=win_len,
                    sampling_rate=target_sample_hz,
                    n_mel_channels=n_mel,
                    mel_fmin=mel_fmin,
                    mel_fmax=mel_fmax_for_loss,
                    center=False,
                )
            )
        # Primary mel extractor (first configuration): builds the generator
        # input in _encode. Note it uses mel_fmax, not mel_fmax_for_loss.
        self.mel_spec_transform = Audio2Mel(
            n_fft=multi_mel_spectral_fft[0],
            hop_length=multi_mel_spectral_hop[0],
            win_length=multi_mel_spectral_window[0],
            sampling_rate=target_sample_hz,
            n_mel_channels=multi_mel_spectral_n_mels[0],
            mel_fmin=mel_fmin,
            mel_fmax=mel_fmax,
            center=False,
            feat_type=mel_feat_type,
        )
        # loss weights
        self.multi_mel_spectral_recon_loss_weight = multi_mel_spectral_recon_loss_weight
        self.adversarial_loss_weight = adversarial_loss_weight
        self.feat_match_loss_weight = feat_match_loss_weight
        self.tpr_loss_weight = tpr_loss_params.get("weight", 0.0)
        self.tpr_loss_tau = tpr_loss_params.get("tau", 0.04)
        # Constant zero kept on the right device; excluded from checkpoints.
        self.register_buffer('zero', torch.tensor([0.]), persistent=False)
        # Running generator-loss total; reset every discriminator turn.
        self.gen_loss = 0
        self.sample_rate = target_sample_hz
        # Counts generator training steps only (see forward()).
        self.forward_step = 0
    def build_decoder(self, conf):
        """Instantiate the HiFT generator from its config dict."""
        from funasr.models.llm_asr.hifigan_module.generator import HiFTGenerator
        return HiFTGenerator(**conf)
    def build_f0_predictor(self, conf):
        """Instantiate the (pretrained) F0 predictor from its config dict."""
        from funasr.models.llm_asr.hifigan_module.nsf_utils import ConvRNNF0Predictor
        return ConvRNNF0Predictor(**conf)
    def build_discriminator(self, conf):
        """Instantiate the combined multi-discriminator from its config dict."""
        from funasr.models.llm_asr.hifigan_module.discriminator import MultipleDiscriminator
        return MultipleDiscriminator(**conf)
    @property
    def generator(self):
        # All generator-side modules (decoder + frozen F0 predictor, if any),
        # exposed as a single ModuleList for the generator optimizer.
        return torch.nn.ModuleList(self.gen_model_list)
    def forward(
        self,
        forward_generator: bool = True,
        batch: Dict = None,
    ) -> Dict[str, Any]:
        """Forward functions of generator and discriminator.
        Args:
            forward_generator (bool): Whether to forward generator.
            batch (Dict[str, Tensor]): one batch including:
                speech (Tensor): Speech waveform tensor (B, T_wav).
                speech_lengths (Tensor): Speech length tensor (B,).
        Returns:
            Dict[str, Any]:
                - loss (Tensor): Loss scalar tensor.
                - stats (Dict[str, float]): Statistics to be monitored.
                - weight (Tensor): Weight tensor to summarize losses.
                - optim_idx (int): Optimizer index (0 for G and 1 for D).
        """
        if forward_generator:
            # Only generator training turns advance the step counter.
            if self.training:
                self.forward_step += 1
            return self._forward_generator(
                speech=batch["speech"],
                speech_lengths=batch["speech_lengths"],
            )
        else:
            return self._forward_discriminator(
                speech=batch["speech"],
                speech_lengths=batch["speech_lengths"],
            )
    def _encode(self, x: torch.Tensor) -> torch.Tensor:
        """Given a tensor `x`, returns the encoded representation for `x`
        (mel features of a (B, 1, T) waveform).
        """
        assert x.dim() == 3
        _, channel, length = x.size()
        assert channel == 1
        mel = self.mel_spec_transform(x)
        # NOTE(review): squeeze() drops ALL singleton dims, so a batch of
        # size 1 also loses its batch dimension here — confirm callers.
        return mel.squeeze()
    @torch.no_grad()
    def _f0_pred(self, x: torch.Tensor) -> torch.Tensor:
        """Given a tensor `x`, return the predicted f0 for `x`, x in (B, C, T)
        """
        if self.nsf_augmented:
            f0 = self.f0_predictor(x)
            # Restore a batch dimension if the predictor returns a 1-D tensor.
            if len(f0.shape) == 1:
                f0 = f0.unsqueeze(0)
            return f0
        else:
            # NOTE(review): zeros shaped like the (B, C, T) input rather than
            # (B, T) as in the augmented branch — confirm this path is unused.
            return torch.zeros_like(x)
    def _decode(self, x: torch.Tensor, g: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Decode the given representation into a waveform.
        Args:
            x (Tensor): Speech representation tensor (B, C1, T)
            g (Tensor): Global conditional vector (B, C2, 1).
        """
        if self.nsf_augmented:
            f0 = self._f0_pred(x)
            return self.decoder(x, f0, g)
        else:
            return self.decoder(x, g)
    def _forward_generator(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
    ) -> Dict[str, Any]:
        """Perform generator forward.
        Args:
            speech (Tensor): Speech waveform tensor (B, T_wav).
            speech_lengths (Tensor): Speech length tensor (B,).
        Returns:
            Dict[str, Any]:
                * loss (Tensor): Loss scalar tensor.
                * stats (Dict[str, float]): Statistics to be monitored.
                * weight (Tensor): Weight tensor to summarize losses.
                * optim_idx (int): Optimizer index (0 for G and 1 for D).
        """
        # setup
        batch_size = speech.size(0)
        speech = speech.unsqueeze(1)
        orig_speech = speech.clone()
        # Analysis-synthesis round trip; trim synthesis to the input length.
        mel = self._encode(speech)  # [B, C, T]
        recon_speech = self._decode(mel)[:, :, :speech.shape[-1]]
        # L1 Mel-Spectrogram Loss
        multi_mel_recon_loss = self.zero
        for lamda, mel_transform in zip(self.multi_mel_spectral_recon_loss_weight, self.multi_mel_spec_transforms):
            orig_mel, recon_mel = map(mel_transform, (orig_speech, recon_speech))
            multi_mel_recon_loss = multi_mel_recon_loss + lamda * F.l1_loss(orig_mel, recon_mel)
        # calculate discriminator outputs
        # disc_outputs in the format [disc1_outputs, disc2_outputs, ...]
        # disc1_outputs includes [logits, intermediates]
        # intermediates includes [layer_1_intermediate, layer_2_intermediate, ...]
        fake_disc_outputs = self.discriminator(recon_speech)
        with torch.no_grad():
            # do not store discriminator gradient in generator turn
            real_disc_outputs = self.discriminator(orig_speech)
        # calculate discriminator loss including adversarial, feat matching losses and tpr losses [Optional]
        adversarial_losses = []
        disc_feature_losses = []
        tpr_losses = []
        for real_output, fake_output in zip(real_disc_outputs, fake_disc_outputs):
            real_logits, real_intermediates = real_output
            fake_logits, fake_intermediates = fake_output
            # LSGAN generator objective: push fake logits towards 1.
            adversarial_losses.append(torch.mean((1 - fake_logits)**2))
            for real_inter, fake_inter in zip(real_intermediates, fake_intermediates):
                _loss = torch.mean(torch.abs(real_inter.detach() - fake_inter))
                disc_feature_losses.append(_loss)
            if self.tpr_loss_weight > 0.0:
                # Truncated Pointwise Relativistic loss (generator side).
                tau = self.tpr_loss_tau
                m_DG = torch.median((fake_logits - real_logits))
                L_rel = torch.mean((((fake_logits - real_logits) - m_DG) ** 2)[fake_logits < real_logits + m_DG])
                tpr_losses.append(tau - F.relu(tau - L_rel))
        adversarial_loss = torch.stack(adversarial_losses).sum()
        feat_match_loss = torch.stack(disc_feature_losses).sum()
        tpr_loss = torch.zeros_like(adversarial_loss)
        if len(tpr_losses) > 0:
            tpr_loss = torch.stack(tpr_losses).sum()
        # calculate losses
        gen_loss = multi_mel_recon_loss + \
                   adversarial_loss * self.adversarial_loss_weight + \
                   feat_match_loss * self.feat_match_loss_weight + \
                   tpr_loss * self.tpr_loss_weight
        # Accumulate the running generator loss (reset on discriminator turns).
        self.gen_loss += gen_loss.item()
        loss = gen_loss
        stats = dict(
            generator_loss=loss.item(),
            generator_multi_mel_recon_loss=multi_mel_recon_loss.item(),
            generator_adv_loss=adversarial_loss.item(),
            generator_feat_match_loss=feat_match_loss.item(),
            generator_tpr_loss=tpr_loss.item(),
            batch_size=batch_size,
            batch_length=speech.shape[2],
        )
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return {
            "loss": loss,
            "stats": stats,
            "weight": weight,
            "optim_idx": 0,  # needed for trainer
            "real": orig_speech,
            "fake": recon_speech,
        }
    def _forward_discriminator(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
    ) -> Dict[str, Any]:
        """Perform discriminator forward.
        Args:
            speech (Tensor): Speech waveform tensor (B, T_wav).
            speech_lengths (Tensor): Speech length tensor (B,).
        Returns:
            Dict[str, Any]:
                * loss (Tensor): Loss scalar tensor.
                * stats (Dict[str, float]): Statistics to be monitored.
                * weight (Tensor): Weight tensor to summarize losses.
                * optim_idx (int): Optimizer index (0 for G and 1 for D).
        """
        # setup
        batch_size = speech.size(0)
        speech = speech.unsqueeze(1)
        orig_speech = speech.clone()
        # A: calculate generator outputs
        with torch.no_grad():
            # do not store generator gradient in discriminator turn
            mel = self._encode(speech)  # [B, C, T]
            recon_speech = self._decode(mel)[:, :, :speech.shape[-1]]
        # B: calculate discriminator outputs
        real, fake = orig_speech.clone(), recon_speech.detach()
        real_disc_outputs = self.discriminator(real)
        fake_disc_outputs = self.discriminator(fake)
        # C: calculate discriminator losses, tpr losses [Optional]
        disc_losses = []
        tpr_losses = []
        for real_output, fake_output in zip(real_disc_outputs, fake_disc_outputs):
            real_logits, real_intermediates = real_output
            fake_logits, fake_intermediates = fake_output
            # LSGAN discriminator objective: real -> 1, fake -> 0.
            one_disc_loss = torch.mean((1-real_logits) ** 2) + torch.mean((0 - fake_logits) ** 2)
            disc_losses.append(one_disc_loss)
            if self.tpr_loss_weight > 0.0:
                # Truncated Pointwise Relativistic loss (discriminator side).
                tau = self.tpr_loss_tau
                m_DG = torch.median((real_logits - fake_logits))
                L_rel = torch.mean((((real_logits - fake_logits) - m_DG) ** 2)[real_logits < fake_logits + m_DG])
                tpr_losses.append(tau - F.relu(tau - L_rel))
        disc_loss = torch.stack(disc_losses).sum()
        tpr_loss = torch.zeros_like(disc_loss)
        if len(tpr_losses) > 0:
            tpr_loss = torch.stack(tpr_losses).sum()
        # Reset the running generator-loss accumulator each discriminator turn.
        self.gen_loss = 0
        loss = disc_loss + self.tpr_loss_weight * tpr_loss
        stats = dict(
            discriminator_total_loss=loss.item(),
            discriminator_loss=disc_loss.item(),
            discriminator_tpr_loss=tpr_loss.item(),
        )
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return {
            "loss": loss,
            "stats": stats,
            "weight": weight,
            "optim_idx": 1,  # needed for trainer
            "real": orig_speech,
            "fake": recon_speech,
        }
    def inference(
        self,
        x: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        """Run inference.
        Args:
            x (torch.Tensor): input representation, B x T x C
        Returns:
            Dict[str, Tensor]:
                * recon_speech (Tensor): Reconstructed waveform tensor (B, T_wav).
        """
        # Transpose to (B, C, T) for the decoder; drop the channel dim after.
        recon_speech = self._decode(x.transpose(1, 2)).squeeze(1)
        retval = dict(
            recon_speech=recon_speech,
        )
        return retval
    def collect_feats(self, **batch: torch.Tensor) -> Dict[str, torch.Tensor]:
        # Feature collection is not applicable to this vocoder task.
        pass
    @property
    def input_size(self):
        # No meaningful scalar input size for a vocoder; returns None.
        return

View File

@ -0,0 +1,14 @@
def get_padding(kernel_size, dilation=1):
    """Padding that keeps the output length of a stride-1 dilated
    convolution equal to its input length: ``(kernel_size - 1) * dilation // 2``."""
    return (kernel_size - 1) * dilation // 2
def init_weights(m, mean=0.0, std=0.01):
    """Initialize convolution weights in-place from N(mean, std).

    Intended for use with ``module.apply(init_weights)``; modules whose
    class name does not contain "Conv" are left untouched.
    """
    if "Conv" in m.__class__.__name__:
        m.weight.data.normal_(mean, std)
from cosyvoice.modules.hifigan_module.generator import HifiGenerator, NsfHifiGenerator, HiFTGenerator
from cosyvoice.modules.hifigan_module.discriminator import MultipleDiscriminator
from cosyvoice.modules.hifigan_module.nsf_utils import ConvRNNF0Predictor

View File

@ -0,0 +1,120 @@
# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
# LICENSE is in incl_licenses directory.
import torch
from torch import nn, sin, pow
from torch.nn import Parameter
class Snake(nn.Module):
    """Sine-based periodic activation: ``Snake(x) = x + sin^2(alpha * x) / alpha``.

    Operates on (B, C, T) tensors with one trainable ``alpha`` per channel.

    Reference:
        Liu Ziyin, Tilman Hartwig, Masahito Ueda,
        https://arxiv.org/abs/2006.08195

    Example:
        >>> act = Snake(256)
        >>> y = act(torch.randn(4, 256, 100))
    """
    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
        """Create the activation.

        Args:
            in_features: number of channels C (size of the alpha vector).
            alpha: initial value; larger alpha means higher frequency.
            alpha_trainable: optimize alpha together with the model.
            alpha_logscale: store alpha in log space (exp() applied in forward).
        """
        super(Snake, self).__init__()
        self.in_features = in_features
        self.alpha_logscale = alpha_logscale
        # Log-scale alphas start at 0 (exp(0) == 1); linear-scale start at 1.
        base = torch.zeros(in_features) if alpha_logscale else torch.ones(in_features)
        self.alpha = Parameter(base * alpha)
        self.alpha.requires_grad = alpha_trainable
        # Small epsilon guarding the division by alpha.
        self.no_div_by_zero = 0.000000001
    def forward(self, x):
        """Apply the activation elementwise on a (B, C, T) input."""
        # Reshape alpha (C,) -> (1, C, 1) so it broadcasts over batch and time.
        a = self.alpha.unsqueeze(0).unsqueeze(-1)
        if self.alpha_logscale:
            a = torch.exp(a)
        return x + pow(sin(x * a), 2) / (a + self.no_div_by_zero)
class SnakeBeta(nn.Module):
    """Snake variant with separate magnitude control:
    ``SnakeBeta(x) = x + sin^2(alpha * x) / beta``.

    ``alpha`` controls the frequency and ``beta`` the magnitude of the
    periodic component; both are per-channel trainable parameters on
    (B, C, T) inputs.

    Reference:
        Modified from Liu Ziyin, Tilman Hartwig, Masahito Ueda,
        https://arxiv.org/abs/2006.08195

    Example:
        >>> act = SnakeBeta(256)
        >>> y = act(torch.randn(4, 256, 100))
    """
    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
        """Create the activation.

        Args:
            in_features: number of channels C.
            alpha: initial value for both alpha and beta; higher alpha means
                higher frequency, higher beta means larger magnitude.
            alpha_trainable: optimize alpha and beta with the model.
            alpha_logscale: store parameters in log space (exp() in forward).
        """
        super(SnakeBeta, self).__init__()
        self.in_features = in_features
        self.alpha_logscale = alpha_logscale
        # Log-scale parameters start at 0 (exp(0) == 1); linear-scale at 1.
        # NOTE: beta is deliberately seeded from `alpha`, matching upstream.
        base = torch.zeros(in_features) if alpha_logscale else torch.ones(in_features)
        self.alpha = Parameter(base * alpha)
        self.beta = Parameter(base * alpha)
        self.alpha.requires_grad = alpha_trainable
        self.beta.requires_grad = alpha_trainable
        # Small epsilon guarding the division by beta.
        self.no_div_by_zero = 0.000000001
    def forward(self, x):
        """Apply the activation elementwise on a (B, C, T) input."""
        a = self.alpha.unsqueeze(0).unsqueeze(-1)  # (1, C, 1) broadcast shape
        b = self.beta.unsqueeze(0).unsqueeze(-1)
        if self.alpha_logscale:
            a = torch.exp(a)
            b = torch.exp(b)
        return x + pow(sin(x * a), 2) / (b + self.no_div_by_zero)

View File

@ -0,0 +1,299 @@
"""hifigan based dicriminator implementation.
This code is modified from https://github.com/jik876/hifi-gan and https://github.com/kan-bayashi/ParallelWaveGAN.
"""
import typing as tp
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.nn import Conv2d, AvgPool1d, Conv1d
from torch.nn.utils import weight_norm, spectral_norm
from funasr.models.llm_asr.hifigan_module import get_padding
class DiscriminatorP(torch.nn.Module):
    """HiFi-GAN period discriminator: folds the waveform into a 2-D
    (frames x period) map and scores it with strided 2-D convolutions."""
    def __init__(self, period, kernel_size=5, stride=3,
                 use_spectral_norm=False, lrelu_slope=0.1):
        super(DiscriminatorP, self).__init__()
        self.period = period
        self.lrelu_slope = lrelu_slope
        norm_f = weight_norm if use_spectral_norm is False else spectral_norm
        # Channel progression of the strided stack, then a stride-1 tail conv.
        channels = [1, 32, 128, 512, 1024]
        convs = []
        for c_in, c_out in zip(channels[:-1], channels[1:]):
            convs.append(norm_f(
                Conv2d(c_in, c_out, (kernel_size, 1), (stride, 1),
                       padding=(get_padding(5, 1), 0))))
        convs.append(norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))))
        self.convs = nn.ModuleList(convs)
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
    def forward(self, x):
        """Score waveform x (B, 1, T); returns (flattened logits, feature maps)."""
        fmap = []
        # 1d to 2d: pad T up to a multiple of `period`, then fold.
        b, c, t = x.shape
        remainder = t % self.period
        if remainder != 0:
            pad_amount = self.period - remainder
            x = F.pad(x, (0, pad_amount), "reflect")
            t += pad_amount
        x = x.view(b, c, t // self.period, self.period)
        for conv in self.convs:
            x = F.leaky_relu(conv(x), self.lrelu_slope)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        return torch.flatten(x, 1, -1), fmap
class MultiPeriodDiscriminator(torch.nn.Module):
    """Bank of :class:`DiscriminatorP` instances, one per period."""
    def __init__(self,
                 in_channels: int = 1,
                 periods: tp.List[int] = [2, 3, 5, 7, 11]):
        super(MultiPeriodDiscriminator, self).__init__()
        self.discriminators = nn.ModuleList([DiscriminatorP(p) for p in periods])
    def forward(self, x: torch.Tensor, return_intermediates: bool = True):
        """Calculate forward propagation.
        Args:
            x (Tensor): Input noise signal (B, 1, T).
        Returns:
            List: one entry per period — a (logits, feature-maps) tuple, or
            just the logits when ``return_intermediates`` is False.
        """
        if return_intermediates:
            return [disc(x) for disc in self.discriminators]
        return [disc(x)[0] for disc in self.discriminators]
class DiscriminatorS(torch.nn.Module):
    """HiFi-GAN scale discriminator: a stack of (mostly grouped) 1-D
    convolutions applied directly to the waveform."""
    def __init__(self, use_spectral_norm=False, lrelu_slope=0.1):
        super(DiscriminatorS, self).__init__()
        self.lrelu_slope = lrelu_slope
        norm_f = weight_norm if use_spectral_norm is False else spectral_norm
        # (in, out, kernel, stride, groups, padding) per layer.
        conv_specs = [
            (1, 128, 15, 1, 1, 7),
            (128, 128, 41, 2, 4, 20),
            (128, 256, 41, 2, 16, 20),
            (256, 512, 41, 4, 16, 20),
            (512, 1024, 41, 4, 16, 20),
            (1024, 1024, 41, 1, 16, 20),
            (1024, 1024, 5, 1, 1, 2),
        ]
        self.convs = nn.ModuleList([
            norm_f(Conv1d(ci, co, k, s, groups=g, padding=p))
            for ci, co, k, s, g, p in conv_specs
        ])
        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
    def forward(self, x):
        """Score waveform x (B, 1, T); returns (flattened logits, feature maps)."""
        fmap = []
        for conv in self.convs:
            x = F.leaky_relu(conv(x), self.lrelu_slope)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        return torch.flatten(x, 1, -1), fmap
class MultiScaleDiscriminator(torch.nn.Module):
    """Three :class:`DiscriminatorS` instances applied to progressively
    average-pooled versions of the waveform.

    NOTE(review): ``in_channels`` and ``nb_scales`` are accepted but unused;
    the scale count is fixed at three — confirm intended.
    """
    def __init__(self, in_channels: int = 1, nb_scales: int = 3):
        super(MultiScaleDiscriminator, self).__init__()
        # First scale (raw audio) uses spectral norm; pooled scales use weight norm.
        self.discriminators = nn.ModuleList([
            DiscriminatorS(use_spectral_norm=True),
            DiscriminatorS(),
            DiscriminatorS(),
        ])
        self.meanpools = nn.ModuleList(
            [AvgPool1d(4, 2, padding=2) for _ in range(2)])
    def forward(self, x: torch.Tensor, return_intermediates: bool = True):
        """Calculate forward propagation.
        Args:
            x (Tensor): Input noise signal (B, 1, T).
        Returns:
            List: one entry per scale — a (logits, feature-maps) tuple, or
            just the logits when ``return_intermediates`` is False.
        """
        outs = []
        for idx, disc in enumerate(self.discriminators):
            if idx > 0:
                x = self.meanpools[idx - 1](x)
            result = disc(x)
            outs.append(result if return_intermediates else result[0])
        return outs
class DiscriminatorR(nn.Module):
    """Single-resolution discriminator: scores the STFT magnitude of the
    waveform with 2-D convolutions (one instance per STFT configuration)."""
    def __init__(
        self,
        stft_params: tp.List[int],
        lrelu_slope: float = 0.1,
        use_spectral_norm: bool = False,
    ):
        super().__init__()
        # stft_params = [n_fft, hop_length, win_length]
        self.stft_params = stft_params
        self.lrelu_slope = lrelu_slope
        norm_f = spectral_norm if use_spectral_norm else weight_norm
        self.convs = nn.ModuleList([
            norm_f(nn.Conv2d(1, 32, (3, 9), padding=(1, 4))),
            norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
            norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
            norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
            norm_f(nn.Conv2d(32, 32, (3, 3), padding=(1, 1))),
        ])
        self.conv_post = norm_f(nn.Conv2d(32, 1, (3, 3), padding=(1, 1)))
    def spectrogram(self, x):
        """Return the linear STFT magnitude of x: (B, n_fft // 2 + 1, frames)."""
        n_fft, hop_length, win_length = self.stft_params
        pad = int((n_fft - hop_length) / 2)
        x = F.pad(x, (pad, pad), mode='reflect').squeeze(1)
        spec = torch.stft(x, n_fft, hop_length=hop_length, win_length=win_length,
                          center=False, pad_mode='reflect', normalized=False,
                          onesided=True, return_complex=True)
        # Magnitude of the complex spectrum, |STFT| = sqrt(re^2 + im^2).
        return spec.abs()
    def forward(self, x):
        """Score waveform x (B, 1, T); returns (flattened logits, feature maps)."""
        fmap = []
        x = self.spectrogram(x).unsqueeze(1)
        for conv in self.convs:
            x = F.leaky_relu(conv(x), self.lrelu_slope)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        return torch.flatten(x, 1, -1), fmap
class MultiResolutionDiscriminator(nn.Module):
    """Bank of :class:`DiscriminatorR` instances, one per STFT resolution."""
    def __init__(
        self,
        in_channels: int,
        fft_sizes: tp.List[int] = [1024, 2048, 512],
        hop_sizes: tp.List[int] = [120, 240, 50],
        win_lengths: tp.List[int] = [600, 1200, 240],
        lrelu_slope: float = 0.1,
    ):
        super().__init__()
        self.discriminators = nn.ModuleList([
            DiscriminatorR([n_fft, hop, win], lrelu_slope)
            for n_fft, hop, win in zip(fft_sizes, hop_sizes, win_lengths)
        ])
    def forward(self, x: torch.Tensor, return_intermediates: bool = True):
        """Calculate forward propagation.
        Args:
            x (Tensor): Input noise signal (B, 1, T).
        Returns:
            List: one entry per resolution — a (logits, feature-maps) tuple,
            or just the logits when ``return_intermediates`` is False.
        """
        if return_intermediates:
            return [disc(x) for disc in self.discriminators]
        return [disc(x)[0] for disc in self.discriminators]
class MultipleDiscriminator(nn.Module):
    """Container combining several discriminator families (MPD / MSD / MRD).

    Each entry of ``disc_conf_list`` must carry a ``name`` key selecting the
    discriminator type; the remaining keys are forwarded to its constructor.
    """
    def __init__(
        self,
        input_size: int = 1,
        disc_conf_list: tp.List[tp.Dict[str, tp.Any]] = None,
    ):
        super().__init__()
        self.support_disc_choices = dict(
            mpd=MultiPeriodDiscriminator,
            msd=MultiScaleDiscriminator,
            mrd=MultiResolutionDiscriminator,
        )
        self.discriminators = nn.ModuleList()
        self.discriminator_type_lst = []
        for conf in disc_conf_list:
            assert "name" in conf, "disc_conf must have `name` attr to specific disc type."
            disc_type = conf.pop("name")
            assert disc_type in self.support_disc_choices, \
                "Unsupported discriminator type, only support {}".format(
                    ",".join(self.support_disc_choices.keys())
                )
            disc_class = self.support_disc_choices[disc_type]
            self.discriminators.append(disc_class(in_channels=input_size, **conf))
            # restore the popped key so the config can be dumped unchanged
            conf["name"] = disc_type
            self.discriminator_type_lst.append(disc_type)
    def get_discriminator_type_lst(self) -> tp.List[str]:
        """Return the type names of the wrapped discriminators, in order."""
        return self.discriminator_type_lst
    def forward(self, x, return_intermediates=True):
        """Run every discriminator on x and concatenate their output lists."""
        retval = []
        for disc in self.discriminators:
            out = disc(x, return_intermediates=return_intermediates)
            if isinstance(out, tuple):
                retval.append(out)
            elif isinstance(out, list):
                retval.extend(out)
            else:
                raise TypeError("The return value of discriminator must be tuple or list[tuple]")
        return retval

View File

@ -0,0 +1,621 @@
"""hifigan based generator implementation.
This code is modified from https://github.com/jik876/hifi-gan
,https://github.com/kan-bayashi/ParallelWaveGAN and
https://github.com/NVIDIA/BigVGAN
"""
import typing as tp
import numpy as np
from scipy.signal import get_window
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Conv1d, ConvTranspose1d
from torch.nn.utils import weight_norm
from torch.nn.utils import remove_weight_norm
from funasr.models.llm_asr.hifigan_module import get_padding, init_weights
from funasr.models.llm_asr.hifigan_module.activations import Snake, SnakeBeta
from funasr.models.llm_asr.hifigan_module.nsf_utils import SourceModule, SourceModuleHnNSF
class ResBlock(torch.nn.Module):
    """Residual block used in HiFiGAN/BigVGAN generators.

    Each of the ``len(dilations)`` stages applies activation -> dilated conv
    (optionally followed by activation -> undilated conv) and adds the result
    back to the stage input.
    """
    def __init__(
        self,
        channels: int = 512,
        kernel_size: int = 3,
        dilations: tp.List[int] = [1, 3, 5],
        use_additional_convs: bool = True,
        nonlinear_activation: str = "LeakyReLU",
        nonlinear_activation_params: tp.Dict[str, tp.Any] = {"negative_slope": 0.1},
    ):
        super(ResBlock, self).__init__()
        self.use_additional_convs = use_additional_convs
        self.convs1 = nn.ModuleList()
        if use_additional_convs:
            self.convs2 = nn.ModuleList()
        # One dilated conv (and optionally one undilated conv) per stage.
        for dilation in dilations:
            self.convs1.append(weight_norm(Conv1d(
                channels, channels, kernel_size, 1,
                dilation=dilation, padding=get_padding(kernel_size, dilation))))
            if use_additional_convs:
                self.convs2.append(weight_norm(Conv1d(
                    channels, channels, kernel_size, 1,
                    dilation=1, padding=get_padding(kernel_size, 1))))
        self.convs1.apply(init_weights)
        if use_additional_convs:
            self.convs2.apply(init_weights)

        def make_activation():
            # Build one fresh activation module of the configured kind.
            if nonlinear_activation == "LeakyReLU":
                return nn.LeakyReLU(nonlinear_activation_params["negative_slope"])
            if nonlinear_activation == "Snake":
                return Snake(channels, alpha_logscale=nonlinear_activation_params.get("alpha_logscale", False))
            if nonlinear_activation == "SnakeBeta":
                return SnakeBeta(channels, alpha_logscale=nonlinear_activation_params.get("alpha_logscale", False))
            raise NotImplementedError

        self.activations1 = nn.ModuleList(
            [make_activation() for _ in range(len(self.convs1))])
        if use_additional_convs:
            self.activations2 = nn.ModuleList(
                [make_activation() for _ in range(len(self.convs2))])
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply all residual stages to x (B, channels, T)."""
        for idx, conv1 in enumerate(self.convs1):
            residual = conv1(self.activations1[idx](x))
            if self.use_additional_convs:
                residual = self.convs2[idx](self.activations2[idx](residual))
            x = x + residual
        return x
    def remove_weight_norm(self):
        """Strip weight normalization from all convolutions (inference prep)."""
        for idx in range(len(self.convs1)):
            remove_weight_norm(self.convs1[idx])
            if self.use_additional_convs:
                remove_weight_norm(self.convs2[idx])
class HifiGenerator(nn.Module):
    """HiFi-GAN generator: transposed-conv upsampling stages, each followed
    by a multi-receptive-field fusion of ResBlocks, mapping an
    (B, in_channels, T) feature sequence to a (B, 1, T * prod(rates))
    waveform squashed by tanh.
    """
    def __init__(
        self,
        in_channels: int = 80,
        base_channels: int = 512,
        global_channels: int = -1,
        upsample_rates: tp.List[int] = [8, 8, 2, 2],
        upsample_kernel_sizes: tp.List[int] = [16, 16, 4, 4],
        resblock_kernel_sizes: tp.List[int] = [3, 7, 11],
        resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        resblock_nonlinear_activation: str = "LeakyReLU",
        resblock_nonlinear_activation_params: tp.Dict[str, tp.Any] = {"negative_slope": 0.1},
        use_additional_convs: bool = True,
        cond_in_each_up_layer: bool = False,
        lrelu_slope: float = 0.1,
        act_pre_each_up_layer: bool = True
    ):
        """Initialize the generator.

        Args:
            in_channels: input feature channels (e.g. mel bins).
            base_channels: channels before the first upsampling stage;
                halved after every stage.
            global_channels: channels of an optional global conditioning
                vector; <= 0 disables global conditioning.
            upsample_rates: time-upsampling factor of each stage.
            upsample_kernel_sizes: transposed-conv kernel size per stage.
            resblock_kernel_sizes: kernel size of each parallel ResBlock.
            resblock_dilation_sizes: dilations inside each ResBlock.
            resblock_nonlinear_activation: ResBlock activation
                ("LeakyReLU", "Snake" or "SnakeBeta").
            resblock_nonlinear_activation_params: kwargs for that activation.
            use_additional_convs: add a second conv per dilated conv.
            cond_in_each_up_layer: inject the global condition after every
                upsampling stage (effective only when global_channels > 0).
            lrelu_slope: negative slope of the inter-stage LeakyReLU.
            act_pre_each_up_layer: apply LeakyReLU before each upsampling.
        """
        super(HifiGenerator, self).__init__()
        self.out_channels = 1
        self.global_channels = global_channels
        self.use_additional_convs = use_additional_convs
        # Per-stage conditioning only makes sense with a global vector.
        self.cond_in_each_up_layer = cond_in_each_up_layer if global_channels > 0 else False
        self.lrelu_slope = lrelu_slope
        self.act_pre_each_up_layer = act_pre_each_up_layer
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.conv_pre = weight_norm(
            Conv1d(in_channels, base_channels, 7, 1, padding=3)
        )
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(
                weight_norm(
                    ConvTranspose1d(
                        base_channels // (2**i),
                        base_channels // (2**(i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )
        # num_kernels ResBlocks per upsampling stage (multi-receptive field).
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = base_channels // (2**(i + 1))
            for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
                self.resblocks.append(ResBlock(ch, k, d, use_additional_convs,
                                               resblock_nonlinear_activation,
                                               resblock_nonlinear_activation_params))
        if self.global_channels > 0:
            self.conv_global_cond = weight_norm(
                Conv1d(global_channels, base_channels, 1)
            )
            self.conv_global_cond.apply(init_weights)
        if self.cond_in_each_up_layer:
            self.conv_conds = nn.ModuleList()
            for i in range(len(self.ups)):
                self.conv_conds.append(weight_norm(
                    nn.Conv1d(global_channels, base_channels // (2**(i + 1)), 1))
                )
            self.conv_conds.apply(init_weights)
        # `ch` is the channel count of the last upsampling stage here.
        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)
    def output_size(self):
        """Number of output waveform channels (always 1)."""
        return self.out_channels
    def forward(self, x: torch.Tensor, g: tp.Optional[torch.Tensor] = None) -> torch.Tensor:
        """Generate a waveform.

        Args:
            x: input features (B, in_channels, T).
            g: optional global conditioning vector (B, global_channels, 1).
        Returns:
            Waveform (B, 1, T * prod(upsample_rates)), tanh-squashed.
        """
        # x in (B, in_channels, T), g in (B, global_channels, 1)
        x = self.conv_pre(x)
        if self.global_channels > 0 and g is not None:
            x = x + self.conv_global_cond(g)
        for i in range(self.num_upsamples):
            if self.act_pre_each_up_layer:
                x = F.leaky_relu(x, self.lrelu_slope)
            x = self.ups[i](x)
            if self.cond_in_each_up_layer and g is not None:
                x = x + self.conv_conds[i](g)
            # Multi-receptive-field fusion: average the parallel ResBlocks.
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)
        return x
    def remove_weight_norm(self):
        """Strip weight normalization from every layer (inference prep)."""
        print('Removing weight norm...')
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)
        if self.global_channels > 0:
            remove_weight_norm(self.conv_global_cond)
        if self.cond_in_each_up_layer:
            for l in self.conv_conds:
                remove_weight_norm(l)
class NsfHifiGenerator(nn.Module):
"""
Neural Source Filter + HifiGan
"""
def __init__(
self,
in_channels: int = 80,
base_channels: int = 512,
global_channels: int = -1,
nb_harmonics: int = 7,
sampling_rate: int = 22050,
nsf_alpha: float = 0.1,
nsf_sigma: float = 0.003,
nsf_voiced_threshold: float = 10,
upsample_rates: tp.List[int] = [8, 8, 2, 2],
upsample_kernel_sizes: tp.List[int] = [16, 16, 4, 4],
resblock_kernel_sizes: tp.List[int] = [3, 7, 11],
resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
resblock_nonlinear_activation: str = "LeakyReLU",
resblock_nonlinear_activation_params: tp.Dict[str, tp.Any] = {"negative_slope": 0.1},
use_additional_convs: bool = True,
cond_in_each_up_layer: bool = False,
lrelu_slope: float = 0.1,
act_pre_each_up_layer: bool = True
):
super(NsfHifiGenerator, self).__init__()
self.out_channels = 1
self.global_channels = global_channels
self.nb_harmonics = nb_harmonics
self.sampling_rate = sampling_rate
self.use_additional_convs = use_additional_convs
self.cond_in_each_up_layer = cond_in_each_up_layer if global_channels > 0 else False
self.lrelu_slope = lrelu_slope
self.act_pre_each_up_layer = act_pre_each_up_layer
self.num_kernels = len(resblock_kernel_sizes)
self.num_upsamples = len(upsample_rates)
self.source_module = SourceModule(nb_harmonics, np.cumprod(upsample_rates)[-1],
sampling_rate, nsf_alpha, nsf_sigma, nsf_voiced_threshold)
self.conv_pre = weight_norm(
Conv1d(in_channels, base_channels, 7, 1, padding=3)
)
# Up
self.ups = nn.ModuleList()
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
self.ups.append(
weight_norm(
ConvTranspose1d(
base_channels // (2**i),
base_channels // (2**(i + 1)),
k,
u,
padding=(k - u) // 2,
)
)
)
# Down
self.source_downs = nn.ModuleList()
downsample_rates = [1] + upsample_rates[::-1][:-1]
downsample_cum_rates = np.cumprod(downsample_rates)
for i, u in enumerate(downsample_cum_rates[::-1]):
if (u == 1):
self.source_downs.append(
weight_norm(Conv1d(1, base_channels // (2 ** (i + 1)), 1, 1))
)
else:
self.source_downs.append(
weight_norm(Conv1d(1, base_channels // (2 ** (i + 1)), u*2, u, padding=(u//2)))
)
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = base_channels // (2**(i + 1))
for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
self.resblocks.append(ResBlock(ch, k, d, use_additional_convs,
resblock_nonlinear_activation,
resblock_nonlinear_activation_params))
if self.global_channels > 0:
self.conv_global_cond = weight_norm(
Conv1d(global_channels, base_channels, 1)
)
self.conv_global_cond.apply(init_weights)
if self.cond_in_each_up_layer:
self.conv_conds = nn.ModuleList()
for i in range(len(self.ups)):
self.conv_conds.append(weight_norm(
nn.Conv1d(global_channels, base_channels // (2**(i + 1)), 1))
)
self.conv_conds.apply(init_weights)
self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
self.ups.apply(init_weights)
self.conv_post.apply(init_weights)
def output_size(self):
return self.out_channels
def _f02source(self, f0: torch.Tensor) -> torch.Tensor:
return self.source_module(f0.unsqueeze(1))
    def forward(self, x: torch.Tensor, f0: torch.Tensor, g: tp.Optional[torch.Tensor] = None) -> torch.Tensor:
        """Synthesize a waveform from acoustic features and F0.

        Args:
            x: acoustic features, (B, in_channels, T).
            f0: frame-level F0 contour, (B, T).
            g: optional global conditioning embedding, (B, global_channels, 1).

        Returns:
            Waveform tensor squashed to (-1, 1) by the final tanh.
        """
        # x in (B, in_channels, T), f0 in (B, T), g in (B, global_channels, 1)
        s = self._f02source(f0)
        x = self.conv_pre(x)
        # add global conditioning once at the input resolution
        if self.global_channels > 0 and g is not None:
            x = x + self.conv_global_cond(g)
        for i in range(self.num_upsamples):
            if self.act_pre_each_up_layer:
                x = F.leaky_relu(x, self.lrelu_slope)
            x = self.ups[i](x)
            # optionally re-inject the global embedding at every scale
            if self.cond_in_each_up_layer and g is not None:
                x = x + self.conv_conds[i](g)
            # fusion: add the NSF source signal, downsampled to this scale
            x = x + self.source_downs[i](s)
            # average the outputs of the multi-receptive-field resblocks
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)
        return x
    def remove_weight_norm(self):
        """Strip weight normalization from every weight-normalized submodule.

        Call once before inference/export; afterwards the reparametrization
        hooks are gone, so this is a one-way operation. In this class the
        source branch (source_module and source_downs) is weight-normalized
        at construction time, so it is included here.
        """
        print('Removing weight norm...')
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)
        # conditioning convs only exist when global conditioning is enabled
        if self.global_channels > 0:
            remove_weight_norm(self.conv_global_cond)
        if self.cond_in_each_up_layer:
            for l in self.conv_conds:
                remove_weight_norm(l)
        self.source_module.remove_weight_norm()
        for l in self.source_downs:
            remove_weight_norm(l)
class HiFTGenerator(nn.Module):
    """
    HiFTNet Generator: Neural Source Filter + ISTFTNet
    https://arxiv.org/abs/2309.09493

    Predicts STFT magnitude and phase from mel features, conditioned on an
    NSF harmonic source derived from F0, then reconstructs the waveform with
    an inverse STFT.
    """
    def __init__(
        self,
        in_channels: int = 80,
        base_channels: int = 512,
        global_channels: int = -1,
        nb_harmonics: int = 8,
        sampling_rate: int = 22050,
        nsf_alpha: float = 0.1,
        nsf_sigma: float = 0.003,
        nsf_voiced_threshold: float = 10,
        upsample_rates: tp.List[int] = [8, 8],
        upsample_kernel_sizes: tp.List[int] = [16, 16],
        istft_params: tp.Dict[str, int] = {"n_fft": 16, "hop_len": 4},
        resblock_kernel_sizes: tp.List[int] = [3, 7, 11],
        resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        resblock_nonlinear_activation: str = "Snake",
        resblock_nonlinear_activation_params: tp.Dict[str, tp.Any] = {"alpha_logscale": False},
        source_resblock_kernel_sizes: tp.List[int] = [7, 11],
        source_resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5]],
        source_resblock_nonlinear_activation: str = "Snake",
        source_resblock_nonlinear_activation_params: tp.Dict[str, tp.Any] = {"alpha_logscale": False},
        use_additional_convs: bool = True,
        cond_in_each_up_layer: bool = False,
        lrelu_slope: float = 0.1,
        act_pre_each_up_layer: bool = True,
        audio_limit: float = 0.99,
    ):
        super(HiFTGenerator, self).__init__()
        self.out_channels = 1
        self.global_channels = global_channels
        self.nb_harmonics = nb_harmonics
        self.sampling_rate = sampling_rate
        self.istft_params = istft_params
        self.use_additional_convs = use_additional_convs
        # per-layer global conditioning only makes sense when a global
        # embedding is provided at all
        self.cond_in_each_up_layer = cond_in_each_up_layer if global_channels > 0 else False
        self.lrelu_slope = lrelu_slope
        self.act_pre_each_up_layer = act_pre_each_up_layer
        self.audio_limit = audio_limit
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        # NSF harmonic source; the total frame->sample factor covers both the
        # transposed convolutions and the ISTFT hop
        self.m_source = SourceModuleHnNSF(
            sampling_rate=sampling_rate,
            upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
            harmonic_num=nb_harmonics,
            sine_amp=nsf_alpha,
            add_noise_std=nsf_sigma,
            voiced_threshod=nsf_voiced_threshold)
        self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])
        self.conv_pre = weight_norm(
            Conv1d(in_channels, base_channels, 7, 1, padding=3)
        )
        # Up
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(
                weight_norm(
                    ConvTranspose1d(
                        base_channels // (2**i),
                        base_channels // (2**(i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )
        # Down: bring the source STFT down to each upsample stage's rate.
        # NOTE: these convs are intentionally built WITHOUT weight_norm.
        self.source_downs = nn.ModuleList()
        self.source_resblocks = nn.ModuleList()
        downsample_rates = [1] + upsample_rates[::-1][:-1]
        downsample_cum_rates = np.cumprod(downsample_rates)
        for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes,
                                          source_resblock_dilation_sizes)):
            if u == 1:
                self.source_downs.append(
                    Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1)
                )
            else:
                self.source_downs.append(
                    Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u*2, u, padding=(u//2))
                )
            self.source_resblocks.append(
                ResBlock(base_channels // (2 ** (i + 1)), k, d,
                         use_additional_convs, source_resblock_nonlinear_activation,
                         source_resblock_nonlinear_activation_params)
            )
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = base_channels // (2**(i + 1))
            for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
                self.resblocks.append(ResBlock(ch, k, d, use_additional_convs,
                                               resblock_nonlinear_activation,
                                               resblock_nonlinear_activation_params))
        if self.global_channels > 0:
            self.conv_global_cond = weight_norm(
                Conv1d(global_channels, base_channels, 1)
            )
            self.conv_global_cond.apply(init_weights)
        if self.cond_in_each_up_layer:
            self.conv_conds = nn.ModuleList()
            for i in range(len(self.ups)):
                self.conv_conds.append(weight_norm(
                    nn.Conv1d(global_channels, base_channels // (2**(i + 1)), 1))
                )
            self.conv_conds.apply(init_weights)
        # predicts n_fft/2+1 log-magnitude and n_fft/2+1 phase channels
        self.conv_post = weight_norm(Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)
        self.reflection_pad = nn.ReflectionPad1d((1, 0))
        window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
        self.register_buffer("stft_window", window)

    def output_size(self):
        """Number of output channels (mono waveform)."""
        return self.out_channels

    def _f02source(self, f0: torch.Tensor) -> torch.Tensor:
        """Upsample F0 (B, T) to sample rate and synthesize the harmonic source (B, 1, samples)."""
        f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2)  # bs,n,t
        har_source, _, _ = self.m_source(f0)
        return har_source.transpose(1, 2)

    def forward(self, x: torch.Tensor, f0: torch.Tensor, g: tp.Optional[torch.Tensor] = None) -> torch.Tensor:
        """Synthesize a waveform from mel features and F0.

        Args:
            x: acoustic features, (B, in_channels, T).
            f0: frame-level F0 contour, (B, T).
            g: optional global conditioning embedding, (B, global_channels, 1).

        Returns:
            Waveform tensor clamped to [-audio_limit, audio_limit].
        """
        # x in (B, in_channels, T), f0 in (B, T), g in (B, global_channels, 1)
        s = self._f02source(f0)
        # work on the source in the STFT domain so it can be fused at every scale
        s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
        s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
        x = self.conv_pre(x)
        if self.global_channels > 0 and g is not None:
            x = x + self.conv_global_cond(g)
        for i in range(self.num_upsamples):
            if self.act_pre_each_up_layer:
                x = F.leaky_relu(x, self.lrelu_slope)
            x = self.ups[i](x)
            if self.cond_in_each_up_layer and g is not None:
                x = x + self.conv_conds[i](g)
            if i == self.num_upsamples - 1:
                x = self.reflection_pad(x)
            # fusion with the downsampled, resblock-refined source spectrum
            si = self.source_downs[i](s_stft)
            si = self.source_resblocks[i](si)
            x = x + si
            # average the outputs of the multi-receptive-field resblocks
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :])
        phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :])  # actually, sin is redundancy
        x = self._istft(magnitude, phase)
        x = torch.clamp(x, -self.audio_limit, self.audio_limit)
        return x

    def remove_weight_norm(self):
        """Strip weight normalization from all weight-normalized submodules.

        Call once before inference/export. Fixed vs. the original: this class
        stores the NSF source as ``self.m_source`` (there is no
        ``self.source_module`` attribute), and ``self.source_downs`` are plain
        Conv1d layers built without weight_norm, so the previous
        ``self.source_module.remove_weight_norm()`` call raised AttributeError
        and stripping ``source_downs`` raised ValueError. Only the
        weight-normalized modules are handled here.
        """
        print('Removing weight norm...')
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)
        if self.global_channels > 0:
            remove_weight_norm(self.conv_global_cond)
        if self.cond_in_each_up_layer:
            for l in self.conv_conds:
                remove_weight_norm(l)
        # m_source and source_downs carry no weight_norm in this class;
        # only the source resblocks do.
        for l in self.source_resblocks:
            l.remove_weight_norm()

    def _stft(self, x):
        """STFT helper; returns (real, imag), each (B, n_fft/2+1, frames)."""
        spec = torch.stft(
            x,
            self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window,
            return_complex=True)
        spec = torch.view_as_real(spec)  # [B, F, TT, 2]
        return spec[..., 0], spec[..., 1]

    def _istft(self, magnitude, phase):
        """Inverse STFT from magnitude/phase; returns (B, 1, samples)."""
        # cap the magnitude to avoid inf after the exp in forward()
        magnitude = torch.clip(magnitude, max=1e2)
        real = magnitude * torch.cos(phase)
        img = magnitude * torch.sin(phase)
        # torch.istft requires a complex input on torch 2.x (real [..., 2]
        # input support was removed), so build the complex tensor explicitly.
        inverse_transform = torch.istft(
            torch.complex(real, img),
            self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window)
        return inverse_transform.unsqueeze(-2)  # unsqueeze to stay consistent with conv_transpose1d implementation

View File

@ -0,0 +1,93 @@
import torch
import torch.utils.data
import numpy as np
from librosa.filters import mel as librosa_mel_fn
def dynamic_range_compression(x, C=1, clip_val=1e-5):
    """Log-compress a numpy magnitude array, clipping tiny values to clip_val first."""
    clipped = np.clip(x, a_min=clip_val, a_max=None)
    return np.log(clipped * C)
def dynamic_range_decompression(x, C=1):
    """Invert dynamic_range_compression for numpy arrays: exp, then undo the gain C."""
    expanded = np.exp(x)
    return expanded / C
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    """Log-compress a torch magnitude tensor, clamping tiny values to clip_val first."""
    safe = torch.clamp(x, min=clip_val)
    return torch.log(safe * C)
def dynamic_range_decompression_torch(x, C=1):
    """Invert dynamic_range_compression_torch: exp, then undo the gain C."""
    restored = torch.exp(x)
    return restored / C
def spectral_normalize_torch(magnitudes):
    """Log-compress spectrogram magnitudes using the default dynamic-range settings."""
    return dynamic_range_compression_torch(magnitudes)
def spectral_de_normalize_torch(magnitudes):
    """Undo spectral_normalize_torch, recovering linear-scale magnitudes."""
    return dynamic_range_decompression_torch(magnitudes)
mel_basis = {}
hann_window = {}
def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
    """Compute a log-compressed mel spectrogram (HiFi-GAN style).

    Args:
        y: waveform tensor (B, T), expected in [-1, 1] (out-of-range is only warned).
        n_fft / hop_size / win_size: STFT parameters.
        num_mels / sampling_rate / fmin / fmax: mel filterbank parameters.
        center: passed through to torch.stft.

    Returns:
        Tensor (B, num_mels, frames) of log-compressed mel energies.
    """
    if torch.min(y) < -1.:
        print('min value is ', torch.min(y))
    if torch.max(y) > 1.:
        print('max value is ', torch.max(y))

    global mel_basis, hann_window
    # BUGFIX: the original guard tested `fmax not in mel_basis` while storing
    # under the composite key f"{fmax}_{device}", so the mel basis and window
    # were rebuilt on every single call. Check the composite keys instead.
    mel_key = str(fmax) + '_' + str(y.device)
    win_key = str(y.device)
    if mel_key not in mel_basis:
        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
        mel_basis[mel_key] = torch.from_numpy(mel).float().to(y.device)
    if win_key not in hann_window:
        hann_window[win_key] = torch.hann_window(win_size).to(y.device)

    # center the frames the HiFi-GAN way: reflect-pad by (n_fft - hop)/2
    y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
    y = y.squeeze(1)

    # return_complex=True (+ view_as_real) keeps the original numerics while
    # remaining valid on torch >= 2.0, where omitting it raises.
    spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[win_key],
                      center=center, pad_mode='reflect', normalized=False, onesided=True,
                      return_complex=True)
    spec = torch.view_as_real(spec)
    spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))
    spec = torch.matmul(mel_basis[mel_key], spec)
    spec = spectral_normalize_torch(spec)
    return spec
def power_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
    """Compute a log-compressed linear-frequency spectrogram.

    The mel basis is still built and cached here (it is not applied) so that
    mel_from_power_spectrogram can reuse it for the same fmax/device.

    Returns:
        Tensor (B, n_fft//2 + 1, frames) of log-compressed magnitudes.
    """
    if torch.min(y) < -1.:
        print('min value is ', torch.min(y))
    if torch.max(y) > 1.:
        print('max value is ', torch.max(y))

    global mel_basis, hann_window
    # BUGFIX: same broken cache guard as mel_spectrogram — test the composite
    # key actually used for storage instead of the bare fmax value.
    mel_key = str(fmax) + '_' + str(y.device)
    win_key = str(y.device)
    if mel_key not in mel_basis:
        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
        mel_basis[mel_key] = torch.from_numpy(mel).float().to(y.device)
    if win_key not in hann_window:
        hann_window[win_key] = torch.hann_window(win_size).to(y.device)

    y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
    y = y.squeeze(1)

    # return_complex=True keeps torch >= 2.0 compatibility (see mel_spectrogram)
    spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[win_key],
                      center=center, pad_mode='reflect', normalized=False, onesided=True,
                      return_complex=True)
    spec = torch.view_as_real(spec)
    spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))
    spec = spectral_normalize_torch(spec)
    return spec
def mel_from_power_spectrogram(spec, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
    """Project a log-compressed power spectrogram onto the cached mel basis.

    NOTE(review): assumes mel_basis was already populated by a previous
    mel_spectrogram/power_spectrogram call for this fmax/device — confirm.
    The unused parameters are kept for signature parity with the other helpers.
    """
    global mel_basis, hann_window
    linear = spectral_de_normalize_torch(spec)
    mel = torch.matmul(mel_basis[str(fmax) + '_' + str(spec.device)], linear)
    return spectral_normalize_torch(mel)

View File

@ -0,0 +1,253 @@
"""
Neural Source Filter based modules implementation.
Neural source-filter waveform models for statistical parametric speech synthesis
"""
import numpy as np
import typing as tp
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm, remove_weight_norm
from torch.distributions.uniform import Uniform
from torch.distributions.normal import Normal
class SineGen(torch.nn.Module):
    """Sine-wave source generator for NSF vocoders.

    Args:
        samp_rate: sampling rate in Hz.
        harmonic_num: number of harmonic overtones above F0 (default 0).
        sine_amp: amplitude of the sine waveform (default 0.1).
        noise_std: std of the additive Gaussian noise in voiced regions
            (default 0.003); unvoiced noise amplitude is tied to sine_amp.
        voiced_threshold: F0 threshold for the voiced/unvoiced decision
            (default 0).
    """

    def __init__(self, samp_rate, harmonic_num=0,
                 sine_amp=0.1, noise_std=0.003,
                 voiced_threshold=0):
        super(SineGen, self).__init__()
        self.sine_amp = sine_amp
        self.noise_std = noise_std
        self.harmonic_num = harmonic_num
        self.sampling_rate = samp_rate
        self.voiced_threshold = voiced_threshold

    def _f02uv(self, f0):
        # voiced/unvoiced flags: 1.0 where F0 exceeds the threshold, else 0.0
        uv = (f0 > self.voiced_threshold).type(torch.float32)
        return uv

    @torch.no_grad()
    def forward(self, f0):
        """Generate noisy harmonic sine waves from a sample-level F0 contour.

        :param f0: [B, 1, sample_len], Hz
        :return: tuple (sine_waves, uv, noise); sine_waves and noise are
            [B, harmonic_num + 1, sample_len], uv is [B, 1, sample_len]
        """
        # per-harmonic normalized instantaneous frequency: f0 * (i+1) / fs
        F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(f0.device)
        for i in range(self.harmonic_num + 1):
            F_mat[:, i:i+1, :] = f0 * (i+1) / self.sampling_rate

        # integrate frequency to phase; "% 1" keeps the cumsum bounded
        theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1)
        # random initial phase per harmonic; the fundamental keeps phase 0
        u_dist = Uniform(low=-np.pi, high=np.pi)
        phase_vec = u_dist.sample(sample_shape=(f0.size(0), self.harmonic_num + 1, 1)).to(F_mat.device)
        phase_vec[:, 0, :] = 0

        # generate sine waveforms
        sine_waves = self.sine_amp * torch.sin(theta_mat + phase_vec)

        # generate uv signal
        uv = self._f02uv(f0)

        # noise: for unvoiced should be similar to sine_amp
        #        std = self.sine_amp/3 -> max value ~ self.sine_amp
        # .      for voiced regions is self.noise_std
        noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
        noise = noise_amp * torch.randn_like(sine_waves)

        # first: set the unvoiced part to 0 by uv
        # then: additive noise
        sine_waves = sine_waves * uv + noise
        return sine_waves, uv, noise
class SourceModuleHnNSF(torch.nn.Module):
    """Harmonic-plus-noise source module for hn-NSF.

    Args:
        sampling_rate: sampling rate in Hz.
        upsample_scale: total frame-to-sample upsampling factor;
            accepted for interface compatibility but not used in this
            implementation.
        harmonic_num: number of harmonics above F0 (default 0).
        sine_amp: amplitude of the sine source signal (default 0.1).
        add_noise_std: std of the additive Gaussian noise (default 0.003);
            the unvoiced noise amplitude is tied to sine_amp.
        voiced_threshod: F0 threshold for the voiced/unvoiced decision
            (default 0).
    """

    def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
                 add_noise_std=0.003, voiced_threshod=0):
        super(SourceModuleHnNSF, self).__init__()
        self.sine_amp = sine_amp
        self.noise_std = add_noise_std

        # to produce sine waveforms
        self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
                                 sine_amp, add_noise_std, voiced_threshod)

        # to merge source harmonics into a single excitation
        self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
        self.l_tanh = torch.nn.Tanh()

    def forward(self, x):
        """
        :param x: sample-level F0 contour, (batchsize, length, 1)
        :return: tuple (sine_merge, noise, uv):
            sine_merge: merged harmonic excitation, (batchsize, length, 1)
            noise: Gaussian noise for the noise branch, same shape as uv
            uv: voiced/unvoiced flags, (batchsize, length, 1)
        """
        # source for harmonic branch; sine generation itself is not trained
        with torch.no_grad():
            sine_wavs, uv, _ = self.l_sin_gen(x.transpose(1,2))
            sine_wavs = sine_wavs.transpose(1,2)
            uv = uv.transpose(1,2)
        # the merging projection stays trainable (outside no_grad)
        sine_merge = self.l_tanh(self.l_linear(sine_wavs))

        # source for noise branch, in the same shape as uv
        noise = torch.randn_like(uv) * self.sine_amp / 3
        return sine_merge, noise, uv
class SourceModule(torch.nn.Module):
    """NSF excitation source: harmonic sines in voiced regions, noise in unvoiced.

    Args:
        nb_harmonics: number of harmonics above F0.
        upsample_ratio: frame-to-sample upsampling factor.
        sampling_rate: sampling rate in Hz.
        alpha: sine amplitude (default 0.1).
        sigma: noise std (default 0.003).
        voiced_threshold: F0 threshold for the voiced/unvoiced decision.
    """

    def __init__(self,
                 nb_harmonics: int,
                 upsample_ratio: int,
                 sampling_rate: int,
                 alpha: float = 0.1,
                 sigma: float = 0.003,
                 voiced_threshold: float = 10
                 ):
        super(SourceModule, self).__init__()
        self.nb_harmonics = nb_harmonics
        self.upsample_ratio = upsample_ratio
        self.sampling_rate = sampling_rate
        self.alpha = alpha
        self.sigma = sigma
        self.voiced_threshold = voiced_threshold

        # merges the (nb_harmonics + 1) excitation channels into one channel
        self.ffn = nn.Sequential(
            weight_norm(nn.Conv1d(self.nb_harmonics + 1, 1, kernel_size=1, stride=1)),
            nn.Tanh())

    def f02uv(self, f0):
        # voiced/unvoiced flags: 1.0 where F0 exceeds the threshold, else 0.0
        uv = (f0 > self.voiced_threshold).type(torch.float32)
        return uv

    def forward(self, f0):
        """
        :param f0: [B, 1, frame_len], Hz
        :return: [B, 1, sample_len]
        """
        # excitation construction is not trained; only the ffn projection is
        with torch.no_grad():
            uv = self.f02uv(f0)

            # nearest-neighbor upsample frame-level F0/uv to sample level
            f0_samples = F.interpolate(f0, scale_factor=(self.upsample_ratio), mode='nearest')
            uv_samples = F.interpolate(uv, scale_factor=(self.upsample_ratio), mode='nearest')

            # per-harmonic normalized instantaneous frequency: f0 * (i+1) / fs
            F_mat = torch.zeros((f0_samples.size(0), self.nb_harmonics + 1, f0_samples.size(-1))).to(f0_samples.device)
            for i in range(self.nb_harmonics + 1):
                F_mat[:, i:i+1, :] = f0_samples * (i+1) / self.sampling_rate

            # integrate frequency to phase; random initial phase per harmonic,
            # fundamental keeps phase 0
            theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1)
            u_dist = Uniform(low=-np.pi, high=np.pi)
            phase_vec = u_dist.sample(sample_shape=(f0.size(0), self.nb_harmonics + 1, 1)).to(F_mat.device)
            phase_vec[:, 0, :] = 0

            n_dist = Normal(loc=0., scale=self.sigma)
            noise = n_dist.sample(sample_shape=(f0_samples.size(0), self.nb_harmonics + 1, f0_samples.size(-1))).to(F_mat.device)

            # voiced: noisy sines; unvoiced: amplitude-matched pure noise
            e_voice = self.alpha * torch.sin(theta_mat + phase_vec) + noise
            e_unvoice = self.alpha / 3 / self.sigma * noise

            e = e_voice * uv_samples + e_unvoice * (1 - uv_samples)

        return self.ffn(e)

    def remove_weight_norm(self):
        # strip weight_norm from the merge convolution (call before inference)
        remove_weight_norm(self.ffn[0])
class ConvRNNF0Predictor(nn.Module):
    """Frame-level F0 predictor: stacked Conv1d+ELU layers, optional GRU, linear head.

    Args:
        num_class: output dimension per frame (1 for scalar F0).
        in_channels: input feature dimension (e.g. 80 mel bins).
        cond_channels: hidden width of the conv stack.
        use_cond_rnn: whether to run a GRU over the conv features.
        bidirectional_rnn: if True, the GRU is bidirectional with half-width
            directions so the concatenated output keeps cond_channels dims.
    """

    def __init__(self,
                 num_class: int = 1,
                 in_channels: int = 80,
                 cond_channels: int = 512,
                 use_cond_rnn: bool = True,
                 bidirectional_rnn: bool = False,
                 ):
        super().__init__()

        self.num_class = num_class
        self.use_cond_rnn = use_cond_rnn
        # five weight-normalized Conv1d + ELU stages, all stride 1 / same-length
        self.condnet = nn.Sequential(
            weight_norm(
                nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1)
            ),
            nn.ELU(),
            weight_norm(
                nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
            ),
            nn.ELU(),
            weight_norm(
                nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
            ),
            nn.ELU(),
            weight_norm(
                nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
            ),
            nn.ELU(),
            weight_norm(
                nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
            ),
            nn.ELU(),
        )
        if self.use_cond_rnn:
            self.rnn = nn.GRU(
                cond_channels,
                cond_channels // 2 if bidirectional_rnn else cond_channels,
                num_layers=1,
                batch_first=True,
                bidirectional=bidirectional_rnn,
            )
        self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Predict non-negative F0 values per frame.

        :param x: (B, in_channels, T) features
        :return: abs of the linear head output with the last axis squeezed;
            (B, T) when num_class == 1
        """
        x = self.condnet(x)
        if self.use_cond_rnn:
            x, _ = self.rnn(x.transpose(1, 2))
        else:
            x = x.transpose(1, 2)
        return torch.abs(self.classifier(x).squeeze(-1))

View File

@ -0,0 +1,93 @@
import torch
import torch.utils.data
import numpy as np
from librosa.filters import mel as librosa_mel_fn
def dynamic_range_compression(x, C=1, clip_val=1e-5):
    """Log-compress a numpy magnitude array, clipping tiny values to clip_val first."""
    clipped = np.clip(x, a_min=clip_val, a_max=None)
    return np.log(clipped * C)
def dynamic_range_decompression(x, C=1):
    """Invert dynamic_range_compression for numpy arrays: exp, then undo the gain C."""
    expanded = np.exp(x)
    return expanded / C
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    """Log-compress a torch magnitude tensor, clamping tiny values to clip_val first."""
    safe = torch.clamp(x, min=clip_val)
    return torch.log(safe * C)
def dynamic_range_decompression_torch(x, C=1):
    """Invert dynamic_range_compression_torch: exp, then undo the gain C."""
    restored = torch.exp(x)
    return restored / C
def spectral_normalize_torch(magnitudes):
    """Log-compress spectrogram magnitudes using the default dynamic-range settings."""
    return dynamic_range_compression_torch(magnitudes)
def spectral_de_normalize_torch(magnitudes):
    """Undo spectral_normalize_torch, recovering linear-scale magnitudes."""
    return dynamic_range_decompression_torch(magnitudes)
mel_basis = {}
hann_window = {}
def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
    """Compute a log-compressed mel spectrogram (HiFi-GAN style).

    Args:
        y: waveform tensor (B, T), expected in [-1, 1] (out-of-range is only warned).
        n_fft / hop_size / win_size: STFT parameters.
        num_mels / sampling_rate / fmin / fmax: mel filterbank parameters.
        center: passed through to torch.stft.

    Returns:
        Tensor (B, num_mels, frames) of log-compressed mel energies.
    """
    if torch.min(y) < -1.:
        print('min value is ', torch.min(y))
    if torch.max(y) > 1.:
        print('max value is ', torch.max(y))

    global mel_basis, hann_window
    # BUGFIX: the original guard tested `fmax not in mel_basis` while storing
    # under the composite key f"{fmax}_{device}", so the mel basis and window
    # were rebuilt on every single call. Check the composite keys instead.
    mel_key = str(fmax) + '_' + str(y.device)
    win_key = str(y.device)
    if mel_key not in mel_basis:
        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
        mel_basis[mel_key] = torch.from_numpy(mel).float().to(y.device)
    if win_key not in hann_window:
        hann_window[win_key] = torch.hann_window(win_size).to(y.device)

    # center the frames the HiFi-GAN way: reflect-pad by (n_fft - hop)/2
    y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
    y = y.squeeze(1)

    # return_complex=True (+ view_as_real) keeps the original numerics while
    # remaining valid on torch >= 2.0, where omitting it raises.
    spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[win_key],
                      center=center, pad_mode='reflect', normalized=False, onesided=True,
                      return_complex=True)
    spec = torch.view_as_real(spec)
    spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))
    spec = torch.matmul(mel_basis[mel_key], spec)
    spec = spectral_normalize_torch(spec)
    return spec
def power_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
    """Compute a log-compressed linear-frequency spectrogram.

    The mel basis is still built and cached here (it is not applied) so that
    mel_from_power_spectrogram can reuse it for the same fmax/device.

    Returns:
        Tensor (B, n_fft//2 + 1, frames) of log-compressed magnitudes.
    """
    if torch.min(y) < -1.:
        print('min value is ', torch.min(y))
    if torch.max(y) > 1.:
        print('max value is ', torch.max(y))

    global mel_basis, hann_window
    # BUGFIX: same broken cache guard as mel_spectrogram — test the composite
    # key actually used for storage instead of the bare fmax value.
    mel_key = str(fmax) + '_' + str(y.device)
    win_key = str(y.device)
    if mel_key not in mel_basis:
        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
        mel_basis[mel_key] = torch.from_numpy(mel).float().to(y.device)
    if win_key not in hann_window:
        hann_window[win_key] = torch.hann_window(win_size).to(y.device)

    y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
    y = y.squeeze(1)

    # return_complex=True keeps torch >= 2.0 compatibility (see mel_spectrogram)
    spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[win_key],
                      center=center, pad_mode='reflect', normalized=False, onesided=True,
                      return_complex=True)
    spec = torch.view_as_real(spec)
    spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))
    spec = spectral_normalize_torch(spec)
    return spec
def mel_from_power_spectrogram(spec, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
    """Project a log-compressed power spectrogram onto the cached mel basis.

    NOTE(review): assumes mel_basis was already populated by a previous
    mel_spectrogram/power_spectrogram call for this fmax/device — confirm.
    The unused parameters are kept for signature parity with the other helpers.
    """
    global mel_basis, hann_window
    linear = spectral_de_normalize_torch(spec)
    mel = torch.matmul(mel_basis[str(fmax) + '_' + str(spec.device)], linear)
    return spectral_normalize_torch(mel)

View File

@ -1,4 +1,6 @@
import logging
import os.path
import torchaudio
from typing import Union, Dict, List, Tuple, Optional
import time
@ -1580,6 +1582,29 @@ class LLMASR5(nn.Module):
reduction=False,
)
mel_decoder_name = kwargs.get("mel_decoder", None)
mel_decoder_conf = kwargs.get("mel_decoder_conf", None)
self.mel_decoder = self.build_mel_decoder(name=mel_decoder_name, conf=mel_decoder_conf)
vocoder_name = kwargs.get("vocoder", None)
vocoder_conf = kwargs.get("vocoder_conf", None)
self.vocoder = self.build_vocoder(name=vocoder_name, conf=vocoder_conf)
def build_mel_decoder(self, name: str, conf: dict):
if name is None or conf is None:
return None
if name == "MaskedDiffWithXvec":
from funasr.models.llm_asr.flow_matching import MaskedDiffWithXvec
return MaskedDiffWithXvec(**conf)
return None
def build_vocoder(self, name: str, conf: dict):
if name is None or conf is None:
return None
if name == "HifiGAN":
from funasr.models.llm_asr.hifigan import HifiGan
return HifiGan(**conf)
return None
def build_audio_decoder(self, name, conf):
if name == "transformer":
from funasr.models.llm_asr.transformer_lm import TransformerEmbedLM
@ -2205,7 +2230,16 @@ class LLMASR5(nn.Module):
target_ids = generated_ids["sequences"]
target_emb = self.llm.model.get_input_embeddings()(target_ids)
if self.concat_emb_hidden:
if not self.concat_emb_hidden_norm:
hidden_states_select = torch.concat((hidden_states_select, target_emb), dim=-1)
hidden_states_select = self.audio_decoder_in_proj(hidden_states_select)
else:
outs = self.hidden_norm(hidden_states_select)
outs = self.fusion_dropout(self.fusion_act(outs))
# emb = model_outputs.hidden_states[0]
emb = self.fusion_dropout(self.fusion_act(self.emb_norm(target_emb)))
outs = self.audio_decoder_in_proj(torch.cat([outs, emb], dim=-1))
hidden_states_select = self.fusion_act(self.fusion_norm(outs))
speech_tokens = self.audio_decode(hidden_states_select, hidden_states_out_len)[
:, :, 0
@ -2221,12 +2255,18 @@ class LLMASR5(nn.Module):
loss = None
# synthesize waveforms
spk_emb = kwargs.get("spk_emb", None)
feat, wav = self.synthesize_waveform(speech_tokens, spk_emb, inputs_embeds.device)
ibest_writer = None
if kwargs.get("output_dir") is not None:
if not hasattr(self, "writer"):
self.writer = DatadirWriter(kwargs.get("output_dir"))
ibest_writer = self.writer[f"{0 + 1}best_recog"]
self.write_mel_wav(kwargs.get("output_dir"), feat, wav, key[0])
results = []
response_clean = re.sub(r"[^\w\s\u3000\u4e00-\u9fff]+", "", response)
result_i = {
@ -2253,6 +2293,48 @@ class LLMASR5(nn.Module):
return results, meta_data
def write_mel_wav(self, output_dir, feat, wav, key):
out_dir = os.path.join(output_dir, "1best_recog", "mels")
os.makedirs(out_dir, exist_ok=True)
if feat is not None:
feat = feat.cpu().numpy()[0]
np.save(os.path.join(out_dir, f"{key}.npy"), feat)
out_dir = os.path.join(output_dir, "1best_recog", "wavs")
os.makedirs(out_dir, exist_ok=True)
if wav is not None:
path = os.path.join(out_dir, f"{key}.wav")
torchaudio.save(
path, wav[0], sample_rate=self.vocoder.sample_rate,
encoding='PCM_S', bits_per_sample=16
)
def synthesize_waveform(self, speech_tokens, spk_emb, device):
mel_feat, wav = None, None
if self.mel_decoder is not None and spk_emb is not None:
# mel_feat in BxCxT
mel_feat = self.token2mel(speech_tokens, spk_emb, device)
if self.vocoder is not None:
wav = self.vocoder.inference(mel_feat.transpose(1, 2))
return mel_feat, wav
def token2mel(self, tokens: torch.Tensor, xvec: torch.Tensor, device: torch.device):
xvec = torch.tensor(xvec).to(device).unsqueeze(0)
xvec_lens = torch.tensor([xvec.shape[1]], device=device, dtype=torch.int64)
token_lens = torch.tensor([tokens.shape[1]], device=device, dtype=torch.int64)
feat = self.mel_decoder.inference(
tokens, token_lens,
xvec, xvec_lens,
diff_steps=10,
temperature=1.0,
prompt=dict(
prompt_text=(None, None),
prompt_audio=(None, None)
)
)
return feat
def audio_decode(
self,
text: torch.Tensor,
@ -2263,9 +2345,8 @@ class LLMASR5(nn.Module):
decoding_length=None,
):
# 1. encode text
text = self.audio_decoder_in_proj(text)
# text = self.audio_decoder_in_proj(text)
device = text.device
out_tokens = []
sos_eos_emb = self.audio_decoder_embedding(
torch.tensor([[self.ad_sos_eos]], dtype=torch.int64, device=device)
)
@ -2273,30 +2354,18 @@ class LLMASR5(nn.Module):
torch.tensor([[self.ad_task_id]], dtype=torch.int64, device=device)
)
prompt = torch.cat([sos_eos_emb, text, task_id_emb], dim=1)
state, cfg_state = None, None
seq_input = torch.zeros(
[1, prompt.shape[1] + max_length, prompt.shape[2]],
dtype=torch.float32, device=device
)
seq_input[:, :prompt.shape[1], :] = prompt
out_tokens = torch.zeros([1, max_length, 1], device=device)
out_token_len = 0
prompt_len = prompt.shape[1]
state, hit_eos = None, False
for i in range(max_length):
if len(out_tokens) > 0:
codec_prompt = torch.tensor([out_tokens], dtype=torch.int64, device=device)
codec_lengths = torch.tensor([len(out_tokens)], dtype=torch.int64, device=device)
# if any quantizer output is eos
if torch.any(codec_prompt[:, -1] == (self.codebook_size + self.ad_sos_eos)):
break
seq_input, _ = self.prepare_audio_decoder_io(
text, text_lengths, codec_prompt, codec_lengths, need_targets=False
)
else:
seq_input, _ = self.prepare_audio_decoder_io(
text, text_lengths, None, None, need_targets=False
)
# use state for speedup
pred, (state, _) = self.audio_decoder.score(seq_input[0], state, prompt[0])
if infer_cfg_ratio is not None:
cond_len = prompt[0].shape[0]
cfg_pred, (cfg_state, _) = self.audio_decoder.score(
seq_input[0][cond_len - 1 :], cfg_state, prompt[0][cond_len - 1 :]
)
pred = (1 + infer_cfg_ratio) * pred - infer_cfg_ratio * cfg_pred
# sampling all `nq` token ids
pred = pred.reshape(self.predict_nq, -1)
@ -2304,49 +2373,44 @@ class LLMASR5(nn.Module):
pred = torch.log_softmax(pred, dim=-1)
if min_length is not None and i < min_length:
pred[:, self.codebook_size + self.ad_sos_eos] = float(np.finfo(np.float32).min)
top_ids = []
for k in range(self.predict_nq):
top_ids.append(self.ras_sampling(pred[k], out_tokens)[0].item())
out_tokens.append(top_ids)
top_ids = self.ras_sampling(pred[0], out_tokens[0])
out_tokens[0, out_token_len, 0] = top_ids[0]
seq_input[0, prompt_len + out_token_len, :] = self.codec_embedder(top_ids)[0]
out_token_len += 1
# remove eos token
hit_eos = False
if torch.any(
torch.tensor(out_tokens[-1], dtype=torch.int64) == self.codebook_size + self.ad_sos_eos
):
if torch.any(out_tokens[:, out_token_len - 1] == (self.codebook_size + self.ad_sos_eos)):
hit_eos = True
out_tokens = out_tokens[:-1]
out_tokens = out_tokens[:, :out_token_len, :]
break
if decoding_length is None:
return torch.tensor([out_tokens], dtype=torch.int64, device=device)
return out_tokens
else:
return torch.tensor([out_tokens], dtype=torch.int64, device=device), hit_eos
return out_tokens, hit_eos
# Repetition Aware Sampling in VALL-E 2
def ras_sampling(
self, weighted_scores, decoded_tokens, *, top_p=0.8, top_k=25, win_size=10, tau_r=0.1
):
top_ids = self.nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k)
rep_num = (torch.tensor(decoded_tokens[-win_size:]).to(top_ids) == top_ids).sum().item()
rep_num = torch.sum(decoded_tokens[-win_size:] == top_ids).item()
if rep_num >= win_size * tau_r:
top_ids = self.random_sampling(weighted_scores)
return top_ids
def nucleus_sampling(self, weighted_scores, top_p=0.8, top_k=25):
prob, indices = [], []
cum_prob = 0.0
sorted_value, sorted_idx = weighted_scores.softmax(dim=0).sort(descending=True, stable=True)
i = len(sorted_idx)
for i in range(len(sorted_idx)):
# sampling both top-p and numbers.
if cum_prob < top_p and len(prob) < top_k:
if cum_prob < top_p and i < top_k:
cum_prob += sorted_value[i]
prob.append(sorted_value[i])
indices.append(sorted_idx[i])
else:
break
prob = torch.tensor(prob).to(weighted_scores)
indices = torch.tensor(indices, dtype=torch.long).to(weighted_scores.device)
prob = sorted_value[:i]
indices = sorted_idx[:i]
sampling_ids = prob.multinomial(1, replacement=True)
top_ids = indices[sampling_ids]
return top_ids

View File

@ -334,3 +334,45 @@ class MaskAlongAxisLFR(torch.nn.Module):
replace_with_zero=self.replace_with_zero,
lfr_rate=self.lfr_rate,
)
class PrefixMaskVariableMaxWidth(torch.nn.Module):
    """Zero out a random-length suffix of each sequence (keeps a clean prefix).

    The mask length is drawn per batch element as a fraction of the time axis.

    Args:
        mask_width_ratio_range: (min, max) suffix-length ratios relative to
            the time dimension; a bare float is treated as (0.0, value).
        replace_value: value written into masked positions.
    """
    def __init__(
        self,
        mask_width_ratio_range: Union[float, Sequence[float]] = (0.0, 0.05),
        replace_value: float = 0.0,
    ):
        super().__init__()
        if isinstance(mask_width_ratio_range, (int, float)):
            # generalize: the declared Union type allows a scalar upper bound,
            # but the original indexing code crashed on it
            mask_width_ratio_range = (0.0, float(mask_width_ratio_range))
        self.mask_width_ratio_range = mask_width_ratio_range
        self.replace_value = replace_value

    def extra_repr(self):
        return (
            f"mask_width_ratio_range={self.mask_width_ratio_range}, "
        )

    def forward(self, spec: torch.Tensor, spec_lengths: torch.Tensor = None, return_mask: bool = False):
        """Apply the suffix mask.

        Args:
            spec: (Batch, Length, Freq) tensor.
            spec_lengths: passed through unchanged.
            return_mask: also return the (Batch, Length, 1) boolean mask.

        Returns:
            (spec, spec_lengths) or (spec, spec_lengths, mask).
        """
        bb, tt, _ = spec.shape
        ratio = torch.tensor(self.mask_width_ratio_range, dtype=torch.float32, device=spec.device)
        width_range = (ratio * tt).long()
        low = int(width_range[0])
        # BUGFIX: torch.randint requires high > low; for short inputs both
        # bounds can truncate to the same integer (e.g. 0.05 * tt < 1) and the
        # original call raised. Clamp the upper bound instead.
        high = max(int(width_range[1]), low + 1)
        # mask_length: (B, 1, 1)
        mask_length = torch.randint(
            low,
            high,
            (bb, 1),
            device=spec.device,
        ).unsqueeze(2)
        # mask_pos: (B, num_mask, 1) — suffix start index
        mask_pos = tt - mask_length
        aran = torch.arange(tt, device=spec.device)[None, None, :]
        # mask: (Batch, num_mask, L)
        mask = (mask_pos <= aran) * (aran < (mask_pos + mask_length))
        # Multiply masks: (Batch, num_mask, L) -> (Batch, L, 1)
        mask = mask.any(dim=1).unsqueeze(2)
        spec = spec.masked_fill(mask, self.replace_value)
        if return_mask:
            return spec, spec_lengths, mask
        return spec, spec_lengths

13
funasr/utils/hinter.py Normal file
View File

@ -0,0 +1,13 @@
import sys
import torch.distributed
import logging
# uids that have already been logged; shared across all hint_once calls
HINTED = set()


def hint_once(content, uid, rank=None):
    """Log ``content`` at most once per unique ``uid``.

    When ``rank`` is given and torch.distributed is initialized, only that
    rank logs; otherwise every process logs (once).
    """
    on_this_rank = (
        rank is None
        or not torch.distributed.is_initialized()
        or torch.distributed.get_rank() == rank
    )
    if on_this_rank and uid not in HINTED:
        logging.info(content)
        HINTED.add(uid)