mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
refine decoding process, merge flow and vocoder
This commit is contained in:
parent
e49e54596c
commit
573ae881cd
628
funasr/models/llm_asr/conformer_encoder.py
Normal file
628
funasr/models/llm_asr/conformer_encoder.py
Normal file
@ -0,0 +1,628 @@
|
||||
# Copyright 2020 Tomoki Hayashi
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
"""Conformer encoder definition."""
|
||||
import logging
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
import torch
|
||||
from torch import nn
|
||||
from funasr.models.transformer.attention import (
|
||||
MultiHeadedAttention, # noqa: H301
|
||||
RelPositionMultiHeadedAttention, # noqa: H301
|
||||
LegacyRelPositionMultiHeadedAttention, # noqa: H301
|
||||
)
|
||||
from funasr.models.transformer.embedding import (
|
||||
PositionalEncoding, # noqa: H301
|
||||
ScaledPositionalEncoding, # noqa: H301
|
||||
RelPositionalEncoding, # noqa: H301
|
||||
LegacyRelPositionalEncoding, # noqa: H301
|
||||
)
|
||||
from funasr.models.transformer.layer_norm import LayerNorm
|
||||
from funasr.models.transformer.utils.multi_layer_conv import Conv1dLinear
|
||||
from funasr.models.transformer.utils.multi_layer_conv import MultiLayeredConv1d
|
||||
from funasr.models.transformer.utils.nets_utils import get_activation
|
||||
from funasr.models.transformer.utils.nets_utils import make_pad_mask
|
||||
from funasr.models.transformer.utils.mask import subsequent_mask
|
||||
from funasr.models.transformer.positionwise_feed_forward import (
|
||||
PositionwiseFeedForward, # noqa: H301
|
||||
)
|
||||
from funasr.models.transformer.utils.repeat import repeat
|
||||
from funasr.models.transformer.utils.subsampling import (
|
||||
Conv2dSubsampling, Conv2dSubsampling2, Conv2dSubsampling6, Conv2dSubsampling8, TooShortUttError,
|
||||
check_short_utt, Conv2dSubsamplingPad
|
||||
)
|
||||
|
||||
|
||||
class ConvolutionModule(nn.Module):
    """Convolution block of the Conformer architecture.

    Applies pointwise-conv -> GLU -> depthwise-conv -> BatchNorm ->
    activation -> pointwise-conv, operating over the time axis.

    Args:
        channels (int): Number of input/output channels.
        kernel_size (int): Depthwise convolution kernel size (must be odd
            so that 'SAME' padding keeps the sequence length).
        activation (nn.Module): Activation applied after normalization.
        bias (bool): Whether the conv layers use a bias term.
    """

    def __init__(self, channels, kernel_size, activation=nn.ReLU(), bias=True):
        """Construct a ConvolutionModule object."""
        super(ConvolutionModule, self).__init__()
        # An odd kernel is required for symmetric 'SAME' padding.
        assert (kernel_size - 1) % 2 == 0

        # First pointwise conv doubles the channels for the GLU gate.
        self.pointwise_conv1 = nn.Conv1d(
            channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=bias
        )
        # Depthwise conv mixes information along time, one group per channel.
        self.depthwise_conv = nn.Conv1d(
            channels,
            channels,
            kernel_size,
            stride=1,
            padding=(kernel_size - 1) // 2,
            groups=channels,
            bias=bias,
        )
        self.norm = nn.BatchNorm1d(channels)
        self.pointwise_conv2 = nn.Conv1d(
            channels, channels, kernel_size=1, stride=1, padding=0, bias=bias
        )
        self.activation = activation

    def forward(self, x):
        """Compute the convolution module.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, channels).

        Returns:
            torch.Tensor: Output tensor (#batch, time, channels).
        """
        # Conv1d expects (batch, channels, time); swap in and out.
        feats = x.transpose(1, 2)

        # GLU gating: (batch, 2*channels, time) -> (batch, channels, time).
        feats = nn.functional.glu(self.pointwise_conv1(feats), dim=1)

        # Depthwise temporal convolution followed by norm + activation.
        feats = self.activation(self.norm(self.depthwise_conv(feats)))

        feats = self.pointwise_conv2(feats)
        return feats.transpose(1, 2)
|
||||
|
||||
|
||||
class EncoderLayer(nn.Module):
    """Encoder layer module.

    One Conformer block: optional macaron feed-forward, multi-headed
    self-attention, optional convolution module, and a final feed-forward,
    each with residual connections and (pre- or post-) layer normalization.

    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance
            can be used as the argument.
        feed_forward (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
            can be used as the argument.
        feed_forward_macaron (torch.nn.Module): Additional feed-forward module instance.
            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
            can be used as the argument.
        conv_module (torch.nn.Module): Convolution module instance.
            `ConvolutionModule` instance can be used as the argument.
        dropout_rate (float): Dropout rate.
        normalize_before (bool): Whether to use layer_norm before the first block.
        concat_after (bool): Whether to concat attention layer's input and output.
            if True, additional linear will be applied.
            i.e. x -> x + linear(concat(x, att(x)))
            if False, no additional linear will be applied. i.e. x -> x + att(x)
        stochastic_depth_rate (float): Probability to skip this layer.
            During training, the layer may skip residual computation and return input
            as-is with given probability.
    """

    def __init__(
        self,
        size,
        self_attn,
        feed_forward,
        feed_forward_macaron,
        conv_module,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
        stochastic_depth_rate=0.0,
    ):
        """Construct an EncoderLayer object."""
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.feed_forward_macaron = feed_forward_macaron
        self.conv_module = conv_module
        self.norm_ff = LayerNorm(size)  # for the FNN module
        self.norm_mha = LayerNorm(size)  # for the MHA module
        if feed_forward_macaron is not None:
            self.norm_ff_macaron = LayerNorm(size)
            # Macaron style: each of the two FFNs contributes half weight.
            self.ff_scale = 0.5
        else:
            self.ff_scale = 1.0
        if self.conv_module is not None:
            self.norm_conv = LayerNorm(size)  # for the CNN module
            self.norm_final = LayerNorm(size)  # for the final output of the block
        self.dropout = nn.Dropout(dropout_rate)
        self.size = size
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear = nn.Linear(size + size, size)
        self.stochastic_depth_rate = stochastic_depth_rate

    def forward(self, x_input, mask, cache=None):
        """Compute encoded features.

        Args:
            x_input (Union[Tuple, torch.Tensor]): Input tensor w/ or w/o pos emb.
                - w/ pos emb: Tuple of tensors [(#batch, time, size), (1, time, size)].
                - w/o pos emb: Tensor (#batch, time, size).
            mask (torch.Tensor): Mask tensor for the input (#batch, time).
            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).

        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, time).

        """
        if isinstance(x_input, tuple):
            x, pos_emb = x_input[0], x_input[1]
        else:
            x, pos_emb = x_input, None

        skip_layer = False
        # with stochastic depth, residual connection `x + f(x)` becomes
        # `x <- x + 1 / (1 - p) * f(x)` at training time.
        stoch_layer_coeff = 1.0
        if self.training and self.stochastic_depth_rate > 0:
            skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
            stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)

        if skip_layer:
            # Layer skipped: still append cache so output length is consistent.
            if cache is not None:
                x = torch.cat([cache, x], dim=1)
            if pos_emb is not None:
                return (x, pos_emb), mask
            return x, mask

        # whether to use macaron style (first half feed-forward)
        if self.feed_forward_macaron is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_ff_macaron(x)
            x = residual + stoch_layer_coeff * self.ff_scale * self.dropout(
                self.feed_forward_macaron(x)
            )
            if not self.normalize_before:
                x = self.norm_ff_macaron(x)

        # multi-headed self-attention module
        residual = x
        if self.normalize_before:
            x = self.norm_mha(x)

        if cache is None:
            x_q = x
        else:
            # Incremental decoding: only the newest frame queries the
            # full (cached + new) key/value sequence.
            assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
            x_q = x[:, -1:, :]
            residual = residual[:, -1:, :]
            mask = None if mask is None else mask[:, -1:, :]

        if pos_emb is not None:
            # Relative-position attention variants take pos_emb explicitly.
            x_att = self.self_attn(x_q, x, x, pos_emb, mask)
        else:
            x_att = self.self_attn(x_q, x, x, mask)

        if self.concat_after:
            x_concat = torch.cat((x, x_att), dim=-1)
            x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
        else:
            x = residual + stoch_layer_coeff * self.dropout(x_att)
        if not self.normalize_before:
            x = self.norm_mha(x)

        # convolution module
        if self.conv_module is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_conv(x)
            x = residual + stoch_layer_coeff * self.dropout(self.conv_module(x))
            if not self.normalize_before:
                x = self.norm_conv(x)

        # feed forward module (second half in macaron style)
        residual = x
        if self.normalize_before:
            x = self.norm_ff(x)
        x = residual + stoch_layer_coeff * self.ff_scale * self.dropout(
            self.feed_forward(x)
        )
        if not self.normalize_before:
            x = self.norm_ff(x)

        if self.conv_module is not None:
            x = self.norm_final(x)

        if cache is not None:
            x = torch.cat([cache, x], dim=1)

        if pos_emb is not None:
            return (x, pos_emb), mask

        return x, mask
|
||||
|
||||
|
||||
class ConformerEncoder(nn.Module):
    """Conformer encoder module.

    Args:
        input_size (int): Input dimension.
        output_size (int): Dimension of attention.
        attention_heads (int): The number of heads of multi head attention.
        linear_units (int): The number of units of position-wise feed forward.
        num_blocks (int): The number of decoder blocks.
        dropout_rate (float): Dropout rate.
        attention_dropout_rate (float): Dropout rate in attention.
        positional_dropout_rate (float): Dropout rate after adding positional encoding.
        input_layer (Union[str, torch.nn.Module]): Input layer type.
        normalize_before (bool): Whether to use layer_norm before the first block.
        concat_after (bool): Whether to concat attention layer's input and output.
            If True, additional linear will be applied.
            i.e. x -> x + linear(concat(x, att(x)))
            If False, no additional linear will be applied. i.e. x -> x + att(x)
        positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear".
        positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer.
        rel_pos_type (str): Whether to use the latest relative positional encoding or
            the legacy one. The legacy relative positional encoding will be deprecated
            in the future. More Details can be found in
            https://github.com/espnet/espnet/pull/2816.
        pos_enc_layer_type (str): Encoder positional encoding layer type.
        selfattention_layer_type (str): Encoder attention layer type.
        activation_type (str): Encoder activation function type.
        macaron_style (bool): Whether to use macaron style for positionwise layer.
        use_cnn_module (bool): Whether to use convolution module.
        zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
        cnn_module_kernel (int): Kernel size of convolution module.
        padding_idx (int): Padding idx for input_layer=embed.
        interctc_layer_idx (List[int]): 1-based block indices after which an
            intermediate (CTC) output is taken.
        interctc_use_conditioning (bool): Whether later blocks are conditioned on
            the intermediate CTC posteriors via `self.conditioning_layer`.
        stochastic_depth_rate (Union[float, List[float]]): Per-layer skip probability.
        causal (bool): Apply a lower-triangular (causal) attention mask.
        skip (bool): Add the raw input to the final output (global residual);
            assumes input and output shapes match.
        channel_first (bool): If True, inputs/outputs are (#batch, dim, time)
            and are permuted to (#batch, time, dim) internally.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: Optional[str] = "conv2d",
        normalize_before: bool = True,
        concat_after: bool = False,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 3,
        macaron_style: bool = False,
        rel_pos_type: str = "legacy",
        pos_enc_layer_type: str = "rel_pos",
        selfattention_layer_type: str = "rel_selfattn",
        activation_type: str = "swish",
        use_cnn_module: bool = True,
        zero_triu: bool = False,
        cnn_module_kernel: int = 31,
        padding_idx: int = -1,
        interctc_layer_idx: List[int] = [],
        interctc_use_conditioning: bool = False,
        stochastic_depth_rate: Union[float, List[float]] = 0.0,
        causal: bool = False,
        skip: bool = False,
        channel_first: bool = False,
    ):
        super().__init__()
        self._output_size = output_size
        self.causal = causal
        self.skip = skip
        self.channel_first = channel_first

        # Map the "legacy" rel-pos aliases to their legacy implementations.
        if rel_pos_type == "legacy":
            if pos_enc_layer_type == "rel_pos":
                pos_enc_layer_type = "legacy_rel_pos"
            if selfattention_layer_type == "rel_selfattn":
                selfattention_layer_type = "legacy_rel_selfattn"
        elif rel_pos_type == "latest":
            assert selfattention_layer_type != "legacy_rel_selfattn"
            assert pos_enc_layer_type != "legacy_rel_pos"
        else:
            raise ValueError("unknown rel_pos_type: " + rel_pos_type)

        activation = get_activation(activation_type)
        # Positional encoding class must match the attention variant below.
        if pos_enc_layer_type == "abs_pos":
            pos_enc_class = PositionalEncoding
        elif pos_enc_layer_type == "scaled_abs_pos":
            pos_enc_class = ScaledPositionalEncoding
        elif pos_enc_layer_type == "rel_pos":
            assert selfattention_layer_type == "rel_selfattn"
            pos_enc_class = RelPositionalEncoding
        elif pos_enc_layer_type == "legacy_rel_pos":
            assert selfattention_layer_type == "legacy_rel_selfattn"
            pos_enc_class = LegacyRelPositionalEncoding
            logging.warning(
                "Using legacy_rel_pos and it will be deprecated in the future."
            )
        else:
            raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)

        # Input frontend: linear projection, conv subsampling, embedding,
        # a user-supplied module, or positional encoding only.
        if input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(input_size, output_size),
                torch.nn.LayerNorm(output_size),
                torch.nn.Dropout(dropout_rate),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d":
            self.embed = Conv2dSubsampling(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2dpad":
            self.embed = Conv2dSubsamplingPad(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d2":
            self.embed = Conv2dSubsampling2(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d6":
            self.embed = Conv2dSubsampling6(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d8":
            self.embed = Conv2dSubsampling8(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif isinstance(input_layer, torch.nn.Module):
            self.embed = torch.nn.Sequential(
                input_layer,
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer is None:
            self.embed = torch.nn.Sequential(
                pos_enc_class(output_size, positional_dropout_rate)
            )
        else:
            raise ValueError("unknown input_layer: " + input_layer)
        self.normalize_before = normalize_before
        if positionwise_layer_type == "linear":
            positionwise_layer = PositionwiseFeedForward
            positionwise_layer_args = (
                output_size,
                linear_units,
                dropout_rate,
                activation,
            )
        elif positionwise_layer_type == "conv1d":
            positionwise_layer = MultiLayeredConv1d
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d-linear":
            positionwise_layer = Conv1dLinear
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        else:
            raise NotImplementedError("Support only linear or conv1d.")

        if selfattention_layer_type == "selfattn":
            encoder_selfattn_layer = MultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
            )
        elif selfattention_layer_type == "legacy_rel_selfattn":
            assert pos_enc_layer_type == "legacy_rel_pos"
            encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
            )
            logging.warning(
                "Using legacy_rel_selfattn and it will be deprecated in the future."
            )
        elif selfattention_layer_type == "rel_selfattn":
            assert pos_enc_layer_type == "rel_pos"
            encoder_selfattn_layer = RelPositionMultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
                zero_triu,
            )
        else:
            raise ValueError("unknown encoder_attn_layer: " + selfattention_layer_type)

        convolution_layer = ConvolutionModule
        convolution_layer_args = (output_size, cnn_module_kernel, activation)

        # Broadcast a scalar stochastic depth rate to all blocks.
        if isinstance(stochastic_depth_rate, float):
            stochastic_depth_rate = [stochastic_depth_rate] * num_blocks

        if len(stochastic_depth_rate) != num_blocks:
            raise ValueError(
                f"Length of stochastic_depth_rate ({len(stochastic_depth_rate)}) "
                f"should be equal to num_blocks ({num_blocks})"
            )

        self.encoders = repeat(
            num_blocks,
            lambda lnum: EncoderLayer(
                output_size,
                encoder_selfattn_layer(*encoder_selfattn_layer_args),
                positionwise_layer(*positionwise_layer_args),
                positionwise_layer(*positionwise_layer_args) if macaron_style else None,
                convolution_layer(*convolution_layer_args) if use_cnn_module else None,
                dropout_rate,
                normalize_before,
                concat_after,
                stochastic_depth_rate[lnum],
            ),
        )
        if self.normalize_before:
            self.after_norm = LayerNorm(output_size)

        self.interctc_layer_idx = interctc_layer_idx
        if len(interctc_layer_idx) > 0:
            assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
        self.interctc_use_conditioning = interctc_use_conditioning
        # Set externally (e.g. by the training wrapper) when conditioning is used.
        self.conditioning_layer = None

    def output_size(self) -> int:
        # Feature dimension of the encoder output.
        return self._output_size

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor = None,
        prev_states: torch.Tensor = None,
        ctc = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Calculate forward propagation.

        Args:
            xs_pad (torch.Tensor): Input tensor (#batch, L, input_size).
            ilens (torch.Tensor): Input length (#batch). If None, all frames
                are treated as valid.
            prev_states (torch.Tensor): Not to be used now.
            ctc: CTC module used only for intermediate-CTC conditioning.

        Returns:
            torch.Tensor: Output tensor (#batch, L, output_size).
            torch.Tensor: Output length (#batch).
            torch.Tensor: Not to be used now.

        """
        raw_input = xs_pad
        if self.channel_first:
            # (#batch, dim, time) -> (#batch, time, dim)
            xs_pad = xs_pad.permute(0, 2, 1)

        if ilens is not None:
            masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
        else:
            # No lengths given: every frame is valid.
            masks = torch.ones(xs_pad.shape[0], 1, xs_pad.shape[1],
                               dtype=torch.bool, device=xs_pad.device)
        if self.causal:
            # Combine padding mask with a lower-triangular mask.
            # NOTE(review): this makes `masks` 3-D (#batch, time, time);
            # the later `masks.squeeze(1).sum(1)` length computation
            # presumably assumes the 2-D case — confirm for causal use.
            causal_mask = subsequent_mask(
                xs_pad.shape[1], device=xs_pad.device, dtype=masks.dtype
            ).unsqueeze(0)
            masks = masks & causal_mask

        if (
            isinstance(self.embed, Conv2dSubsampling)
            or isinstance(self.embed, Conv2dSubsampling2)
            or isinstance(self.embed, Conv2dSubsampling6)
            or isinstance(self.embed, Conv2dSubsampling8)
            or isinstance(self.embed, Conv2dSubsamplingPad)
        ):
            # Conv subsampling shrinks time; guard against too-short inputs.
            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
            if short_status:
                raise TooShortUttError(
                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
                    + f"(it needs more than {limit_size} frames), return empty results",
                    xs_pad.size(1),
                    limit_size,
                )
            xs_pad, masks = self.embed(xs_pad, masks)
        else:
            xs_pad = self.embed(xs_pad)

        intermediate_outs = []
        if len(self.interctc_layer_idx) == 0:
            xs_pad, masks = self.encoders(xs_pad, masks)
        else:
            # Run block-by-block so intermediate outputs can be captured.
            for layer_idx, encoder_layer in enumerate(self.encoders):
                xs_pad, masks = encoder_layer(xs_pad, masks)

                if layer_idx + 1 in self.interctc_layer_idx:
                    encoder_out = xs_pad
                    if isinstance(encoder_out, tuple):
                        encoder_out = encoder_out[0]

                    # intermediate outputs are also normalized
                    if self.normalize_before:
                        encoder_out = self.after_norm(encoder_out)

                    intermediate_outs.append((layer_idx + 1, encoder_out))

                    if self.interctc_use_conditioning:
                        # Self-conditioned CTC: feed the CTC posterior back in.
                        ctc_out = ctc.softmax(encoder_out)

                        if isinstance(xs_pad, tuple):
                            x, pos_emb = xs_pad
                            x = x + self.conditioning_layer(ctc_out)
                            xs_pad = (x, pos_emb)
                        else:
                            xs_pad = xs_pad + self.conditioning_layer(ctc_out)

        if isinstance(xs_pad, tuple):
            # Drop the positional embedding carried alongside the features.
            xs_pad = xs_pad[0]
        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)

        if self.channel_first:
            xs_pad = xs_pad.permute(0, 2, 1)

        if self.skip:
            # Global residual over the whole encoder (shapes must match).
            xs_pad = xs_pad + raw_input

        olens = masks.squeeze(1).sum(1)
        if len(intermediate_outs) > 0:
            return (xs_pad, intermediate_outs), olens, None

        if ilens is not None:
            return xs_pad, olens, None
        else:
            return xs_pad
|
||||
0
funasr/models/llm_asr/diffusion_models/__init__.py
Normal file
0
funasr/models/llm_asr/diffusion_models/__init__.py
Normal file
178
funasr/models/llm_asr/diffusion_models/flow_matching.py
Normal file
178
funasr/models/llm_asr/diffusion_models/flow_matching.py
Normal file
@ -0,0 +1,178 @@
|
||||
from abc import ABC
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from funasr.models.llm_asr.diffusion_models.matcha_decoder import (Decoder, ConditionalDecoder)
|
||||
import logging
|
||||
from funasr.utils.hinter import hint_once
|
||||
|
||||
|
||||
class BASECFM(torch.nn.Module, ABC):
    """Base class for conditional flow matching (CFM) decoders.

    Holds the CFM hyper-parameters and implements the fixed-step Euler ODE
    solver for inference plus the flow-matching training loss. Subclasses
    must assign ``self.estimator``, a module with signature
    ``estimator(x, mask, mu, t, spks, cond)``.

    Args:
        n_feats (int): Number of feature channels (e.g. mel bins).
        cfm_params (dict): CFM options: ``solver``, ``sigma_min``,
            ``t_scheduler``, ``training_cfg_rate``, ``inference_cfg_rate``,
            ``reg_loss_type``.
        n_spks (int): Number of speakers.
        spk_emb_dim (int): Speaker embedding dimension.
    """

    def __init__(
        self,
        n_feats: int,
        cfm_params: dict,
        n_spks: int = 1,
        spk_emb_dim: int = 128,
    ):
        super().__init__()
        self.n_feats = n_feats
        self.n_spks = n_spks
        self.spk_emb_dim = spk_emb_dim
        self.solver = cfm_params.get("solver", "euler")
        self.sigma_min = cfm_params.get("sigma_min", 1e-4)

        # Must be set by subclasses (see CFM below).
        self.estimator = None
        self.t_scheduler = cfm_params.get("t_scheduler", "linear")
        self.training_cfg_rate = cfm_params.get("training_cfg_rate", 0.0)
        self.inference_cfg_rate = cfm_params.get("inference_cfg_rate", 0.0)
        self.reg_loss_type = cfm_params.get("reg_loss_type", "l2")

    @torch.inference_mode()
    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
        """Forward diffusion

        Args:
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            n_timesteps (int): number of diffusion steps
            temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
            spks (torch.Tensor, optional): speaker ids. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond: Not used but kept for future purposes

        Returns:
            sample: generated mel-spectrogram
                shape: (batch_size, n_feats, mel_timesteps)
        """
        z = torch.randn_like(mu) * temperature
        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
        if self.t_scheduler == 'cosine':
            # Cosine schedule: denser steps near t=0.
            t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)

    def solve_euler(self, x, t_span, mu, mask, spks=None, cond=None):
        """
        Fixed euler solver for ODEs.
        Args:
            x (torch.Tensor): random noise
            t_span (torch.Tensor): n_timesteps interpolated
                shape: (n_timesteps + 1,)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            spks (torch.Tensor, optional): speaker ids. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond: Not used but kept for future purposes
        """
        t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]

        # I am storing this because I can later plot it by putting a debugger here and saving it to a file
        # Or in future might add like a return_all_steps flag
        sol = []

        steps = 1
        while steps <= len(t_span) - 1:
            dphi_dt = self.estimator(x, mask, mu, t, spks, cond)
            # Classifier-Free Guidance inference introduced in VoiceBox
            if self.inference_cfg_rate > 0:
                # BUGFIX: spks/cond may legitimately be None (they default to
                # None); guard both before zeroing for the unconditional pass.
                cfg_dphi_dt = self.estimator(
                    x, mask,
                    torch.zeros_like(mu), t,
                    torch.zeros_like(spks) if spks is not None else None,
                    torch.zeros_like(cond) if cond is not None else None,
                )
                dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt -
                           self.inference_cfg_rate * cfg_dphi_dt)

            x = x + dt * dphi_dt
            t = t + dt
            sol.append(x)
            if steps < len(t_span) - 1:
                # Step sizes may be non-uniform (e.g. cosine schedule).
                dt = t_span[steps + 1] - t
            steps += 1

        return sol[-1]

    def calc_reg_loss(self, prediction, target, loss_mask):
        """Masked regression loss between predicted and target vector fields.

        Returns an unreduced, mask-weighted loss tensor; the loss type is
        selected by ``self.reg_loss_type`` ('l1', 'l2', or mixed otherwise).
        """
        if self.reg_loss_type == 'l1':
            hint_once("use l1 loss to train CFM", "CFM_LOSS_L1")
            l1_loss = F.l1_loss(prediction, target, reduction="none")
            l1_loss = l1_loss * loss_mask
            return l1_loss
        elif self.reg_loss_type == 'l2':
            hint_once("use l2 loss to train CFM", "CFM_LOSS_L2")
            l2_loss = F.mse_loss(prediction, target, reduction="none")
            l2_loss = l2_loss * loss_mask
            return l2_loss
        else:
            hint_once("use l1+l2 loss to train CFM", "CFM_LOSS_L1_L2")
            l1_loss = F.l1_loss(prediction, target, reduction="none")
            l1_loss = l1_loss * loss_mask
            l2_loss = 0.5 * F.mse_loss(prediction, target, reduction="none")
            l2_loss = l2_loss * loss_mask
            return l1_loss * 0.5 + l2_loss * 0.5

    def compute_loss(self, x1, mask, mu, spks=None, cond=None, reduction='none'):
        """Computes diffusion loss

        Args:
            x1 (torch.Tensor): Target
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): target mask
                shape: (batch_size, 1, mel_timesteps)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            spks (torch.Tensor, optional): speaker embedding. Defaults to None.
                shape: (batch_size, spk_emb_dim)

        Returns:
            loss: conditional flow matching loss
            y: conditional flow
                shape: (batch_size, n_feats, mel_timesteps)
        """
        b, _, t = mu.shape

        # random timestep (note: rebinds `t` from the unpack above)
        t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
        if self.t_scheduler == 'cosine':
            t = 1 - torch.cos(t * 0.5 * torch.pi)
        # sample noise p(x_0)
        z = torch.randn_like(x1)

        # Linear interpolation between noise and target; u is the target
        # vector field of the conditional flow.
        y = (1 - (1 - self.sigma_min) * t) * z + t * x1
        u = x1 - (1 - self.sigma_min) * z

        if self.training_cfg_rate > 0:
            # Randomly drop the condition for classifier-free guidance training.
            cfg_mask = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype) > self.training_cfg_rate
            mu = mu * cfg_mask
            if spks is not None:
                spks = spks * cfg_mask.squeeze(-1)
            if cond is not None:
                cond = cond * cfg_mask

        pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
        loss = self.calc_reg_loss(pred, u, mask)
        if reduction == "mean":
            loss = loss.sum() / (torch.sum(mask) * u.shape[1])
        return loss, y
|
||||
|
||||
|
||||
class CFM(BASECFM):
    """Conditional flow matching model with a configurable estimator network."""

    def __init__(self, in_channels, out_channel, cfm_params, decoder_params,
                 n_spks=1, spk_emb_dim=64, decoder_name="Decoder"):
        super().__init__(
            n_feats=in_channels,
            cfm_params=cfm_params,
            n_spks=n_spks,
            spk_emb_dim=spk_emb_dim,
        )

        # The estimator sees features concatenated with the speaker embedding
        # when speakers are used.
        estimator_in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
        # Just change the architecture of the estimator here
        decoder_cls = Decoder if decoder_name == "Decoder" else ConditionalDecoder
        self.estimator = decoder_cls(
            in_channels=estimator_in_channels,
            out_channels=out_channel,
            **decoder_params,
        )
|
||||
219
funasr/models/llm_asr/diffusion_models/length_regulator.py
Normal file
219
funasr/models/llm_asr/diffusion_models/length_regulator.py
Normal file
@ -0,0 +1,219 @@
|
||||
import torch
|
||||
from typing import Tuple
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
from funasr.models.llm_asr.diffusion_models.matcha_decoder import Upsample1D, Downsample1D
|
||||
from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
|
||||
from einops import repeat, pack
|
||||
import logging
|
||||
|
||||
|
||||
class UpSamplingRegulator(nn.Module):
    """Length regulator built from learned upsampling stages.

    Each sampling ratio > 1 adds an ``Upsample1D`` stage, otherwise a
    channel-wise ``Linear``; every stage is followed by ``LayerNorm`` and
    ``LeakyReLU``, with a final ``Linear`` projection to ``out_channels``.
    """

    def __init__(
        self,
        channels: int,
        sampling_ratios: Tuple,
        out_channels: int = None,
        groups: int = 1,
    ):
        super().__init__()
        self.sampling_ratios = sampling_ratios
        out_channels = out_channels or channels
        stages = []
        # One (resample/linear, norm, activation) triple per ratio.
        # NOTE: module creation order is kept identical to preserve RNG state.
        for ratio in sampling_ratios:
            if ratio > 1:
                stage = Upsample1D(channels=channels, channel_first=False)
            else:
                stage = nn.Linear(channels, channels)
            stages.extend([stage, nn.LayerNorm(channels), nn.LeakyReLU()])
        stages.append(nn.Linear(channels, out_channels))
        self.model = nn.Sequential(*stages)

    def forward(self, x, xlens, y=None, y_lens=None, cond=None):
        """Resample ``x`` and trim it to the target length of ``y``.

        All of x, the output, and y are (B, T, D); returns (output, y_lens).
        """
        resampled = self.model(x)
        # Trim any overshoot from the upsampling to the target length.
        resampled = resampled[:, : y.shape[1]]
        return resampled, y_lens
|
||||
|
||||
|
||||
class DownSamplingRegulator(nn.Module):
    """Length regulator built from learned downsampling stages.

    Each sampling ratio > 1 adds a ``Downsample1D`` stage, otherwise a
    channel-wise ``Linear``; every stage is followed by ``LayerNorm`` and
    ``LeakyReLU``, with a final ``Linear`` projection to ``out_channels``.
    """

    def __init__(
        self,
        channels: int,
        sampling_ratios: Tuple,
        out_channels: int = None,
        groups: int = 1,
    ):
        super().__init__()
        self.sampling_ratios = sampling_ratios
        out_channels = out_channels or channels
        stages = []
        # One (resample/linear, norm, activation) triple per ratio.
        # NOTE: module creation order is kept identical to preserve RNG state.
        for ratio in sampling_ratios:
            if ratio > 1:
                stage = Downsample1D(dim=channels, channel_first=False, padding=2)
            else:
                stage = nn.Linear(channels, channels)
            stages.extend([stage, nn.LayerNorm(channels), nn.LeakyReLU()])
        stages.append(nn.Linear(channels, out_channels))
        self.model = nn.Sequential(*stages)

    def forward(self, x, xlens, y=None, y_lens=None, cond=None):
        """Resample ``x`` and trim it to the target length of ``y``.

        All of x, the output, and y are (B, T, D); returns (output, y_lens).
        """
        resampled = self.model(x)
        # Trim any overshoot from the downsampling to the target length.
        resampled = resampled[:, : y.shape[1]]
        return resampled, y_lens
|
||||
|
||||
|
||||
class InterpolateRegulator(nn.Module):
    """Length regulator that resizes (B, T, D) features to the target length
    with ``F.interpolate`` and refines them with a small Conv1d stack.

    One Conv1d/GroupNorm/Mish triple is stacked per entry of
    ``sampling_ratios`` (the ratio values themselves are not used — only
    the count), followed by a 1x1 projection to ``out_channels``.
    """

    def __init__(
        self,
        channels: int,
        sampling_ratios: Tuple,
        out_channels: int = None,
        groups: int = 1,
        mode="nearest",
        align_corners=False,
    ):
        super().__init__()
        self.sampling_ratios = sampling_ratios
        out_channels = out_channels or channels
        stages = []
        for _ in sampling_ratios:
            stages += [
                nn.Conv1d(channels, channels, 3, 1, 1),
                nn.GroupNorm(groups, channels),
                nn.Mish(),
            ]
        stages.append(nn.Conv1d(channels, out_channels, 1, 1))
        self.model = nn.Sequential(*stages)
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x, xlens, y=None, ylens=None, cond=None):
        """Resize ``x`` along time to match ``y`` and refine the result.

        Args:
            x: input features, shape (B, T, D).
            xlens: input lengths (unused; kept for interface parity).
            y: reference tensor giving the target time length — required.
            ylens: target lengths, used for masking and returned as lengths.
            cond: unused placeholder.

        Returns:
            Tuple of (masked features of shape (B, T', out_channels), ylens).
        """
        mask = (~make_pad_mask(ylens)).to(x).unsqueeze(-1)
        # align_corners is only a legal kwarg for these resampling modes.
        interp_kwargs = {}
        if self.mode in ["linear", "bilinear", "bicubic", "trilinear"]:
            interp_kwargs = dict(align_corners=self.align_corners)
        resized = F.interpolate(
            x.transpose(1, 2).contiguous(), size=y.shape[1],
            mode=self.mode, **interp_kwargs
        )
        out = self.model(resized).transpose(1, 2).contiguous()
        olens = ylens
        return out * mask, olens
|
||||
|
||||
|
||||
class UpsamplingBlock(nn.Module):
    """Residual 2x (or ``stride``x) upsampling block.

    A transposed convolution stretches the time axis by ``stride``; two
    masked Conv1d/GroupNorm/Mish blocks with residual additions refine the
    result, plus a 1x1 residual projection of the upsampled features.
    """

    def __init__(
        self,
        in_channels: int,
        channels: int,
        stride=2,
        groups=1,
        channel_first=False,
    ):
        super().__init__()
        self.channel_first = channel_first
        self.stride = stride

        self.up_conv = nn.ConvTranspose1d(in_channels, channels, stride * 2, stride, 1)
        self.block1 = torch.nn.Sequential(
            torch.nn.Conv1d(channels, channels, 3, padding=1),
            torch.nn.GroupNorm(groups, channels),
            nn.Mish(),
        )
        self.block2 = torch.nn.Sequential(
            torch.nn.Conv1d(channels, channels, 3, padding=1),
            torch.nn.GroupNorm(groups, channels),
            nn.Mish(),
        )
        self.res_conv = torch.nn.Conv1d(channels, channels, 1)

    def forward(self, x, ilens):
        """Upsample ``x`` by ``self.stride`` along time.

        Args:
            x: features, (B, T, C) by default or (B, C, T) if channel_first.
            ilens: valid input lengths, shape (B,).

        Returns:
            Tuple of (upsampled features in the input layout, olens = ilens * stride).
        """
        if not self.channel_first:
            x = x.transpose(1, 2)

        olens = ilens * self.stride
        o_masks = (~make_pad_mask(olens))[:, None, :].to(x)
        res = out = self.up_conv(x) * o_masks

        # Two masked residual refinements, then the 1x1 residual branch.
        out = self.block1(out) * o_masks + out
        out = self.block2(out) * o_masks + out
        out = out + self.res_conv(res) * o_masks

        if not self.channel_first:
            out = out.transpose(1, 2)

        return out, olens
|
||||
|
||||
|
||||
class RepeatLengthRegulator(torch.nn.Module):
    """Repeat Length regulator module for feed-forward Transformer.

    This is a module of length regulator described in
    `FastSpeech: Fast, Robust and Controllable Text to Speech`_.
    It expands char or phoneme-level embeddings to frame level by
    repeating each feature according to its predicted duration.

    .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
        https://arxiv.org/pdf/1905.09263.pdf
    """

    def __init__(self, pad_value=0.0):
        """Initialize the length regulator.

        Args:
            pad_value (float, optional): Value used for padding.
        """
        super().__init__()
        self.pad_value = pad_value

    def forward(self, xs, ds, alpha=1.0):
        """Expand each frame by its duration.

        Args:
            xs (Tensor): Batch of char/phoneme embeddings (B, Tmax, D).
            ds (LongTensor): Batch of per-frame durations (B, T).
            alpha (float, optional): Speed-control factor; durations are
                rescaled by ``alpha`` when it differs from 1.0.

        Returns:
            Tensor: Replicated input tensor based on durations (B, T*, D).
        """
        if alpha != 1.0:
            assert alpha > 0
            ds = torch.round(ds.float() * alpha).long()

        if ds.sum() == 0:
            logging.warning(
                "predicted durations includes all 0 sequences. "
                "fill the first element with 1."
            )
            # NOTE(kan-bayashi): This case must not be happened in teacher forcing.
            # It will be happened in inference with a bad duration predictor.
            # So we do not need to care the padded sequence case here.
            ds[ds.sum(dim=1).eq(0)] = 1

        # Renamed from "repeat" to avoid shadowing the einops import.
        expanded = [torch.repeat_interleave(x, d, dim=0) for x, d in zip(xs, ds)]
        return pad_list(expanded, self.pad_value)
|
||||
844
funasr/models/llm_asr/diffusion_models/matcha_decoder.py
Normal file
844
funasr/models/llm_asr/diffusion_models/matcha_decoder.py
Normal file
@ -0,0 +1,844 @@
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from conformer import ConformerBlock
|
||||
from diffusers.models.activations import get_activation
|
||||
from einops import pack, rearrange, repeat
|
||||
from funasr.models.transformer.utils.nets_utils import make_pad_mask
|
||||
from funasr.models.llm_asr.diffusion_models.transformer import BasicTransformerBlock
|
||||
|
||||
|
||||
class SinusoidalPosEmb(torch.nn.Module):
    """Transformer-style sinusoidal embedding for scalar positions/timesteps."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even"

    def forward(self, x, scale=1000):
        """Embed scalar positions ``x`` into ``self.dim`` sinusoidal features.

        Args:
            x: positions, shape (B,); a 0-dim tensor is promoted to (1,).
            scale: multiplier applied to the positions before the sinusoids.

        Returns:
            Tensor of shape (B, dim) with sin features followed by cos features.
        """
        if x.ndim < 1:
            x = x.unsqueeze(0)
        half = self.dim // 2
        # Geometric frequency ladder from 1 down to 1/10000.
        step = math.log(10000) / (half - 1)
        freqs = torch.exp(-step * torch.arange(half, device=x.device).float())
        phase = scale * x.unsqueeze(1) * freqs.unsqueeze(0)
        return torch.cat((phase.sin(), phase.cos()), dim=-1)
|
||||
|
||||
|
||||
class Block1D(torch.nn.Module):
    """Masked Conv1d -> GroupNorm -> Mish block on (B, C, T) features."""

    def __init__(self, dim, dim_out, groups=8):
        super().__init__()
        # Kept as one Sequential so checkpoint keys remain "block.*".
        self.block = torch.nn.Sequential(
            torch.nn.Conv1d(dim, dim_out, 3, padding=1),
            torch.nn.GroupNorm(groups, dim_out),
            nn.Mish(),
        )

    def forward(self, x, mask):
        """Apply the block; padded frames are zeroed both before and after."""
        return self.block(x * mask) * mask
|
||||
|
||||
|
||||
class ResnetBlock1D(torch.nn.Module):
    """Masked 1D residual block with a time-embedding injection.

    Two ``Block1D`` stages; the time embedding is projected and added between
    them, and a 1x1 convolution carries the residual from the (masked) input.
    """

    def __init__(self, dim, dim_out, time_emb_dim, groups=8):
        super().__init__()
        self.mlp = torch.nn.Sequential(nn.Mish(), torch.nn.Linear(time_emb_dim, dim_out))

        self.block1 = Block1D(dim, dim_out, groups=groups)
        self.block2 = Block1D(dim_out, dim_out, groups=groups)

        self.res_conv = torch.nn.Conv1d(dim, dim_out, 1)

    def forward(self, x, mask, time_emb):
        """Run the residual block.

        Args:
            x: features, shape (B, dim, T).
            mask: padding mask, shape (B, 1, T).
            time_emb: time embedding, shape (B, time_emb_dim).

        Returns:
            Tensor of shape (B, dim_out, T).
        """
        hidden = self.block1(x, mask)
        # Broadcast the projected time embedding over the time axis.
        hidden = hidden + self.mlp(time_emb).unsqueeze(-1)
        hidden = self.block2(hidden, mask)
        return hidden + self.res_conv(x * mask)
|
||||
|
||||
|
||||
class Downsample1D(nn.Module):
    """Strided-conv 2x temporal downsampling, optionally channel-last."""

    def __init__(self, dim, channel_first=True, padding=1):
        super().__init__()
        self.channel_first = channel_first
        self.conv = torch.nn.Conv1d(dim, dim, 3, 2, padding)

    def forward(self, x):
        """Downsample ``x``; layout is (B, C, T), or (B, T, C) when
        ``channel_first`` is False (converted internally and restored)."""
        if self.channel_first:
            return self.conv(x)
        # Channel-last input: move channels in for the conv, then restore.
        out = self.conv(x.transpose(1, 2).contiguous())
        return out.transpose(1, 2).contiguous()
|
||||
|
||||
|
||||
class TimestepEmbedding(nn.Module):
    """Two-layer MLP mapping a timestep embedding to ``time_embed_dim``.

    Optionally adds a projected conditioning vector to the input first
    (``cond_proj_dim``) and applies a post-activation (``post_act_fn``).
    """

    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        act_fn: str = "silu",
        out_dim: int = None,
        post_act_fn: Optional[str] = None,
        cond_proj_dim=None,
    ):
        super().__init__()

        self.linear_1 = nn.Linear(in_channels, time_embed_dim)

        # Bias-free projection for the optional conditioning input.
        self.cond_proj = (
            nn.Linear(cond_proj_dim, in_channels, bias=False)
            if cond_proj_dim is not None
            else None
        )

        self.act = get_activation(act_fn)

        self.linear_2 = nn.Linear(
            time_embed_dim,
            out_dim if out_dim is not None else time_embed_dim,
        )

        self.post_act = get_activation(post_act_fn) if post_act_fn is not None else None

    def forward(self, sample, condition=None):
        """Project ``sample``; if ``condition`` is given it is first
        projected by ``cond_proj`` and added to ``sample``."""
        if condition is not None:
            sample = sample + self.cond_proj(condition)
        sample = self.linear_1(sample)

        if self.act is not None:
            sample = self.act(sample)

        sample = self.linear_2(sample)

        if self.post_act is not None:
            sample = self.post_act(sample)
        return sample
|
||||
|
||||
|
||||
class Upsample1D(nn.Module):
    """A 1D upsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        use_conv_transpose (`bool`, default `True`):
            option to use a convolution transpose.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        channel_first (`bool`, default `True`):
            when False, inputs/outputs are (B, T, C) instead of (B, C, T).
        stride (`int`, default `2`):
            temporal upsampling factor.
    """

    def __init__(self, channels, use_conv=False, use_conv_transpose=True,
                 out_channels=None, name="conv", channel_first=True, stride=2):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name
        self.channel_first = channel_first
        self.stride = stride

        self.conv = None
        if use_conv_transpose:
            # kernel = stride*2, padding = 1 -> output length = stride * input length.
            self.conv = nn.ConvTranspose1d(channels, self.out_channels, stride * 2, stride, 1)
        elif use_conv:
            self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)

    def forward(self, inputs):
        """Upsample ``inputs`` by ``self.stride`` along the time axis.

        The layout is (B, C, T), or (B, T, C) when ``channel_first`` is
        False; the output uses the same layout as the input.
        """
        if not self.channel_first:
            inputs = inputs.transpose(1, 2).contiguous()
        assert inputs.shape[1] == self.channels
        if self.use_conv_transpose:
            outputs = self.conv(inputs)
        else:
            outputs = F.interpolate(inputs, scale_factor=self.stride, mode="nearest")
            if self.use_conv:
                outputs = self.conv(outputs)

        # BUGFIX: the conv-transpose path used to return early without
        # restoring the channel-last layout, so channel_first=False callers
        # received (B, C, T) instead of (B, T, C).
        if not self.channel_first:
            outputs = outputs.transpose(1, 2).contiguous()
        return outputs
|
||||
|
||||
|
||||
class ConformerWrapper(ConformerBlock):
    """Adapts ``conformer.ConformerBlock`` to the U-Net attention-block API.

    The extra keyword arguments accepted by ``forward`` (encoder states,
    timestep, ...) exist only for interface compatibility and are ignored —
    the underlying conformer block is unconditional.
    """

    def __init__(  # pylint: disable=useless-super-delegation
        self,
        *,
        dim,
        dim_head=64,
        heads=8,
        ff_mult=4,
        conv_expansion_factor=2,
        conv_kernel_size=31,
        attn_dropout=0,
        ff_dropout=0,
        conv_dropout=0,
        conv_causal=False,
    ):
        super().__init__(
            dim=dim,
            dim_head=dim_head,
            heads=heads,
            ff_mult=ff_mult,
            conv_expansion_factor=conv_expansion_factor,
            conv_kernel_size=conv_kernel_size,
            attn_dropout=attn_dropout,
            ff_dropout=ff_dropout,
            conv_dropout=conv_dropout,
            conv_causal=conv_causal,
        )

    def forward(
        self,
        hidden_states,
        attention_mask,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        timestep=None,
        **kwargs
    ):
        # Only the features and a boolean padding mask reach the base block.
        return super().forward(x=hidden_states, mask=attention_mask.bool())
|
||||
|
||||
|
||||
class TransformerDecoderWrapper(nn.Module):
    """Wraps a FunASR transformer ``DecoderLayer`` as a U-Net attention block
    with cross-attention to a prompt sequence.

    Input features are projected to the attention width, processed by one
    self-attention + cross-attention decoder layer, then projected back.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int = 8,
        attention_head_dim: int = 64,
        dropout: float = 0.1,
        activation_fn: str = "snakebeta",
        cond_dim: int = 80,
        concat_after: bool = False,
    ):
        super().__init__()
        attn_dim = num_attention_heads * attention_head_dim
        self.input_proj = torch.nn.Linear(dim, attn_dim)
        self.cond_proj = torch.nn.Linear(cond_dim, attn_dim)
        self.output_proj = torch.nn.Linear(attn_dim, dim)
        # NOTE(review): imported inside __init__ — presumably to avoid a
        # circular import at module load time; confirm before moving it.
        from funasr.models.transformer.decoder import (
            DecoderLayer, MultiHeadedAttention, PositionwiseFeedForward
        )
        self.decoder_layer = DecoderLayer(
            attn_dim,
            MultiHeadedAttention(num_attention_heads, attn_dim, dropout),
            MultiHeadedAttention(num_attention_heads, attn_dim, dropout),
            PositionwiseFeedForward(attn_dim, attn_dim * 4, dropout),
            dropout,
            normalize_before=True,
            concat_after=concat_after,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        timestep: torch.Tensor = None,
        prompt: torch.Tensor = None,
        prompt_lengths: torch.Tensor = None,
        **kwargs
    ):
        """Run the decoder layer.

        Args:
            hidden_states: features, shape (B, T, dim).
            attention_mask: square feature mask; only its first row is used
                to derive the per-frame padding mask.
            timestep: unused; accepted for interface parity.
            prompt: cross-attention memory, shape (B, Tp, cond_dim).
            prompt_lengths: valid lengths of ``prompt``, shape (B,).

        Returns:
            Tensor of shape (B, T, dim), zeroed at padded frames.
        """
        # Per-frame mask (B, T, 1) for the features, (B, Tp, 1) for the prompt.
        feat_mask = attention_mask[:, :1].transpose(1, 2).contiguous()
        prompt_mask = (~make_pad_mask(prompt_lengths)[:, :, None]).to(hidden_states.device)
        out = self.input_proj(hidden_states) * feat_mask
        memory = self.cond_proj(prompt) * prompt_mask
        out = self.decoder_layer(
            out,
            feat_mask.transpose(1, 2).contiguous(),
            memory,
            prompt_mask.transpose(1, 2).contiguous(),
        )[0]
        return self.output_proj(out) * feat_mask
|
||||
|
||||
|
||||
class Decoder(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        channels=(256, 256),
        dropout=0.05,
        attention_head_dim=64,
        n_blocks=1,
        num_mid_blocks=2,
        num_heads=4,
        act_fn="snake",
        down_block_type="transformer",
        mid_block_type="transformer",
        up_block_type="transformer",
        conditions: dict = None,
        concat_after: bool = False,
    ):
        """
        This decoder requires an input with the same shape of the target. So, if your text content
        is shorter or longer than the outputs, please re-sampling it before feeding to the decoder.

        1D U-Net estimator: per down/up level a ResnetBlock1D, ``n_blocks``
        attention blocks (see ``get_block``) and a 2x down/upsample; plus
        ``num_mid_blocks`` middle stages.  ``conditions`` may declare:
          - "xvec":   {"dim": D} — a vector concatenated on the channel axis
                      before every resnet (widens the resnet input).
          - "prompt": {"dim": D} — a sequence cross-attended to by
                      "transformer_decoder" blocks.
        """
        super().__init__()
        channels = tuple(channels)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.conditions = conditions

        # Sinusoidal timestep embedding followed by a 2-layer MLP.
        self.time_embeddings = SinusoidalPosEmb(in_channels)
        time_embed_dim = channels[0] * 4
        self.time_mlp = TimestepEmbedding(
            in_channels=in_channels,
            time_embed_dim=time_embed_dim,
            act_fn="silu",
        )
        self.down_block_type = down_block_type
        self.mid_block_type = mid_block_type
        self.up_block_type = up_block_type

        self.down_blocks = nn.ModuleList([])
        self.mid_blocks = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])

        output_channel = in_channels
        for i in range(len(channels)):  # pylint: disable=consider-using-enumerate
            input_channel = output_channel
            output_channel = channels[i]
            is_last = i == len(channels) - 1
            # Speaker-vector conditioning widens the resnet input channels.
            if conditions is not None and "xvec" in conditions:
                input_channel = input_channel + conditions["xvec"]["dim"]
            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
            transformer_blocks = nn.ModuleList(
                [
                    self.get_block(
                        down_block_type,
                        output_channel,
                        attention_head_dim,
                        num_heads,
                        dropout,
                        act_fn,
                        concat_after=concat_after,
                    )
                    for _ in range(n_blocks)
                ]
            )
            # Last level keeps the resolution (plain conv instead of stride-2).
            downsample = (
                Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
            )
            self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))

        for i in range(num_mid_blocks):
            input_channel = channels[-1]
            # NOTE(review): this shadows the ``out_channels`` constructor
            # argument locally; it is unused afterwards — ``self.out_channels``
            # (set above) is what ``final_proj`` uses.
            out_channels = channels[-1]

            if conditions is not None and "xvec" in conditions:
                input_channel = input_channel + conditions["xvec"]["dim"]
            # ``output_channel`` here is the value left over from the down
            # loop, i.e. channels[-1].
            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)

            transformer_blocks = nn.ModuleList(
                [
                    self.get_block(
                        mid_block_type,
                        output_channel,
                        attention_head_dim,
                        num_heads,
                        dropout,
                        act_fn,
                        concat_after=concat_after,
                    )
                    for _ in range(n_blocks)
                ]
            )

            self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))

        # Mirror the channel list for the decoder half; the *2 accounts for
        # the skip connection concatenated in ``forward``.
        channels = channels[::-1] + (channels[0],)
        for i in range(len(channels) - 1):
            input_channel = channels[i] * 2
            output_channel = channels[i + 1]
            is_last = i == len(channels) - 2

            if conditions is not None and "xvec" in conditions:
                input_channel = input_channel + conditions["xvec"]["dim"]
            resnet = ResnetBlock1D(
                dim=input_channel,
                dim_out=output_channel,
                time_emb_dim=time_embed_dim,
            )
            transformer_blocks = nn.ModuleList(
                [
                    self.get_block(
                        up_block_type,
                        output_channel,
                        attention_head_dim,
                        num_heads,
                        dropout,
                        act_fn,
                        concat_after=concat_after,
                    )
                    for _ in range(n_blocks)
                ]
            )
            upsample = (
                Upsample1D(output_channel, use_conv_transpose=True)
                if not is_last
                else nn.Conv1d(output_channel, output_channel, 3, padding=1)
            )

            self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))

        self.final_block = Block1D(channels[-1], channels[-1])
        self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)

        self.initialize_weights()
        # nn.init.normal_(self.final_proj.weight)

    def get_block(self, block_type, dim, attention_head_dim, num_heads, dropout, act_fn, **kwargs):
        """Build one attention block of the requested ``block_type``
        ("conformer", "transformer" or "transformer_decoder")."""
        if block_type == "conformer":
            block = ConformerWrapper(
                dim=dim,
                dim_head=attention_head_dim,
                heads=num_heads,
                ff_mult=1,
                conv_expansion_factor=2,
                ff_dropout=dropout,
                attn_dropout=dropout,
                conv_dropout=dropout,
                conv_kernel_size=31,
            )
        elif block_type == "transformer":
            block = BasicTransformerBlock(
                dim=dim,
                num_attention_heads=num_heads,
                attention_head_dim=attention_head_dim,
                dropout=dropout,
                activation_fn=act_fn,
            )
        elif block_type == "transformer_decoder":
            # Cross-attends to the "prompt" condition; requires
            # conditions["prompt"]["dim"] to be declared.
            block = TransformerDecoderWrapper(
                dim=dim,
                num_attention_heads=num_heads,
                attention_head_dim=attention_head_dim,
                dropout=dropout,
                activation_fn=act_fn,
                cond_dim=self.conditions["prompt"]["dim"],
                concat_after=kwargs.get("concat_after", False)
            )
        else:
            raise ValueError(f"Unknown block type {block_type}")

        return block

    def initialize_weights(self):
        """Kaiming-initialize conv/linear weights; unit-init GroupNorm."""
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")

                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

            elif isinstance(m, nn.GroupNorm):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")

                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x, mask, mu, t, spks=None, cond=None):
        """Forward pass of the UNet1DConditional model.

        Args:
            x (torch.Tensor): shape (batch_size, in_channels, time)
            mask (torch.Tensor): shape (batch_size, 1, time)
            mu (torch.Tensor): conditioning features packed with ``x`` on the
                channel axis; same time length as ``x``.
            t (torch.Tensor): shape (batch_size)
            spks (torch.Tensor, optional): shape (batch_size, condition_channels);
                broadcast over time and packed on the channel axis. Defaults to None.
            cond (dict, optional): may hold "xvec" (B, D) and/or
                "prompt" = (prompt_tensor, prompt_lengths). Each entry is used
                only if the same key was declared in ``conditions`` at
                construction time. Defaults to None.

        Returns:
            torch.Tensor: shape (batch_size, out_channels, time), zeroed at
            padded frames.
        """

        t = self.time_embeddings(t)
        t = self.time_mlp(t)

        # Concatenate the estimator input and the conditioning mean on the
        # channel axis.
        x = pack([x, mu], "b * t")[0]

        if spks is not None:
            spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
            x = pack([x, spks], "b * t")[0]

        hiddens = []
        masks = [mask]
        for resnet, transformer_blocks, downsample in self.down_blocks:
            mask_down = masks[-1]
            if (cond is not None and "xvec" in cond and
                    self.conditions is not None and "xvec" in self.conditions):
                xvec = repeat(cond["xvec"], "b c -> b c t", t=x.shape[-1])
                x = pack([x, xvec], "b * t")[0]
            x = resnet(x, mask_down, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            # mask_down = rearrange(mask_down, "b 1 t -> b t")
            # "transformer" blocks take a full (B, T, T) attention mask;
            # other block types take the per-frame mask.
            if self.down_block_type == "transformer":
                attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
            else:
                attn_mask = mask_down.squeeze(1)
            for transformer_block in transformer_blocks:
                cond_kwargs = {}
                if (cond is not None and "prompt" in cond and
                        self.conditions is not None and "prompt" in self.conditions):
                    cond_kwargs["prompt"], cond_kwargs["prompt_lengths"] = cond["prompt"]
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                    **cond_kwargs
                )
            x = rearrange(x, "b t c -> b c t").contiguous()
            # mask_down = rearrange(mask_down, "b t -> b 1 t")
            hiddens.append(x)  # Save hidden states for skip connections
            x = downsample(x * mask_down)
            # Subsample the mask to follow the stride-2 downsampling.
            masks.append(mask_down[:, :, ::2])

        # Drop the mask of the last (resolution-preserving) level.
        masks = masks[:-1]
        mask_mid = masks[-1]

        for resnet, transformer_blocks in self.mid_blocks:
            if (cond is not None and "xvec" in cond and
                    self.conditions is not None and "xvec" in self.conditions):
                xvec = repeat(cond["xvec"], "b c -> b c t", t=x.shape[-1])
                x = pack([x, xvec], "b * t")[0]
            x = resnet(x, mask_mid, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            # mask_mid = rearrange(mask_mid, "b 1 t -> b t")
            if self.mid_block_type == "transformer":
                attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
            else:
                attn_mask = mask_mid.squeeze(1)
            for transformer_block in transformer_blocks:
                cond_kwargs = {}
                if (cond is not None and "prompt" in cond and
                        self.conditions is not None and "prompt" in self.conditions):
                    cond_kwargs["prompt"], cond_kwargs["prompt_lengths"] = cond["prompt"]
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                    **cond_kwargs
                )
            x = rearrange(x, "b t c -> b c t").contiguous()
            # mask_mid = rearrange(mask_mid, "b t -> b 1 t")

        for resnet, transformer_blocks, upsample in self.up_blocks:
            mask_up = masks.pop()
            skip = hiddens.pop()
            # Crop x to the skip length (upsampling may overshoot by a frame)
            # and concatenate the skip connection on the channel axis.
            x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
            if (cond is not None and "xvec" in cond and
                    self.conditions is not None and "xvec" in self.conditions):
                xvec = repeat(cond["xvec"], "b c -> b c t", t=x.shape[-1])
                x = pack([x, xvec], "b * t")[0]
            x = resnet(x, mask_up, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            # mask_up = rearrange(mask_up, "b 1 t -> b t")
            if self.up_block_type == "transformer":
                attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
            else:
                attn_mask = mask_up.squeeze(1)
            for transformer_block in transformer_blocks:
                cond_kwargs = {}
                if (cond is not None and "prompt" in cond and
                        self.conditions is not None and "prompt" in self.conditions):
                    cond_kwargs["prompt"], cond_kwargs["prompt_lengths"] = cond["prompt"]
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                    **cond_kwargs
                )
            x = rearrange(x, "b t c -> b c t").contiguous()
            # mask_up = rearrange(mask_up, "b t -> b 1 t")
            x = upsample(x * mask_up)

        x = self.final_block(x, mask_up)
        output = self.final_proj(x * mask_up)

        return output * mask
|
||||
|
||||
|
||||
class ConditionalDecoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
channels=(256, 256),
|
||||
dropout=0.05,
|
||||
attention_head_dim=64,
|
||||
n_blocks=1,
|
||||
num_mid_blocks=2,
|
||||
num_heads=4,
|
||||
act_fn="snake",
|
||||
down_block_type="transformer",
|
||||
mid_block_type="transformer",
|
||||
up_block_type="transformer",
|
||||
):
|
||||
"""
|
||||
This decoder requires an input with the same shape of the target. So, if your text content
|
||||
is shorter or longer than the outputs, please re-sampling it before feeding to the decoder.
|
||||
"""
|
||||
super().__init__()
|
||||
channels = tuple(channels)
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
|
||||
self.time_embeddings = SinusoidalPosEmb(in_channels)
|
||||
time_embed_dim = channels[0] * 4
|
||||
self.time_mlp = TimestepEmbedding(
|
||||
in_channels=in_channels,
|
||||
time_embed_dim=time_embed_dim,
|
||||
act_fn="silu",
|
||||
)
|
||||
self.down_block_type = down_block_type
|
||||
self.mid_block_type = mid_block_type
|
||||
self.up_block_type = up_block_type
|
||||
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
self.mid_blocks = nn.ModuleList([])
|
||||
self.up_blocks = nn.ModuleList([])
|
||||
|
||||
output_channel = in_channels
|
||||
for i in range(len(channels)): # pylint: disable=consider-using-enumerate
|
||||
input_channel = output_channel
|
||||
output_channel = channels[i]
|
||||
is_last = i == len(channels) - 1
|
||||
resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
|
||||
transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
self.get_block(
|
||||
down_block_type,
|
||||
output_channel,
|
||||
attention_head_dim,
|
||||
num_heads,
|
||||
dropout,
|
||||
act_fn,
|
||||
)
|
||||
for _ in range(n_blocks)
|
||||
]
|
||||
)
|
||||
downsample = (
|
||||
Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
|
||||
)
|
||||
self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
|
||||
|
||||
for i in range(num_mid_blocks):
|
||||
input_channel = channels[-1]
|
||||
out_channels = channels[-1]
|
||||
resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
|
||||
|
||||
transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
self.get_block(
|
||||
mid_block_type,
|
||||
output_channel,
|
||||
attention_head_dim,
|
||||
num_heads,
|
||||
dropout,
|
||||
act_fn,
|
||||
)
|
||||
for _ in range(n_blocks)
|
||||
]
|
||||
)
|
||||
|
||||
self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
|
||||
|
||||
channels = channels[::-1] + (channels[0],)
|
||||
for i in range(len(channels) - 1):
|
||||
input_channel = channels[i] * 2
|
||||
output_channel = channels[i + 1]
|
||||
is_last = i == len(channels) - 2
|
||||
resnet = ResnetBlock1D(
|
||||
dim=input_channel,
|
||||
dim_out=output_channel,
|
||||
time_emb_dim=time_embed_dim,
|
||||
)
|
||||
transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
self.get_block(
|
||||
up_block_type,
|
||||
output_channel,
|
||||
attention_head_dim,
|
||||
num_heads,
|
||||
dropout,
|
||||
act_fn,
|
||||
)
|
||||
for _ in range(n_blocks)
|
||||
]
|
||||
)
|
||||
upsample = (
|
||||
Upsample1D(output_channel, use_conv_transpose=True)
|
||||
if not is_last
|
||||
else nn.Conv1d(output_channel, output_channel, 3, padding=1)
|
||||
)
|
||||
|
||||
self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
|
||||
|
||||
self.final_block = Block1D(channels[-1], channels[-1])
|
||||
self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
|
||||
|
||||
self.initialize_weights()
|
||||
# nn.init.normal_(self.final_proj.weight)
|
||||
|
||||
def get_block(self, block_type, dim, attention_head_dim, num_heads, dropout, act_fn, **kwargs):
|
||||
if block_type == "conformer":
|
||||
block = ConformerWrapper(
|
||||
dim=dim,
|
||||
dim_head=attention_head_dim,
|
||||
heads=num_heads,
|
||||
ff_mult=1,
|
||||
conv_expansion_factor=2,
|
||||
ff_dropout=dropout,
|
||||
attn_dropout=dropout,
|
||||
conv_dropout=dropout,
|
||||
conv_kernel_size=31,
|
||||
)
|
||||
elif block_type == "transformer":
|
||||
block = BasicTransformerBlock(
|
||||
dim=dim,
|
||||
num_attention_heads=num_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
dropout=dropout,
|
||||
activation_fn=act_fn,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown block type {block_type}")
|
||||
|
||||
return block
|
||||
|
||||
def initialize_weights(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv1d):
|
||||
nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
|
||||
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
elif isinstance(m, nn.GroupNorm):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
elif isinstance(m, nn.Linear):
|
||||
nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
|
||||
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def forward(self, x, mask, mu, t, spks=None, cond=None):
    """Forward pass of the UNet1DConditional model.

    Runs a 1-D UNet (down blocks -> mid blocks -> up blocks with skip
    connections). The diffusion timestep embedding ``t`` conditions every
    resnet/attention block; ``mu``, ``spks`` and ``cond`` are stacked onto
    the input along the channel axis before the first block.

    Args:
        x (torch.Tensor): shape (batch_size, in_channels, time)
        mask (torch.Tensor): shape (batch_size, 1, time); 1 for valid frames.
        mu (torch.Tensor): conditioning features, channel-stacked onto ``x``.
            # assumes the same (batch, channels, time) layout as x — packed with "b * t"
        t (torch.Tensor): shape (batch_size), diffusion timestep per item.
        spks (torch.Tensor, optional): shape (batch_size, condition_channels). Defaults to None.
        cond (torch.Tensor, optional): placeholder for future use. Defaults to None.

    Returns:
        torch.Tensor: output of ``final_proj``, zeroed outside ``mask``.
    """

    # Embed the scalar diffusion step, then run it through the timestep MLP.
    t = self.time_embeddings(t)
    t = self.time_mlp(t)

    # Stack the conditioning features onto the input along the channel axis.
    x = pack([x, mu], "b * t")[0]

    if spks is not None:
        # Broadcast the per-utterance speaker embedding across time first.
        spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
        x = pack([x, spks], "b * t")[0]
    if cond is not None:
        x = pack([x, cond], "b * t")[0]

    hiddens = []
    masks = [mask]
    # ---- encoder path: resnet -> attention blocks -> downsample ----
    for resnet, transformer_blocks, downsample in self.down_blocks:
        mask_down = masks[-1]
        x = resnet(x, mask_down, t)
        x = rearrange(x, "b c t -> b t c").contiguous()
        # mask_down = rearrange(mask_down, "b 1 t -> b t")
        if self.down_block_type == "transformer":
            # Pairwise (B, T, T) attention mask derived from the padding mask.
            attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
        else:
            attn_mask = mask_down.squeeze(1)
        for transformer_block in transformer_blocks:
            x = transformer_block(
                hidden_states=x,
                attention_mask=attn_mask,
                timestep=t,
            )
        x = rearrange(x, "b t c -> b c t").contiguous()
        # mask_down = rearrange(mask_down, "b t -> b 1 t")
        hiddens.append(x)  # Save hidden states for skip connections
        x = downsample(x * mask_down)
        # Keep a mask at the halved time resolution for the next stage.
        masks.append(mask_down[:, :, ::2])

    # Drop the deepest mask; mid blocks run at the last downsampled rate.
    masks = masks[:-1]
    mask_mid = masks[-1]

    for resnet, transformer_blocks in self.mid_blocks:
        x = resnet(x, mask_mid, t)
        x = rearrange(x, "b c t -> b t c").contiguous()
        # mask_mid = rearrange(mask_mid, "b 1 t -> b t")
        if self.mid_block_type == "transformer":
            attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
        else:
            attn_mask = mask_mid.squeeze(1)
        for transformer_block in transformer_blocks:
            x = transformer_block(
                hidden_states=x,
                attention_mask=attn_mask,
                timestep=t,
            )
        x = rearrange(x, "b t c -> b c t").contiguous()
        # mask_mid = rearrange(mask_mid, "b t -> b 1 t")

    # ---- decoder path: concat skip connection -> resnet -> attention -> upsample ----
    for resnet, transformer_blocks, upsample in self.up_blocks:
        mask_up = masks.pop()
        skip = hiddens.pop()
        # Trim to the skip length (upsampling can overshoot by a frame),
        # then concatenate the skip connection along channels.
        x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
        x = resnet(x, mask_up, t)
        x = rearrange(x, "b c t -> b t c").contiguous()
        # mask_up = rearrange(mask_up, "b 1 t -> b t")
        if self.up_block_type == "transformer":
            attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
        else:
            attn_mask = mask_up.squeeze(1)
        for transformer_block in transformer_blocks:
            x = transformer_block(
                hidden_states=x,
                attention_mask=attn_mask,
                timestep=t,
            )
        x = rearrange(x, "b t c -> b c t").contiguous()
        # mask_up = rearrange(mask_up, "b t -> b 1 t")
        x = upsample(x * mask_up)

    # mask_up is now the full-resolution mask again (last popped entry).
    x = self.final_block(x, mask_up)
    output = self.final_proj(x * mask_up)

    return output * mask
317
funasr/models/llm_asr/diffusion_models/transformer.py
Normal file
317
funasr/models/llm_asr/diffusion_models/transformer.py
Normal file
@ -0,0 +1,317 @@
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from diffusers.models.attention import (
|
||||
GEGLU,
|
||||
GELU,
|
||||
AdaLayerNorm,
|
||||
AdaLayerNormZero,
|
||||
ApproximateGELU,
|
||||
)
|
||||
from diffusers.models.attention_processor import Attention
|
||||
from diffusers.models.lora import LoRACompatibleLinear
|
||||
from diffusers.utils.torch_utils import maybe_allow_in_graph
|
||||
|
||||
|
||||
class SnakeBeta(nn.Module):
    """Linear projection followed by the SnakeBeta activation.

    A modified Snake function which uses separate parameters for the
    frequency (``alpha``) and magnitude (``beta``) of the periodic component:

        SnakeBeta(x) := x + 1/(beta + eps) * sin^2(x * alpha)

    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha - trainable parameter that controls frequency
        - beta - trainable parameter that controls magnitude
    References:
        - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
        https://arxiv.org/abs/2006.08195
    Examples:
        >>> a1 = SnakeBeta(256, 256)
        >>> x = torch.randn(1, 16, 256)
        >>> x = a1(x)
    """

    def __init__(self, in_features, out_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
        """
        Initialization.

        INPUT:
            - in_features: width of the input to the internal linear projection
            - out_features: projection output width; also the shape of alpha/beta
            - alpha: initial value for the linear-scale parameterization
            - alpha_trainable: whether alpha/beta receive gradients
            - alpha_logscale: store parameters in log scale (exp() applied in forward)

        With ``alpha_logscale=True`` the parameters start at 0, i.e. an
        effective alpha/beta of exp(0) == 1. Higher alpha = higher frequency;
        higher beta = higher magnitude. alpha and beta are trained along with
        the rest of the model.
        """
        super().__init__()
        self.in_features = out_features if isinstance(out_features, list) else [out_features]
        self.proj = LoRACompatibleLinear(in_features, out_features)

        # initialize alpha / beta (attribute names kept for checkpoint compatibility)
        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:  # log scale alphas initialized to zeros
            # BUG FIX: the previous `torch.zeros(...) * alpha` multiplication
            # was a no-op (zeros times any finite value is zeros) — dropped.
            self.alpha = nn.Parameter(torch.zeros(self.in_features))
            self.beta = nn.Parameter(torch.zeros(self.in_features))
        else:  # linear scale alphas initialized to ones * alpha
            self.alpha = nn.Parameter(torch.ones(self.in_features) * alpha)
            self.beta = nn.Parameter(torch.ones(self.in_features) * alpha)

        self.alpha.requires_grad = alpha_trainable
        self.beta.requires_grad = alpha_trainable

        # numerical guard against division by zero when beta == 0
        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        """
        Forward pass of the function.
        Applies the function to the input elementwise.
        SnakeBeta := x + 1/b * sin^2 (xa)
        """
        x = self.proj(x)
        if self.alpha_logscale:
            alpha = torch.exp(self.alpha)
            beta = torch.exp(self.beta)
        else:
            alpha = self.alpha
            beta = self.beta

        x = x + (1.0 / (beta + self.no_div_by_zero)) * torch.pow(torch.sin(x * alpha), 2)

        return x
||||
|
||||
class FeedForward(nn.Module):
    r"""
    A feed-forward layer.

    Parameters:
        dim (`int`): The number of channels in the input.
        dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
        mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
            One of "gelu", "gelu-approximate", "geglu", "geglu-approximate", "snakebeta".
        final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.

    Raises:
        ValueError: if ``activation_fn`` is not one of the supported names.
    """

    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        mult: int = 4,
        dropout: float = 0.0,
        activation_fn: str = "geglu",
        final_dropout: bool = False,
    ):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim

        # BUG FIX: the first two branches were independent `if`s, and an
        # unrecognised name left `act_fn` unbound (UnboundLocalError later).
        # Use a proper elif chain and fail fast on unknown activations.
        if activation_fn == "gelu":
            act_fn = GELU(dim, inner_dim)
        elif activation_fn == "gelu-approximate":
            act_fn = GELU(dim, inner_dim, approximate="tanh")
        elif activation_fn == "geglu":
            act_fn = GEGLU(dim, inner_dim)
        elif activation_fn == "geglu-approximate":
            act_fn = ApproximateGELU(dim, inner_dim)
        elif activation_fn == "snakebeta":
            act_fn = SnakeBeta(dim, inner_dim)
        else:
            raise ValueError(f"Unknown activation function {activation_fn}")

        self.net = nn.ModuleList([])
        # project in
        self.net.append(act_fn)
        # project dropout
        self.net.append(nn.Dropout(dropout))
        # project out
        self.net.append(LoRACompatibleLinear(inner_dim, dim_out))
        # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
        if final_dropout:
            self.net.append(nn.Dropout(dropout))

    def forward(self, hidden_states):
        """Apply activation, dropout and output projection sequentially."""
        for module in self.net:
            hidden_states = module(hidden_states)
        return hidden_states
||||
|
||||
@maybe_allow_in_graph
class BasicTransformerBlock(nn.Module):
    r"""
    A basic Transformer block: self-attention, optional cross-attention, and
    a feed-forward layer, each preceded by (possibly timestep-adaptive)
    layer normalization and wrapped in a residual connection.

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
        only_cross_attention (`bool`, *optional*):
            Whether to use only cross-attention layers. In this case two cross attention layers are used.
        double_self_attention (`bool`, *optional*):
            Whether to use two self-attention layers. In this case no cross attention layers are used.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        num_embeds_ada_norm (:
            obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
        attention_bias (:
            obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        double_self_attention: bool = False,
        upcast_attention: bool = False,
        norm_elementwise_affine: bool = True,
        norm_type: str = "layer_norm",
        final_dropout: bool = False,
    ):
        super().__init__()
        self.only_cross_attention = only_cross_attention

        # Adaptive norms condition the LayerNorm statistics on the diffusion
        # timestep; both variants require a timestep-embedding vocabulary.
        self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
        self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"

        if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
            raise ValueError(
                f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
                f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
            )

        # Define 3 blocks. Each block has its own normalization layer.
        # 1. Self-Attn
        if self.use_ada_layer_norm:
            self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
        elif self.use_ada_layer_norm_zero:
            self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
        else:
            self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
        self.attn1 = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            # attn1 becomes a cross-attention layer only in only_cross_attention mode
            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
            upcast_attention=upcast_attention,
        )

        # 2. Cross-Attn
        if cross_attention_dim is not None or double_self_attention:
            # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
            # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
            # the second cross attention block.
            self.norm2 = (
                AdaLayerNorm(dim, num_embeds_ada_norm)
                if self.use_ada_layer_norm
                else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
            )
            self.attn2 = Attention(
                query_dim=dim,
                cross_attention_dim=cross_attention_dim if not double_self_attention else None,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
                # scale_qk=False, # uncomment this to not to use flash attention
            )  # is self-attn if encoder_hidden_states is none
        else:
            self.norm2 = None
            self.attn2 = None

        # 3. Feed-forward
        self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)

        # let chunk size default to None (no feed-forward chunking)
        self._chunk_size = None
        self._chunk_dim = 0

    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
        # Sets chunk feed-forward: split the FF input along `dim` into chunks
        # of `chunk_size` to trade compute for peak memory.
        self._chunk_size = chunk_size
        self._chunk_dim = dim

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Dict[str, Any] = None,
        class_labels: Optional[torch.LongTensor] = None,
        **kwargs
    ):
        # Notice that normalization is always applied before the real computation in the following blocks.
        # 1. Self-Attention
        if self.use_ada_layer_norm:
            norm_hidden_states = self.norm1(hidden_states, timestep)
        elif self.use_ada_layer_norm_zero:
            # ada_norm_zero also emits gating/shift/scale chunks reused below
            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
                hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
            )
        else:
            norm_hidden_states = self.norm1(hidden_states)

        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}

        attn_output = self.attn1(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
            attention_mask=encoder_attention_mask if self.only_cross_attention else attention_mask,
            **cross_attention_kwargs,
        )
        if self.use_ada_layer_norm_zero:
            attn_output = gate_msa.unsqueeze(1) * attn_output
        hidden_states = attn_output + hidden_states

        # 2. Cross-Attention
        if self.attn2 is not None:
            norm_hidden_states = (
                self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
            )

            attn_output = self.attn2(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                **cross_attention_kwargs,
            )
            hidden_states = attn_output + hidden_states

        # 3. Feed-forward
        norm_hidden_states = self.norm3(hidden_states)

        if self.use_ada_layer_norm_zero:
            norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

        if self._chunk_size is not None:
            # "feed_forward_chunk_size" can be used to save memory
            if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
                raise ValueError(
                    f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
                )

            num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
            ff_output = torch.cat(
                [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
                dim=self._chunk_dim,
            )
        else:
            ff_output = self.ff(norm_hidden_states)

        if self.use_ada_layer_norm_zero:
            ff_output = gate_mlp.unsqueeze(1) * ff_output

        hidden_states = ff_output + hidden_states

        return hidden_states
847
funasr/models/llm_asr/flow_matching.py
Normal file
847
funasr/models/llm_asr/flow_matching.py
Normal file
@ -0,0 +1,847 @@
|
||||
import logging
|
||||
import math
|
||||
from typing import Dict, List, Optional
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from funasr.models.transformer.utils.nets_utils import make_pad_mask
|
||||
from funasr.train_utils.device_funcs import force_gatherable
|
||||
import random
|
||||
from funasr.models.llm_asr.mel_spectrum import (
|
||||
mel_spectrogram, power_spectrogram, mel_from_power_spectrogram
|
||||
)
|
||||
from torch.nn import functional as F
|
||||
from funasr.models.transformer.utils.nets_utils import pad_list
|
||||
from distutils.version import LooseVersion
|
||||
from contextlib import contextmanager
|
||||
from funasr.utils.hinter import hint_once
|
||||
# torch.cuda.amp.autocast exists from torch 1.6.0 onwards; on older
# versions fall back to a no-op context manager so call sites can use
# ``with autocast(...)`` unconditionally.
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
else:
    # Nothing to do if torch<1.6.0
    @contextmanager
    def autocast(enabled=True):
        yield
||||
|
||||
class MelSpectrumExtractor(nn.Module):
    """Frame-level spectral feature extractor (mel or power spectrogram).

    Stores a fixed STFT configuration on the module and wraps the functional
    ``mel_spectrogram`` / ``power_spectrogram`` helpers. ``spec_type``
    selects which spectrogram ``forward`` produces.
    """

    def __init__(
        self,
        n_fft=1024,
        num_mels=80,
        sampling_rate=22050,
        hop_size=256,
        win_size=1024,
        fmin=0,
        fmax=8000,
        spec_type="mel",
    ):
        super().__init__()
        self.n_fft = n_fft
        self.num_mels = num_mels
        self.sampling_rate = sampling_rate
        self.hop_size = hop_size
        self.win_size = win_size
        self.fmin = fmin
        self.fmax = fmax
        self.spec_type = spec_type

    def extra_repr(self):
        """Summarise the STFT configuration for ``repr`` output."""
        fields = ("n_fft", "num_mels", "sampling_rate", "hop_size", "win_size", "fmin", "fmax")
        return ", ".join(f"{name}={getattr(self, name)}" for name in fields)

    def forward(self, x, ilens):
        """Extract (B, T', num_mels) features and their lengths from waveforms.

        Args:
            x: (B, num_samples) waveform batch.
            ilens: (B,) valid sample counts per item.
        """
        extract = power_spectrogram if self.spec_type == "power" else mel_spectrogram
        feat = extract(x, self.n_fft, self.num_mels, self.sampling_rate,
                       self.hop_size, self.win_size, self.fmin, self.fmax)
        # determine olens by comparing the lengths of inputs and outputs
        olens = ilens // (x.shape[1] // feat.shape[2])
        return feat.transpose(1, 2), olens

    def convert_power_to_mel(self, x, ilens):
        """Convert an existing power spectrogram into mel features; lengths unchanged."""
        feat = mel_from_power_spectrogram(x, self.n_fft, self.num_mels, self.sampling_rate,
                                          self.hop_size, self.win_size, self.fmin, self.fmax)
        return feat, ilens
||||
|
||||
class QuantizerCodebook(torch.nn.Module):
    """Dense-embedding lookup over a multi-codebook (RVQ-style) quantizer.

    Holds a (num_quantizers, codebook_size, codebook_dim) embedding buffer
    (filled from a pretrained codec checkpoint) and converts per-frame codec
    indices into dense embeddings by summing the per-quantizer lookups.

    Args:
        num_quantizers: number of residual codebooks.
        codebook_size: entries per codebook.
        codebook_dim: embedding dimension of each entry.
        hop_length: codec frame hop in samples (stored as ``hop_size``).
        sampling_rate: codec sampling rate in Hz.
    """

    def __init__(
        self,
        num_quantizers,
        codebook_size,
        codebook_dim,
        hop_length,
        sampling_rate
    ):
        super().__init__()
        self.num_quantizers = num_quantizers
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.hop_size = hop_length
        self.sampling_rate = sampling_rate
        embed = torch.zeros(num_quantizers, codebook_size, codebook_dim)
        self.register_buffer("embed", embed)
        # Offset for quantizer q is q * codebook_size, so indices can address
        # the flattened (num_quantizers * codebook_size) embedding table.
        # BUG FIX: this was hard-coded as 1024 * arange(32), which is wrong
        # whenever codebook_size != 1024 or num_quantizers != 32.
        codec_index_shift = codebook_size * torch.arange(num_quantizers, dtype=torch.float32)[None, None, :]
        self.register_buffer("codec_index_shift", codec_index_shift)

    def save_embedding(self, file_name, dense_emb, emb_lengths):
        """Debug helper: dump up to 10 dense embeddings to a kaldi ark/scp pair."""
        import kaldiio
        wav_writer = kaldiio.WriteHelper("ark,scp,f:{}.ark,{}.scp".format(file_name, file_name))
        dense_emb = dense_emb.cpu().numpy()
        for i in range(min(dense_emb.shape[0], 10)):
            wav_writer(str(i), dense_emb[i, :emb_lengths[i]])

        wav_writer.close()

    def forward(self, codec: torch.Tensor, codec_lengths, return_subs=False):
        """Embed codec indices into dense vectors.

        Args:
            codec: (B, T) or (B, T, Nq) integer codec indices.
            codec_lengths: (B,) valid frame counts.
            return_subs: also return the per-quantizer embeddings.

        Returns:
            ((B, T, codebook_dim) dense embeddings, codec_lengths); with
            ``return_subs`` the first element is a tuple
            (dense, (B, T, Nq, codebook_dim) per-quantizer embeddings).
        """
        if len(codec.shape) == 2:
            codec = codec.unsqueeze(-1)
        bz, tt, nq = codec.shape[0], codec.shape[1], codec.shape[2]
        codec_mask = ~make_pad_mask(codec_lengths, maxlen=codec.shape[1]).unsqueeze(-1).to(codec.device)
        # zero the padded frames, then shift each quantizer into its own
        # slice of the flattened embedding table
        codec = codec * codec_mask + self.codec_index_shift[:, :, :nq].long()
        codec = codec.reshape(-1, nq)
        emb = self.embed.reshape(-1, self.codebook_dim)
        codec_emb = F.embedding(codec, emb)  # (BT, Nq, D)
        # residual quantization: the dense embedding is the sum over codebooks
        dense_emb = codec_emb.sum(dim=1)
        dense_emb = dense_emb.reshape(bz, tt, self.codebook_dim)
        if return_subs:
            sub_embs = codec_emb.reshape(bz, tt, nq, self.codebook_dim) * codec_mask.unsqueeze(-2)
            return (dense_emb * codec_mask, sub_embs), codec_lengths
        return dense_emb * codec_mask, codec_lengths
||||
|
||||
class BaseDiffWithXvec(nn.Module):
    """Token-to-acoustic-feature flow-matching model conditioned on x-vectors.

    Pipeline: (optional token embedding) -> encoder -> linear projection ->
    length regulator (to the acoustic frame rate) -> conditional
    flow-matching decoder producing mel or codec-embedding features.

    Args:
        input_size: encoder input dimension (token embedding size).
        output_size: acoustic feature dimension (default 80 mel bins).
        xvec_size: dimension of the raw speaker x-vectors.
        output_type: "mel" (waveform targets -> mel) or "codec"
            (codec-index targets -> dense codebook embeddings).
        encoder_conf / decoder_conf / mel_feat_conf / codec_conf: sub-module configs.
        length_regulator_conf: length-regulator config (required; see
            ``build_length_regulator``).
        prompt_conf: optional prompt-conditioning config.
        vocab_size: if set and > 0, an nn.Embedding maps token ids to vectors.
        token_list: vocabulary list (stored for reference only).
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 80,
        xvec_size: int = 198,
        output_type: str = "mel",
        encoder_conf: Dict = {},
        decoder_conf: Dict = {},
        mel_feat_conf: Dict = {},
        codec_conf: Dict = {},
        length_regulator_conf: Dict = None,
        prompt_conf: Dict = None,
        vocab_size: int = None,
        token_list: List = None,
        **kwargs,
    ):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.encoder_conf = encoder_conf
        self.decoder_conf = decoder_conf
        self.mel_feat_conf = mel_feat_conf
        self.vocab_size = vocab_size
        self.token_list = token_list
        self.output_type = output_type
        self.prompt_conf = prompt_conf
        # frame rate of the encoder input tokens (used to size outputs at inference)
        self.input_frame_rate = kwargs.get("input_frame_rate", 50)
        logging.info(f"input frame rate={self.input_frame_rate}")
        if output_type == 'mel':
            self.mel_extractor = MelSpectrumExtractor(**mel_feat_conf)
        elif output_type == 'codec':
            num_quantizers = codec_conf.get("num_quantizers", 32)
            codebook_size = codec_conf.get("codebook_size", 1024)
            codebook_dim = codec_conf.get("codebook_dim", 128)
            hop_length = codec_conf.get("hop_length", 640)
            sampling_rate = codec_conf.get("sampling_rate", 16000)
            self.quantizer_codebook = QuantizerCodebook(num_quantizers, codebook_size, codebook_dim,
                                                        hop_length, sampling_rate)
        if vocab_size is not None and vocab_size > 0:
            self.input_embedding = nn.Embedding(vocab_size, input_size)
        self.xvec_proj = torch.nn.Linear(xvec_size, output_size)
        self.encoder = self.build_encoder()
        self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)

        self.decoder = self.build_decoder()

        self.length_regulator_conf = length_regulator_conf
        self.length_regulator = self.build_length_regulator()

    def build_encoder(self):
        """Build the text/token encoder from ``self.encoder_conf``.

        The "name" key is popped during construction and restored afterwards
        so the config dict can be reused/serialized. "transformer" reuses the
        ConformerEncoder with its conformer-specific parts disabled.
        Returns None for unknown names.
        """
        encoder_name = self.encoder_conf.pop("name", "transformer")
        model = None
        if encoder_name == "transformer":
            from funasr.models.llm_asr.conformer_encoder import ConformerEncoder
            model = ConformerEncoder(
                **self.encoder_conf,
                input_size=self.input_size,
                use_cnn_module=False,
                macaron_style=False,
            )
        elif encoder_name == "conformer":
            from funasr.models.llm_asr.conformer_encoder import ConformerEncoder
            model = ConformerEncoder(
                **self.encoder_conf,
                input_size=self.input_size,
            )

        self.encoder_conf["name"] = encoder_name

        return model

    def build_decoder(self):
        """Build the flow-matching decoder ("matcha" -> CFM) from ``self.decoder_conf``.

        The decoder consumes [noisy_y; mu] stacked channel-wise, hence
        in_channels = 2 * output_size. Returns None for unknown names.
        """
        decoder_name = self.decoder_conf.pop("name", "transformer")
        model = None

        if decoder_name == "matcha":
            from funasr.models.llm_asr.diffusion_models.flow_matching import CFM
            model = CFM(
                **self.decoder_conf,
                in_channels=self.output_size * 2,  # 2 for noise_y and mu
                out_channel=self.output_size,
                spk_emb_dim=self.output_size
            )

        self.decoder_conf["name"] = decoder_name

        return model

    def select_target_prompt(self, y, y_lengths):
        """Cut a random sub-segment of each target as an acoustic prompt.

        The prompt length is drawn uniformly from a fraction range of the
        target length; with probability ``cgf_prob`` the prompt is zeroed
        (classifier-free guidance).

        Returns:
            (padded prompts (B, T_p, D), prompt lengths (B,)).
        """
        prompt_conf = self.prompt_conf
        prompt_list = []
        prompt_lengths = []
        for i, y_len in enumerate(y_lengths):
            # NOTE(review): config key reads "prompt_with_range_ratio" —
            # presumably "width"; keep the key as-is for config compatibility.
            prompt_len = random.randint(
                int(y_len * prompt_conf["prompt_with_range_ratio"][0]),
                int(y_len * prompt_conf["prompt_with_range_ratio"][1])
            )
            prompt_pos = random.randint(0, y_len - prompt_len)
            prompt_list.append(y[i, prompt_pos:prompt_pos+prompt_len])
            prompt_lengths.append(prompt_len)
        prompt = pad_list(prompt_list, 0.0)
        prompt_lengths = torch.tensor(prompt_lengths, dtype=torch.int64, device=y.device)

        if "cgf_prob" in prompt_conf and prompt_conf["cgf_prob"] > 0:
            # randomly drop whole prompts for classifier-free guidance training
            cgf_mask = torch.rand([y.shape[0], 1, 1], dtype=torch.float32, device=y.device) < prompt_conf["cgf_prob"]
            prompt = prompt * cgf_mask
        return prompt, prompt_lengths

    def build_length_regulator(self):
        """Build the token-rate -> frame-rate length regulator.

        Raises:
            ValueError: if the configured name is not one of
                "upsampling", "downsampling", "interpolate".
        """
        name = self.length_regulator_conf.pop("name", None)
        model = None
        if name == "upsampling":
            from funasr.models.llm_asr.diffusion_models.length_regulator import UpSamplingRegulator
            model = UpSamplingRegulator(self.output_size, self.length_regulator_conf.get("sampling_ratios"))
        elif name == "downsampling":
            from funasr.models.llm_asr.diffusion_models.length_regulator import DownSamplingRegulator
            model = DownSamplingRegulator(self.output_size, self.length_regulator_conf.get("sampling_ratios"))
        elif name == "interpolate":
            from funasr.models.llm_asr.diffusion_models.length_regulator import InterpolateRegulator
            model = InterpolateRegulator(self.output_size, **self.length_regulator_conf)
        else:
            raise ValueError(f"Unknown length_regulator {name}")

        self.length_regulator_conf["name"] = name

        return model

    def forward(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        audio: torch.Tensor,
        audio_lengths: torch.Tensor,
        xvec: torch.Tensor,
        xvec_lengths: torch.Tensor,
    ):
        """Compute the flow-matching training loss for one batch.

        Args:
            text: (B, T_text) token ids (or pre-embedded features).
            text_lengths: (B,) valid token counts.
            audio: (B, T_audio) waveforms or codec indices (per output_type).
            audio_lengths: (B,) valid audio lengths.
            xvec: (B, N, xvec_size) candidate x-vectors per utterance.
            xvec_lengths: (B,) number of valid x-vector frames.

        Returns:
            (loss, stats dict, batch weight) gathered for DataParallel.
        """
        batch_size = audio.shape[0]
        # for data parallel
        x = text[:, :text_lengths.max()]
        y = audio[:, :audio_lengths.max()]
        xvec = xvec[:, :xvec_lengths.max()]
        if self.vocab_size is not None and self.vocab_size > 0:
            # -1 marks padding: clamp to a valid id for the lookup, then zero it out
            mask = (x != -1).float().unsqueeze(-1)
            x = self.input_embedding(torch.clamp(x, min=0)) * mask

        # random select a xvec from xvec matrix
        xvec_list = []
        for i, ilen in enumerate(xvec_lengths):
            idx = random.randint(0, ilen-1)
            # resample until a fully finite x-vector is drawn
            # NOTE(review): loops forever if an item has no finite frame — confirm upstream filtering
            while torch.any(~torch.isfinite(xvec[i, idx])):
                idx = random.randint(0, ilen - 1)
            xvec_list.append(xvec[i, idx])
        rand_xvec = torch.vstack(xvec_list)
        rand_xvec = self.xvec_proj(rand_xvec)

        y, y_lengths = self.extract_feat(y, audio_lengths)
        h, h_lengths, _ = self.encoder(x, text_lengths)
        h = self.encoder_proj(h)
        # align encoder outputs to the acoustic frame rate
        h, h_lengths = self.length_regulator(h, h_lengths, y, y_lengths)
        if self.prompt_conf is not None:
            target_prompt = self.select_target_prompt(y, y_lengths)
            conditions = dict(
                xvec=rand_xvec,
                target_prompt=target_prompt,
            )
        else:
            conditions = None

        mask = (~make_pad_mask(y_lengths)).to(y)
        # y, h in (B, T, D)
        loss, _ = self.decoder.compute_loss(
            y.transpose(1, 2).contiguous(),
            mask.unsqueeze(1),
            h.transpose(1, 2).contiguous(),
            rand_xvec,
            cond=conditions
        )

        stats = dict(loss=torch.clone(loss.detach()))

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

    @torch.no_grad()
    def extract_feat(self, y: torch.Tensor, y_lengths: torch.Tensor):
        """Turn raw targets into decoder features according to ``output_type``."""
        if self.output_type == 'mel':
            return self.mel_extractor(y, y_lengths)
        elif self.output_type == "codec":
            return self.quantizer_codebook(y.long(), y_lengths)
        else:
            # pass-through for pre-extracted features
            return y, y_lengths

    @torch.no_grad()
    def inference(self, text, text_lens, xvec, xvec_lens, diff_steps=10, temperature=1.0, prompt=None):
        """Generate features from tokens with the flow-matching decoder.

        Args:
            text: (B, T) token ids (or embedded features).
            text_lens: (B,) token counts.
            xvec: (B, N, xvec_size) speaker x-vectors; averaged over N.
            xvec_lens: unused here (kept for interface symmetry).
            diff_steps: number of ODE/diffusion steps.
            temperature: sampling temperature for the decoder.
            prompt: unused placeholder.
        """
        avg_xvec = torch.mean(xvec, dim=1)
        avg_xvec = self.xvec_proj(avg_xvec)
        if self.vocab_size is not None and self.vocab_size > 0:
            mask = (text != -1).float().unsqueeze(-1)
            text = self.input_embedding(torch.clamp(text, min=0)) * mask
        h, h_lengths, _ = self.encoder(text, text_lens)
        h = self.encoder_proj(h)
        # coeff converts token frames to acoustic frames
        if self.output_type == "mel":
            coeff = ((self.mel_extractor.sampling_rate / self.mel_extractor.hop_size) /
                     self.input_frame_rate)
        else:
            coeff = ((self.quantizer_codebook.sampling_rate / self.quantizer_codebook.hop_size) /
                     self.input_frame_rate)
        # NOTE(review): feature dim 80 is hard-coded; presumably should track
        # self.output_size — confirm before using output_size != 80.
        y = torch.zeros([1, int(h.shape[1] * coeff), 80], device=text.device)
        y_lens = (text_lens * coeff).long()
        h, h_lengths = self.length_regulator(h, h_lengths, y, y_lens)
        mask = (~make_pad_mask(y_lens)).to(y)
        feat = self.decoder.forward(
            mu=h.transpose(1, 2).contiguous(),
            mask=mask.unsqueeze(1),
            n_timesteps=diff_steps,
            temperature=temperature,
            spks=avg_xvec,
            cond=None,
        )
        return feat

    def collect_feats(self, **batch: torch.Tensor) -> Dict[str, torch.Tensor]:
        # intentionally unimplemented for this model family
        pass
||||
|
||||
class MaskedDiffWithXvec(BaseDiffWithXvec):
|
||||
def __init__(self, input_size: int, output_size: int = 80, xvec_size: int = 198, output_type: str = "mel",
             encoder_conf: Dict = {}, decoder_conf: Dict = {}, mel_feat_conf: Dict = {}, codec_conf: Dict = {},
             length_regulator_conf: Dict = None, prompt_conf: Dict = None, vocab_size: int = None,
             token_list: List = None, **kwargs):
    """Masked-diffusion variant: adds prompt masking / classifier-free guidance.

    All positional arguments are forwarded to BaseDiffWithXvec; extra
    behaviour is controlled by ``prompt_conf`` and the kwargs below:
        only_mask_loss: restrict the loss to the masked (non-prompt) region.
        io_ratio: output/input frame-rate ratio; "auto" derives it from the
            mel config and ``input_frame_rate``.
        first_package_conf: config for first-package (streaming) sampling.
        length_normalizer_ratio: optional output-length normalization ratio.
    """
    super().__init__(input_size, output_size, xvec_size, output_type, encoder_conf, decoder_conf, mel_feat_conf,
                     codec_conf, length_regulator_conf, prompt_conf, vocab_size, token_list, **kwargs)
    if self.prompt_conf is not None:
        self.masker = self.build_masker()
        # probability of dropping the condition entirely (classifier-free guidance)
        self.cgf_prob = prompt_conf.get("cgf_prob", 0.0)
        self.prompt_dropout_rate = prompt_conf.get("prompt_dropout", 0.0)
        if self.prompt_dropout_rate > 0:
            self.prompt_dropout = nn.Dropout(self.prompt_dropout_rate)
        else:
            self.prompt_dropout = None
    self.only_mask_loss = kwargs.get("only_mask_loss", False)
    self.io_ratio = kwargs.get("io_ratio", None)
    if self.io_ratio == "auto":
        # acoustic frames per input token, derived from the mel config
        self.io_ratio = mel_feat_conf["sampling_rate"] / mel_feat_conf["hop_size"] / self.input_frame_rate
    self.first_package_conf = kwargs.get("first_package_conf", None)
    self.length_normalizer_ratio = kwargs.get("length_normalizer_ratio", None)
|
||||
def build_masker(self):
    """Create the time-axis masker that carves prompt regions out of targets.

    ``prompt_type == "prefix"`` masks a variable-width prefix; any other
    value falls back to a single variable-width mask at a random position.
    """
    prompt_type = self.prompt_conf.get("prompt_type", "free")
    width_range = self.prompt_conf["prompt_width_ratio_range"]
    if prompt_type == "prefix":
        from funasr.models.specaug.mask_along_axis import PrefixMaskVariableMaxWidth
        return PrefixMaskVariableMaxWidth(
            mask_width_ratio_range=width_range,
        )
    from funasr.models.specaug.mask_along_axis import MaskAlongAxisVariableMaxWidth
    return MaskAlongAxisVariableMaxWidth(
        mask_width_ratio_range=width_range,
        num_mask=1,
    )
|
||||
@staticmethod
|
||||
def norm_and_sample_xvec(xvec, xvec_lengths):
|
||||
xvec_list = []
|
||||
for i, ilen in enumerate(xvec_lengths):
|
||||
if ilen == 1:
|
||||
idx = 0
|
||||
else:
|
||||
idx = random.randint(0, ilen - 1)
|
||||
while torch.any(~torch.isfinite(xvec[i, idx])):
|
||||
idx = random.randint(0, ilen - 1)
|
||||
if torch.any(~torch.isfinite(xvec[i, idx])):
|
||||
to_add = torch.zeros_like(xvec[i, idx])
|
||||
else:
|
||||
to_add = xvec[i, idx]
|
||||
xvec_list.append(to_add)
|
||||
rand_xvec = torch.vstack(xvec_list)
|
||||
rand_xvec = F.normalize(rand_xvec, dim=1)
|
||||
|
||||
return rand_xvec
|
||||
|
||||
    def select_target_prompt(self, y: torch.Tensor, y_lengths: torch.Tensor):
        """Build the condition mask marking which frames serve as prompt.

        The masker presumably returns a mask over the region to be predicted,
        so it is inverted here: truthy entries of the returned mask are frames
        kept as conditioning context.

        Args:
            y: target features (B, T, D).
            y_lengths: valid lengths (B,).

        Returns:
            Mask broadcastable over ``y``; truthy = conditioning frame.
        """
        _, _, cond_mask = self.masker(y, y_lengths, return_mask=True)
        cond_mask = ~cond_mask

        # Classifier-free-guidance style dropout: with probability ``cgf_prob``
        # the whole condition of a sample is zeroed, so the model also learns
        # the unconditional path ("cgf" presumably a typo of "cfg").
        if self.cgf_prob > 0:
            cgf_mask = torch.rand([y.shape[0], 1, 1], dtype=torch.float32, device=y.device)
            cond_mask = cond_mask * (cgf_mask > self.cgf_prob)

        return cond_mask
|
||||
|
||||
def build_decoder(self):
|
||||
decoder_name = self.decoder_conf.pop("name", "transformer")
|
||||
model = None
|
||||
|
||||
if decoder_name == "matcha":
|
||||
from funasr.models.llm_asr.diffusion_models.flow_matching import CFM
|
||||
model = CFM(
|
||||
**self.decoder_conf,
|
||||
)
|
||||
|
||||
self.decoder_conf["name"] = decoder_name
|
||||
|
||||
return model
|
||||
|
||||
def sample_first_package(
|
||||
self,
|
||||
text: torch.Tensor,
|
||||
text_lengths: torch.Tensor,
|
||||
audio: torch.Tensor,
|
||||
audio_lengths: torch.Tensor,
|
||||
):
|
||||
sample_rate = self.first_package_conf["sample_rate"]
|
||||
min_token_len, max_token_len = self.first_package_conf["token_len_range"]
|
||||
random_start = self.first_package_conf.get("random_start", False)
|
||||
bs = text.shape[0]
|
||||
sample_mask = torch.rand((bs, ), device=text_lengths.device) < sample_rate
|
||||
if random_start:
|
||||
text_list, text_lengths_list, audio_list, audio_lengths_list = [], [], [], []
|
||||
for i, total_len in enumerate(text_lengths):
|
||||
total_len = total_len.item()
|
||||
if sample_mask[i].item():
|
||||
if isinstance(min_token_len, float) and 0.0 < min_token_len <= 1.0:
|
||||
min_token_len = math.floor(min_token_len * total_len)
|
||||
if isinstance(max_token_len, float) and 0.0 < max_token_len <= 1.0:
|
||||
max_token_len = math.floor(max_token_len * total_len)
|
||||
if total_len > max_token_len > min_token_len:
|
||||
fp_len = random.randint(min_token_len, max_token_len)
|
||||
start = random.randint(0, total_len - fp_len)
|
||||
audio_st, audio_len = self.calc_target_len(torch.tensor(start)), self.calc_target_len(torch.tensor(fp_len))
|
||||
else:
|
||||
start, fp_len = 0, total_len
|
||||
audio_st, audio_len = 0, self.calc_target_len(fp_len)
|
||||
text_list.append(text[i, start: start+fp_len])
|
||||
text_lengths_list.append(fp_len)
|
||||
audio_list.append(audio[i, audio_st: audio_st+audio_len])
|
||||
audio_lengths_list.append(audio_list[-1].shape[0])
|
||||
else:
|
||||
text_list.append(text[i])
|
||||
text_lengths_list.append(text_lengths[i])
|
||||
audio_list.append(audio[i, :min(self.calc_target_len(text_lengths[i]), audio_lengths[i])])
|
||||
audio_lengths_list.append(audio_list[-1].shape[0])
|
||||
text = pad_list(text_list, pad_value=0.0).to(text)
|
||||
new_text_lengths = torch.tensor(text_lengths_list, dtype=torch.int64, device=text.device)
|
||||
audio = pad_list(audio_list, pad_value=0.0).to(audio)
|
||||
new_audio_lengths = torch.tensor(audio_lengths_list, dtype=torch.int64, device=audio.device)
|
||||
else:
|
||||
fp_token_len = torch.randint(min_token_len, max_token_len + 1, (bs,))
|
||||
fp_token_len = torch.minimum(fp_token_len.to(text_lengths), text_lengths)
|
||||
fp_audio_len = self.calc_target_len(fp_token_len)
|
||||
fp_audio_len = torch.minimum(fp_audio_len.to(audio_lengths), audio_lengths)
|
||||
new_text_lengths = torch.where(sample_mask, fp_token_len, text_lengths)
|
||||
new_audio_lengths = torch.where(sample_mask, fp_audio_len, audio_lengths)
|
||||
text = text * (~make_pad_mask(new_text_lengths, maxlen=text.shape[1]).unsqueeze(-1)).to(text.device)
|
||||
audio = audio * (~make_pad_mask(new_audio_lengths, maxlen=audio.shape[1]).unsqueeze(-1)).to(audio.device)
|
||||
|
||||
return text, new_text_lengths, audio, new_audio_lengths
|
||||
|
||||
@staticmethod
|
||||
def clip_both_side(y, y_lengths, raw_lengths):
|
||||
res_list = []
|
||||
new_length = []
|
||||
for i, (new_len, org_len) in enumerate(zip(y_lengths, raw_lengths)):
|
||||
if new_len >= org_len:
|
||||
res_list.append(y[i, :new_len])
|
||||
else:
|
||||
left = (org_len - new_len) // 2
|
||||
right = org_len - new_len - left
|
||||
res_list.append(y[i, left: org_len-right])
|
||||
|
||||
new_length.append(res_list[-1].shape[0])
|
||||
|
||||
new_length = torch.tensor(new_length).to(y_lengths)
|
||||
return pad_list(res_list, 0.0), new_length
|
||||
|
||||
    def forward(
            self,
            text: torch.Tensor,
            text_lengths: torch.Tensor,
            audio: torch.Tensor,
            audio_lengths: torch.Tensor,
            xvec: Optional[torch.Tensor] = None,
            xvec_lengths: Optional[torch.Tensor] = None,
    ):
        """Training forward: flow-matching loss on masked target features.

        Args:
            text: input token ids (B, T_text) when ``vocab_size`` is set,
                otherwise pre-computed embeddings.
            text_lengths: (B,) valid token counts.
            audio: target waveform/codec input consumed by ``extract_feat``.
            audio_lengths: (B,) valid audio lengths.
            xvec: optional speaker-embedding matrix per utterance.
            xvec_lengths: (B,) valid x-vector counts.

        Returns:
            (loss, stats, weight) made gatherable for DataParallel.
        """
        batch_size = audio.shape[0]
        # for data parallel
        with autocast(False):
            x = text[:, :text_lengths.max()]
            y = audio[:, :audio_lengths.max()]
        if self.vocab_size is not None and self.vocab_size > 0:
            # -1 marks padding tokens: clamp for the embedding lookup, then
            # zero the embedded padding positions.
            mask = (x != -1).float().unsqueeze(-1)
            x = self.input_embedding(torch.clamp(x, min=0)) * mask

        # random select a xvec from xvec matrix
        rand_xvec = None
        if xvec is not None:
            xvec = xvec[:, :xvec_lengths.max()]
            rand_xvec = self.norm_and_sample_xvec(xvec, xvec_lengths)
            rand_xvec = self.xvec_proj(rand_xvec)

        y, y_lengths = self.extract_feat(y, audio_lengths)
        if self.length_normalizer_ratio is not None:
            # Hard cap on target length: at most ratio * token_count frames,
            # center-cropping over-long targets.
            # NOTE(review): ``y_lengths`` (not ``new_y_lengths``) is what flows
            # on below — confirm the two always agree after clip_both_side.
            max_y_lengths = torch.round(text_lengths * self.length_normalizer_ratio).long()
            raw_lengths = y_lengths.clone()
            y_lengths = torch.where(y_lengths > max_y_lengths, max_y_lengths, y_lengths)
            y, new_y_lengths = self.clip_both_side(y, y_lengths, raw_lengths)
            logging.info(f"normalized y_length from {raw_lengths.cpu().tolist()} to {y_lengths.cpu().tolist()} "
                         f"new_y_length {new_y_lengths.cpu().tolist()}, with text_lengths {text_lengths.cpu().tolist()}")
            y = y[:, :new_y_lengths.max()]
        elif self.io_ratio is not None:
            # Softer cap using the configured input/output frame-rate ratio.
            hint_once(f"cut output with ratio {self.io_ratio}", "print_ratio", rank=0)
            max_y_lengths = (text_lengths * self.io_ratio + 3).long()
            if y_lengths.max() > max_y_lengths.max():
                logging.info(f"cut output with ratio {self.io_ratio} from {y_lengths.max()} to {max_y_lengths.max()}")
                y_lengths = torch.where(y_lengths > max_y_lengths, max_y_lengths, y_lengths)
                y = y[:, :y_lengths.max()]

        if self.first_package_conf is not None:
            # Optionally shorten a random subset to streaming-prefix lengths.
            x, text_lengths, y, y_lengths = self.sample_first_package(
                x, text_lengths, y, y_lengths
            )
            x = x[:, :text_lengths.max()]
            y = y[:, :y_lengths.max()]
        h, _, _ = self.encoder(x, text_lengths)
        h_lengths = text_lengths
        h = self.encoder_proj(h)
        # Up-sample token features to the target frame grid.
        h, h_lengths = self.length_regulator(h, h_lengths, y, y_lengths)
        if self.prompt_conf is not None:
            cond_mask = self.select_target_prompt(y, y_lengths)
            if self.prompt_dropout is not None:
                hint_once(f"prompt dropout {self.prompt_dropout_rate}", "prompt dropout")
                y = self.prompt_dropout(y)
            # Conditioning = prompt frames of the (possibly dropped-out) target.
            conditions = (y * cond_mask).transpose(1, 2)
        else:
            cond_mask, conditions = None, None

        stats = dict(
            batch_size=batch_size,
            in_lengths=text_lengths.max(),
            out_lengths=y_lengths.max(),
        )

        mask = (~make_pad_mask(y_lengths)).to(y)
        # y, h in (B, T, D)
        loss, _ = self.decoder.compute_loss(
            y.transpose(1, 2).contiguous(),
            mask.unsqueeze(1),
            h.transpose(1, 2).contiguous(),
            rand_xvec,
            cond=conditions,
            reduction="none",
        )
        loss = loss.transpose(1, 2)
        # Mean over all valid frames vs. mean over only the masked
        # (to-be-predicted) frames; which one is trained on is configurable.
        all_loss = (loss * mask.unsqueeze(-1)).sum() / (mask.sum() * loss.shape[-1])
        if cond_mask is not None:
            masked_loss_mask = mask.unsqueeze(-1) * (~cond_mask)
        else:
            masked_loss_mask = mask.unsqueeze(-1)
        masked_loss = (loss * masked_loss_mask).sum() / (masked_loss_mask.sum() * loss.shape[-1])
        stats["all_loss"] = all_loss.item()
        stats["masked_loss"] = masked_loss.item()

        loss = masked_loss if self.only_mask_loss else all_loss

        stats["loss"] = loss.item()

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight
|
||||
|
||||
@staticmethod
|
||||
def concat_prompt(prompt, prompt_lengths, text, text_lengths):
|
||||
xs_list, x_len_list = [], []
|
||||
for idx, (_prompt_len, _text_len) in enumerate(zip(prompt_lengths, text_lengths)):
|
||||
xs_list.append(torch.concat([prompt[idx, :_prompt_len], text[idx, :_text_len]], dim=0))
|
||||
x_len_list.append(_prompt_len + _text_len)
|
||||
|
||||
xs = pad_list(xs_list, pad_value=0.0)
|
||||
x_lens = torch.tensor(x_len_list, dtype=torch.int64).to(xs.device)
|
||||
|
||||
return xs, x_lens
|
||||
|
||||
@staticmethod
|
||||
def remove_prompt(prompt, prompt_lengths, padded, padded_lengths):
|
||||
xs_list = []
|
||||
for idx, (_prompt_len, _x_len) in enumerate(zip(prompt_lengths, padded_lengths)):
|
||||
xs_list.append(padded[idx, _prompt_len: _x_len])
|
||||
|
||||
xs = pad_list(xs_list, pad_value=0.0)
|
||||
|
||||
return xs, padded_lengths - prompt_lengths
|
||||
|
||||
@staticmethod
|
||||
def norm_and_avg_xvec(xvec: torch.Tensor, xvec_lens: torch.Tensor):
|
||||
mask = torch.isfinite(xvec.norm(dim=-1, keepdim=True))
|
||||
norm_xvec = F.normalize(xvec, dim=-1) * mask
|
||||
avg_xvec = F.normalize(torch.sum(norm_xvec, dim=1) / mask.sum(), dim=-1)
|
||||
return avg_xvec
|
||||
|
||||
    def calc_target_len(self, in_len):
        """Estimate the output feature length for ``in_len`` input tokens.

        Supports 25 Hz and 50 Hz input token rates for mel output; codec
        output is 1:1. ``in_len`` may be a Python int or a tensor, and the
        return type matches.

        The formula ``(in_len * k + k) * 160 + 400`` appears to convert token
        counts to 16 kHz sample counts (hop 160, window 400) before rescaling
        to the mel extractor's frame grid — TODO confirm against the codec
        configuration.
        """
        if self.input_frame_rate == 25 and self.output_type == "mel":
            if self.length_normalizer_ratio is not None:
                # Fixed frames-per-token ratio overrides the sample-rate math.
                if isinstance(in_len, int):
                    in_len = torch.tensor(in_len)
                ll = torch.round(in_len * self.length_normalizer_ratio)
            else:
                ll = (in_len * 4 + 4) * 160 + 400
                ll = ll / 16000 * self.mel_extractor.sampling_rate / self.mel_extractor.hop_size
                if isinstance(in_len, int):
                    ll = int(round(ll))
                else:
                    ll = torch.round(ll).long()
            return ll
        if self.input_frame_rate == 50 and self.output_type == "mel":
            if self.length_normalizer_ratio is not None:
                if isinstance(in_len, int):
                    in_len = torch.tensor(in_len)
                ll = torch.round(in_len * self.length_normalizer_ratio)
            else:
                # Same conversion as the 25 Hz branch with half the upsampling.
                ll = (in_len * 2 + 2) * 160 + 400
                ll = ll / 16000 * self.mel_extractor.sampling_rate / self.mel_extractor.hop_size
                if isinstance(in_len, int):
                    ll = int(round(ll))
                else:
                    ll = torch.round(ll).long()
            return ll
        elif self.output_type == "codec":
            # Codec targets share the input token grid.
            return in_len
        else:
            raise ValueError(f"Frame rate {self.input_frame_rate} has not implemented.")
|
||||
|
||||
    @torch.no_grad()
    def inference(self, text, text_lens,
                  xvec=None, xvec_lens=None,
                  diff_steps=10, temperature=1.0, prompt: dict = None, y_lens=None):
        """Generate target features from tokens, optionally prompt-conditioned.

        Args:
            text: input token ids (or embeddings when ``vocab_size`` unset).
            text_lens: (B,) valid token counts.
            xvec: optional speaker embeddings; a 2-D tensor is treated as one
                embedding per utterance.
            xvec_lens: valid x-vector counts.
            diff_steps: number of flow/diffusion sampling steps.
            temperature: sampling temperature passed to the decoder.
            prompt: dict with optional ``prompt_text`` and ``prompt_audio``
                entries, each a (tensor, lengths) pair; prompt text is
                prepended to the input and prompt audio overlaid as condition.
            y_lens: optional explicit output lengths; derived from
                ``calc_target_len`` when None.

        Returns:
            Generated features (the prompt region is stripped off again when
            both prompt text and prompt audio were supplied).
        """
        rand_xvec = None
        if xvec is not None:
            if xvec.dim() == 2:
                # Single embedding per utterance: add a length-1 time axis.
                xvec = xvec.unsqueeze(1)
                xvec_lens = torch.ones_like(xvec_lens)
            rand_xvec = self.norm_and_avg_xvec(xvec, xvec_lens)
            rand_xvec = self.xvec_proj(rand_xvec)

        prompt_text, prompt_text_lens = prompt.get("prompt_text", (None, None))
        prompt_audio, prompt_audio_lens = prompt.get("prompt_audio", (None, None))

        if self.vocab_size is not None and self.vocab_size > 0:
            if prompt_text is not None:
                text, text_lens = self.concat_prompt(prompt_text, prompt_text_lens, text, text_lens)
            # -1 marks padding: clamp for the lookup, zero the padded rows.
            mask = (text != -1).float().unsqueeze(-1)
            text = self.input_embedding(torch.clamp(text, min=0)) * mask

        h, h_lengths, _ = self.encoder(text, text_lens)
        h = self.encoder_proj(h)
        if y_lens is None:
            y_lens = self.calc_target_len(text_lens)
        y = torch.zeros([1, y_lens.max().item(), self.output_size], device=text.device)
        h, h_lengths = self.length_regulator(h, h_lengths, y, y_lens)

        # get conditions
        if prompt_audio is not None:
            if prompt_audio.ndim == 2:
                # Raw waveform prompt: turn it into features first.
                prompt_audio, prompt_audio_lens = self.extract_feat(prompt_audio, prompt_audio_lens)
            for i, _len in enumerate(prompt_audio_lens):
                y[i, :_len] = prompt_audio[i]
        # Zeros where no prompt was overlaid — i.e. an empty condition.
        conds = y.transpose(1, 2)

        mask = (~make_pad_mask(y_lens)).to(y)
        feat = self.decoder.forward(
            mu=h.transpose(1, 2).contiguous(),
            mask=mask.unsqueeze(1),
            n_timesteps=diff_steps,
            temperature=temperature,
            spks=rand_xvec,
            cond=conds,
        )

        if prompt_text is not None and prompt_audio is not None:
            # The prompt occupies the head of the generated sequence; cut it.
            feat = feat.transpose(1, 2)
            feat_lens = torch.tensor([feat.shape[1]], dtype=torch.int64, device=feat.device)
            feat, feat_lens = self.remove_prompt(None, prompt_audio_lens, feat, feat_lens)
            feat = feat.transpose(1, 2)

        # if prompt_audio is not None:
        #     feat_rmq = torch.sqrt(torch.mean(torch.pow(feat, 2), dim=[1, 2], keepdim=True))
        #     prompt_rmq = torch.sqrt(torch.mean(torch.pow(prompt_audio, 2), dim=[1, 2], keepdim=True))
        #     feat = feat / feat_rmq * prompt_rmq

        return feat
|
||||
|
||||
|
||||
class MaskedDiffTTS(MaskedDiffWithXvec):
    """Masked-diffusion TTS variant with a learned utterance-length predictor.

    Unlike the parent, the output length is predicted from the encoder output
    (see ``calc_target_len``) and trained with an auxiliary L1 duration loss.
    """

    def __init__(self, input_size: int, output_size: int = 80, xvec_size: int = 198, output_type: str = "mel",
                 encoder_conf: Dict = {}, decoder_conf: Dict = {}, mel_feat_conf: Dict = {}, codec_conf: Dict = {},
                 length_regulator_conf: Dict = None, prompt_conf: Dict = None, vocab_size: int = None,
                 token_list: List = None, **kwargs):
        super().__init__(input_size, output_size, xvec_size, output_type, encoder_conf, decoder_conf, mel_feat_conf,
                         codec_conf, length_regulator_conf, prompt_conf, vocab_size, token_list, **kwargs)
        # Weight of the auxiliary utterance-duration L1 loss; the predictor
        # head is only created when the loss is active.
        self.length_loss_weight = kwargs.get("length_loss_weight", 0.0)
        if self.length_loss_weight > 0.0:
            # Per-token log-duration head on top of the encoder output.
            self.length_predictor = nn.Linear(self.encoder.output_size(), 1)
|
||||
|
||||
def calc_target_len(self, enc_outs, enc_lens):
|
||||
text_durs = self.length_predictor(enc_outs)
|
||||
text_durs = torch.exp(text_durs)
|
||||
mask = ~make_pad_mask(enc_lens, xs=text_durs)
|
||||
utt_durs = (text_durs * mask).sum(dim=1).squeeze(-1)
|
||||
return utt_durs
|
||||
|
||||
    def forward(self, text: torch.Tensor, text_lengths: torch.Tensor, audio: torch.Tensor, audio_lengths: torch.Tensor,
                xvec: Optional[torch.Tensor] = None, xvec_lengths: Optional[torch.Tensor] = None):
        """Training forward: flow-matching loss plus duration-prediction loss.

        Mirrors :meth:`MaskedDiffWithXvec.forward` but without length
        clipping / first-package sampling, and with an auxiliary L1 loss
        between predicted and true utterance lengths.

        Returns:
            (loss, stats, weight) made gatherable for DataParallel.
        """
        batch_size = audio.shape[0]
        # for data parallel
        x = text[:, :text_lengths.max()]
        y = audio[:, :audio_lengths.max()]
        if self.vocab_size is not None and self.vocab_size > 0:
            # -1 marks padding tokens: clamp for the lookup, zero padded rows.
            mask = (x != -1).float().unsqueeze(-1)
            x = self.input_embedding(torch.clamp(x, min=0)) * mask

        # random select a xvec from xvec matrix
        rand_xvec = None
        if xvec is not None:
            xvec = xvec[:, :xvec_lengths.max()]
            rand_xvec = self.norm_and_sample_xvec(xvec, xvec_lengths)
            rand_xvec = self.xvec_proj(rand_xvec)

        y, y_lengths = self.extract_feat(y, audio_lengths)
        h, _, _ = self.encoder(x, text_lengths)
        h_lengths = text_lengths
        # Duration loss between predicted and actual target lengths.
        # NOTE(review): ``y_lengths`` is an integer tensor — confirm
        # F.l1_loss handles the dtype as intended at the training sites.
        h_durs = self.calc_target_len(h, h_lengths)
        utt_dur_loss = self.length_loss_weight * F.l1_loss(h_durs, y_lengths)

        h = self.encoder_proj(h)
        h, h_lengths = self.length_regulator(h, h_lengths, y, y_lengths)
        if self.prompt_conf is not None:
            cond_mask = self.select_target_prompt(y, y_lengths)
            conditions = (y * cond_mask).transpose(1, 2)
        else:
            cond_mask, conditions = None, None

        stats = dict(
            batch_size=batch_size,
            in_lengths=text_lengths.max(),
            out_lengths=y_lengths.max(),
        )

        mask = (~make_pad_mask(y_lengths)).to(y)
        # y, h in (B, T, D)
        loss, _ = self.decoder.compute_loss(
            y.transpose(1, 2).contiguous(),
            mask.unsqueeze(1),
            h.transpose(1, 2).contiguous(),
            rand_xvec,
            cond=conditions,
            reduction="none",
        )
        loss = loss.transpose(1, 2)
        # Mean over all valid frames vs. only the masked (predicted) frames.
        all_loss = (loss * mask.unsqueeze(-1)).sum() / (mask.sum() * loss.shape[-1])
        if cond_mask is not None:
            masked_loss_mask = mask.unsqueeze(-1) * (~cond_mask)
        else:
            masked_loss_mask = mask.unsqueeze(-1)
        masked_loss = (loss * masked_loss_mask).sum() / (masked_loss_mask.sum() * loss.shape[-1])
        stats["all_loss"] = all_loss.item()
        stats["masked_loss"] = masked_loss.item()

        loss = masked_loss if self.only_mask_loss else all_loss
        stats["mel_loss"] = loss.item()

        loss = loss + utt_dur_loss
        stats["loss"] = loss.item()
        stats["utt_dur_loss"] = utt_dur_loss.item()

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight
|
||||
|
||||
    def inference(self, text, text_lens, xvec=None, xvec_lens=None, diff_steps=10, temperature=1.0,
                  prompt: dict = None):
        """Generate features with the *predicted* output length.

        Same flow as :meth:`MaskedDiffWithXvec.inference`, but the output
        length comes from the learned duration predictor instead of a
        token-count heuristic, and no explicit ``y_lens`` can be supplied.

        NOTE(review): unlike the parent this override is not decorated with
        ``@torch.no_grad()`` — confirm callers wrap it.
        """
        rand_xvec = None
        if xvec is not None:
            if xvec.dim() == 2:
                # Single embedding per utterance: add a length-1 time axis.
                xvec = xvec.unsqueeze(1)
                xvec_lens = torch.ones_like(xvec_lens)
            rand_xvec = self.norm_and_avg_xvec(xvec, xvec_lens)
            rand_xvec = self.xvec_proj(rand_xvec)

        prompt_text, prompt_text_lens = prompt.get("prompt_text", (None, None))
        prompt_audio, prompt_audio_lens = prompt.get("prompt_audio", (None, None))

        if self.vocab_size is not None and self.vocab_size > 0:
            if prompt_text is not None:
                text, text_lens = self.concat_prompt(prompt_text, prompt_text_lens, text, text_lens)
            mask = (text != -1).float().unsqueeze(-1)
            text = self.input_embedding(torch.clamp(text, min=0)) * mask

        h, _, _ = self.encoder(text, text_lens)
        h_lengths = text_lens
        # Learned duration prediction, rounded to whole frames.
        y_lens = self.calc_target_len(h, h_lengths).round().long()
        y = torch.zeros([1, y_lens.max().item(), self.output_size], device=text.device)
        h = self.encoder_proj(h)
        h, h_lengths = self.length_regulator(h, h_lengths, y, y_lens)

        # get conditions
        if prompt_audio is not None:
            if prompt_audio.ndim == 2:
                # Raw waveform prompt: extract features first.
                prompt_audio, prompt_audio_lens = self.extract_feat(prompt_audio, prompt_audio_lens)
            for i, _len in enumerate(prompt_audio_lens):
                y[i, :_len] = prompt_audio[i]
        conds = y.transpose(1, 2)

        mask = (~make_pad_mask(y_lens)).to(y)
        feat = self.decoder.forward(
            mu=h.transpose(1, 2).contiguous(),
            mask=mask.unsqueeze(1),
            n_timesteps=diff_steps,
            temperature=temperature,
            spks=rand_xvec,
            cond=conds,
        )

        if prompt_text is not None and prompt_audio is not None:
            # The prompt occupies the head of the generated sequence; cut it.
            feat = feat.transpose(1, 2)
            feat_lens = torch.tensor([feat.shape[1]], dtype=torch.int64, device=feat.device)
            feat, feat_lens = self.remove_prompt(None, prompt_audio_lens, feat, feat_lens)
            feat = feat.transpose(1, 2)

        return feat
|
||||
|
||||
477
funasr/models/llm_asr/hifigan.py
Normal file
477
funasr/models/llm_asr/hifigan.py
Normal file
@ -0,0 +1,477 @@
|
||||
# Copyright 2023 KaiHu
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
"""HIFI-GAN"""
|
||||
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import Optional
|
||||
from typing import Tuple, List, Union
|
||||
import typing as tp
|
||||
import torch
|
||||
import torchaudio
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from typeguard import check_argument_types
|
||||
from funasr.train_utils.device_funcs import force_gatherable
|
||||
from librosa.filters import mel as librosa_mel_fn
|
||||
import logging
|
||||
from funasr.utils.hinter import hint_once
|
||||
|
||||
|
||||
class Audio2Mel(nn.Module):
    """Waveform-to-log-mel front-end used for the HiFi-GAN training losses.

    Computes an STFT magnitude spectrum, projects it onto a librosa mel
    filterbank, and log-compresses the result. ``feat_type`` selects between
    natural-log dynamic-range compression (default ``"power_log"``) and a
    log10-magnitude variant (``"mag_log10"``).

    Fixes vs. the previous revision:
      * ``torch.stft`` is called with ``return_complex=True`` (the parameter
        is mandatory on modern PyTorch) and the output viewed as (..., 2) so
        the original real/imag power math is preserved bit-for-bit.
      * ``spectral_de_normalize_torch`` no longer forwards ``clip_val`` to
        ``dynamic_range_decompression`` (which does not accept it); any call
        previously raised ``TypeError``.
    """

    def __init__(
        self,
        n_fft=1024,
        hop_length=256,
        win_length=1024,
        sampling_rate=22050,
        n_mel_channels=80,
        mel_fmin=0.0,
        mel_fmax=None,
        center=False,
        device='cuda',  # NOTE(review): default requires CUDA at construction
        feat_type="power_log",
    ):
        super().__init__()
        ##############################################
        # FFT Parameters                              #
        ##############################################
        window = torch.hann_window(win_length, device=device).float()
        mel_basis = librosa_mel_fn(
            sr=sampling_rate, n_fft=n_fft, n_mels=n_mel_channels, fmin=mel_fmin, fmax=mel_fmax
        )
        mel_basis = torch.from_numpy(mel_basis).float().to(device)
        # Buffers so .to()/.cuda() on the module moves them along.
        self.register_buffer("mel_basis", mel_basis)
        self.register_buffer("window", window)
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.sampling_rate = sampling_rate
        self.n_mel_channels = n_mel_channels
        self.center = center
        self.feat_type = feat_type

    def forward(self, audioin):
        """Convert a (B, 1, T) waveform batch into (B, n_mels, T') log-mels."""
        # Manual reflect padding so frame centers line up with center=False.
        p = (self.n_fft - self.hop_length) // 2
        audio = F.pad(audioin, (p, p), "reflect").squeeze(1)
        fft = torch.stft(
            audio,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=self.window,
            center=self.center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
        # View as (..., 2) real/imag pairs to keep the original power math.
        fft = torch.view_as_real(fft)
        if self.feat_type == "mag_log10":
            power_spec = torch.sqrt(torch.sum(torch.pow(fft, 2), dim=[-1]))
            mel_output = torch.matmul(self.mel_basis, power_spec)
            return torch.log10(torch.clamp(mel_output, min=1e-5))
        power_spec = torch.sum(torch.pow(fft, 2), dim=[-1])
        mel_spec = torch.matmul(self.mel_basis, torch.sqrt(power_spec + 1e-9))
        return self.spectral_normalize(mel_spec)

    @classmethod
    def spectral_normalize(cls, spec, C=1, clip_val=1e-5):
        """Natural-log dynamic-range compression of a (mel) spectrogram."""
        output = cls.dynamic_range_compression(spec, C, clip_val)
        return output

    @classmethod
    def spectral_de_normalize_torch(cls, spec, C=1, clip_val=1e-5):
        """Inverse of :meth:`spectral_normalize`.

        ``clip_val`` is accepted for signature symmetry but has no inverse
        role, so it is intentionally not forwarded.
        """
        output = cls.dynamic_range_decompression(spec, C)
        return output

    @staticmethod
    def dynamic_range_compression(x, C=1, clip_val=1e-5):
        """log(clip(x) * C); clipping avoids log(0)."""
        return torch.log(torch.clamp(x, min=clip_val) * C)

    @staticmethod
    def dynamic_range_decompression(x, C=1):
        """Inverse of :meth:`dynamic_range_compression` (without the clip)."""
        return torch.exp(x) / C
|
||||
|
||||
|
||||
class HifiGan(nn.Module):
|
||||
"""HIFIGAN-style vocoders (generator [stack of time-level-upsampling blocks] + discriminator).
|
||||
NSF-HIFIGAN, HiFTNet Optional.
|
||||
"""
|
||||
|
||||
    def __init__(
            self,
            input_size: int,
            frontend: torch.nn.Module = None,
            nsf_augmented: bool = False,
            f0_predictor: dict = None,
            generator: dict = None,
            discriminator: dict = None,
            target_sample_hz: int = 22_050,
            multi_mel_spectral_window: Union[Tuple, List] = tuple([1024]),
            multi_mel_spectral_hop: Union[Tuple, List] = tuple([256]),
            multi_mel_spectral_fft: Union[Tuple, List] = tuple([1024]),
            multi_mel_spectral_n_mels: Union[Tuple, List] = tuple([80]),
            mel_fmin: float = 0,
            mel_fmax: float = 8000,
            mel_fmax_for_loss: Optional[float] = None,
            multi_mel_spectral_recon_loss_weight: Union[Tuple[float], List[float]] = tuple([45]),
            adversarial_loss_weight: float = 1.0,
            feat_match_loss_weight: float = 2.0,
            tpr_loss_params: tp.Dict[str, tp.Any] = {"weight": 0.0, "tau": 0.04},
            mel_feat_type="power_log",
    ):
        """Initialize HifiGan model.

        Args:
            f0_predictor: f0 predictor (pretrained && frozen) for NSF-HIFIGAN, Optional.
            generator: hifigan generator
            discriminator: several discriminators, such as MSD, MPD, MRD
            multi_mel_spectral_window: stft window length
            multi_mel_spectral_hop: stft hop length
            multi_mel_spectral_fft: fft bins
            multi_mel_spectral_n_mels: Mel frequency bins
            mel_fmin: fmin for mel
            mel_fmax: fmax for mel
            mel_fmax_for_loss: fmax for multi mel spectral loss
            multi_mel_spectral_recon_loss_weight: the weight of frequency-domain reconstruction loss
            adversarial_loss_weight: the weight of adversarial loss from discriminator
            feat_match_loss_weight: the weight of intermediate feature loss from discriminator
            tpr_loss_params: the weight and tau of Truncated Pointwise Relativistic (TPR) loss from discriminator.
                NOTE(review): a mutable dict default — only read here, but
                worth replacing with None at the next interface revision.
        """
        super().__init__()

        self.decoder = self.build_decoder(generator)
        # Used by task and trainer
        self.gen_model_list = [self.decoder]

        # nsf-hifigan or original hifigan
        self.nsf_augmented = nsf_augmented
        if nsf_augmented:
            assert f0_predictor is not None
            self.f0_predictor = self.build_f0_predictor(f0_predictor)
            # frozen: the F0 predictor is pretrained and never updated here.
            for param in self.f0_predictor.parameters():
                param.requires_grad = False
            # Kept in gen_model_list so the trainer moves/saves it with
            # the generator side.
            self.gen_model_list.append(self.f0_predictor)

        self.discriminator = self.build_discriminator(discriminator)

        # One Audio2Mel per (fft, hop, window, n_mels) tuple for the
        # multi-resolution mel reconstruction loss.
        self.multi_mel_spec_transforms = nn.ModuleList()
        for n_fft, hop_len, win_len, n_mel in zip(multi_mel_spectral_fft, multi_mel_spectral_hop,
                                                  multi_mel_spectral_window, multi_mel_spectral_n_mels):
            self.multi_mel_spec_transforms.append(
                Audio2Mel(
                    n_fft=n_fft,
                    hop_length=hop_len,
                    win_length=win_len,
                    sampling_rate=target_sample_hz,
                    n_mel_channels=n_mel,
                    mel_fmin=mel_fmin,
                    mel_fmax=mel_fmax_for_loss,
                    center=False,
                )
            )

        # Primary front-end (first configuration) used to drive the
        # generator in _encode; note it uses mel_fmax, not mel_fmax_for_loss.
        self.mel_spec_transform = Audio2Mel(
            n_fft=multi_mel_spectral_fft[0],
            hop_length=multi_mel_spectral_hop[0],
            win_length=multi_mel_spectral_window[0],
            sampling_rate=target_sample_hz,
            n_mel_channels=multi_mel_spectral_n_mels[0],
            mel_fmin=mel_fmin,
            mel_fmax=mel_fmax,
            center=False,
            feat_type=mel_feat_type,
        )

        # loss weights
        self.multi_mel_spectral_recon_loss_weight = multi_mel_spectral_recon_loss_weight
        self.adversarial_loss_weight = adversarial_loss_weight
        self.feat_match_loss_weight = feat_match_loss_weight
        self.tpr_loss_weight = tpr_loss_params.get("weight", 0.0)
        self.tpr_loss_tau = tpr_loss_params.get("tau", 0.04)
        # Non-persistent zero scalar used as a device-correct loss accumulator.
        self.register_buffer('zero', torch.tensor([0.]), persistent=False)
        self.gen_loss = 0
        self.sample_rate = target_sample_hz
        self.forward_step = 0
|
||||
|
||||
    def build_decoder(self, conf):
        """Build the HiFT waveform generator from its config dict."""
        from funasr.models.llm_asr.hifigan_module.generator import HiFTGenerator
        return HiFTGenerator(**conf)
|
||||
|
||||
    def build_f0_predictor(self, conf):
        """Build the (pretrained, later frozen) F0 predictor for NSF mode."""
        from funasr.models.llm_asr.hifigan_module.nsf_utils import ConvRNNF0Predictor
        return ConvRNNF0Predictor(**conf)
|
||||
|
||||
    def build_discriminator(self, conf):
        """Build the combined (MSD/MPD/MRD-style) discriminator stack."""
        from funasr.models.llm_asr.hifigan_module.discriminator import MultipleDiscriminator
        return MultipleDiscriminator(**conf)
|
||||
|
||||
    @property
    def generator(self):
        """Generator-side modules (decoder [+ F0 predictor]) as a ModuleList.

        Rebuilt on each access so the trainer always sees the current list.
        """
        return torch.nn.ModuleList(self.gen_model_list)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
forward_generator: bool = True,
|
||||
batch: Dict = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Forward functions of generator and discriminator.
|
||||
|
||||
Args:
|
||||
forward_generator (bool): Whether to forward generator.
|
||||
batch (Dict[str, Tensor]): one batch including:
|
||||
speech (Tensor): Speech waveform tensor (B, T_wav).
|
||||
speech_lengths (Tensor): Speech length tensor (B,).
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]:
|
||||
- loss (Tensor): Loss scalar tensor.
|
||||
- stats (Dict[str, float]): Statistics to be monitored.
|
||||
- weight (Tensor): Weight tensor to summarize losses.
|
||||
- optim_idx (int): Optimizer index (0 for G and 1 for D).
|
||||
|
||||
"""
|
||||
if forward_generator:
|
||||
if self.training:
|
||||
self.forward_step += 1
|
||||
return self._forward_generator(
|
||||
speech=batch["speech"],
|
||||
speech_lengths=batch["speech_lengths"],
|
||||
)
|
||||
else:
|
||||
return self._forward_discriminator(
|
||||
speech=batch["speech"],
|
||||
speech_lengths=batch["speech_lengths"],
|
||||
)
|
||||
|
||||
    def _encode(self, x: torch.Tensor) -> torch.Tensor:
        """Given a tensor `x`, returns the encoded representation for `x`

        Args:
            x: mono waveform batch of shape (B, 1, T_wav).

        Returns:
            Mel features from ``self.mel_spec_transform`` with singleton dims
            squeezed out.
        """
        assert x.dim() == 3
        _, channel, length = x.size()
        assert channel == 1
        mel = self.mel_spec_transform(x)
        # NOTE(review): bare ``squeeze()`` also drops the batch dim when
        # B == 1 — confirm downstream consumers expect that shape.
        return mel.squeeze()
|
||||
|
||||
    @torch.no_grad()
    def _f0_pred(self, x: torch.Tensor) -> torch.Tensor:
        """Given a tensor `x`, return the predicted f0 for `x`, x in (B, C, T)

        The predictor is pretrained and frozen (see ``__init__``), hence the
        ``no_grad`` wrapper.
        """
        if self.nsf_augmented:
            f0 = self.f0_predictor(x)
            if len(f0.shape) == 1:
                # Re-add the batch dim for single-sample predictor outputs.
                f0 = f0.unsqueeze(0)
            return f0
        else:
            # NOTE(review): fallback is shaped like the full (B, C, T) input
            # rather than an f0 track — confirm this path is never reached in
            # the non-NSF configuration (callers gate on nsf_augmented).
            return torch.zeros_like(x)
|
||||
|
||||
    def _decode(self, x: torch.Tensor, g: Union[torch.Tensor] = None) -> torch.Tensor:
        """Decode the given representation into a waveform.

        Args:
            x (Tensor): Speech representation tensor (B, C1, T)
            g (Tensor): Global conditional vector (B, C2, 1).
        """
        if self.nsf_augmented:
            # NSF variant additionally conditions the generator on F0.
            f0 = self._f0_pred(x)
            return self.decoder(x, f0, g)
        else:
            return self.decoder(x, g)
|
||||
|
||||
def _forward_generator(
    self,
    speech: torch.Tensor,
    speech_lengths: torch.Tensor,
) -> Dict[str, Any]:
    """Perform generator forward.

    One generator optimization step: reconstruct the waveform from its mel
    encoding, then combine multi-resolution mel L1, adversarial,
    feature-matching and (optional) TPR losses.

    Args:
        speech (Tensor): Speech waveform tensor (B, T_wav).
        speech_lengths (Tensor): Speech length tensor (B,).
            NOTE(review): not used below — losses are computed on the full
            (possibly padded) waveforms; confirm batches are equal-length.

    Returns:
        Dict[str, Any]:
            * loss (Tensor): Loss scalar tensor.
            * stats (Dict[str, float]): Statistics to be monitored.
            * weight (Tensor): Weight tensor to summarize losses.
            * optim_idx (int): Optimizer index (0 for G and 1 for D).
            * real / fake (Tensor): original and reconstructed waveforms.
    """
    # setup
    batch_size = speech.size(0)
    speech = speech.unsqueeze(1)  # (B, T_wav) -> (B, 1, T_wav)
    orig_speech = speech.clone()

    mel = self._encode(speech)  # [B, C, T]
    # decode back to waveform; trim to input length (upsampling may overshoot)
    recon_speech = self._decode(mel)[:, :, :speech.shape[-1]]

    # L1 Mel-Spectrogram Loss, summed over several mel resolutions,
    # each weighted by its `lamda`
    multi_mel_recon_loss = self.zero
    for lamda, mel_transform in zip(self.multi_mel_spectral_recon_loss_weight, self.multi_mel_spec_transforms):
        orig_mel, recon_mel = map(mel_transform, (orig_speech, recon_speech))
        multi_mel_recon_loss = multi_mel_recon_loss + lamda * F.l1_loss(orig_mel, recon_mel)

    # calculate discriminator outputs
    # disc_outputs in the format [disc1_outputs, disc2_outputs, ...]
    # disc1_outputs includes [logits, intermediates]
    # intermediates includes [layer_1_intermediate, layer_2_intermediate, ...]
    fake_disc_outputs = self.discriminator(recon_speech)
    with torch.no_grad():
        # do not store discriminator gradient in generator turn
        real_disc_outputs = self.discriminator(orig_speech)

    # calculate discriminator loss including adversarial, feat matching losses and tpr losses [Optional]
    adversarial_losses = []
    disc_feature_losses = []
    tpr_losses = []
    for real_output, fake_output in zip(real_disc_outputs, fake_disc_outputs):
        real_logits, real_intermediates = real_output
        fake_logits, fake_intermediates = fake_output
        # LSGAN generator objective: push fake logits toward 1
        adversarial_losses.append(torch.mean((1 - fake_logits)**2))
        # feature matching: L1 between real/fake intermediate activations
        for real_inter, fake_inter in zip(real_intermediates, fake_intermediates):
            _loss = torch.mean(torch.abs(real_inter.detach() - fake_inter))
            disc_feature_losses.append(_loss)

        if self.tpr_loss_weight > 0.0:
            # margin-style regularizer on the real/fake logit gap, clipped at tau
            tau = self.tpr_loss_tau
            m_DG = torch.median((fake_logits - real_logits))
            L_rel = torch.mean((((fake_logits - real_logits) - m_DG) ** 2)[fake_logits < real_logits + m_DG])
            tpr_losses.append(tau - F.relu(tau - L_rel))

    adversarial_loss = torch.stack(adversarial_losses).sum()
    feat_match_loss = torch.stack(disc_feature_losses).sum()
    tpr_loss = torch.zeros_like(adversarial_loss)
    if len(tpr_losses) > 0:
        tpr_loss = torch.stack(tpr_losses).sum()

    # calculate losses
    gen_loss = multi_mel_recon_loss + \
        adversarial_loss * self.adversarial_loss_weight + \
        feat_match_loss * self.feat_match_loss_weight + \
        tpr_loss * self.tpr_loss_weight
    # running generator-loss accumulator (reset to 0 in the discriminator turn)
    self.gen_loss += gen_loss.item()
    loss = gen_loss

    stats = dict(
        generator_loss=loss.item(),
        generator_multi_mel_recon_loss=multi_mel_recon_loss.item(),
        generator_adv_loss=adversarial_loss.item(),
        generator_feat_match_loss=feat_match_loss.item(),
        generator_tpr_loss=tpr_loss.item(),
        batch_size=batch_size,
        batch_length=speech.shape[2],
    )

    loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)

    return {
        "loss": loss,
        "stats": stats,
        "weight": weight,
        "optim_idx": 0,  # needed for trainer
        "real": orig_speech,
        "fake": recon_speech,
    }
|
||||
|
||||
def _forward_discriminator(
    self,
    speech: torch.Tensor,
    speech_lengths: torch.Tensor,
) -> Dict[str, Any]:
    """Perform discriminator forward.

    One discriminator optimization step: synthesize a detached fake waveform
    and train the discriminator to separate it from the real one (LSGAN),
    plus an optional TPR term.

    Args:
        speech (Tensor): Speech waveform tensor (B, T_wav).
        speech_lengths (Tensor): Speech length tensor (B,).
            NOTE(review): not used below — confirm batches are equal-length.

    Returns:
        Dict[str, Any]:
            * loss (Tensor): Loss scalar tensor.
            * stats (Dict[str, float]): Statistics to be monitored.
            * weight (Tensor): Weight tensor to summarize losses.
            * optim_idx (int): Optimizer index (0 for G and 1 for D).
    """
    # setup
    batch_size = speech.size(0)
    speech = speech.unsqueeze(1)  # (B, T_wav) -> (B, 1, T_wav)
    orig_speech = speech.clone()

    # A: calculate generator outputs
    with torch.no_grad():
        # do not store generator gradient in discriminator turn
        mel = self._encode(speech)  # [B, C, T]
        recon_speech = self._decode(mel)[:, :, :speech.shape[-1]]

    # B: calculate discriminator outputs
    real, fake = orig_speech.clone(), recon_speech.detach()
    real_disc_outputs = self.discriminator(real)
    fake_disc_outputs = self.discriminator(fake)

    # C: calculate discriminator losses, tpr losses [Optional]
    disc_losses = []
    tpr_losses = []
    for real_output, fake_output in zip(real_disc_outputs, fake_disc_outputs):
        real_logits, real_intermediates = real_output
        fake_logits, fake_intermediates = fake_output
        # LSGAN discriminator objective: real logits -> 1, fake logits -> 0
        one_disc_loss = torch.mean((1-real_logits) ** 2) + torch.mean((0 - fake_logits) ** 2)
        disc_losses.append(one_disc_loss)

        if self.tpr_loss_weight > 0.0:
            # margin-style regularizer on the real/fake logit gap, clipped at tau
            # (mirror of the generator-side term with real/fake swapped)
            tau = self.tpr_loss_tau
            m_DG = torch.median((real_logits - fake_logits))
            L_rel = torch.mean((((real_logits - fake_logits) - m_DG) ** 2)[real_logits < fake_logits + m_DG])
            tpr_losses.append(tau - F.relu(tau - L_rel))

    disc_loss = torch.stack(disc_losses).sum()
    tpr_loss = torch.zeros_like(disc_loss)
    if len(tpr_losses) > 0:
        tpr_loss = torch.stack(tpr_losses).sum()

    # reset the running generator-loss accumulator built up in generator turns
    self.gen_loss = 0

    loss = disc_loss + self.tpr_loss_weight * tpr_loss

    stats = dict(
        discriminator_total_loss=loss.item(),
        discriminator_loss=disc_loss.item(),
        discriminator_tpr_loss=tpr_loss.item(),
    )
    loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)

    return {
        "loss": loss,
        "stats": stats,
        "weight": weight,
        "optim_idx": 1,  # needed for trainer
        "real": orig_speech,
        "fake": recon_speech,
    }
|
||||
|
||||
def inference(
    self,
    x: torch.Tensor,
) -> Dict[str, torch.Tensor]:
    """Run inference.

    Args:
        x (torch.Tensor): input representation, B x T x C

    Returns:
        Dict[str, Tensor]:
            * recon_speech (Tensor): Reconstructed waveform tensor (B, T_wav).
    """
    # decoder expects channel-first input, so swap (B, T, C) -> (B, C, T)
    waveform = self._decode(x.transpose(1, 2)).squeeze(1)
    return dict(recon_speech=waveform)
|
||||
|
||||
def collect_feats(self, **batch: torch.Tensor) -> Dict[str, torch.Tensor]:
    """Feature-collection hook required by the trainer interface.

    No statistics are collected for this model; the stub returns None.
    """
    pass
|
||||
|
||||
@property
def input_size(self):
    # NOTE(review): bare `return` yields None — looks like a stub; confirm
    # whether any caller expects an integer input size here.
    return
|
||||
14
funasr/models/llm_asr/hifigan_module/__init__.py
Normal file
14
funasr/models/llm_asr/hifigan_module/__init__.py
Normal file
@ -0,0 +1,14 @@
|
||||
|
||||
def get_padding(kernel_size, dilation=1):
    """Return the "same" padding for a stride-1 (dilated) convolution.

    Args:
        kernel_size: convolution kernel size.
        dilation: dilation factor (default 1).

    Returns:
        int: padding that preserves the sequence length.
    """
    # Integer floor division avoids the float round-trip of int((k*d - d) / 2)
    # and is exact for arbitrarily large arguments.
    return (kernel_size - 1) * dilation // 2
|
||||
|
||||
|
||||
def init_weights(m, mean=0.0, std=0.01):
    """Initialize conv-layer weights in place from N(mean, std).

    Intended for use with ``module.apply``; modules whose class name does not
    contain "Conv" are left untouched.
    """
    if "Conv" in m.__class__.__name__:
        m.weight.data.normal_(mean, std)
|
||||
|
||||
|
||||
# Fix: this package lives under funasr, not cosyvoice — import from the local
# hifigan_module subpackage (its sibling files import `get_padding` from this
# very package path) so it resolves without the cosyvoice repo installed.
from funasr.models.llm_asr.hifigan_module.generator import HifiGenerator, NsfHifiGenerator, HiFTGenerator
from funasr.models.llm_asr.hifigan_module.discriminator import MultipleDiscriminator
from funasr.models.llm_asr.hifigan_module.nsf_utils import ConvRNNF0Predictor
|
||||
120
funasr/models/llm_asr/hifigan_module/activations.py
Normal file
120
funasr/models/llm_asr/hifigan_module/activations.py
Normal file
@ -0,0 +1,120 @@
|
||||
# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
|
||||
# LICENSE is in incl_licenses directory.
|
||||
|
||||
import torch
|
||||
from torch import nn, sin, pow
|
||||
from torch.nn import Parameter
|
||||
|
||||
|
||||
class Snake(nn.Module):
    """Sine-based periodic activation: Snake(x) = x + (1/alpha) * sin^2(alpha * x).

    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input

    `alpha` is a per-channel trainable parameter controlling the frequency of
    the periodic component (higher alpha = higher frequency).

    Reference: Liu Ziyin, Tilman Hartwig, Masahito Ueda,
    https://arxiv.org/abs/2006.08195

    Example:
        >>> a1 = Snake(256)
        >>> x = torch.randn(256)
        >>> x = a1(x)
    """

    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
        """Create the activation.

        Args:
            in_features: channel count (size of the per-channel alpha vector).
            alpha: initial alpha (linear scale only; log scale starts at 0).
            alpha_trainable: whether alpha receives gradients.
            alpha_logscale: store alpha in log space (exponentiated in forward).
        """
        super().__init__()
        self.in_features = in_features
        self.alpha_logscale = alpha_logscale
        # log scale: start at zeros (exp(0) == 1); linear scale: start at `alpha`
        base = torch.zeros(in_features) if alpha_logscale else torch.ones(in_features)
        self.alpha = Parameter(base * alpha)
        self.alpha.requires_grad = alpha_trainable
        # small epsilon guarding the 1/alpha division
        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        """Apply Snake elementwise to x of shape (B, C, T)."""
        # broadcast per-channel alpha to (1, C, 1)
        a = self.alpha.unsqueeze(0).unsqueeze(-1)
        if self.alpha_logscale:
            a = torch.exp(a)
        return x + (1.0 / (a + self.no_div_by_zero)) * pow(sin(x * a), 2)
|
||||
|
||||
|
||||
class SnakeBeta(nn.Module):
    """Snake variant with separate magnitude control:
    SnakeBeta(x) = x + (1/beta) * sin^2(alpha * x).

    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input

    `alpha` (frequency) and `beta` (magnitude) are per-channel trainable
    parameters.

    Reference (modified from): Liu Ziyin, Tilman Hartwig, Masahito Ueda,
    https://arxiv.org/abs/2006.08195

    Example:
        >>> a1 = SnakeBeta(256)
        >>> x = torch.randn(256)
        >>> x = a1(x)
    """

    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
        """Create the activation.

        Args:
            in_features: channel count (size of the alpha/beta vectors).
            alpha: initial value for both alpha and beta (linear scale only;
                log scale starts both at 0, i.e. exp(0) == 1).
            alpha_trainable: whether alpha and beta receive gradients.
            alpha_logscale: store both parameters in log space.
        """
        super().__init__()
        self.in_features = in_features
        self.alpha_logscale = alpha_logscale
        # log scale: start at zeros; linear scale: start at `alpha`
        base = torch.zeros(in_features) if alpha_logscale else torch.ones(in_features)
        self.alpha = Parameter(base * alpha)
        self.beta = Parameter(base * alpha)
        self.alpha.requires_grad = alpha_trainable
        self.beta.requires_grad = alpha_trainable
        # small epsilon guarding the 1/beta division
        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        """Apply SnakeBeta elementwise to x of shape (B, C, T)."""
        # broadcast per-channel parameters to (1, C, 1)
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)
        beta = self.beta.unsqueeze(0).unsqueeze(-1)
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
            beta = torch.exp(beta)
        return x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
|
||||
299
funasr/models/llm_asr/hifigan_module/discriminator.py
Normal file
299
funasr/models/llm_asr/hifigan_module/discriminator.py
Normal file
@ -0,0 +1,299 @@
|
||||
"""hifigan based dicriminator implementation.
|
||||
|
||||
This code is modified from https://github.com/jik876/hifi-gan and https://github.com/kan-bayashi/ParallelWaveGAN.
|
||||
|
||||
"""
|
||||
|
||||
import typing as tp
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.nn as nn
|
||||
from torch.nn import Conv2d, AvgPool1d, Conv1d
|
||||
from torch.nn.utils import weight_norm, spectral_norm
|
||||
|
||||
from funasr.models.llm_asr.hifigan_module import get_padding
|
||||
|
||||
|
||||
class DiscriminatorP(torch.nn.Module):
    """HiFi-GAN period discriminator.

    Folds a waveform (B, 1, T) into a 2-D map (B, 1, T/period, period) and
    scores it with strided 2-D convolutions along the time axis.
    """

    def __init__(self, period, kernel_size=5, stride=3,
                 use_spectral_norm=False, lrelu_slope=0.1):
        """
        Args:
            period: fold width; T is reflect-padded up to a multiple of it.
            kernel_size: time-axis kernel size of the strided convs.
            stride: time-axis stride of the strided convs.
            use_spectral_norm: use spectral norm instead of weight norm.
            lrelu_slope: negative slope of the leaky-ReLU between convs.
        """
        super(DiscriminatorP, self).__init__()
        self.period = period
        self.lrelu_slope = lrelu_slope

        norm_f = weight_norm if use_spectral_norm is False else spectral_norm
        # Fix: padding was hard-coded as get_padding(5, 1) regardless of
        # kernel_size; derive it from the actual kernel so non-default kernels
        # stay "same"-padded (unchanged for the default kernel_size=5).
        pad = (get_padding(kernel_size, 1), 0)
        self.convs = nn.ModuleList([
            norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=pad)),
            norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=pad)),
            norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=pad)),
            norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=pad)),
            norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=pad)),
        ])
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        """Score x (B, 1, T).

        Returns:
            tuple: (flattened logits (B, N), list of intermediate feature maps).
        """
        fmap = []

        # 1d to 2d: pad T up to a multiple of the period, then fold
        b, c, t = x.shape
        if t % self.period != 0:  # pad first
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, self.lrelu_slope)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap
|
||||
|
||||
|
||||
class MultiPeriodDiscriminator(torch.nn.Module):
    """Bank of DiscriminatorP instances, one per fold period."""

    def __init__(self,
                 in_channels: int = 1,
                 periods: tp.List[int] = [2, 3, 5, 7, 11]):
        super(MultiPeriodDiscriminator, self).__init__()
        # NOTE(review): `in_channels` is accepted for interface uniformity but
        # unused — DiscriminatorP hardcodes mono input.
        self.discriminators = nn.ModuleList(DiscriminatorP(p) for p in periods)

    def forward(self, x: torch.Tensor, return_intermediates: bool = True):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input signal (B, 1, T).

        Returns:
            List: one entry per period — (logits, feature_maps) tuples when
            `return_intermediates`, otherwise just the logits tensors.
        """
        results = []
        for disc in self.discriminators:
            out = disc(x)
            results.append(out if return_intermediates else out[0])
        return results
|
||||
|
||||
|
||||
class DiscriminatorS(torch.nn.Module):
    """HiFi-GAN scale discriminator: stacked strided/grouped 1-D convs over raw audio."""

    def __init__(self, use_spectral_norm=False, lrelu_slope=0.1):
        super(DiscriminatorS, self).__init__()
        self.lrelu_slope = lrelu_slope
        norm_f = weight_norm if use_spectral_norm is False else spectral_norm
        # per stage: (in_ch, out_ch, kernel, stride, groups, padding)
        stage_cfg = [
            (1, 128, 15, 1, 1, 7),
            (128, 128, 41, 2, 4, 20),
            (128, 256, 41, 2, 16, 20),
            (256, 512, 41, 4, 16, 20),
            (512, 1024, 41, 4, 16, 20),
            (1024, 1024, 41, 1, 16, 20),
            (1024, 1024, 5, 1, 1, 2),
        ]
        self.convs = nn.ModuleList(
            norm_f(Conv1d(ci, co, k, s, groups=g, padding=p))
            for ci, co, k, s, g, p in stage_cfg
        )
        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        """Score x (B, 1, T).

        Returns:
            tuple: (flattened logits (B, N), list of intermediate feature maps).
        """
        feature_maps = []
        for conv in self.convs:
            x = F.leaky_relu(conv(x), self.lrelu_slope)
            feature_maps.append(x)
        x = self.conv_post(x)
        feature_maps.append(x)
        return torch.flatten(x, 1, -1), feature_maps
|
||||
|
||||
|
||||
class MultiScaleDiscriminator(torch.nn.Module):
    """HiFi-GAN multi-scale discriminator: one DiscriminatorS per scale, each
    subsequent scale seeing a 2x average-pooled version of the signal.
    Only the first (full-resolution) scale uses spectral norm.
    """

    def __init__(self, in_channels: int = 1, nb_scales: int = 3):
        """
        Args:
            in_channels: accepted for interface uniformity; unused
                (DiscriminatorS hardcodes mono input).
            nb_scales: number of scales/discriminators (>= 1).
        """
        super(MultiScaleDiscriminator, self).__init__()
        # Fix: honor `nb_scales` — it was previously ignored and the count was
        # hardwired to 3 discriminators / 2 pools (default nb_scales=3 is
        # byte-for-byte equivalent to the old behavior).
        self.discriminators = nn.ModuleList(
            [DiscriminatorS(use_spectral_norm=True)]
            + [DiscriminatorS() for _ in range(nb_scales - 1)]
        )
        self.meanpools = nn.ModuleList(
            [AvgPool1d(4, 2, padding=2) for _ in range(nb_scales - 1)]
        )

    def forward(self, x: torch.Tensor, return_intermediates: bool = True):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input signal (B, 1, T).

        Returns:
            List: one entry per scale — (logits, feature_maps) tuples when
            `return_intermediates`, otherwise just the logits tensors.
        """
        outs = []
        for i, f in enumerate(self.discriminators):
            if i != 0:
                # downsample once more before each subsequent scale
                x = self.meanpools[i - 1](x)
            if return_intermediates:
                outs.append(f(x))
            else:
                outs.append(f(x)[0])
        return outs
|
||||
|
||||
|
||||
class DiscriminatorR(nn.Module):
    """Single-resolution spectrogram discriminator: 2-D convs over the STFT magnitude."""

    def __init__(
        self,
        stft_params: tp.List[int],
        lrelu_slope: float = 0.1,
        use_spectral_norm: bool = False,
    ):
        super().__init__()

        # [n_fft, hop_length, win_length]
        self.stft_params = stft_params
        self.lrelu_slope = lrelu_slope
        norm_f = weight_norm if use_spectral_norm == False else spectral_norm

        self.convs = nn.ModuleList([
            norm_f(nn.Conv2d(1, 32, (3, 9), padding=(1, 4))),
            norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
            norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
            norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
            norm_f(nn.Conv2d(32, 32, (3, 3), padding=(1, 1))),
        ])
        self.conv_post = norm_f(nn.Conv2d(32, 1, (3, 3), padding=(1, 1)))

    def spectrogram(self, x):
        """STFT magnitude of x (B, 1, T) -> (B, F, frames)."""
        n_fft, hop_length, win_length = self.stft_params
        pad = int((n_fft - hop_length) / 2)
        x = F.pad(x, (pad, pad), mode='reflect')
        x = x.squeeze(1)
        spec = torch.stft(x, n_fft, hop_length=hop_length, win_length=win_length,
                          center=False, pad_mode='reflect', normalized=False,
                          onesided=True, return_complex=True)
        spec = torch.view_as_real(spec)  # [B, F, TT, 2]: real/imag parts
        return torch.norm(spec, p=2, dim=-1)  # magnitude, [B, F, TT]

    def forward(self, x):
        """Score x (B, 1, T).

        Returns:
            tuple: (flattened logits (B, N), list of intermediate feature maps).
        """
        feature_maps = []
        x = self.spectrogram(x).unsqueeze(1)
        for conv in self.convs:
            x = F.leaky_relu(conv(x), self.lrelu_slope)
            feature_maps.append(x)
        x = self.conv_post(x)
        feature_maps.append(x)
        return torch.flatten(x, 1, -1), feature_maps
|
||||
|
||||
|
||||
class MultiResolutionDiscriminator(nn.Module):
    """Bank of DiscriminatorR instances, one per STFT resolution."""

    def __init__(
        self,
        in_channels: int,
        fft_sizes: tp.List[int] = [1024, 2048, 512],
        hop_sizes: tp.List[int] = [120, 240, 50],
        win_lengths: tp.List[int] = [600, 1200, 240],
        lrelu_slope: float = 0.1,
    ):
        super().__init__()

        # NOTE(review): `in_channels` is accepted for interface uniformity but
        # unused — DiscriminatorR assumes mono input.
        self.discriminators = nn.ModuleList(
            DiscriminatorR([n_fft, hop, win], lrelu_slope)
            for n_fft, hop, win in zip(fft_sizes, hop_sizes, win_lengths)
        )

    def forward(self, x: torch.Tensor, return_intermediates: bool = True):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input signal (B, 1, T).

        Returns:
            List: one entry per resolution — (logits, feature_maps) tuples when
            `return_intermediates`, otherwise just the logits tensors.
        """
        results = []
        for disc in self.discriminators:
            out = disc(x)
            results.append(out if return_intermediates else out[0])
        return results
|
||||
|
||||
|
||||
class MultipleDiscriminator(nn.Module):
    """Ensemble of GAN discriminators (MPD / MSD / MRD) built from a config list.

    Args:
        input_size: number of input audio channels, forwarded to each
            sub-discriminator as ``in_channels``.
        disc_conf_list: one config dict per discriminator; each must carry a
            ``name`` key in {"mpd", "msd", "mrd"}; the remaining keys are
            passed to the chosen class's constructor. ``None`` means no
            discriminators.
    """

    def __init__(
        self,
        input_size: int = 1,
        disc_conf_list: tp.List[tp.Dict[str, tp.Any]] = None,
    ):
        super().__init__()

        self.support_disc_choices = dict(
            mpd=MultiPeriodDiscriminator,
            msd=MultiScaleDiscriminator,
            mrd=MultiResolutionDiscriminator,
        )

        self.discriminators = nn.ModuleList()
        self.discriminator_type_lst = []
        # Fix: the default disc_conf_list=None previously raised a TypeError
        # when iterated; treat it as an empty configuration instead.
        for args in (disc_conf_list or []):
            assert "name" in args, "disc_conf must have `name` attr to specific disc type."
            disc_type = args.pop("name")
            assert disc_type in self.support_disc_choices, \
                "Unsupported discriminator type, only support {}".format(
                    ",".join(self.support_disc_choices.keys())
                )

            disc_class = self.support_disc_choices[disc_type]
            one_disc = disc_class(in_channels=input_size, **args)
            self.discriminators.append(one_disc)
            # add back to the args for dump config.yaml (the caller's dict was
            # mutated by the pop above)
            args["name"] = disc_type
            self.discriminator_type_lst.append(disc_type)

    def get_discriminator_type_lst(self) -> tp.List[str]:
        """Return the ordered list of configured discriminator type names."""
        return self.discriminator_type_lst

    def forward(self, x, return_intermediates=True):
        """Run every sub-discriminator on x, flattening list outputs into one list.

        Returns:
            List: concatenation of all sub-discriminator outputs.
        """
        retval = []
        for disc in self.discriminators:
            out = disc(x, return_intermediates=return_intermediates)
            if isinstance(out, tuple):
                retval.append(out)
            elif isinstance(out, list):
                retval.extend(out)
            else:
                raise TypeError("The return value of discriminator must be tuple or list[tuple]")

        return retval
|
||||
621
funasr/models/llm_asr/hifigan_module/generator.py
Normal file
621
funasr/models/llm_asr/hifigan_module/generator.py
Normal file
@ -0,0 +1,621 @@
|
||||
"""hifigan based generator implementation.
|
||||
|
||||
This code is modified from https://github.com/jik876/hifi-gan
|
||||
,https://github.com/kan-bayashi/ParallelWaveGAN and
|
||||
https://github.com/NVIDIA/BigVGAN
|
||||
|
||||
"""
|
||||
|
||||
import typing as tp
|
||||
|
||||
import numpy as np
|
||||
from scipy.signal import get_window
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from torch.nn import Conv1d, ConvTranspose1d
|
||||
from torch.nn.utils import weight_norm
|
||||
from torch.nn.utils import remove_weight_norm
|
||||
|
||||
from funasr.models.llm_asr.hifigan_module import get_padding, init_weights
|
||||
from funasr.models.llm_asr.hifigan_module.activations import Snake, SnakeBeta
|
||||
from funasr.models.llm_asr.hifigan_module.nsf_utils import SourceModule, SourceModuleHnNSF
|
||||
|
||||
|
||||
class ResBlock(torch.nn.Module):
    """Residual block module in HiFiGAN/BigVGAN.

    For every dilation it stacks [activation -> dilated conv] and, when
    ``use_additional_convs``, a second [activation -> non-dilated conv],
    with a residual connection around the pair.
    """

    def __init__(
        self,
        channels: int = 512,
        kernel_size: int = 3,
        dilations: tp.List[int] = [1, 3, 5],
        use_additional_convs: bool = True,
        nonlinear_activation: str = "LeakyReLU",
        nonlinear_activation_params: tp.Dict[str, tp.Any] = {"negative_slope": 0.1},
    ):
        """
        Args:
            channels: in/out channel count of every conv.
            kernel_size: conv kernel size.
            dilations: one residual stage per dilation factor.
            use_additional_convs: add the second (non-dilated) conv per stage.
            nonlinear_activation: "LeakyReLU", "Snake" or "SnakeBeta".
            nonlinear_activation_params: kwargs for the chosen activation.
        """
        super(ResBlock, self).__init__()
        self.use_additional_convs = use_additional_convs

        def _conv(dilation: int) -> nn.Module:
            # weight-normalized, "same"-padded conv
            return weight_norm(
                Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    1,
                    dilation=dilation,
                    padding=get_padding(kernel_size, dilation),
                )
            )

        def _act() -> nn.Module:
            # one fresh activation per conv (Snake variants hold parameters),
            # replacing the previously triplicated construction blocks
            if nonlinear_activation == "LeakyReLU":
                return nn.LeakyReLU(nonlinear_activation_params["negative_slope"])
            if nonlinear_activation == "Snake":
                return Snake(channels, alpha_logscale=nonlinear_activation_params.get("alpha_logscale", False))
            if nonlinear_activation == "SnakeBeta":
                return SnakeBeta(channels, alpha_logscale=nonlinear_activation_params.get("alpha_logscale", False))
            raise NotImplementedError

        self.convs1 = nn.ModuleList([_conv(d) for d in dilations])
        self.activations1 = nn.ModuleList([_act() for _ in dilations])
        if use_additional_convs:
            self.convs2 = nn.ModuleList([_conv(1) for _ in dilations])
            self.activations2 = nn.ModuleList([_act() for _ in dilations])

        self.convs1.apply(init_weights)
        if use_additional_convs:
            self.convs2.apply(init_weights)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the residual stack; input and output are (B, channels, T)."""
        for idx in range(len(self.convs1)):
            xt = self.activations1[idx](x)
            xt = self.convs1[idx](xt)
            if self.use_additional_convs:
                xt = self.activations2[idx](xt)
                xt = self.convs2[idx](xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        """Strip the weight-norm reparameterization (for inference/export)."""
        for idx in range(len(self.convs1)):
            remove_weight_norm(self.convs1[idx])
            if self.use_additional_convs:
                remove_weight_norm(self.convs2[idx])
|
||||
|
||||
|
||||
class HifiGenerator(nn.Module):
    """HiFi-GAN style generator.

    Pipeline: pre-conv -> N transposed-conv upsampling stages (each followed
    by a bank of ResBlocks whose outputs are averaged) -> post-conv -> tanh,
    producing a mono waveform. Optionally conditioned on a global embedding
    ``g`` at the input and, if ``cond_in_each_up_layer``, after every stage.
    """
    def __init__(
        self,
        in_channels: int = 80,
        base_channels: int = 512,
        global_channels: int = -1,
        upsample_rates: tp.List[int] = [8, 8, 2, 2],
        upsample_kernel_sizes: tp.List[int] = [16, 16, 4, 4],
        resblock_kernel_sizes: tp.List[int] = [3, 7, 11],
        resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        resblock_nonlinear_activation: str = "LeakyReLU",
        resblock_nonlinear_activation_params: tp.Dict[str, tp.Any] = {"negative_slope": 0.1},
        use_additional_convs: bool = True,
        cond_in_each_up_layer: bool = False,
        lrelu_slope: float = 0.1,
        act_pre_each_up_layer: bool = True
    ):
        """
        Args:
            in_channels: input feature dim (e.g. number of mel bins).
            base_channels: width before the first upsample; halved per stage.
            global_channels: dim of the global conditioning vector; <= 0 disables it.
            upsample_rates: per-stage upsampling factors.
            upsample_kernel_sizes: per-stage transposed-conv kernel sizes.
            resblock_kernel_sizes / resblock_dilation_sizes: the ResBlock bank
                applied after every stage (zipped pairwise).
            resblock_nonlinear_activation(_params): activation used inside ResBlocks.
            use_additional_convs: forwarded to ResBlock.
            cond_in_each_up_layer: also inject `g` after every upsample stage.
            lrelu_slope: slope of the leaky-ReLU before each upsample conv.
            act_pre_each_up_layer: apply leaky-ReLU before each upsample conv.
        """
        super(HifiGenerator, self).__init__()

        self.out_channels = 1
        self.global_channels = global_channels
        self.use_additional_convs = use_additional_convs
        # per-stage conditioning only meaningful with global conditioning enabled
        self.cond_in_each_up_layer = cond_in_each_up_layer if global_channels > 0 else False
        self.lrelu_slope = lrelu_slope
        self.act_pre_each_up_layer = act_pre_each_up_layer

        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)

        self.conv_pre = weight_norm(
            Conv1d(in_channels, base_channels, 7, 1, padding=3)
        )

        # transposed convs; channel width halves at every stage
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(
                weight_norm(
                    ConvTranspose1d(
                        base_channels // (2**i),
                        base_channels // (2**(i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )

        # num_kernels ResBlocks per stage, flattened as index i*num_kernels + j
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = base_channels // (2**(i + 1))
            for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
                self.resblocks.append(ResBlock(ch, k, d, use_additional_convs,
                                               resblock_nonlinear_activation,
                                               resblock_nonlinear_activation_params))

        if self.global_channels > 0:
            # 1x1 conv projecting `g` onto the pre-conv output width
            self.conv_global_cond = weight_norm(
                Conv1d(global_channels, base_channels, 1)
            )
            self.conv_global_cond.apply(init_weights)

        if self.cond_in_each_up_layer:
            # per-stage 1x1 projections of `g` onto each stage's width
            self.conv_conds = nn.ModuleList()
            for i in range(len(self.ups)):
                self.conv_conds.append(weight_norm(
                    nn.Conv1d(global_channels, base_channels // (2**(i + 1)), 1))
                )
            self.conv_conds.apply(init_weights)

        # NOTE(review): `ch` is the width after the last stage — this relies on
        # the resblock loop having run, i.e. upsample_rates must be non-empty.
        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

    def output_size(self):
        # generator always emits a single (mono) waveform channel
        return self.out_channels

    def forward(self, x: torch.Tensor, g: tp.Optional[torch.Tensor] = None) -> torch.Tensor:
        # x in (B, in_channels, T), g in (B, global_channels, 1)
        x = self.conv_pre(x)
        if self.global_channels > 0 and g is not None:
            x = x + self.conv_global_cond(g)

        for i in range(self.num_upsamples):
            if self.act_pre_each_up_layer:
                x = F.leaky_relu(x, self.lrelu_slope)
            x = self.ups[i](x)

            if self.cond_in_each_up_layer and g is not None:
                x = x + self.conv_conds[i](g)

            # average the outputs of this stage's ResBlock bank
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels

        # NOTE: uses the default negative_slope (0.01) here, not self.lrelu_slope
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

    def remove_weight_norm(self):
        # strip all weight-norm wrappers (inference/export); irreversible
        print('Removing weight norm...')
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)
        if self.global_channels > 0:
            remove_weight_norm(self.conv_global_cond)
        if self.cond_in_each_up_layer:
            for l in self.conv_conds:
                remove_weight_norm(l)
|
||||
|
||||
|
||||
class NsfHifiGenerator(nn.Module):
    """
    Neural Source Filter + HifiGan

    HiFi-GAN style waveform generator excited by a harmonic source signal
    derived from F0 (Neural Source Filter). Input features (B, in_channels, T)
    are upsampled to waveform rate; at each stage the source excitation,
    downsampled to that stage's rate, is added before the multi-receptive-field
    residual blocks.
    """
    def __init__(
        self,
        in_channels: int = 80,
        base_channels: int = 512,
        global_channels: int = -1,
        nb_harmonics: int = 7,
        sampling_rate: int = 22050,
        nsf_alpha: float = 0.1,
        nsf_sigma: float = 0.003,
        nsf_voiced_threshold: float = 10,
        upsample_rates: tp.List[int] = [8, 8, 2, 2],
        upsample_kernel_sizes: tp.List[int] = [16, 16, 4, 4],
        resblock_kernel_sizes: tp.List[int] = [3, 7, 11],
        resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        resblock_nonlinear_activation: str = "LeakyReLU",
        resblock_nonlinear_activation_params: tp.Dict[str, tp.Any] = {"negative_slope": 0.1},
        use_additional_convs: bool = True,
        cond_in_each_up_layer: bool = False,
        lrelu_slope: float = 0.1,
        act_pre_each_up_layer: bool = True
    ):
        # NOTE(review): list/dict parameter defaults are mutable and shared
        # across instances; safe only as long as they are never mutated.
        super(NsfHifiGenerator, self).__init__()

        self.out_channels = 1
        self.global_channels = global_channels
        self.nb_harmonics = nb_harmonics
        self.sampling_rate = sampling_rate
        self.use_additional_convs = use_additional_convs
        # Per-stage conditioning only makes sense when a global embedding exists.
        self.cond_in_each_up_layer = cond_in_each_up_layer if global_channels > 0 else False
        self.lrelu_slope = lrelu_slope
        self.act_pre_each_up_layer = act_pre_each_up_layer

        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)

        # NSF excitation generator; total upsampling factor = prod(upsample_rates).
        self.source_module = SourceModule(nb_harmonics, np.cumprod(upsample_rates)[-1],
                                          sampling_rate, nsf_alpha, nsf_sigma, nsf_voiced_threshold)
        self.conv_pre = weight_norm(
            Conv1d(in_channels, base_channels, 7, 1, padding=3)
        )

        # Up
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(
                weight_norm(
                    ConvTranspose1d(
                        base_channels // (2**i),
                        base_channels // (2**(i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )
        # Down: bring the waveform-rate source signal to each stage's resolution.
        self.source_downs = nn.ModuleList()
        downsample_rates = [1] + upsample_rates[::-1][:-1]
        downsample_cum_rates = np.cumprod(downsample_rates)
        for i, u in enumerate(downsample_cum_rates[::-1]):
            if (u == 1):
                # Same rate as the stage output: a 1x1 projection suffices.
                self.source_downs.append(
                    weight_norm(Conv1d(1, base_channels // (2 ** (i + 1)), 1, 1))
                )
            else:
                # Strided conv downsamples the excitation by the cumulative factor u.
                self.source_downs.append(
                    weight_norm(Conv1d(1, base_channels // (2 ** (i + 1)), u*2, u, padding=(u//2)))
                )

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = base_channels // (2**(i + 1))
            for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
                self.resblocks.append(ResBlock(ch, k, d, use_additional_convs,
                                               resblock_nonlinear_activation,
                                               resblock_nonlinear_activation_params))

        if self.global_channels > 0:
            self.conv_global_cond = weight_norm(
                Conv1d(global_channels, base_channels, 1)
            )
            self.conv_global_cond.apply(init_weights)

        if self.cond_in_each_up_layer:
            self.conv_conds = nn.ModuleList()
            for i in range(len(self.ups)):
                self.conv_conds.append(weight_norm(
                    nn.Conv1d(global_channels, base_channels // (2**(i + 1)), 1))
                )
            self.conv_conds.apply(init_weights)

        # `ch` is the channel count after the last upsample stage.
        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

    def output_size(self):
        """Return the number of output channels (mono waveform -> 1)."""
        return self.out_channels

    def _f02source(self, f0: torch.Tensor) -> torch.Tensor:
        """Convert frame-rate F0 (B, T) to a waveform-rate excitation (B, 1, T*prod(rates))."""
        return self.source_module(f0.unsqueeze(1))

    def forward(self, x: torch.Tensor, f0: torch.Tensor, g: tp.Optional[torch.Tensor] = None) -> torch.Tensor:
        """Synthesize a waveform.

        :param x: (B, in_channels, T) input features.
        :param f0: (B, T) frame-level F0 in Hz.
        :param g: optional (B, global_channels, 1) global conditioning embedding.
        :return: (B, 1, T * prod(upsample_rates)) waveform, tanh-bounded.
        """
        # x in (B, in_channels, T), f0 in (B, T), g in (B, global_channels, 1)

        s = self._f02source(f0)

        x = self.conv_pre(x)
        if self.global_channels > 0 and g is not None:
            x = x + self.conv_global_cond(g)

        for i in range(self.num_upsamples):
            if self.act_pre_each_up_layer:
                x = F.leaky_relu(x, self.lrelu_slope)
            x = self.ups[i](x)

            if self.cond_in_each_up_layer and g is not None:
                x = x + self.conv_conds[i](g)

            # fusion: add the downsampled NSF excitation at this resolution
            x = x + self.source_downs[i](s)

            # Average the MRF residual blocks of this stage.
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels

        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

    def remove_weight_norm(self):
        """Strip weight normalization from all layers (call before inference/export)."""
        print('Removing weight norm...')
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)
        if self.global_channels > 0:
            remove_weight_norm(self.conv_global_cond)
        if self.cond_in_each_up_layer:
            for l in self.conv_conds:
                remove_weight_norm(l)
        self.source_module.remove_weight_norm()
        for l in self.source_downs:
            remove_weight_norm(l)
|
||||
|
||||
|
||||
class HiFTGenerator(nn.Module):
    """
    HiFTNet Generator: Neural Source Filter + ISTFTNet
    https://arxiv.org/abs/2309.09493

    Input features (B, in_channels, T) are upsampled by `upsample_rates`; the
    NSF harmonic source is converted to stacked STFT real/imag channels and
    fused in at each stage. The network predicts the log-magnitude and phase
    of an ISTFT, and the waveform is reconstructed with `torch.istft`.
    """
    def __init__(
        self,
        in_channels: int = 80,
        base_channels: int = 512,
        global_channels: int = -1,
        nb_harmonics: int = 8,
        sampling_rate: int = 22050,
        nsf_alpha: float = 0.1,
        nsf_sigma: float = 0.003,
        nsf_voiced_threshold: float = 10,
        upsample_rates: tp.List[int] = [8, 8],
        upsample_kernel_sizes: tp.List[int] = [16, 16],
        istft_params: tp.Dict[str, int] = {"n_fft": 16, "hop_len": 4},
        resblock_kernel_sizes: tp.List[int] = [3, 7, 11],
        resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        resblock_nonlinear_activation: str = "Snake",
        resblock_nonlinear_activation_params: tp.Dict[str, tp.Any] = {"alpha_logscale": False},
        source_resblock_kernel_sizes: tp.List[int] = [7, 11],
        source_resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5]],
        source_resblock_nonlinear_activation: str = "Snake",
        source_resblock_nonlinear_activation_params: tp.Dict[str, tp.Any] = {"alpha_logscale": False},
        use_additional_convs: bool = True,
        cond_in_each_up_layer: bool = False,
        lrelu_slope: float = 0.1,
        act_pre_each_up_layer: bool = True,
        audio_limit: float = 0.99,
    ):
        # NOTE(review): mutable container defaults are shared across instances;
        # safe only while they are never mutated.
        super(HiFTGenerator, self).__init__()

        self.out_channels = 1
        self.global_channels = global_channels
        self.nb_harmonics = nb_harmonics
        self.sampling_rate = sampling_rate
        self.istft_params = istft_params
        self.use_additional_convs = use_additional_convs
        # Per-stage conditioning only makes sense when a global embedding exists.
        self.cond_in_each_up_layer = cond_in_each_up_layer if global_channels > 0 else False
        self.lrelu_slope = lrelu_slope
        self.act_pre_each_up_layer = act_pre_each_up_layer
        self.audio_limit = audio_limit

        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        # NSF harmonic source at waveform rate (upsampling x ISTFT hop).
        self.m_source = SourceModuleHnNSF(
            sampling_rate=sampling_rate,
            upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
            harmonic_num=nb_harmonics,
            sine_amp=nsf_alpha,
            add_noise_std=nsf_sigma,
            voiced_threshod=nsf_voiced_threshold)
        self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])

        self.conv_pre = weight_norm(
            Conv1d(in_channels, base_channels, 7, 1, padding=3)
        )

        # Up
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(
                weight_norm(
                    ConvTranspose1d(
                        base_channels // (2**i),
                        base_channels // (2**(i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )

        # Down: project the source's STFT channels to each stage's resolution.
        # NOTE: these Conv1d layers are intentionally NOT weight-normed.
        self.source_downs = nn.ModuleList()
        self.source_resblocks = nn.ModuleList()
        downsample_rates = [1] + upsample_rates[::-1][:-1]
        downsample_cum_rates = np.cumprod(downsample_rates)
        for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes,
                                          source_resblock_dilation_sizes)):
            if u == 1:
                self.source_downs.append(
                    Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1)
                )
            else:
                self.source_downs.append(
                    Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u*2, u, padding=(u//2))
                )

            self.source_resblocks.append(
                ResBlock(base_channels // (2 ** (i + 1)), k, d,
                         use_additional_convs, source_resblock_nonlinear_activation,
                         source_resblock_nonlinear_activation_params)
            )

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = base_channels // (2**(i + 1))
            for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
                self.resblocks.append(ResBlock(ch, k, d, use_additional_convs,
                                               resblock_nonlinear_activation,
                                               resblock_nonlinear_activation_params))

        if self.global_channels > 0:
            self.conv_global_cond = weight_norm(
                Conv1d(global_channels, base_channels, 1)
            )
            self.conv_global_cond.apply(init_weights)

        if self.cond_in_each_up_layer:
            self.conv_conds = nn.ModuleList()
            for i in range(len(self.ups)):
                self.conv_conds.append(weight_norm(
                    nn.Conv1d(global_channels, base_channels // (2**(i + 1)), 1))
                )
            self.conv_conds.apply(init_weights)

        # Head predicts n_fft/2+1 log-magnitudes and n_fft/2+1 phases.
        self.conv_post = weight_norm(Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

        self.reflection_pad = nn.ReflectionPad1d((1, 0))
        window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
        self.register_buffer("stft_window", window)

    def output_size(self):
        """Return the number of output channels (mono waveform -> 1)."""
        return self.out_channels

    def _f02source(self, f0: torch.Tensor) -> torch.Tensor:
        """Upsample frame-rate F0 (B, T) to waveform rate and run the NSF source.

        :return: (B, 1, T * prod(upsample_rates) * hop_len) harmonic excitation.
        """
        f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2)  # bs,n,t

        har_source, _, _ = self.m_source(f0)
        return har_source.transpose(1, 2)

    def forward(self, x: torch.Tensor, f0: torch.Tensor, g: tp.Optional[torch.Tensor] = None) -> torch.Tensor:
        """Synthesize a waveform.

        :param x: (B, in_channels, T) input features.
        :param f0: (B, T) frame-level F0 in Hz.
        :param g: optional (B, global_channels, 1) global conditioning embedding.
        :return: (B, 1, samples) waveform clamped to +/- audio_limit.
        """
        # x in (B, in_channels, T), f0 in (B, T), g in (B, global_channels, 1)

        s = self._f02source(f0)

        # Represent the excitation by its STFT real/imag parts, stacked on the
        # channel axis: (B, n_fft + 2, frames).
        s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
        s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)

        x = self.conv_pre(x)
        if self.global_channels > 0 and g is not None:
            x = x + self.conv_global_cond(g)

        for i in range(self.num_upsamples):
            if self.act_pre_each_up_layer:
                x = F.leaky_relu(x, self.lrelu_slope)
            x = self.ups[i](x)

            if self.cond_in_each_up_layer and g is not None:
                x = x + self.conv_conds[i](g)

            if i == self.num_upsamples - 1:
                # Align the last stage's length with the source STFT frames.
                x = self.reflection_pad(x)

            # fusion
            si = self.source_downs[i](s_stft)
            si = self.source_resblocks[i](si)
            x = x + si

            # Average the MRF residual blocks of this stage.
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels

        x = F.leaky_relu(x)
        x = self.conv_post(x)
        magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :])
        phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :])  # actually, sin is redundancy

        x = self._istft(magnitude, phase)
        x = torch.clamp(x, -self.audio_limit, self.audio_limit)
        return x

    def remove_weight_norm(self):
        """Strip weight normalization from all weight-normed layers.

        Bug fix: the original referenced ``self.source_module`` (this class
        stores its NSF source as ``self.m_source``, which contains no
        weight-normed layers) and called ``remove_weight_norm`` on the
        ``self.source_downs`` convs, which are created *without* weight_norm —
        both raised at runtime.
        """
        print('Removing weight norm...')
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)
        if self.global_channels > 0:
            remove_weight_norm(self.conv_global_cond)
        if self.cond_in_each_up_layer:
            for l in self.conv_conds:
                remove_weight_norm(l)
        for l in self.source_resblocks:
            l.remove_weight_norm()

    def _stft(self, x):
        """STFT of (B, samples) -> real and imaginary parts, each (B, F, frames)."""
        spec = torch.stft(
            x,
            self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window,
            return_complex=True)
        spec = torch.view_as_real(spec)  # [B, F, TT, 2]
        return spec[..., 0], spec[..., 1]

    def _istft(self, magnitude, phase):
        """Inverse STFT from magnitude and phase, both (B, F, frames).

        Bug fix: recent torch requires a complex-valued input for
        ``torch.istft``; build it with ``torch.complex`` instead of the
        real-stacked tensor the original passed (value-identical result).
        """
        magnitude = torch.clip(magnitude, max=1e2)
        real = magnitude * torch.cos(phase)
        img = magnitude * torch.sin(phase)
        inverse_transform = torch.istft(
            torch.complex(real, img),
            self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window)

        return inverse_transform.unsqueeze(-2)  # unsqueeze to stay consistent with conv_transpose1d implementation
|
||||
93
funasr/models/llm_asr/hifigan_module/mel_spectrum.py
Normal file
93
funasr/models/llm_asr/hifigan_module/mel_spectrum.py
Normal file
@ -0,0 +1,93 @@
|
||||
import torch
|
||||
import torch.utils.data
|
||||
import numpy as np
|
||||
from librosa.filters import mel as librosa_mel_fn
|
||||
|
||||
|
||||
def dynamic_range_compression(x, C=1, clip_val=1e-5):
|
||||
return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
|
||||
|
||||
|
||||
def dynamic_range_decompression(x, C=1):
|
||||
return np.exp(x) / C
|
||||
|
||||
|
||||
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
|
||||
return torch.log(torch.clamp(x, min=clip_val) * C)
|
||||
|
||||
|
||||
def dynamic_range_decompression_torch(x, C=1):
|
||||
return torch.exp(x) / C
|
||||
|
||||
|
||||
def spectral_normalize_torch(magnitudes):
|
||||
output = dynamic_range_compression_torch(magnitudes)
|
||||
return output
|
||||
|
||||
|
||||
def spectral_de_normalize_torch(magnitudes):
|
||||
output = dynamic_range_decompression_torch(magnitudes)
|
||||
return output
|
||||
|
||||
|
||||
mel_basis = {}
|
||||
hann_window = {}
|
||||
|
||||
|
||||
def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
    """Compute a log-compressed mel spectrogram of waveform `y`.

    :param y: (B, samples) waveform, expected in [-1, 1] (out-of-range values
        are only warned about).
    :return: (B, num_mels, frames) log-mel spectrogram.

    The mel filterbank is cached per (fmax, device) and the Hann window per
    device in the module-level dicts.
    """
    if torch.min(y) < -1.:
        print('min value is ', torch.min(y))
    if torch.max(y) > 1.:
        print('max value is ', torch.max(y))

    global mel_basis, hann_window
    # Bug fix: the original tested `fmax not in mel_basis`, but entries are
    # keyed as f"{fmax}_{device}", so the cache never hit and the mel basis
    # was rebuilt on every call.
    key = str(fmax) + '_' + str(y.device)
    if key not in mel_basis:
        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
        mel_basis[key] = torch.from_numpy(mel).float().to(y.device)
    if str(y.device) not in hann_window:
        hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)

    # Manual reflect padding (center=False below) to center the frames.
    y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
    y = y.squeeze(1)

    # Bug fix: recent torch requires `return_complex`; request complex output
    # and view as real so the pow/sum magnitude math below is unchanged.
    spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
                      center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
    spec = torch.view_as_real(spec)

    spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))

    spec = torch.matmul(mel_basis[key], spec)
    spec = spectral_normalize_torch(spec)

    return spec
|
||||
|
||||
|
||||
def power_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
    """Compute a log-compressed linear (power) spectrogram of waveform `y`.

    :param y: (B, samples) waveform, expected in [-1, 1].
    :return: (B, n_fft//2 + 1, frames) log-compressed magnitude spectrogram.

    Also primes the module-level mel-basis cache for the (fmax, device) pair so
    that `mel_from_power_spectrogram` can reuse it.
    """
    if torch.min(y) < -1.:
        print('min value is ', torch.min(y))
    if torch.max(y) > 1.:
        print('max value is ', torch.max(y))

    global mel_basis, hann_window
    # Bug fix: the original tested `fmax not in mel_basis`, but entries are
    # keyed as f"{fmax}_{device}", so the cache never hit and the mel basis
    # was rebuilt on every call.
    key = str(fmax) + '_' + str(y.device)
    if key not in mel_basis:
        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
        mel_basis[key] = torch.from_numpy(mel).float().to(y.device)
    if str(y.device) not in hann_window:
        hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)

    # Manual reflect padding (center=False below) to center the frames.
    y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
    y = y.squeeze(1)

    # Bug fix: recent torch requires `return_complex`; request complex output
    # and view as real so the pow/sum magnitude math below is unchanged.
    spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
                      center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
    spec = torch.view_as_real(spec)

    spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))
    spec = spectral_normalize_torch(spec)

    return spec
|
||||
|
||||
|
||||
def mel_from_power_spectrogram(spec, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
    """Project a log-compressed power spectrogram onto the cached mel basis.

    :param spec: (B, n_fft//2 + 1, frames) log-compressed power spectrogram.
    :return: (B, num_mels, frames) log-mel spectrogram.

    NOTE(review): this reads `mel_basis[f"{fmax}_{device}"]` but never
    populates it — callers must have run `mel_spectrogram`/`power_spectrogram`
    with the same fmax on the same device first, otherwise this raises
    KeyError; confirm. The n_fft/num_mels/... parameters are unused here.
    """
    global mel_basis, hann_window
    # Undo the log compression to get linear magnitudes, project, re-compress.
    spec = spectral_de_normalize_torch(spec)
    spec = torch.matmul(mel_basis[str(fmax) + '_' + str(spec.device)], spec)
    spec = spectral_normalize_torch(spec)

    return spec
|
||||
253
funasr/models/llm_asr/hifigan_module/nsf_utils.py
Normal file
253
funasr/models/llm_asr/hifigan_module/nsf_utils.py
Normal file
@ -0,0 +1,253 @@
|
||||
"""
|
||||
Neural Source Filter based modules implementation.
|
||||
|
||||
Neural source-filter waveform models for statistical parametric speech synthesis
|
||||
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import typing as tp
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.utils import weight_norm, remove_weight_norm
|
||||
from torch.distributions.uniform import Uniform
|
||||
from torch.distributions.normal import Normal
|
||||
|
||||
class SineGen(torch.nn.Module):
    """ Definition of sine generator
    SineGen(samp_rate, harmonic_num = 0,
            sine_amp = 0.1, noise_std = 0.003,
            voiced_threshold = 0,
            flag_for_pulse=False)
    samp_rate: sampling rate in Hz
    harmonic_num: number of harmonic overtones (default 0)
    sine_amp: amplitude of sine-wavefrom (default 0.1)
    noise_std: std of Gaussian noise (default 0.003)
    voiced_thoreshold: F0 threshold for U/V classification (default 0)
    flag_for_pulse: this SinGen is used inside PulseGen (default False)
    Note: when flag_for_pulse is True, the first time step of a voiced
    segment is always sin(np.pi) or cos(0)

    NOTE(review): this implementation does not actually accept a
    `flag_for_pulse` argument — the docstring above describes the upstream
    reference API; confirm against callers.
    """

    def __init__(self, samp_rate, harmonic_num=0,
                 sine_amp=0.1, noise_std=0.003,
                 voiced_threshold=0):
        super(SineGen, self).__init__()
        self.sine_amp = sine_amp
        self.noise_std = noise_std
        self.harmonic_num = harmonic_num
        self.sampling_rate = samp_rate
        self.voiced_threshold = voiced_threshold

    def _f02uv(self, f0):
        # generate uv signal: 1.0 where F0 exceeds the voiced threshold, else 0.0
        uv = (f0 > self.voiced_threshold).type(torch.float32)
        return uv

    @torch.no_grad()
    def forward(self, f0):
        """
        :param f0: [B, 1, sample_len], Hz
        :return: (sine_waves, uv, noise); sine_waves and noise are
            [B, harmonic_num + 1, sample_len], uv is [B, 1, sample_len]
        """
        # Per-harmonic normalized instantaneous frequency: row i = f0*(i+1)/fs.
        F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(f0.device)
        for i in range(self.harmonic_num + 1):
            F_mat[:, i:i+1, :] = f0 * (i+1) / self.sampling_rate

        # Integrate frequency (mod 1) to phase; add a random initial phase per
        # harmonic, except the fundamental which is fixed to phase 0.
        theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1)
        u_dist = Uniform(low=-np.pi, high=np.pi)
        phase_vec = u_dist.sample(sample_shape=(f0.size(0), self.harmonic_num + 1, 1)).to(F_mat.device)
        phase_vec[:, 0, :] = 0

        # generate sine waveforms
        sine_waves = self.sine_amp * torch.sin(theta_mat + phase_vec)

        # generate uv signal
        uv = self._f02uv(f0)

        # noise: for unvoiced should be similar to sine_amp
        # std = self.sine_amp/3 -> max value ~ self.sine_amp
        # . for voiced regions is self.noise_std
        noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
        noise = noise_amp * torch.randn_like(sine_waves)

        # first: set the unvoiced part to 0 by uv
        # then: additive noise
        sine_waves = sine_waves * uv + noise
        return sine_waves, uv, noise
|
||||
|
||||
|
||||
class SourceModuleHnNSF(torch.nn.Module):
    """ SourceModule for hn-nsf
    SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
                 add_noise_std=0.003, voiced_threshod=0)
    sampling_rate: sampling_rate in Hz
    harmonic_num: number of harmonic above F0 (default: 0)
    sine_amp: amplitude of sine source signal (default: 0.1)
    add_noise_std: std of additive Gaussian noise (default: 0.003)
        note that amplitude of noise in unvoiced is decided
        by sine_amp
    voiced_threshold: threhold to set U/V given F0 (default: 0)
    Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
    F0_sampled (batchsize, length, 1)
    Sine_source (batchsize, length, 1)
    noise_source (batchsize, length 1)
    uv (batchsize, length, 1)
    """

    def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
                 add_noise_std=0.003, voiced_threshod=0):
        # NOTE(review): `upsample_scale` is accepted but never used in this
        # implementation — confirm whether callers rely on it.
        super(SourceModuleHnNSF, self).__init__()

        self.sine_amp = sine_amp
        self.noise_std = add_noise_std

        # to produce sine waveforms
        self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
                                 sine_amp, add_noise_std, voiced_threshod)

        # to merge source harmonics into a single excitation
        self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
        self.l_tanh = torch.nn.Tanh()

    def forward(self, x):
        """
        Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
        F0_sampled (batchsize, length, 1)
        Sine_source (batchsize, length, 1)
        noise_source (batchsize, length 1)
        """
        # source for harmonic branch: the sine generator is stochastic and
        # non-differentiable, so it runs under no_grad; only the merging
        # linear layer is trained.
        with torch.no_grad():
            sine_wavs, uv, _ = self.l_sin_gen(x.transpose(1,2))
            sine_wavs = sine_wavs.transpose(1,2)
            uv = uv.transpose(1,2)
        sine_merge = self.l_tanh(self.l_linear(sine_wavs))

        # source for noise branch, in the same shape as uv
        noise = torch.randn_like(uv) * self.sine_amp / 3
        return sine_merge, noise, uv
|
||||
|
||||
|
||||
class SourceModule(torch.nn.Module):
    """NSF source module: harmonic-plus-noise excitation from frame-rate F0.

    F0 is nearest-neighbor upsampled by `upsample_ratio` to waveform rate, a
    bank of `nb_harmonics + 1` sines with random initial phases (fundamental
    phase fixed to 0) plus Gaussian noise is built, voiced/unvoiced regions
    are mixed, and a 1x1 weight-normed conv + tanh merges the harmonics into
    a single-channel excitation.
    """
    def __init__(self,
                 nb_harmonics: int,
                 upsample_ratio: int,
                 sampling_rate: int,
                 alpha: float = 0.1,
                 sigma: float = 0.003,
                 voiced_threshold: float = 10
                 ):
        super(SourceModule, self).__init__()

        self.nb_harmonics = nb_harmonics      # harmonics above the fundamental
        self.upsample_ratio = upsample_ratio  # frame rate -> sample rate factor
        self.sampling_rate = sampling_rate
        self.alpha = alpha                    # sine amplitude
        self.sigma = sigma                    # additive noise std
        self.voiced_threshold = voiced_threshold  # F0 (Hz) voicing threshold

        # Merge harmonics into one excitation channel.
        self.ffn = nn.Sequential(
            weight_norm(nn.Conv1d(self.nb_harmonics + 1, 1, kernel_size=1, stride=1)),
            nn.Tanh())

    def f02uv(self, f0):
        # generate uv signal: 1.0 where F0 exceeds the voiced threshold, else 0.0
        uv = (f0 > self.voiced_threshold).type(torch.float32)
        return uv

    def forward(self, f0):
        """
        :param f0: [B, 1, frame_len], Hz
        :return: [B, 1, sample_len]
        """
        # Excitation construction is stochastic and non-differentiable; only
        # the merging conv below is trained.
        with torch.no_grad():
            uv = self.f02uv(f0)
            # Nearest-neighbor upsample F0 and UV masks to waveform rate.
            f0_samples = F.interpolate(f0, scale_factor=(self.upsample_ratio), mode='nearest')
            uv_samples = F.interpolate(uv, scale_factor=(self.upsample_ratio), mode='nearest')

            # Per-harmonic normalized frequency: row i = f0*(i+1)/fs.
            F_mat = torch.zeros((f0_samples.size(0), self.nb_harmonics + 1, f0_samples.size(-1))).to(f0_samples.device)
            for i in range(self.nb_harmonics + 1):
                F_mat[:, i:i+1, :] = f0_samples * (i+1) / self.sampling_rate

            # Integrate frequency (mod 1) to phase; random initial phase per
            # harmonic except the fundamental (fixed to 0).
            theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1)
            u_dist = Uniform(low=-np.pi, high=np.pi)
            phase_vec = u_dist.sample(sample_shape=(f0.size(0), self.nb_harmonics + 1, 1)).to(F_mat.device)
            phase_vec[:, 0, :] = 0

            n_dist = Normal(loc=0., scale=self.sigma)
            noise = n_dist.sample(sample_shape=(f0_samples.size(0), self.nb_harmonics + 1, f0_samples.size(-1))).to(F_mat.device)

            # Voiced: noisy sines; unvoiced: scaled noise only.
            e_voice = self.alpha * torch.sin(theta_mat + phase_vec) + noise
            e_unvoice = self.alpha / 3 / self.sigma * noise

            e = e_voice * uv_samples + e_unvoice * (1 - uv_samples)

        return self.ffn(e)

    def remove_weight_norm(self):
        """Strip weight normalization from the merging conv (pre-export)."""
        remove_weight_norm(self.ffn[0])
|
||||
|
||||
|
||||
class ConvRNNF0Predictor(nn.Module):
    """Frame-level F0 predictor: stacked Conv1d+ELU cond-net, optional GRU,
    linear head.

    Input is (B, in_channels, T); output is a non-negative (B, T) tensor —
    the absolute value of the linear head's prediction per frame.
    """
    def __init__(self,
                 num_class: int = 1,
                 in_channels: int = 80,
                 cond_channels: int = 512,
                 use_cond_rnn: bool = True,
                 bidirectional_rnn: bool = False,
                 ):

        super().__init__()

        self.num_class = num_class
        self.use_cond_rnn = use_cond_rnn

        # Five weight-normed Conv1d(k=3, pad=1) + ELU stages: the first maps
        # in_channels -> cond_channels, the rest keep cond_channels.
        stage_channels = [in_channels] + [cond_channels] * 5
        stages = []
        for c_in, c_out in zip(stage_channels[:-1], stage_channels[1:]):
            stages.append(weight_norm(nn.Conv1d(c_in, c_out, kernel_size=3, padding=1)))
            stages.append(nn.ELU())
        self.condnet = nn.Sequential(*stages)

        if self.use_cond_rnn:
            # Halve the hidden size when bidirectional so the concatenated
            # output still has cond_channels features.
            rnn_hidden = cond_channels // 2 if bidirectional_rnn else cond_channels
            self.rnn = nn.GRU(
                cond_channels,
                rnn_hidden,
                num_layers=1,
                batch_first=True,
                bidirectional=bidirectional_rnn,
            )

        self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Predict per-frame F0 from (B, in_channels, T) features."""
        feats = self.condnet(x)
        feats = feats.transpose(1, 2)  # (B, T, cond_channels) for the RNN/head
        if self.use_cond_rnn:
            feats, _ = self.rnn(feats)

        return torch.abs(self.classifier(feats).squeeze(-1))
|
||||
|
||||
|
||||
|
||||
93
funasr/models/llm_asr/mel_spectrum.py
Normal file
93
funasr/models/llm_asr/mel_spectrum.py
Normal file
@ -0,0 +1,93 @@
|
||||
import torch
|
||||
import torch.utils.data
|
||||
import numpy as np
|
||||
from librosa.filters import mel as librosa_mel_fn
|
||||
|
||||
|
||||
def dynamic_range_compression(x, C=1, clip_val=1e-5):
|
||||
return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
|
||||
|
||||
|
||||
def dynamic_range_decompression(x, C=1):
|
||||
return np.exp(x) / C
|
||||
|
||||
|
||||
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    """Log-compress a magnitude tensor: log(clamp(x, min=clip_val) * C)."""
    return torch.log(torch.clamp(x, min=clip_val) * C)


def dynamic_range_decompression_torch(x, C=1):
    """Invert `dynamic_range_compression_torch`: exp(x) / C."""
    return torch.exp(x) / C


def spectral_normalize_torch(magnitudes):
    """Apply log dynamic-range compression to a magnitude spectrogram."""
    return dynamic_range_compression_torch(magnitudes)


def spectral_de_normalize_torch(magnitudes):
    """Undo `spectral_normalize_torch`."""
    return dynamic_range_decompression_torch(magnitudes)
|
||||
|
||||
|
||||
mel_basis = {}
|
||||
hann_window = {}
|
||||
|
||||
|
||||
def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
    """Compute a log-compressed mel spectrogram of waveform `y`.

    :param y: (B, samples) waveform, expected in [-1, 1] (out-of-range values
        are only warned about).
    :return: (B, num_mels, frames) log-mel spectrogram.

    The mel filterbank is cached per (fmax, device) and the Hann window per
    device in the module-level dicts.
    """
    if torch.min(y) < -1.:
        print('min value is ', torch.min(y))
    if torch.max(y) > 1.:
        print('max value is ', torch.max(y))

    global mel_basis, hann_window
    # Bug fix: the original tested `fmax not in mel_basis`, but entries are
    # keyed as f"{fmax}_{device}", so the cache never hit and the mel basis
    # was rebuilt on every call.
    key = str(fmax) + '_' + str(y.device)
    if key not in mel_basis:
        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
        mel_basis[key] = torch.from_numpy(mel).float().to(y.device)
    if str(y.device) not in hann_window:
        hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)

    # Manual reflect padding (center=False below) to center the frames.
    y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
    y = y.squeeze(1)

    # Bug fix: recent torch requires `return_complex`; request complex output
    # and view as real so the pow/sum magnitude math below is unchanged.
    spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
                      center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
    spec = torch.view_as_real(spec)

    spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))

    spec = torch.matmul(mel_basis[key], spec)
    spec = spectral_normalize_torch(spec)

    return spec
|
||||
|
||||
|
||||
def power_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
    """Compute a log-compressed linear (power) spectrogram of waveform `y`.

    :param y: (B, samples) waveform, expected in [-1, 1].
    :return: (B, n_fft//2 + 1, frames) log-compressed magnitude spectrogram.

    Also primes the module-level mel-basis cache for the (fmax, device) pair so
    that `mel_from_power_spectrogram` can reuse it.
    """
    if torch.min(y) < -1.:
        print('min value is ', torch.min(y))
    if torch.max(y) > 1.:
        print('max value is ', torch.max(y))

    global mel_basis, hann_window
    # Bug fix: the original tested `fmax not in mel_basis`, but entries are
    # keyed as f"{fmax}_{device}", so the cache never hit and the mel basis
    # was rebuilt on every call.
    key = str(fmax) + '_' + str(y.device)
    if key not in mel_basis:
        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
        mel_basis[key] = torch.from_numpy(mel).float().to(y.device)
    if str(y.device) not in hann_window:
        hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)

    # Manual reflect padding (center=False below) to center the frames.
    y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
    y = y.squeeze(1)

    # Bug fix: recent torch requires `return_complex`; request complex output
    # and view as real so the pow/sum magnitude math below is unchanged.
    spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
                      center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
    spec = torch.view_as_real(spec)

    spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))
    spec = spectral_normalize_torch(spec)

    return spec
|
||||
|
||||
|
||||
def mel_from_power_spectrogram(spec, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
    """Project a log-compressed power spectrogram onto the cached mel basis.

    :param spec: (B, n_fft//2 + 1, frames) log-compressed power spectrogram.
    :return: (B, num_mels, frames) log-mel spectrogram.

    NOTE(review): this reads `mel_basis[f"{fmax}_{device}"]` but never
    populates it — callers must have run `mel_spectrogram`/`power_spectrogram`
    with the same fmax on the same device first, otherwise this raises
    KeyError; confirm. The n_fft/num_mels/... parameters are unused here.
    """
    global mel_basis, hann_window
    # Undo the log compression to get linear magnitudes, project, re-compress.
    spec = spectral_de_normalize_torch(spec)
    spec = torch.matmul(mel_basis[str(fmax) + '_' + str(spec.device)], spec)
    spec = spectral_normalize_torch(spec)

    return spec
|
||||
@ -1,4 +1,6 @@
|
||||
import logging
|
||||
import os.path
|
||||
import torchaudio
|
||||
from typing import Union, Dict, List, Tuple, Optional
|
||||
|
||||
import time
|
||||
@ -1580,6 +1582,29 @@ class LLMASR5(nn.Module):
|
||||
reduction=False,
|
||||
)
|
||||
|
||||
mel_decoder_name = kwargs.get("mel_decoder", None)
|
||||
mel_decoder_conf = kwargs.get("mel_decoder_conf", None)
|
||||
self.mel_decoder = self.build_mel_decoder(name=mel_decoder_name, conf=mel_decoder_conf)
|
||||
vocoder_name = kwargs.get("vocoder", None)
|
||||
vocoder_conf = kwargs.get("vocoder_conf", None)
|
||||
self.vocoder = self.build_vocoder(name=vocoder_name, conf=vocoder_conf)
|
||||
|
||||
def build_mel_decoder(self, name: str, conf: dict):
    """Instantiate the flow-matching mel decoder.

    :param name: decoder class name; only "MaskedDiffWithXvec" is supported.
    :param conf: keyword-argument dict for the decoder constructor.
    :return: the decoder module, or None when name/conf is missing or the
        name is unrecognized.
    """
    if name is None or conf is None:
        return None
    if name != "MaskedDiffWithXvec":
        return None
    # Imported lazily so the dependency is only required when actually used.
    from funasr.models.llm_asr.flow_matching import MaskedDiffWithXvec
    return MaskedDiffWithXvec(**conf)
|
||||
|
||||
def build_vocoder(self, name: str, conf: dict):
    """Instantiate the waveform vocoder.

    :param name: vocoder class name; only "HifiGAN" is supported.
    :param conf: keyword-argument dict for the vocoder constructor.
    :return: the vocoder module, or None when name/conf is missing or the
        name is unrecognized.
    """
    if name is None or conf is None:
        return None
    if name != "HifiGAN":
        return None
    # Imported lazily so the dependency is only required when actually used.
    from funasr.models.llm_asr.hifigan import HifiGan
    return HifiGan(**conf)
|
||||
|
||||
def build_audio_decoder(self, name, conf):
|
||||
if name == "transformer":
|
||||
from funasr.models.llm_asr.transformer_lm import TransformerEmbedLM
|
||||
@ -2205,7 +2230,16 @@ class LLMASR5(nn.Module):
|
||||
target_ids = generated_ids["sequences"]
|
||||
target_emb = self.llm.model.get_input_embeddings()(target_ids)
|
||||
if self.concat_emb_hidden:
|
||||
if not self.concat_emb_hidden_norm:
|
||||
hidden_states_select = torch.concat((hidden_states_select, target_emb), dim=-1)
|
||||
hidden_states_select = self.audio_decoder_in_proj(hidden_states_select)
|
||||
else:
|
||||
outs = self.hidden_norm(hidden_states_select)
|
||||
outs = self.fusion_dropout(self.fusion_act(outs))
|
||||
# emb = model_outputs.hidden_states[0]
|
||||
emb = self.fusion_dropout(self.fusion_act(self.emb_norm(target_emb)))
|
||||
outs = self.audio_decoder_in_proj(torch.cat([outs, emb], dim=-1))
|
||||
hidden_states_select = self.fusion_act(self.fusion_norm(outs))
|
||||
|
||||
speech_tokens = self.audio_decode(hidden_states_select, hidden_states_out_len)[
|
||||
:, :, 0
|
||||
@ -2221,12 +2255,18 @@ class LLMASR5(nn.Module):
|
||||
|
||||
loss = None
|
||||
|
||||
# synthesize waveforms
|
||||
spk_emb = kwargs.get("spk_emb", None)
|
||||
feat, wav = self.synthesize_waveform(speech_tokens, spk_emb, inputs_embeds.device)
|
||||
|
||||
ibest_writer = None
|
||||
if kwargs.get("output_dir") is not None:
|
||||
if not hasattr(self, "writer"):
|
||||
self.writer = DatadirWriter(kwargs.get("output_dir"))
|
||||
ibest_writer = self.writer[f"{0 + 1}best_recog"]
|
||||
|
||||
self.write_mel_wav(kwargs.get("output_dir"), feat, wav, key[0])
|
||||
|
||||
results = []
|
||||
response_clean = re.sub(r"[^\w\s\u3000\u4e00-\u9fff]+", "", response)
|
||||
result_i = {
|
||||
@ -2253,6 +2293,48 @@ class LLMASR5(nn.Module):
|
||||
|
||||
return results, meta_data
|
||||
|
||||
def write_mel_wav(self, output_dir, feat, wav, key):
|
||||
out_dir = os.path.join(output_dir, "1best_recog", "mels")
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
if feat is not None:
|
||||
feat = feat.cpu().numpy()[0]
|
||||
np.save(os.path.join(out_dir, f"{key}.npy"), feat)
|
||||
|
||||
out_dir = os.path.join(output_dir, "1best_recog", "wavs")
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
if wav is not None:
|
||||
path = os.path.join(out_dir, f"{key}.wav")
|
||||
torchaudio.save(
|
||||
path, wav[0], sample_rate=self.vocoder.sample_rate,
|
||||
encoding='PCM_S', bits_per_sample=16
|
||||
)
|
||||
|
||||
def synthesize_waveform(self, speech_tokens, spk_emb, device):
|
||||
mel_feat, wav = None, None
|
||||
if self.mel_decoder is not None and spk_emb is not None:
|
||||
# mel_feat in BxCxT
|
||||
mel_feat = self.token2mel(speech_tokens, spk_emb, device)
|
||||
if self.vocoder is not None:
|
||||
wav = self.vocoder.inference(mel_feat.transpose(1, 2))
|
||||
|
||||
return mel_feat, wav
|
||||
|
||||
def token2mel(self, tokens: torch.Tensor, xvec: torch.Tensor, device: torch.device):
|
||||
xvec = torch.tensor(xvec).to(device).unsqueeze(0)
|
||||
xvec_lens = torch.tensor([xvec.shape[1]], device=device, dtype=torch.int64)
|
||||
token_lens = torch.tensor([tokens.shape[1]], device=device, dtype=torch.int64)
|
||||
feat = self.mel_decoder.inference(
|
||||
tokens, token_lens,
|
||||
xvec, xvec_lens,
|
||||
diff_steps=10,
|
||||
temperature=1.0,
|
||||
prompt=dict(
|
||||
prompt_text=(None, None),
|
||||
prompt_audio=(None, None)
|
||||
)
|
||||
)
|
||||
return feat
|
||||
|
||||
def audio_decode(
|
||||
self,
|
||||
text: torch.Tensor,
|
||||
@ -2263,9 +2345,8 @@ class LLMASR5(nn.Module):
|
||||
decoding_length=None,
|
||||
):
|
||||
# 1. encode text
|
||||
text = self.audio_decoder_in_proj(text)
|
||||
# text = self.audio_decoder_in_proj(text)
|
||||
device = text.device
|
||||
out_tokens = []
|
||||
sos_eos_emb = self.audio_decoder_embedding(
|
||||
torch.tensor([[self.ad_sos_eos]], dtype=torch.int64, device=device)
|
||||
)
|
||||
@ -2273,30 +2354,18 @@ class LLMASR5(nn.Module):
|
||||
torch.tensor([[self.ad_task_id]], dtype=torch.int64, device=device)
|
||||
)
|
||||
prompt = torch.cat([sos_eos_emb, text, task_id_emb], dim=1)
|
||||
state, cfg_state = None, None
|
||||
seq_input = torch.zeros(
|
||||
[1, prompt.shape[1] + max_length, prompt.shape[2]],
|
||||
dtype=torch.float32, device=device
|
||||
)
|
||||
seq_input[:, :prompt.shape[1], :] = prompt
|
||||
out_tokens = torch.zeros([1, max_length, 1], device=device)
|
||||
out_token_len = 0
|
||||
prompt_len = prompt.shape[1]
|
||||
state, hit_eos = None, False
|
||||
for i in range(max_length):
|
||||
if len(out_tokens) > 0:
|
||||
codec_prompt = torch.tensor([out_tokens], dtype=torch.int64, device=device)
|
||||
codec_lengths = torch.tensor([len(out_tokens)], dtype=torch.int64, device=device)
|
||||
# if any quantizer output is eos
|
||||
if torch.any(codec_prompt[:, -1] == (self.codebook_size + self.ad_sos_eos)):
|
||||
break
|
||||
seq_input, _ = self.prepare_audio_decoder_io(
|
||||
text, text_lengths, codec_prompt, codec_lengths, need_targets=False
|
||||
)
|
||||
else:
|
||||
seq_input, _ = self.prepare_audio_decoder_io(
|
||||
text, text_lengths, None, None, need_targets=False
|
||||
)
|
||||
|
||||
# use state for speedup
|
||||
pred, (state, _) = self.audio_decoder.score(seq_input[0], state, prompt[0])
|
||||
if infer_cfg_ratio is not None:
|
||||
cond_len = prompt[0].shape[0]
|
||||
cfg_pred, (cfg_state, _) = self.audio_decoder.score(
|
||||
seq_input[0][cond_len - 1 :], cfg_state, prompt[0][cond_len - 1 :]
|
||||
)
|
||||
pred = (1 + infer_cfg_ratio) * pred - infer_cfg_ratio * cfg_pred
|
||||
|
||||
# sampling all `nq` token ids
|
||||
pred = pred.reshape(self.predict_nq, -1)
|
||||
@ -2304,49 +2373,44 @@ class LLMASR5(nn.Module):
|
||||
pred = torch.log_softmax(pred, dim=-1)
|
||||
if min_length is not None and i < min_length:
|
||||
pred[:, self.codebook_size + self.ad_sos_eos] = float(np.finfo(np.float32).min)
|
||||
top_ids = []
|
||||
for k in range(self.predict_nq):
|
||||
top_ids.append(self.ras_sampling(pred[k], out_tokens)[0].item())
|
||||
out_tokens.append(top_ids)
|
||||
top_ids = self.ras_sampling(pred[0], out_tokens[0])
|
||||
out_tokens[0, out_token_len, 0] = top_ids[0]
|
||||
seq_input[0, prompt_len + out_token_len, :] = self.codec_embedder(top_ids)[0]
|
||||
out_token_len += 1
|
||||
|
||||
# remove eos token
|
||||
hit_eos = False
|
||||
if torch.any(
|
||||
torch.tensor(out_tokens[-1], dtype=torch.int64) == self.codebook_size + self.ad_sos_eos
|
||||
):
|
||||
if torch.any(out_tokens[:, out_token_len - 1] == (self.codebook_size + self.ad_sos_eos)):
|
||||
hit_eos = True
|
||||
out_tokens = out_tokens[:-1]
|
||||
out_tokens = out_tokens[:, :out_token_len, :]
|
||||
break
|
||||
|
||||
if decoding_length is None:
|
||||
return torch.tensor([out_tokens], dtype=torch.int64, device=device)
|
||||
return out_tokens
|
||||
else:
|
||||
return torch.tensor([out_tokens], dtype=torch.int64, device=device), hit_eos
|
||||
return out_tokens, hit_eos
|
||||
|
||||
# Repetition Aware Sampling in VALL-E 2
|
||||
def ras_sampling(
|
||||
self, weighted_scores, decoded_tokens, *, top_p=0.8, top_k=25, win_size=10, tau_r=0.1
|
||||
):
|
||||
top_ids = self.nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k)
|
||||
rep_num = (torch.tensor(decoded_tokens[-win_size:]).to(top_ids) == top_ids).sum().item()
|
||||
rep_num = torch.sum(decoded_tokens[-win_size:] == top_ids).item()
|
||||
if rep_num >= win_size * tau_r:
|
||||
top_ids = self.random_sampling(weighted_scores)
|
||||
|
||||
return top_ids
|
||||
|
||||
def nucleus_sampling(self, weighted_scores, top_p=0.8, top_k=25):
|
||||
prob, indices = [], []
|
||||
cum_prob = 0.0
|
||||
sorted_value, sorted_idx = weighted_scores.softmax(dim=0).sort(descending=True, stable=True)
|
||||
i = len(sorted_idx)
|
||||
for i in range(len(sorted_idx)):
|
||||
# sampling both top-p and numbers.
|
||||
if cum_prob < top_p and len(prob) < top_k:
|
||||
if cum_prob < top_p and i < top_k:
|
||||
cum_prob += sorted_value[i]
|
||||
prob.append(sorted_value[i])
|
||||
indices.append(sorted_idx[i])
|
||||
else:
|
||||
break
|
||||
prob = torch.tensor(prob).to(weighted_scores)
|
||||
indices = torch.tensor(indices, dtype=torch.long).to(weighted_scores.device)
|
||||
prob = sorted_value[:i]
|
||||
indices = sorted_idx[:i]
|
||||
sampling_ids = prob.multinomial(1, replacement=True)
|
||||
top_ids = indices[sampling_ids]
|
||||
return top_ids
|
||||
|
||||
@ -334,3 +334,45 @@ class MaskAlongAxisLFR(torch.nn.Module):
|
||||
replace_with_zero=self.replace_with_zero,
|
||||
lfr_rate=self.lfr_rate,
|
||||
)
|
||||
|
||||
|
||||
class PrefixMaskVariableMaxWidth(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
mask_width_ratio_range: Union[float, Sequence[float]] = (0.0, 0.05),
|
||||
replace_value: float = 0.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.mask_width_ratio_range = mask_width_ratio_range
|
||||
self.replace_value = replace_value
|
||||
|
||||
def extra_repr(self):
|
||||
return (
|
||||
f"mask_width_ratio_range={self.mask_width_ratio_range}, "
|
||||
)
|
||||
|
||||
def forward(self, spec: torch.Tensor, spec_lengths: torch.Tensor = None, return_mask: bool = False):
|
||||
bb, tt, _ = spec.shape
|
||||
|
||||
mask_width_ratio_range = torch.tensor(self.mask_width_ratio_range, dtype=torch.float32, device=spec.device)
|
||||
mask_width_range = (mask_width_ratio_range * tt).long()
|
||||
mask_length = torch.randint(
|
||||
mask_width_range[0],
|
||||
mask_width_range[1],
|
||||
(bb, 1),
|
||||
device=spec.device,
|
||||
).unsqueeze(2)
|
||||
|
||||
# mask_pos: (B, num_mask, 1)
|
||||
mask_pos = tt - mask_length
|
||||
|
||||
aran = torch.arange(tt, device=spec.device)[None, None, :]
|
||||
# mask: (Batch, num_mask, L)
|
||||
mask = (mask_pos <= aran) * (aran < (mask_pos + mask_length))
|
||||
# Multiply masks: (Batch, num_mask, L) -> (Batch, L, 1)
|
||||
mask = mask.any(dim=1).unsqueeze(2)
|
||||
|
||||
spec = spec.masked_fill(mask, self.replace_value)
|
||||
if return_mask:
|
||||
return spec, spec_lengths, mask
|
||||
return spec, spec_lengths
|
||||
|
||||
13
funasr/utils/hinter.py
Normal file
13
funasr/utils/hinter.py
Normal file
@ -0,0 +1,13 @@
|
||||
import sys
|
||||
import torch.distributed
|
||||
import logging
|
||||
|
||||
HINTED = set()
|
||||
|
||||
|
||||
def hint_once(content, uid, rank=None):
|
||||
if (rank is None) or (not torch.distributed.is_initialized()) or torch.distributed.get_rank() == rank:
|
||||
if uid not in HINTED:
|
||||
logging.info(content)
|
||||
HINTED.add(uid)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user