diff --git a/funasr/models/llm_asr/conformer_encoder.py b/funasr/models/llm_asr/conformer_encoder.py new file mode 100644 index 000000000..d78db1aeb --- /dev/null +++ b/funasr/models/llm_asr/conformer_encoder.py @@ -0,0 +1,628 @@ +# Copyright 2020 Tomoki Hayashi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Conformer encoder definition.""" +import logging +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union +import torch +from torch import nn +from funasr.models.transformer.attention import ( + MultiHeadedAttention, # noqa: H301 + RelPositionMultiHeadedAttention, # noqa: H301 + LegacyRelPositionMultiHeadedAttention, # noqa: H301 +) +from funasr.models.transformer.embedding import ( + PositionalEncoding, # noqa: H301 + ScaledPositionalEncoding, # noqa: H301 + RelPositionalEncoding, # noqa: H301 + LegacyRelPositionalEncoding, # noqa: H301 +) +from funasr.models.transformer.layer_norm import LayerNorm +from funasr.models.transformer.utils.multi_layer_conv import Conv1dLinear +from funasr.models.transformer.utils.multi_layer_conv import MultiLayeredConv1d +from funasr.models.transformer.utils.nets_utils import get_activation +from funasr.models.transformer.utils.nets_utils import make_pad_mask +from funasr.models.transformer.utils.mask import subsequent_mask +from funasr.models.transformer.positionwise_feed_forward import ( + PositionwiseFeedForward, # noqa: H301 +) +from funasr.models.transformer.utils.repeat import repeat +from funasr.models.transformer.utils.subsampling import ( + Conv2dSubsampling, Conv2dSubsampling2, Conv2dSubsampling6, Conv2dSubsampling8, TooShortUttError, + check_short_utt, Conv2dSubsamplingPad +) + + +class ConvolutionModule(nn.Module): + """ConvolutionModule in Conformer model. + + Args: + channels (int): The number of channels of conv layers. + kernel_size (int): Kernerl size of conv layers. 
class ConvolutionModule(nn.Module):
    """Conformer convolution module.

    Applies pointwise-conv + GLU, a depthwise conv with 'SAME' padding,
    batch norm + activation, and a final pointwise conv, all along time.

    Args:
        channels (int): Number of input/output channels.
        kernel_size (int): Depthwise kernel size; must be odd for 'SAME' padding.
        activation (nn.Module): Activation applied after batch norm.
        bias (bool): Whether conv layers use bias terms.
    """

    def __init__(self, channels, kernel_size, activation=nn.ReLU(), bias=True):
        """Construct an ConvolutionModule object."""
        super(ConvolutionModule, self).__init__()
        # 'SAME' padding requires an odd kernel size.
        assert (kernel_size - 1) % 2 == 0

        self.pointwise_conv1 = nn.Conv1d(
            channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=bias
        )
        self.depthwise_conv = nn.Conv1d(
            channels,
            channels,
            kernel_size,
            stride=1,
            padding=(kernel_size - 1) // 2,
            groups=channels,
            bias=bias,
        )
        self.norm = nn.BatchNorm1d(channels)
        self.pointwise_conv2 = nn.Conv1d(
            channels, channels, kernel_size=1, stride=1, padding=0, bias=bias
        )
        self.activation = activation

    def forward(self, x):
        """Compute convolution module.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, channels).

        Returns:
            torch.Tensor: Output tensor (#batch, time, channels).
        """
        # Conv layers are channel-first, so swap time and feature axes.
        h = x.transpose(1, 2)

        # Pointwise conv doubles channels; GLU halves them back (gating).
        h = nn.functional.glu(self.pointwise_conv1(h), dim=1)

        # Depthwise conv over time, then norm + activation.
        h = self.activation(self.norm(self.depthwise_conv(h)))

        # Final pointwise projection, back to channel-last.
        return self.pointwise_conv2(h).transpose(1, 2)
class EncoderLayer(nn.Module):
    """Single Conformer encoder block.

    Chains (optional) macaron feed-forward, multi-headed self-attention,
    (optional) convolution module and a feed-forward module, each wrapped
    with a residual connection and pre-/post-layer-normalization.

    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance
            (`MultiHeadedAttention` or `RelPositionMultiHeadedAttention`).
        feed_forward (torch.nn.Module): Feed-forward module instance
            (`PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear`).
        feed_forward_macaron (torch.nn.Module): Additional feed-forward module
            placed before the attention block (macaron style), or ``None``.
        conv_module (torch.nn.Module): `ConvolutionModule` instance or ``None``.
        dropout_rate (float): Dropout rate.
        normalize_before (bool): Whether to use layer_norm before each block.
        concat_after (bool): Whether to concat attention layer's input and
            output; if True: x -> x + linear(concat(x, att(x))),
            if False: x -> x + att(x).
        stochastic_depth_rate (float): Probability to skip this layer during
            training; the kept residual branch is rescaled by 1 / (1 - p).
    """

    def __init__(
        self,
        size,
        self_attn,
        feed_forward,
        feed_forward_macaron,
        conv_module,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
        stochastic_depth_rate=0.0,
    ):
        """Construct an EncoderLayer object."""
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.feed_forward_macaron = feed_forward_macaron
        self.conv_module = conv_module
        self.norm_ff = LayerNorm(size)  # for the FNN module
        self.norm_mha = LayerNorm(size)  # for the MHA module
        if feed_forward_macaron is not None:
            self.norm_ff_macaron = LayerNorm(size)
            # Macaron style halves both feed-forward contributions.
            self.ff_scale = 0.5
        else:
            self.ff_scale = 1.0
        if self.conv_module is not None:
            self.norm_conv = LayerNorm(size)  # for the CNN module
            self.norm_final = LayerNorm(size)  # for the final output of the block
        self.dropout = nn.Dropout(dropout_rate)
        self.size = size
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear = nn.Linear(size + size, size)
        self.stochastic_depth_rate = stochastic_depth_rate

    def forward(self, x_input, mask, cache=None):
        """Compute encoded features.

        Args:
            x_input (Union[Tuple, torch.Tensor]): Input tensor w/ or w/o pos emb.
                - w/ pos emb: Tuple [(#batch, time, size), (1, time, size)].
                - w/o pos emb: Tensor (#batch, time, size).
            mask (torch.Tensor): Mask tensor for the input (#batch, time).
            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).

        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, time).
        """
        if isinstance(x_input, tuple):
            x, pos_emb = x_input
        else:
            x, pos_emb = x_input, None

        # Stochastic depth: with probability p the whole layer is skipped at
        # training time; the surviving residual `x + f(x)` becomes
        # `x + f(x) / (1 - p)`.
        skip_layer = False
        stoch_layer_coeff = 1.0
        if self.training and self.stochastic_depth_rate > 0:
            skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
            stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)

        if skip_layer:
            if cache is not None:
                x = torch.cat([cache, x], dim=1)
            if pos_emb is not None:
                return (x, pos_emb), mask
            return x, mask

        # Macaron-style feed-forward (before self-attention).
        if self.feed_forward_macaron is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_ff_macaron(x)
            x = residual + stoch_layer_coeff * self.ff_scale * self.dropout(
                self.feed_forward_macaron(x)
            )
            if not self.normalize_before:
                x = self.norm_ff_macaron(x)

        # Multi-headed self-attention module (streaming cache: only the
        # newest frame is used as query).
        residual = x
        if self.normalize_before:
            x = self.norm_mha(x)

        if cache is None:
            x_q = x
        else:
            assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
            x_q = x[:, -1:, :]
            residual = residual[:, -1:, :]
            mask = None if mask is None else mask[:, -1:, :]

        if pos_emb is not None:
            x_att = self.self_attn(x_q, x, x, pos_emb, mask)
        else:
            x_att = self.self_attn(x_q, x, x, mask)

        if self.concat_after:
            x = residual + stoch_layer_coeff * self.concat_linear(
                torch.cat((x, x_att), dim=-1)
            )
        else:
            x = residual + stoch_layer_coeff * self.dropout(x_att)
        if not self.normalize_before:
            x = self.norm_mha(x)

        # Convolution module.
        if self.conv_module is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_conv(x)
            x = residual + stoch_layer_coeff * self.dropout(self.conv_module(x))
            if not self.normalize_before:
                x = self.norm_conv(x)

        # Feed-forward module.
        residual = x
        if self.normalize_before:
            x = self.norm_ff(x)
        x = residual + stoch_layer_coeff * self.ff_scale * self.dropout(
            self.feed_forward(x)
        )
        if not self.normalize_before:
            x = self.norm_ff(x)

        if self.conv_module is not None:
            x = self.norm_final(x)

        if cache is not None:
            x = torch.cat([cache, x], dim=1)

        if pos_emb is not None:
            return (x, pos_emb), mask
        return x, mask
class ConformerEncoder(nn.Module):
    """Conformer encoder module.

    Args:
        input_size (int): Input dimension.
        output_size (int): Dimension of attention.
        attention_heads (int): The number of heads of multi head attention.
        linear_units (int): The number of units of position-wise feed forward.
        num_blocks (int): The number of decoder blocks.
        dropout_rate (float): Dropout rate.
        positional_dropout_rate (float): Dropout rate after adding positional encoding.
        attention_dropout_rate (float): Dropout rate in attention.
        input_layer (Union[str, torch.nn.Module]): Input layer type.
        normalize_before (bool): Whether to use layer_norm before the first block.
        concat_after (bool): Whether to concat attention layer's input and output.
        positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear".
        positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer.
        macaron_style (bool): Whether to use macaron style for positionwise layer.
        rel_pos_type (str): "legacy" or "latest" relative positional encoding
            (see https://github.com/espnet/espnet/pull/2816).
        pos_enc_layer_type (str): Encoder positional encoding layer type.
        selfattention_layer_type (str): Encoder attention layer type.
        activation_type (str): Encoder activation function type.
        use_cnn_module (bool): Whether to use convolution module.
        zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
        cnn_module_kernel (int): Kernel size of convolution module.
        padding_idx (int): Padding idx for input_layer=embed.
        interctc_layer_idx (List[int]): 1-based indices of layers whose
            (normalized) outputs are also returned for intermediate CTC.
            ``None`` (the default) means no intermediate outputs.
        interctc_use_conditioning (bool): Condition later layers on the
            intermediate CTC posterior via ``self.conditioning_layer``.
        stochastic_depth_rate (Union[float, List[float]]): Per-layer (or
            shared) probability of skipping a block during training.
        causal (bool): Apply a subsequent (lower-triangular) mask.
        skip (bool): Add the raw input back to the output (residual over the
            whole encoder; requires matching shapes).
        channel_first (bool): Inputs/outputs are (B, D, T) instead of (B, T, D).
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: Optional[str] = "conv2d",
        normalize_before: bool = True,
        concat_after: bool = False,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 3,
        macaron_style: bool = False,
        rel_pos_type: str = "legacy",
        pos_enc_layer_type: str = "rel_pos",
        selfattention_layer_type: str = "rel_selfattn",
        activation_type: str = "swish",
        use_cnn_module: bool = True,
        zero_triu: bool = False,
        cnn_module_kernel: int = 31,
        padding_idx: int = -1,
        interctc_layer_idx: Optional[List[int]] = None,
        interctc_use_conditioning: bool = False,
        stochastic_depth_rate: Union[float, List[float]] = 0.0,
        causal: bool = False,
        skip: bool = False,
        channel_first: bool = False,
    ):
        super().__init__()
        # FIX: the default used to be a mutable list literal (shared across
        # instances); `None` is backward-compatible and safe.
        if interctc_layer_idx is None:
            interctc_layer_idx = []
        self._output_size = output_size
        self.causal = causal
        self.skip = skip
        self.channel_first = channel_first

        if rel_pos_type == "legacy":
            if pos_enc_layer_type == "rel_pos":
                pos_enc_layer_type = "legacy_rel_pos"
            if selfattention_layer_type == "rel_selfattn":
                selfattention_layer_type = "legacy_rel_selfattn"
        elif rel_pos_type == "latest":
            assert selfattention_layer_type != "legacy_rel_selfattn"
            assert pos_enc_layer_type != "legacy_rel_pos"
        else:
            raise ValueError("unknown rel_pos_type: " + rel_pos_type)

        activation = get_activation(activation_type)
        if pos_enc_layer_type == "abs_pos":
            pos_enc_class = PositionalEncoding
        elif pos_enc_layer_type == "scaled_abs_pos":
            pos_enc_class = ScaledPositionalEncoding
        elif pos_enc_layer_type == "rel_pos":
            assert selfattention_layer_type == "rel_selfattn"
            pos_enc_class = RelPositionalEncoding
        elif pos_enc_layer_type == "legacy_rel_pos":
            assert selfattention_layer_type == "legacy_rel_selfattn"
            pos_enc_class = LegacyRelPositionalEncoding
            logging.warning(
                "Using legacy_rel_pos and it will be deprecated in the future."
            )
        else:
            raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)

        # All conv2d-style subsampling front-ends share the same signature.
        conv_subsampling_classes = {
            "conv2d": Conv2dSubsampling,
            "conv2dpad": Conv2dSubsamplingPad,
            "conv2d2": Conv2dSubsampling2,
            "conv2d6": Conv2dSubsampling6,
            "conv2d8": Conv2dSubsampling8,
        }
        if input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(input_size, output_size),
                torch.nn.LayerNorm(output_size),
                torch.nn.Dropout(dropout_rate),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer in conv_subsampling_classes:
            self.embed = conv_subsampling_classes[input_layer](
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif isinstance(input_layer, torch.nn.Module):
            self.embed = torch.nn.Sequential(
                input_layer,
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer is None:
            self.embed = torch.nn.Sequential(
                pos_enc_class(output_size, positional_dropout_rate)
            )
        else:
            raise ValueError("unknown input_layer: " + input_layer)

        self.normalize_before = normalize_before
        if positionwise_layer_type == "linear":
            positionwise_layer = PositionwiseFeedForward
            positionwise_layer_args = (
                output_size,
                linear_units,
                dropout_rate,
                activation,
            )
        elif positionwise_layer_type == "conv1d":
            positionwise_layer = MultiLayeredConv1d
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d-linear":
            positionwise_layer = Conv1dLinear
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        else:
            raise NotImplementedError("Support only linear or conv1d.")

        if selfattention_layer_type == "selfattn":
            encoder_selfattn_layer = MultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
            )
        elif selfattention_layer_type == "legacy_rel_selfattn":
            assert pos_enc_layer_type == "legacy_rel_pos"
            encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
            )
            logging.warning(
                "Using legacy_rel_selfattn and it will be deprecated in the future."
            )
        elif selfattention_layer_type == "rel_selfattn":
            assert pos_enc_layer_type == "rel_pos"
            encoder_selfattn_layer = RelPositionMultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
                zero_triu,
            )
        else:
            raise ValueError("unknown encoder_attn_layer: " + selfattention_layer_type)

        convolution_layer = ConvolutionModule
        convolution_layer_args = (output_size, cnn_module_kernel, activation)

        if isinstance(stochastic_depth_rate, float):
            stochastic_depth_rate = [stochastic_depth_rate] * num_blocks
        if len(stochastic_depth_rate) != num_blocks:
            raise ValueError(
                f"Length of stochastic_depth_rate ({len(stochastic_depth_rate)}) "
                f"should be equal to num_blocks ({num_blocks})"
            )

        self.encoders = repeat(
            num_blocks,
            lambda lnum: EncoderLayer(
                output_size,
                encoder_selfattn_layer(*encoder_selfattn_layer_args),
                positionwise_layer(*positionwise_layer_args),
                positionwise_layer(*positionwise_layer_args) if macaron_style else None,
                convolution_layer(*convolution_layer_args) if use_cnn_module else None,
                dropout_rate,
                normalize_before,
                concat_after,
                stochastic_depth_rate[lnum],
            ),
        )
        if self.normalize_before:
            self.after_norm = LayerNorm(output_size)

        self.interctc_layer_idx = interctc_layer_idx
        if len(interctc_layer_idx) > 0:
            assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
        self.interctc_use_conditioning = interctc_use_conditioning
        # Set externally (e.g. by the model wrapper) when conditioning is used.
        self.conditioning_layer = None

    def output_size(self) -> int:
        """Return the encoder output feature dimension."""
        return self._output_size

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor = None,
        prev_states: torch.Tensor = None,
        ctc = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Calculate forward propagation.

        Args:
            xs_pad (torch.Tensor): Input tensor (#batch, L, input_size)
                (or (#batch, input_size, L) when ``channel_first``).
            ilens (torch.Tensor): Input length (#batch); when ``None`` all
                frames are treated as valid and only ``xs_pad`` is returned.
            prev_states (torch.Tensor): Not to be used now.
            ctc: CTC module used only for intermediate-CTC conditioning.

        Returns:
            torch.Tensor: Output tensor (#batch, L, output_size); a tuple
                ``(output, intermediate_outs)`` when interctc layers are set.
            torch.Tensor: Output length (#batch) (omitted when ``ilens`` is None).
            torch.Tensor: Not to be used now.

        Raises:
            TooShortUttError: If the input is too short for conv subsampling.
        """
        raw_input = xs_pad
        if self.channel_first:
            xs_pad = xs_pad.permute(0, 2, 1)

        if ilens is not None:
            masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
        else:
            # No lengths given: every frame is valid.
            masks = torch.ones(
                xs_pad.shape[0], 1, xs_pad.shape[1],
                dtype=torch.bool, device=xs_pad.device,
            )
        if self.causal:
            causal_mask = subsequent_mask(
                xs_pad.shape[1], device=xs_pad.device, dtype=masks.dtype
            ).unsqueeze(0)
            masks = masks & causal_mask

        if isinstance(
            self.embed,
            (
                Conv2dSubsampling,
                Conv2dSubsampling2,
                Conv2dSubsampling6,
                Conv2dSubsampling8,
                Conv2dSubsamplingPad,
            ),
        ):
            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
            if short_status:
                raise TooShortUttError(
                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
                    + f"(it needs more than {limit_size} frames), return empty results",
                    xs_pad.size(1),
                    limit_size,
                )
            xs_pad, masks = self.embed(xs_pad, masks)
        else:
            xs_pad = self.embed(xs_pad)

        intermediate_outs = []
        if len(self.interctc_layer_idx) == 0:
            xs_pad, masks = self.encoders(xs_pad, masks)
        else:
            for layer_idx, encoder_layer in enumerate(self.encoders):
                xs_pad, masks = encoder_layer(xs_pad, masks)

                if layer_idx + 1 in self.interctc_layer_idx:
                    encoder_out = xs_pad
                    if isinstance(encoder_out, tuple):
                        encoder_out = encoder_out[0]

                    # Intermediate outputs are also normalized.
                    if self.normalize_before:
                        encoder_out = self.after_norm(encoder_out)

                    intermediate_outs.append((layer_idx + 1, encoder_out))

                    if self.interctc_use_conditioning:
                        ctc_out = ctc.softmax(encoder_out)

                        if isinstance(xs_pad, tuple):
                            x, pos_emb = xs_pad
                            x = x + self.conditioning_layer(ctc_out)
                            xs_pad = (x, pos_emb)
                        else:
                            xs_pad = xs_pad + self.conditioning_layer(ctc_out)

        if isinstance(xs_pad, tuple):
            xs_pad = xs_pad[0]
        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)

        if self.channel_first:
            xs_pad = xs_pad.permute(0, 2, 1)

        if self.skip:
            xs_pad = xs_pad + raw_input

        olens = masks.squeeze(1).sum(1)
        if len(intermediate_outs) > 0:
            return (xs_pad, intermediate_outs), olens, None

        # NOTE: callers rely on the single-tensor return when ilens is None.
        if ilens is not None:
            return xs_pad, olens, None
        else:
            return xs_pad
class BASECFM(torch.nn.Module, ABC):
    """Base conditional flow-matching (CFM) model.

    Args:
        n_feats (int): Number of feature channels (e.g. mel bins).
        cfm_params (dict): CFM hyper-parameters; recognized keys:
            ``solver``, ``sigma_min``, ``t_scheduler``, ``training_cfg_rate``,
            ``inference_cfg_rate``, ``reg_loss_type``.
        n_spks (int): Number of speakers.
        spk_emb_dim (int): Speaker embedding dimension.
    """

    def __init__(
        self,
        n_feats: int,
        cfm_params: dict,
        n_spks: int = 1,
        spk_emb_dim: int = 128,
    ):
        super().__init__()
        self.n_feats = n_feats
        self.n_spks = n_spks
        self.spk_emb_dim = spk_emb_dim
        self.solver = cfm_params.get("solver", "euler")
        self.sigma_min = cfm_params.get("sigma_min", 1e-4)

        # Subclasses are expected to assign the vector-field estimator.
        self.estimator = None
        self.t_scheduler = cfm_params.get("t_scheduler", "linear")
        self.training_cfg_rate = cfm_params.get("training_cfg_rate", 0.0)
        self.inference_cfg_rate = cfm_params.get("inference_cfg_rate", 0.0)
        self.reg_loss_type = cfm_params.get("reg_loss_type", "l2")

    @torch.inference_mode()
    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
        """Forward diffusion

        Args:
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            n_timesteps (int): number of diffusion steps
            temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
            spks (torch.Tensor, optional): speaker ids. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond: Not used but kept for future purposes

        Returns:
            sample: generated mel-spectrogram
                shape: (batch_size, n_feats, mel_timesteps)
        """
        z = torch.randn_like(mu) * temperature
        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
        if self.t_scheduler == 'cosine':
            t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)

    def solve_euler(self, x, t_span, mu, mask, spks=None, cond=None):
        """
        Fixed euler solver for ODEs.
        Args:
            x (torch.Tensor): random noise
            t_span (torch.Tensor): n_timesteps interpolated
                shape: (n_timesteps + 1,)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            spks (torch.Tensor, optional): speaker ids. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond: Not used but kept for future purposes
        """
        t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]

        # I am storing this because I can later plot it by putting a debugger here and saving it to a file
        # Or in future might add like a return_all_steps flag
        sol = []

        steps = 1
        while steps <= len(t_span) - 1:
            dphi_dt = self.estimator(x, mask, mu, t, spks, cond)
            # Classifier-Free Guidance inference introduced in VoiceBox
            if self.inference_cfg_rate > 0:
                # FIX: guard `cond` like `spks` — `cond` defaults to None and
                # `torch.zeros_like(None)` raised a TypeError here.
                cfg_dphi_dt = self.estimator(
                    x, mask,
                    torch.zeros_like(mu), t,
                    torch.zeros_like(spks) if spks is not None else None,
                    torch.zeros_like(cond) if cond is not None else None,
                )
                dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt -
                           self.inference_cfg_rate * cfg_dphi_dt)

            x = x + dt * dphi_dt
            t = t + dt
            sol.append(x)
            if steps < len(t_span) - 1:
                dt = t_span[steps + 1] - t
            steps += 1

        return sol[-1]

    def calc_reg_loss(self, prediction, target, loss_mask):
        """Element-wise regression loss (masked), selected by ``reg_loss_type``."""
        if self.reg_loss_type == 'l1':
            hint_once("use l1 loss to train CFM", "CFM_LOSS_L1")
            l1_loss = F.l1_loss(prediction, target, reduction="none")
            l1_loss = l1_loss * loss_mask
            return l1_loss
        elif self.reg_loss_type == 'l2':
            hint_once("use l2 loss to train CFM", "CFM_LOSS_L2")
            l2_loss = F.mse_loss(prediction, target, reduction="none")
            l2_loss = l2_loss * loss_mask
            return l2_loss
        else:
            hint_once("use l1+l2 loss to train CFM", "CFM_LOSS_L1_L2")
            l1_loss = F.l1_loss(prediction, target, reduction="none")
            l1_loss = l1_loss * loss_mask
            l2_loss = 0.5 * F.mse_loss(prediction, target, reduction="none")
            l2_loss = l2_loss * loss_mask
            return l1_loss * 0.5 + l2_loss * 0.5

    def compute_loss(self, x1, mask, mu, spks=None, cond=None, reduction='none'):
        """Computes diffusion loss

        Args:
            x1 (torch.Tensor): Target
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): target mask
                shape: (batch_size, 1, mel_timesteps)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            spks (torch.Tensor, optional): speaker embedding. Defaults to None.
                shape: (batch_size, spk_emb_dim)

        Returns:
            loss: conditional flow matching loss
            y: conditional flow
                shape: (batch_size, n_feats, mel_timesteps)
        """
        b, _, t = mu.shape

        # random timestep
        t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
        if self.t_scheduler == 'cosine':
            t = 1 - torch.cos(t * 0.5 * torch.pi)
        # sample noise p(x_0)
        z = torch.randn_like(x1)

        y = (1 - (1 - self.sigma_min) * t) * z + t * x1
        u = x1 - (1 - self.sigma_min) * z

        # Randomly drop the conditioning during training for CFG.
        if self.training_cfg_rate > 0:
            cfg_mask = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype) > self.training_cfg_rate
            mu = mu * cfg_mask
            if spks is not None:
                spks = spks * cfg_mask.squeeze(-1)
            if cond is not None:
                cond = cond * cfg_mask

        pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
        loss = self.calc_reg_loss(pred, u, mask)
        if reduction == "mean":
            loss = loss.sum() / (torch.sum(mask) * u.shape[1])
        return loss, y


class CFM(BASECFM):
    """CFM with a concrete decoder network as the vector-field estimator."""

    def __init__(self, in_channels, out_channel, cfm_params, decoder_params,
                 n_spks=1, spk_emb_dim=64, decoder_name="Decoder"):
        super().__init__(
            n_feats=in_channels,
            cfm_params=cfm_params,
            n_spks=n_spks,
            spk_emb_dim=spk_emb_dim,
        )

        in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
        # Just change the architecture of the estimator here
        if decoder_name == "Decoder":
            self.estimator = Decoder(in_channels=in_channels, out_channels=out_channel, **decoder_params)
        else:
            self.estimator = ConditionalDecoder(in_channels=in_channels, out_channels=out_channel, **decoder_params)
torch.nn as nn +from torch.nn import functional as F +from funasr.models.llm_asr.diffusion_models.matcha_decoder import Upsample1D, Downsample1D +from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list +from einops import repeat, pack +import logging + + +class UpSamplingRegulator(nn.Module): + def __init__( + self, + channels: int, + sampling_ratios: Tuple, + out_channels: int = None, + groups: int = 1, + ): + super().__init__() + self.sampling_ratios = sampling_ratios + out_channels = out_channels or channels + model = nn.ModuleList([]) + if len(sampling_ratios) > 0: + for ratio in sampling_ratios: + if ratio > 1: + module = Upsample1D(channels=channels, channel_first=False) + else: + module = nn.Linear(channels, channels) + norm = nn.LayerNorm(channels) + act = nn.LeakyReLU() + model.extend([module, norm, act]) + model.append( + nn.Linear(channels, out_channels) + ) + self.model = nn.Sequential(*model) + + def forward(self, x, xlens, y=None, y_lens=None, cond=None): + # x, out, y in (B, T, D) + out = self.model(x) + out = out[:, :y.shape[1]] + olens = y_lens + + return out, olens + + +class DownSamplingRegulator(nn.Module): + def __init__( + self, + channels: int, + sampling_ratios: Tuple, + out_channels: int = None, + groups: int = 1, + ): + super().__init__() + self.sampling_ratios = sampling_ratios + out_channels = out_channels or channels + model = nn.ModuleList([]) + if len(sampling_ratios) > 0: + for ratio in sampling_ratios: + if ratio > 1: + module = Downsample1D(dim=channels, channel_first=False, padding=2) + else: + module = nn.Linear(channels, channels) + norm = nn.LayerNorm(channels) + act = nn.LeakyReLU() + model.extend([module, norm, act]) + + model.append( + nn.Linear(channels, out_channels) + ) + self.model = nn.Sequential(*model) + + def forward(self, x, xlens, y=None, y_lens=None, cond=None): + # x, out, y in (B, T, D) + out = self.model(x) + out = out[:, :y.shape[1]] + olens = y_lens + + return out, olens + + +class 
class InterpolateRegulator(nn.Module):
    """Length regulator based on ``F.interpolate`` along the time axis.

    Resizes the input to the target length of ``y`` and refines it with a
    stack of Conv1d + GroupNorm + Mish stages (one per entry of
    ``sampling_ratios``), finishing with a 1x1 conv to ``out_channels``.

    Args:
        channels (int): Input feature dimension.
        sampling_ratios (Tuple): One refinement stage is added per entry.
        out_channels (int): Output dimension (defaults to ``channels``).
        groups (int): GroupNorm group count.
        mode (str): Interpolation mode for ``F.interpolate``.
        align_corners (bool): Passed to ``F.interpolate`` for the modes
            that support it.
    """

    def __init__(
        self,
        channels: int,
        sampling_ratios: Tuple,
        out_channels: int = None,
        groups: int = 1,
        mode="nearest",
        align_corners=False,
    ):
        super().__init__()
        self.sampling_ratios = sampling_ratios
        out_channels = out_channels or channels
        layers = []
        for _ in sampling_ratios:
            layers += [
                nn.Conv1d(channels, channels, 3, 1, 1),
                nn.GroupNorm(groups, channels),
                nn.Mish(),
            ]
        layers.append(nn.Conv1d(channels, out_channels, 1, 1))
        self.model = nn.Sequential(*layers)
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x, xlens, y=None, ylens=None, cond=None):
        """Resize x (B, T, D) to y's length, refine, and zero padded frames."""
        mask = (~make_pad_mask(ylens)).to(x).unsqueeze(-1)
        # `align_corners` is only accepted by these interpolation modes.
        interp_kwargs = {}
        if self.mode in ["linear", "bilinear", "bicubic", "trilinear"]:
            interp_kwargs = dict(align_corners=self.align_corners)
        resized = F.interpolate(
            x.transpose(1, 2).contiguous(), size=y.shape[1],
            mode=self.mode, **interp_kwargs,
        )
        out = self.model(resized).transpose(1, 2).contiguous()
        return out * mask, ylens


class UpsamplingBlock(nn.Module):
    """Transposed-conv upsampler followed by two residual conv blocks.

    Args:
        in_channels (int): Input channel count.
        channels (int): Output channel count.
        stride (int): Upsampling factor of the transposed convolution.
        groups (int): GroupNorm group count.
        channel_first (bool): If False, inputs/outputs are (B, T, D).
    """

    def __init__(
        self,
        in_channels: int,
        channels: int,
        stride=2,
        groups=1,
        channel_first=False,
    ):
        super().__init__()
        self.channel_first = channel_first
        self.stride = stride

        self.up_conv = nn.ConvTranspose1d(in_channels, channels, stride * 2, stride, 1)
        self.block1 = torch.nn.Sequential(
            torch.nn.Conv1d(channels, channels, 3, padding=1),
            torch.nn.GroupNorm(groups, channels),
            nn.Mish(),
        )
        self.block2 = torch.nn.Sequential(
            torch.nn.Conv1d(channels, channels, 3, padding=1),
            torch.nn.GroupNorm(groups, channels),
            nn.Mish(),
        )
        self.res_conv = torch.nn.Conv1d(channels, channels, 1)

    def forward(self, x, ilens):
        """Upsample x by ``stride``; returns (output, ilens * stride)."""
        if not self.channel_first:
            x = x.transpose(1, 2)

        # Output lengths scale with the upsampling factor; mask padding.
        olens = ilens * self.stride
        o_masks = (~make_pad_mask(olens))[:, None, :].to(x)
        res = out = self.up_conv(x) * o_masks

        out = self.block1(out) * o_masks + out
        out = self.block2(out) * o_masks + out
        out = out + self.res_conv(res) * o_masks

        if not self.channel_first:
            out = out.transpose(1, 2)

        return out, olens


class RepeatLengthRegulator(torch.nn.Module):
    """Repeat Length regulator module for feed-forward Transformer.

    This is a module of length regulator described in
    `FastSpeech: Fast, Robust and Controllable Text to Speech`_.
    The length regulator expands char or
    phoneme-level embedding features to frame-level by repeating each
    feature based on the corresponding predicted durations.

    .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
        https://arxiv.org/pdf/1905.09263.pdf

    """

    def __init__(self, pad_value=0.0):
        """Initilize length regulator module.

        Args:
            pad_value (float, optional): Value used for padding.

        """
        super().__init__()
        self.pad_value = pad_value

    def forward(self, xs, ds, alpha=1.0):
        """Calculate forward propagation.

        Args:
            xs (Tensor): Batch of sequences of char or phoneme embeddings (B, Tmax, D).
            ds (LongTensor): Batch of durations of each frame (B, T).
            alpha (float, optional): Alpha value to control speed of speech.

        Returns:
            Tensor: replicated input tensor based on durations (B, T*, D).

        """
        if alpha != 1.0:
            assert alpha > 0
            ds = torch.round(ds.float() * alpha).long()

        if ds.sum() == 0:
            logging.warning(
                "predicted durations includes all 0 sequences. "
                "fill the first element with 1."
            )
            # NOTE(kan-bayashi): This case must not be happened in teacher forcing.
            #   It will be happened in inference with a bad duration predictor.
            #   So we do not need to care the padded sequence case here.
            ds[ds.sum(dim=1).eq(0)] = 1

        expanded = [torch.repeat_interleave(x, d, dim=0) for x, d in zip(xs, ds)]
        return pad_list(expanded, self.pad_value)
class SinusoidalPosEmb(torch.nn.Module):
    """Sinusoidal position/timestep embedding (Transformer-style).

    Maps a 1-D tensor of positions to (N, dim) embeddings whose first half
    is sines and second half cosines over log-spaced frequencies.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even"

    def forward(self, x, scale=1000):
        """Embed positions ``x`` (scaled by ``scale``) into ``self.dim`` dims."""
        if x.ndim < 1:
            # accept a 0-d scalar tensor
            x = x.unsqueeze(0)
        half = self.dim // 2
        # log-spaced frequencies from 1 down to 1/10000
        step = math.log(10000) / (half - 1)
        freqs = torch.exp(torch.arange(half, device=x.device).float() * -step)
        angles = scale * x.unsqueeze(1) * freqs.unsqueeze(0)
        return torch.cat((angles.sin(), angles.cos()), dim=-1)


class Block1D(torch.nn.Module):
    """Masked, length-preserving Conv1d -> GroupNorm -> Mish stage."""

    def __init__(self, dim, dim_out, groups=8):
        super().__init__()
        self.block = torch.nn.Sequential(
            torch.nn.Conv1d(dim, dim_out, 3, padding=1),
            torch.nn.GroupNorm(groups, dim_out),
            nn.Mish(),
        )

    def forward(self, x, mask):
        """Apply the stage; padded frames are zeroed before and after."""
        return self.block(x * mask) * mask
class ResnetBlock1D(torch.nn.Module):
    """Two masked conv stages with a time-embedding shift and a 1x1 residual."""

    def __init__(self, dim, dim_out, time_emb_dim, groups=8):
        super().__init__()
        # project the time embedding into the block's channel space
        self.mlp = torch.nn.Sequential(nn.Mish(), torch.nn.Linear(time_emb_dim, dim_out))
        self.block1 = Block1D(dim, dim_out, groups=groups)
        self.block2 = Block1D(dim_out, dim_out, groups=groups)
        self.res_conv = torch.nn.Conv1d(dim, dim_out, 1)

    def forward(self, x, mask, time_emb):
        """x: (B, dim, T); mask: (B, 1, T); time_emb: (B, time_emb_dim)."""
        h = self.block1(x, mask)
        h += self.mlp(time_emb).unsqueeze(-1)  # broadcast over time
        h = self.block2(h, mask)
        return h + self.res_conv(x * mask)


class Downsample1D(nn.Module):
    """Strided Conv1d that halves the temporal resolution."""

    def __init__(self, dim, channel_first=True, padding=1):
        super().__init__()
        self.channel_first = channel_first
        self.conv = torch.nn.Conv1d(dim, dim, 3, 2, padding)

    def forward(self, x):
        """Downsample; accepts (B, C, T) or, if channel_first=False, (B, T, C)."""
        to_channel_last = not self.channel_first
        if to_channel_last:
            x = x.transpose(1, 2).contiguous()
        out = self.conv(x)
        if to_channel_last:
            out = out.transpose(1, 2).contiguous()
        return out


class TimestepEmbedding(nn.Module):
    """Two-layer MLP over a timestep embedding.

    Optionally projects a conditioning vector into the input space
    (``cond_proj``) and applies a post-activation.
    """

    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        act_fn: str = "silu",
        out_dim: int = None,
        post_act_fn: Optional[str] = None,
        cond_proj_dim=None,
    ):
        super().__init__()
        self.linear_1 = nn.Linear(in_channels, time_embed_dim)

        # NOTE(review): `get_activation` is diffusers' string->module factory.
        self.cond_proj = (
            nn.Linear(cond_proj_dim, in_channels, bias=False)
            if cond_proj_dim is not None
            else None
        )
        self.act = get_activation(act_fn)

        out_features = out_dim if out_dim is not None else time_embed_dim
        self.linear_2 = nn.Linear(time_embed_dim, out_features)

        self.post_act = None if post_act_fn is None else get_activation(post_act_fn)

    def forward(self, sample, condition=None):
        """Embed ``sample``; ``condition`` (if given) is added pre-MLP."""
        if condition is not None:
            sample = sample + self.cond_proj(condition)
        sample = self.linear_1(sample)
        if self.act is not None:
            sample = self.act(sample)
        sample = self.linear_2(sample)
        if self.post_act is not None:
            sample = self.post_act(sample)
        return sample
class Upsample1D(nn.Module):
    """A 1D upsampling layer (factor ``stride``) with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to apply a 3-wide conv after nearest-neighbor upsampling.
        use_conv_transpose (`bool`, default `True`):
            option to upsample with a transposed convolution instead.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
    """

    def __init__(self, channels, use_conv=False, use_conv_transpose=True,
                 out_channels=None, name="conv", channel_first=True, stride=2):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name
        self.channel_first = channel_first
        self.stride = stride

        if use_conv_transpose:
            # kernel=2*stride, padding=1 -> output length = stride * T
            self.conv = nn.ConvTranspose1d(channels, self.out_channels, stride * 2, stride, 1)
        elif use_conv:
            self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
        else:
            self.conv = None

    def forward(self, inputs):
        """Upsample (B, C, T) — or (B, T, C) when ``channel_first`` is False.

        NOTE(review): the transposed-conv path returns early and does NOT
        transpose back when ``channel_first`` is False — kept as-is to
        preserve the original behavior; confirm with callers.
        """
        if not self.channel_first:
            inputs = inputs.transpose(1, 2).contiguous()
        assert inputs.shape[1] == self.channels

        if self.use_conv_transpose:
            return self.conv(inputs)

        out = F.interpolate(inputs, scale_factor=self.stride, mode="nearest")
        if self.use_conv:
            out = self.conv(out)

        if not self.channel_first:
            out = out.transpose(1, 2).contiguous()
        return out
class TransformerDecoderWrapper(nn.Module):
    """Single funasr DecoderLayer (self- + cross-attention) behind linear
    projections, exposing the interface the diffusion U-Net blocks expect.

    ``prompt`` is projected and attended to via cross-attention; ``timestep``
    and any extra kwargs are accepted for interface compatibility but unused.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int = 8,
        attention_head_dim: int = 64,
        dropout: float = 0.1,
        activation_fn: str = "snakebeta",
        cond_dim: int = 80,
        concat_after: bool = False,
    ):
        super().__init__()
        attn_dim = num_attention_heads * attention_head_dim
        self.input_proj = torch.nn.Linear(dim, attn_dim)
        self.cond_proj = torch.nn.Linear(cond_dim, attn_dim)
        self.output_proj = torch.nn.Linear(attn_dim, dim)
        # imported lazily to avoid a hard dependency at module import time
        from funasr.models.transformer.decoder import (
            DecoderLayer, MultiHeadedAttention, PositionwiseFeedForward
        )
        self.decoder_layer = DecoderLayer(
            attn_dim,
            MultiHeadedAttention(num_attention_heads, attn_dim, dropout),
            MultiHeadedAttention(num_attention_heads, attn_dim, dropout),
            PositionwiseFeedForward(attn_dim, attn_dim * 4, dropout),
            dropout,
            normalize_before=True,
            concat_after=concat_after,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        timestep: torch.Tensor = None,
        prompt: torch.Tensor = None,
        prompt_lengths: torch.Tensor = None,
        **kwargs
    ):
        """Run one decoder layer; returns (B, T, dim) masked like the input."""
        # (B, T, 1) validity masks for the features and for the prompt
        feat_mask = attention_mask[:, :1].transpose(1, 2).contiguous()
        prompt_mask = (~make_pad_mask(prompt_lengths)[:, :, None]).to(hidden_states.device)

        queries = self.input_proj(hidden_states) * feat_mask
        memory = self.cond_proj(prompt) * prompt_mask
        decoded = self.decoder_layer(
            queries,
            feat_mask.transpose(1, 2).contiguous(),
            memory,
            prompt_mask.transpose(1, 2).contiguous(),
        )[0]
        return self.output_proj(decoded) * feat_mask
class Decoder(nn.Module):
    """U-Net style 1D decoder for conditional flow-matching / diffusion.

    This decoder requires an input with the same shape as the target; if your
    text content is shorter or longer than the output, resample it before
    feeding it to the decoder.

    Supports two optional conditions (configured via ``conditions``):
    ``xvec`` (a per-utterance vector concatenated channel-wise before every
    resnet) and ``prompt`` (cross-attended inside "transformer_decoder"
    blocks).
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        channels=(256, 256),
        dropout=0.05,
        attention_head_dim=64,
        n_blocks=1,
        num_mid_blocks=2,
        num_heads=4,
        act_fn="snake",
        down_block_type="transformer",
        mid_block_type="transformer",
        up_block_type="transformer",
        conditions: dict = None,
        concat_after: bool = False,
    ):
        super().__init__()
        channels = tuple(channels)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.conditions = conditions

        self.time_embeddings = SinusoidalPosEmb(in_channels)
        time_embed_dim = channels[0] * 4
        self.time_mlp = TimestepEmbedding(
            in_channels=in_channels,
            time_embed_dim=time_embed_dim,
            act_fn="silu",
        )
        self.down_block_type = down_block_type
        self.mid_block_type = mid_block_type
        self.up_block_type = up_block_type

        self.down_blocks = nn.ModuleList([])
        self.mid_blocks = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])

        # Extra channels concatenated to every resnet input when an x-vector
        # (speaker embedding) condition is configured.
        xvec_dim = 0
        if conditions is not None and "xvec" in conditions:
            xvec_dim = conditions["xvec"]["dim"]

        output_channel = in_channels
        for i in range(len(channels)):
            input_channel = output_channel + xvec_dim
            output_channel = channels[i]
            is_last = i == len(channels) - 1
            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
            transformer_blocks = nn.ModuleList(
                [
                    self.get_block(
                        down_block_type,
                        output_channel,
                        attention_head_dim,
                        num_heads,
                        dropout,
                        act_fn,
                        concat_after=concat_after,
                    )
                    for _ in range(n_blocks)
                ]
            )
            # the deepest level keeps its resolution (plain conv, no stride)
            downsample = (
                Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
            )
            self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))

        for _ in range(num_mid_blocks):
            # NOTE: the original also rebound `out_channels` here, shadowing
            # the constructor argument with a dead local — removed.
            input_channel = channels[-1] + xvec_dim
            resnet = ResnetBlock1D(dim=input_channel, dim_out=channels[-1], time_emb_dim=time_embed_dim)
            transformer_blocks = nn.ModuleList(
                [
                    self.get_block(
                        mid_block_type,
                        channels[-1],
                        attention_head_dim,
                        num_heads,
                        dropout,
                        act_fn,
                        concat_after=concat_after,
                    )
                    for _ in range(n_blocks)
                ]
            )
            self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))

        # mirror the channel plan for the decoder half; skip connections
        # double the resnet input width.
        channels = channels[::-1] + (channels[0],)
        for i in range(len(channels) - 1):
            input_channel = channels[i] * 2 + xvec_dim
            output_channel = channels[i + 1]
            is_last = i == len(channels) - 2
            resnet = ResnetBlock1D(
                dim=input_channel,
                dim_out=output_channel,
                time_emb_dim=time_embed_dim,
            )
            transformer_blocks = nn.ModuleList(
                [
                    self.get_block(
                        up_block_type,
                        output_channel,
                        attention_head_dim,
                        num_heads,
                        dropout,
                        act_fn,
                        concat_after=concat_after,
                    )
                    for _ in range(n_blocks)
                ]
            )
            upsample = (
                Upsample1D(output_channel, use_conv_transpose=True)
                if not is_last
                else nn.Conv1d(output_channel, output_channel, 3, padding=1)
            )
            self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))

        self.final_block = Block1D(channels[-1], channels[-1])
        self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)

        self.initialize_weights()

    def get_block(self, block_type, dim, attention_head_dim, num_heads, dropout, act_fn, **kwargs):
        """Build one attention block of the requested type.

        Raises:
            ValueError: if ``block_type`` is not "conformer", "transformer"
                or "transformer_decoder".
        """
        if block_type == "conformer":
            return ConformerWrapper(
                dim=dim,
                dim_head=attention_head_dim,
                heads=num_heads,
                ff_mult=1,
                conv_expansion_factor=2,
                ff_dropout=dropout,
                attn_dropout=dropout,
                conv_dropout=dropout,
                conv_kernel_size=31,
            )
        if block_type == "transformer":
            return BasicTransformerBlock(
                dim=dim,
                num_attention_heads=num_heads,
                attention_head_dim=attention_head_dim,
                dropout=dropout,
                activation_fn=act_fn,
            )
        if block_type == "transformer_decoder":
            return TransformerDecoderWrapper(
                dim=dim,
                num_attention_heads=num_heads,
                attention_head_dim=attention_head_dim,
                dropout=dropout,
                activation_fn=act_fn,
                cond_dim=self.conditions["prompt"]["dim"],
                concat_after=kwargs.get("concat_after", False),
            )
        raise ValueError(f"Unknown block type {block_type}")

    def initialize_weights(self):
        """Kaiming-init conv/linear weights; unit/zero-init GroupNorm."""
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.GroupNorm):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _concat_xvec(self, x, cond):
        """Concatenate the x-vector condition along channels, if configured."""
        if (cond is not None and "xvec" in cond and
                self.conditions is not None and "xvec" in self.conditions):
            xvec = repeat(cond["xvec"], "b c -> b c t", t=x.shape[-1])
            x = pack([x, xvec], "b * t")[0]
        return x

    def _prompt_kwargs(self, cond):
        """Extract prompt kwargs for the attention blocks, if configured."""
        kwargs = {}
        if (cond is not None and "prompt" in cond and
                self.conditions is not None and "prompt" in self.conditions):
            kwargs["prompt"], kwargs["prompt_lengths"] = cond["prompt"]
        return kwargs

    def _attn_mask(self, mask, block_type):
        """Pairwise (B, T, T) mask for "transformer" blocks, (B, T) otherwise."""
        if block_type == "transformer":
            return torch.matmul(mask.transpose(1, 2).contiguous(), mask)
        return mask.squeeze(1)

    def forward(self, x, mask, mu, t, spks=None, cond=None):
        """Forward pass of the UNet1DConditional model.

        Args:
            x (torch.Tensor): shape (batch_size, in_channels, time)
            mask (torch.Tensor): shape (batch_size, 1, time)
            mu (torch.Tensor): content features, concatenated channel-wise.
            t (torch.Tensor): shape (batch_size)
            spks (torch.Tensor, optional): shape (batch_size, condition_channels).
            cond (dict, optional): may hold "xvec" and/or "prompt" conditions.

        Returns:
            torch.Tensor: (batch_size, out_channels, time), masked.
        """
        t = self.time_embeddings(t)
        t = self.time_mlp(t)

        x = pack([x, mu], "b * t")[0]

        if spks is not None:
            spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
            x = pack([x, spks], "b * t")[0]

        hiddens = []
        masks = [mask]
        for resnet, transformer_blocks, downsample in self.down_blocks:
            mask_down = masks[-1]
            x = self._concat_xvec(x, cond)
            x = resnet(x, mask_down, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            attn_mask = self._attn_mask(mask_down, self.down_block_type)
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                    **self._prompt_kwargs(cond),
                )
            x = rearrange(x, "b t c -> b c t").contiguous()
            hiddens.append(x)  # save hidden states for skip connections
            x = downsample(x * mask_down)
            masks.append(mask_down[:, :, ::2])

        # the mask appended after the last (stride-1) down block is unused
        masks = masks[:-1]
        mask_mid = masks[-1]

        for resnet, transformer_blocks in self.mid_blocks:
            x = self._concat_xvec(x, cond)
            x = resnet(x, mask_mid, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            attn_mask = self._attn_mask(mask_mid, self.mid_block_type)
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                    **self._prompt_kwargs(cond),
                )
            x = rearrange(x, "b t c -> b c t").contiguous()

        for resnet, transformer_blocks, upsample in self.up_blocks:
            mask_up = masks.pop()
            skip = hiddens.pop()
            # trim to the skip length before concatenating the skip connection
            x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
            x = self._concat_xvec(x, cond)
            x = resnet(x, mask_up, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            attn_mask = self._attn_mask(mask_up, self.up_block_type)
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                    **self._prompt_kwargs(cond),
                )
            x = rearrange(x, "b t c -> b c t").contiguous()
            x = upsample(x * mask_up)

        x = self.final_block(x, mask_up)
        output = self.final_proj(x * mask_up)

        return output * mask
class ConditionalDecoder(nn.Module):
    """U-Net style 1D decoder without the dict-based condition plumbing.

    This decoder requires an input with the same shape as the target; if your
    text content is shorter or longer than the output, resample it before
    feeding it to the decoder. ``spks`` and ``cond`` tensors, when given, are
    concatenated channel-wise to the input.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        channels=(256, 256),
        dropout=0.05,
        attention_head_dim=64,
        n_blocks=1,
        num_mid_blocks=2,
        num_heads=4,
        act_fn="snake",
        down_block_type="transformer",
        mid_block_type="transformer",
        up_block_type="transformer",
    ):
        super().__init__()
        channels = tuple(channels)
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.time_embeddings = SinusoidalPosEmb(in_channels)
        time_embed_dim = channels[0] * 4
        self.time_mlp = TimestepEmbedding(
            in_channels=in_channels,
            time_embed_dim=time_embed_dim,
            act_fn="silu",
        )
        self.down_block_type = down_block_type
        self.mid_block_type = mid_block_type
        self.up_block_type = up_block_type

        self.down_blocks = nn.ModuleList([])
        self.mid_blocks = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])

        output_channel = in_channels
        for i in range(len(channels)):
            input_channel = output_channel
            output_channel = channels[i]
            is_last = i == len(channels) - 1
            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
            transformer_blocks = nn.ModuleList(
                [
                    self.get_block(
                        down_block_type,
                        output_channel,
                        attention_head_dim,
                        num_heads,
                        dropout,
                        act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            # the deepest level keeps its resolution (plain conv, no stride)
            downsample = (
                Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
            )
            self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))

        for _ in range(num_mid_blocks):
            # NOTE: the original also rebound `out_channels` here, shadowing
            # the constructor argument with a dead local — removed.
            input_channel = channels[-1]
            resnet = ResnetBlock1D(dim=input_channel, dim_out=channels[-1], time_emb_dim=time_embed_dim)
            transformer_blocks = nn.ModuleList(
                [
                    self.get_block(
                        mid_block_type,
                        channels[-1],
                        attention_head_dim,
                        num_heads,
                        dropout,
                        act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))

        # mirror the channel plan for the decoder half; skip connections
        # double the resnet input width.
        channels = channels[::-1] + (channels[0],)
        for i in range(len(channels) - 1):
            input_channel = channels[i] * 2
            output_channel = channels[i + 1]
            is_last = i == len(channels) - 2
            resnet = ResnetBlock1D(
                dim=input_channel,
                dim_out=output_channel,
                time_emb_dim=time_embed_dim,
            )
            transformer_blocks = nn.ModuleList(
                [
                    self.get_block(
                        up_block_type,
                        output_channel,
                        attention_head_dim,
                        num_heads,
                        dropout,
                        act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            upsample = (
                Upsample1D(output_channel, use_conv_transpose=True)
                if not is_last
                else nn.Conv1d(output_channel, output_channel, 3, padding=1)
            )
            self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))

        self.final_block = Block1D(channels[-1], channels[-1])
        self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)

        self.initialize_weights()

    def get_block(self, block_type, dim, attention_head_dim, num_heads, dropout, act_fn, **kwargs):
        """Build one attention block of the requested type.

        Raises:
            ValueError: if ``block_type`` is not "conformer" or "transformer".
        """
        if block_type == "conformer":
            return ConformerWrapper(
                dim=dim,
                dim_head=attention_head_dim,
                heads=num_heads,
                ff_mult=1,
                conv_expansion_factor=2,
                ff_dropout=dropout,
                attn_dropout=dropout,
                conv_dropout=dropout,
                conv_kernel_size=31,
            )
        if block_type == "transformer":
            return BasicTransformerBlock(
                dim=dim,
                num_attention_heads=num_heads,
                attention_head_dim=attention_head_dim,
                dropout=dropout,
                activation_fn=act_fn,
            )
        raise ValueError(f"Unknown block type {block_type}")

    def initialize_weights(self):
        """Kaiming-init conv/linear weights; unit/zero-init GroupNorm."""
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.GroupNorm):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _attn_mask(self, mask, block_type):
        """Pairwise (B, T, T) mask for "transformer" blocks, (B, T) otherwise."""
        if block_type == "transformer":
            return torch.matmul(mask.transpose(1, 2).contiguous(), mask)
        return mask.squeeze(1)

    def forward(self, x, mask, mu, t, spks=None, cond=None):
        """Forward pass of the UNet1DConditional model.

        Args:
            x (torch.Tensor): shape (batch_size, in_channels, time)
            mask (torch.Tensor): shape (batch_size, 1, time)
            mu (torch.Tensor): content features, concatenated channel-wise.
            t (torch.Tensor): shape (batch_size)
            spks (torch.Tensor, optional): shape (batch_size, condition_channels).
            cond (torch.Tensor, optional): extra channel-wise condition.

        Returns:
            torch.Tensor: (batch_size, out_channels, time), masked.
        """
        t = self.time_embeddings(t)
        t = self.time_mlp(t)

        x = pack([x, mu], "b * t")[0]

        if spks is not None:
            spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
            x = pack([x, spks], "b * t")[0]
        if cond is not None:
            x = pack([x, cond], "b * t")[0]

        hiddens = []
        masks = [mask]
        for resnet, transformer_blocks, downsample in self.down_blocks:
            mask_down = masks[-1]
            x = resnet(x, mask_down, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            attn_mask = self._attn_mask(mask_down, self.down_block_type)
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t").contiguous()
            hiddens.append(x)  # save hidden states for skip connections
            x = downsample(x * mask_down)
            masks.append(mask_down[:, :, ::2])

        # the mask appended after the last (stride-1) down block is unused
        masks = masks[:-1]
        mask_mid = masks[-1]

        for resnet, transformer_blocks in self.mid_blocks:
            x = resnet(x, mask_mid, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            attn_mask = self._attn_mask(mask_mid, self.mid_block_type)
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t").contiguous()

        for resnet, transformer_blocks, upsample in self.up_blocks:
            mask_up = masks.pop()
            skip = hiddens.pop()
            # trim to the skip length before concatenating the skip connection
            x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
            x = resnet(x, mask_up, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            attn_mask = self._attn_mask(mask_up, self.up_block_type)
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t").contiguous()
            x = upsample(x * mask_up)

        x = self.final_block(x, mask_up)
        output = self.final_proj(x * mask_up)

        return output * mask
class SnakeBeta(nn.Module):
    """Snake activation with separate frequency/magnitude parameters and a
    fused input projection.

    Computes ``y = proj(x) + (1/beta) * sin^2(alpha * proj(x))`` where
    ``alpha`` (frequency) and ``beta`` (magnitude) are trainable, per-feature
    parameters. Modified from the Snake function of Liu Ziyin, Tilman
    Hartwig, Masahito Ueda (https://arxiv.org/abs/2006.08195).

    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    """

    def __init__(self, in_features, out_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
        """Initialize the projection and the alpha/beta parameters.

        Args:
            in_features: input width of the fused linear projection.
            out_features: output width (int, or list of dims for the params).
            alpha: initial scale; with ``alpha_logscale`` the parameters start
                at zero (i.e. exp(0)=1 effective value), otherwise at ``alpha``.
            alpha_trainable: whether alpha/beta receive gradients.
            alpha_logscale: store alpha/beta in log space.
        """
        super().__init__()
        self.in_features = out_features if isinstance(out_features, list) else [out_features]
        self.proj = LoRACompatibleLinear(in_features, out_features)

        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:
            # log-scale parameters start at zero (effective value exp(0)=1)
            init = torch.zeros(self.in_features) * alpha
        else:
            init = torch.ones(self.in_features) * alpha
        self.alpha = nn.Parameter(init)
        self.beta = nn.Parameter(init.clone())

        self.alpha.requires_grad = alpha_trainable
        self.beta.requires_grad = alpha_trainable

        # guards the 1/beta term against division by zero
        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        """Project ``x`` and apply the snake-beta nonlinearity elementwise."""
        x = self.proj(x)
        if self.alpha_logscale:
            alpha = torch.exp(self.alpha)
            beta = torch.exp(self.beta)
        else:
            alpha = self.alpha
            beta = self.beta

        x = x + (1.0 / (beta + self.no_div_by_zero)) * torch.pow(torch.sin(x * alpha), 2)

        return x
class FeedForward(nn.Module):
    r"""
    A feed-forward layer.

    Parameters:
        dim (`int`): The number of channels in the input.
        dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
        mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.

    Raises:
        ValueError: if ``activation_fn`` is not one of "gelu",
            "gelu-approximate", "geglu", "geglu-approximate", "snakebeta".
    """

    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        mult: int = 4,
        dropout: float = 0.0,
        activation_fn: str = "geglu",
        final_dropout: bool = False,
    ):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim

        # FIX: the original used a bare `if` for "gelu" followed by a second
        # `if`/`elif` chain, and an unknown activation name left `act_fn`
        # undefined (UnboundLocalError). Use one chain and fail loudly.
        if activation_fn == "gelu":
            act_fn = GELU(dim, inner_dim)
        elif activation_fn == "gelu-approximate":
            act_fn = GELU(dim, inner_dim, approximate="tanh")
        elif activation_fn == "geglu":
            act_fn = GEGLU(dim, inner_dim)
        elif activation_fn == "geglu-approximate":
            act_fn = ApproximateGELU(dim, inner_dim)
        elif activation_fn == "snakebeta":
            act_fn = SnakeBeta(dim, inner_dim)
        else:
            raise ValueError(f"Unknown activation function: {activation_fn}")

        self.net = nn.ModuleList([])
        # project in (the activation modules fuse the input projection)
        self.net.append(act_fn)
        # project dropout
        self.net.append(nn.Dropout(dropout))
        # project out
        self.net.append(LoRACompatibleLinear(inner_dim, dim_out))
        # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
        if final_dropout:
            self.net.append(nn.Dropout(dropout))

    def forward(self, hidden_states):
        """Apply the feed-forward stack to ``hidden_states``."""
        for module in self.net:
            hidden_states = module(hidden_states)
        return hidden_states
+ """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", + final_dropout: bool = False, + ): + super().__init__() + self.only_cross_attention = only_cross_attention + + self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" + self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" + + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" + f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + if self.use_ada_layer_norm: + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif self.use_ada_layer_norm_zero: + self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) + else: + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + ) + + # 2. Cross-Attn + if cross_attention_dim is not None or double_self_attention: + # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # the second cross attention block. 
+ self.norm2 = ( + AdaLayerNorm(dim, num_embeds_ada_norm) + if self.use_ada_layer_norm + else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + ) + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim if not double_self_attention else None, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + # scale_qk=False, # uncomment this to not to use flash attention + ) # is self-attn if encoder_hidden_states is none + else: + self.norm2 = None + self.attn2 = None + + # 3. Feed-forward + self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, + **kwargs + ): + # Notice that normalization is always applied before the real computation in the following blocks. + # 1. 
Self-Attention + if self.use_ada_layer_norm: + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.use_ada_layer_norm_zero: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + else: + norm_hidden_states = self.norm1(hidden_states) + + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=encoder_attention_mask if self.only_cross_attention else attention_mask, + **cross_attention_kwargs, + ) + if self.use_ada_layer_norm_zero: + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = attn_output + hidden_states + + # 2. Cross-Attention + if self.attn2 is not None: + norm_hidden_states = ( + self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) + ) + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + + # 3. Feed-forward + norm_hidden_states = self.norm3(hidden_states) + + if self.use_ada_layer_norm_zero: + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0: + raise ValueError( + f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." 
+ ) + + num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size + ff_output = torch.cat( + [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)], + dim=self._chunk_dim, + ) + else: + ff_output = self.ff(norm_hidden_states) + + if self.use_ada_layer_norm_zero: + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = ff_output + hidden_states + + return hidden_states \ No newline at end of file diff --git a/funasr/models/llm_asr/flow_matching.py b/funasr/models/llm_asr/flow_matching.py new file mode 100644 index 000000000..bcbb3fc80 --- /dev/null +++ b/funasr/models/llm_asr/flow_matching.py @@ -0,0 +1,847 @@ +import logging +import math +from typing import Dict, List, Optional +import torch +import torch.nn as nn +from funasr.models.transformer.utils.nets_utils import make_pad_mask +from funasr.train_utils.device_funcs import force_gatherable +import random +from funasr.models.llm_asr.mel_spectrum import ( + mel_spectrogram, power_spectrogram, mel_from_power_spectrogram +) +from torch.nn import functional as F +from funasr.models.transformer.utils.nets_utils import pad_list +from distutils.version import LooseVersion +from contextlib import contextmanager +from funasr.utils.hinter import hint_once +if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): + from torch.cuda.amp import autocast +else: + # Nothing to do if torch<1.6.0 + @contextmanager + def autocast(enabled=True): + yield + + +class MelSpectrumExtractor(nn.Module): + def __init__( + self, + n_fft=1024, + num_mels=80, + sampling_rate=22050, + hop_size=256, + win_size=1024, + fmin=0, + fmax=8000, + spec_type="mel", + ): + super().__init__() + self.n_fft = n_fft + self.num_mels = num_mels + self.sampling_rate = sampling_rate + self.hop_size = hop_size + self.win_size = win_size + self.fmin = fmin + self.fmax = fmax + self.spec_type = spec_type + + def extra_repr(self): + return f"n_fft={self.n_fft}, num_mels={self.num_mels}, 
class QuantizerCodebook(torch.nn.Module):
    """Dense embedding lookup for multi-codebook (RVQ-style) codec tokens.

    Holds ``num_quantizers`` codebooks of ``codebook_size`` entries each
    (expected to be loaded from a pretrained codec checkpoint into the
    ``embed`` buffer) and maps integer codec indices to the sum of the
    per-quantizer embeddings.
    """

    def __init__(
        self,
        num_quantizers,
        codebook_size,
        codebook_dim,
        hop_length,
        sampling_rate
    ):
        super().__init__()
        self.num_quantizers = num_quantizers
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.hop_size = hop_length
        self.sampling_rate = sampling_rate
        # Codebook weights; zeros until loaded from a checkpoint.
        embed = torch.zeros(num_quantizers, codebook_size, codebook_dim)
        self.register_buffer("embed", embed)
        # Offset that maps quantizer q's indices into the flattened
        # (num_quantizers * codebook_size, codebook_dim) table used in forward.
        # FIX: this was hard-coded as ``1024 * torch.arange(32)``, which is
        # only correct for codebook_size=1024 / num_quantizers=32; derive it
        # from the configured sizes instead (identical values for defaults).
        codec_index_shift = (
            codebook_size * torch.arange(num_quantizers, dtype=torch.float32)[None, None, :]
        )
        self.register_buffer("codec_index_shift", codec_index_shift)

    def save_embedding(self, file_name, dense_emb, emb_lengths):
        """Dump up to 10 dense embeddings to a Kaldi ark/scp pair (debug aid)."""
        import kaldiio
        wav_writer = kaldiio.WriteHelper("ark,scp,f:{}.ark,{}.scp".format(file_name, file_name))
        dense_emb = dense_emb.cpu().numpy()
        for i in range(min(dense_emb.shape[0], 10)):
            wav_writer(str(i), dense_emb[i, :emb_lengths[i]])

        wav_writer.close()

    def forward(self, codec: torch.Tensor, codec_lengths, return_subs=False):
        """Embed codec tokens.

        Args:
            codec: (B, T) or (B, T, Nq) integer token indices.
            codec_lengths: (B,) valid lengths along T.
            return_subs: also return the per-quantizer embeddings.

        Returns:
            ((B, T, codebook_dim) dense embeddings [, (B, T, Nq, codebook_dim)
            sub-embeddings]) and ``codec_lengths`` unchanged. Padded frames
            are zeroed.
        """
        if len(codec.shape) == 2:
            codec = codec.unsqueeze(-1)
        bz, tt, nq = codec.shape[0], codec.shape[1], codec.shape[2]
        codec_mask = ~make_pad_mask(codec_lengths, maxlen=codec.shape[1]).unsqueeze(-1).to(codec.device)
        # Zero padded positions, then shift each quantizer's indices into its
        # slice of the flattened codebook table.
        codec = codec * codec_mask + self.codec_index_shift[:, :, :nq].long()
        codec = codec.reshape(-1, nq)
        emb = self.embed.reshape(-1, self.codebook_dim)
        codec_emb = F.embedding(codec, emb)  # (BT, Nq, D)
        # Sum contributions of all quantizers (residual VQ reconstruction).
        dense_emb = codec_emb.sum(dim=1)
        dense_emb = dense_emb.reshape(bz, tt, self.codebook_dim)
        if return_subs:
            sub_embs = codec_emb.reshape(bz, tt, nq, self.codebook_dim) * codec_mask.unsqueeze(-2)
            return (dense_emb * codec_mask, sub_embs), codec_lengths
        return dense_emb * codec_mask, codec_lengths
+ self.input_embedding = nn.Embedding(vocab_size, input_size) + self.xvec_proj = torch.nn.Linear(xvec_size, output_size) + self.encoder = self.build_encoder() + self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size) + + self.decoder = self.build_decoder() + + self.length_regulator_conf = length_regulator_conf + self.length_regulator = self.build_length_regulator() + + def build_encoder(self): + encoder_name = self.encoder_conf.pop("name", "transformer") + model = None + if encoder_name == "transformer": + from funasr.models.llm_asr.conformer_encoder import ConformerEncoder + model = ConformerEncoder( + **self.encoder_conf, + input_size=self.input_size, + use_cnn_module=False, + macaron_style=False, + ) + elif encoder_name == "conformer": + from funasr.models.llm_asr.conformer_encoder import ConformerEncoder + model = ConformerEncoder( + **self.encoder_conf, + input_size=self.input_size, + ) + + self.encoder_conf["name"] = encoder_name + + return model + + def build_decoder(self): + decoder_name = self.decoder_conf.pop("name", "transformer") + model = None + + if decoder_name == "matcha": + from funasr.models.llm_asr.diffusion_models.flow_matching import CFM + model = CFM( + **self.decoder_conf, + in_channels=self.output_size * 2, # 2 for noise_y and mu + out_channel=self.output_size, + spk_emb_dim=self.output_size + ) + + self.decoder_conf["name"] = decoder_name + + return model + + def select_target_prompt(self, y, y_lengths): + prompt_conf = self.prompt_conf + prompt_list = [] + prompt_lengths = [] + for i, y_len in enumerate(y_lengths): + prompt_len = random.randint( + int(y_len * prompt_conf["prompt_with_range_ratio"][0]), + int(y_len * prompt_conf["prompt_with_range_ratio"][1]) + ) + prompt_pos = random.randint(0, y_len - prompt_len) + prompt_list.append(y[i, prompt_pos:prompt_pos+prompt_len]) + prompt_lengths.append(prompt_len) + prompt = pad_list(prompt_list, 0.0) + prompt_lengths = torch.tensor(prompt_lengths, dtype=torch.int64, 
    def forward(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        audio: torch.Tensor,
        audio_lengths: torch.Tensor,
        xvec: torch.Tensor,
        xvec_lengths: torch.Tensor,
    ):
        """Training step: predict acoustic features from text conditioned on a speaker x-vector.

        Args:
            text: (B, T_text) token ids (or dense features when no vocab is configured).
            text_lengths: (B,) valid lengths of ``text``.
            audio: (B, T_audio) target waveform samples or codec tokens.
            audio_lengths: (B,) valid lengths of ``audio``.
            xvec: (B, N, xvec_size) candidate speaker embeddings per utterance.
            xvec_lengths: (B,) number of valid candidates per utterance.

        Returns:
            (loss, stats, weight) tuple shaped for DataParallel gathering.
        """
        batch_size = audio.shape[0]
        # for data parallel: trim padding to the batch maxima
        x = text[:, :text_lengths.max()]
        y = audio[:, :audio_lengths.max()]
        xvec = xvec[:, :xvec_lengths.max()]
        if self.vocab_size is not None and self.vocab_size > 0:
            # -1 marks padding tokens; clamp to 0 for the embedding lookup and
            # zero out the resulting vectors via the mask.
            mask = (x != -1).float().unsqueeze(-1)
            x = self.input_embedding(torch.clamp(x, min=0)) * mask

        # random select a xvec from xvec matrix, rejecting rows with NaN/Inf.
        # NOTE(review): this rejection loop is unbounded — it never terminates
        # if every candidate row for an utterance is non-finite; confirm
        # upstream data guarantees at least one finite row.
        xvec_list = []
        for i, ilen in enumerate(xvec_lengths):
            idx = random.randint(0, ilen-1)
            while torch.any(~torch.isfinite(xvec[i, idx])):
                idx = random.randint(0, ilen - 1)
            xvec_list.append(xvec[i, idx])
        rand_xvec = torch.vstack(xvec_list)
        rand_xvec = self.xvec_proj(rand_xvec)

        # Target features (mel / codec embeddings) and encoded text, length-
        # regulated to the target frame rate.
        y, y_lengths = self.extract_feat(y, audio_lengths)
        h, h_lengths, _ = self.encoder(x, text_lengths)
        h = self.encoder_proj(h)
        h, h_lengths = self.length_regulator(h, h_lengths, y, y_lengths)
        if self.prompt_conf is not None:
            # Optional acoustic prompt sliced from the target itself.
            target_prompt = self.select_target_prompt(y, y_lengths)
            conditions = dict(
                xvec=rand_xvec,
                target_prompt=target_prompt,
            )
        else:
            conditions = None

        mask = (~make_pad_mask(y_lengths)).to(y)
        # y, h in (B, T, D); the decoder consumes (B, D, T)
        loss, _ = self.decoder.compute_loss(
            y.transpose(1, 2).contiguous(),
            mask.unsqueeze(1),
            h.transpose(1, 2).contiguous(),
            rand_xvec,
            cond=conditions
        )

        stats = dict(loss=torch.clone(loss.detach()))

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight
= (~make_pad_mask(y_lens)).to(y) + feat = self.decoder.forward( + mu=h.transpose(1, 2).contiguous(), + mask=mask.unsqueeze(1), + n_timesteps=diff_steps, + temperature=temperature, + spks=avg_xvec, + cond=None, + ) + return feat + + def collect_feats(self, **batch: torch.Tensor) -> Dict[str, torch.Tensor]: + pass + + +class MaskedDiffWithXvec(BaseDiffWithXvec): + def __init__(self, input_size: int, output_size: int = 80, xvec_size: int = 198, output_type: str = "mel", + encoder_conf: Dict = {}, decoder_conf: Dict = {}, mel_feat_conf: Dict = {}, codec_conf: Dict = {}, + length_regulator_conf: Dict = None, prompt_conf: Dict = None, vocab_size: int = None, + token_list: List = None, **kwargs): + super().__init__(input_size, output_size, xvec_size, output_type, encoder_conf, decoder_conf, mel_feat_conf, + codec_conf, length_regulator_conf, prompt_conf, vocab_size, token_list, **kwargs) + if self.prompt_conf is not None: + self.masker = self.build_masker() + self.cgf_prob = prompt_conf.get("cgf_prob", 0.0) + self.prompt_dropout_rate = prompt_conf.get("prompt_dropout", 0.0) + if self.prompt_dropout_rate > 0: + self.prompt_dropout = nn.Dropout(self.prompt_dropout_rate) + else: + self.prompt_dropout = None + self.only_mask_loss = kwargs.get("only_mask_loss", False) + self.io_ratio = kwargs.get("io_ratio", None) + if self.io_ratio == "auto": + self.io_ratio = mel_feat_conf["sampling_rate"] / mel_feat_conf["hop_size"] / self.input_frame_rate + self.first_package_conf = kwargs.get("first_package_conf", None) + self.length_normalizer_ratio = kwargs.get("length_normalizer_ratio", None) + + def build_masker(self): + prompt_type = self.prompt_conf.get("prompt_type", "free") + if prompt_type == "prefix": + from funasr.models.specaug.mask_along_axis import PrefixMaskVariableMaxWidth + masker = PrefixMaskVariableMaxWidth( + mask_width_ratio_range=self.prompt_conf["prompt_width_ratio_range"], + ) + else: + from funasr.models.specaug.mask_along_axis import MaskAlongAxisVariableMaxWidth 
+ masker = MaskAlongAxisVariableMaxWidth( + mask_width_ratio_range=self.prompt_conf["prompt_width_ratio_range"], + num_mask=1, + ) + return masker + + @staticmethod + def norm_and_sample_xvec(xvec, xvec_lengths): + xvec_list = [] + for i, ilen in enumerate(xvec_lengths): + if ilen == 1: + idx = 0 + else: + idx = random.randint(0, ilen - 1) + while torch.any(~torch.isfinite(xvec[i, idx])): + idx = random.randint(0, ilen - 1) + if torch.any(~torch.isfinite(xvec[i, idx])): + to_add = torch.zeros_like(xvec[i, idx]) + else: + to_add = xvec[i, idx] + xvec_list.append(to_add) + rand_xvec = torch.vstack(xvec_list) + rand_xvec = F.normalize(rand_xvec, dim=1) + + return rand_xvec + + def select_target_prompt(self, y: torch.Tensor, y_lengths: torch.Tensor): + _, _, cond_mask = self.masker(y, y_lengths, return_mask=True) + cond_mask = ~cond_mask + + if self.cgf_prob > 0: + cgf_mask = torch.rand([y.shape[0], 1, 1], dtype=torch.float32, device=y.device) + cond_mask = cond_mask * (cgf_mask > self.cgf_prob) + + return cond_mask + + def build_decoder(self): + decoder_name = self.decoder_conf.pop("name", "transformer") + model = None + + if decoder_name == "matcha": + from funasr.models.llm_asr.diffusion_models.flow_matching import CFM + model = CFM( + **self.decoder_conf, + ) + + self.decoder_conf["name"] = decoder_name + + return model + + def sample_first_package( + self, + text: torch.Tensor, + text_lengths: torch.Tensor, + audio: torch.Tensor, + audio_lengths: torch.Tensor, + ): + sample_rate = self.first_package_conf["sample_rate"] + min_token_len, max_token_len = self.first_package_conf["token_len_range"] + random_start = self.first_package_conf.get("random_start", False) + bs = text.shape[0] + sample_mask = torch.rand((bs, ), device=text_lengths.device) < sample_rate + if random_start: + text_list, text_lengths_list, audio_list, audio_lengths_list = [], [], [], [] + for i, total_len in enumerate(text_lengths): + total_len = total_len.item() + if sample_mask[i].item(): + if 
isinstance(min_token_len, float) and 0.0 < min_token_len <= 1.0: + min_token_len = math.floor(min_token_len * total_len) + if isinstance(max_token_len, float) and 0.0 < max_token_len <= 1.0: + max_token_len = math.floor(max_token_len * total_len) + if total_len > max_token_len > min_token_len: + fp_len = random.randint(min_token_len, max_token_len) + start = random.randint(0, total_len - fp_len) + audio_st, audio_len = self.calc_target_len(torch.tensor(start)), self.calc_target_len(torch.tensor(fp_len)) + else: + start, fp_len = 0, total_len + audio_st, audio_len = 0, self.calc_target_len(fp_len) + text_list.append(text[i, start: start+fp_len]) + text_lengths_list.append(fp_len) + audio_list.append(audio[i, audio_st: audio_st+audio_len]) + audio_lengths_list.append(audio_list[-1].shape[0]) + else: + text_list.append(text[i]) + text_lengths_list.append(text_lengths[i]) + audio_list.append(audio[i, :min(self.calc_target_len(text_lengths[i]), audio_lengths[i])]) + audio_lengths_list.append(audio_list[-1].shape[0]) + text = pad_list(text_list, pad_value=0.0).to(text) + new_text_lengths = torch.tensor(text_lengths_list, dtype=torch.int64, device=text.device) + audio = pad_list(audio_list, pad_value=0.0).to(audio) + new_audio_lengths = torch.tensor(audio_lengths_list, dtype=torch.int64, device=audio.device) + else: + fp_token_len = torch.randint(min_token_len, max_token_len + 1, (bs,)) + fp_token_len = torch.minimum(fp_token_len.to(text_lengths), text_lengths) + fp_audio_len = self.calc_target_len(fp_token_len) + fp_audio_len = torch.minimum(fp_audio_len.to(audio_lengths), audio_lengths) + new_text_lengths = torch.where(sample_mask, fp_token_len, text_lengths) + new_audio_lengths = torch.where(sample_mask, fp_audio_len, audio_lengths) + text = text * (~make_pad_mask(new_text_lengths, maxlen=text.shape[1]).unsqueeze(-1)).to(text.device) + audio = audio * (~make_pad_mask(new_audio_lengths, maxlen=audio.shape[1]).unsqueeze(-1)).to(audio.device) + + return text, 
    def forward(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        audio: torch.Tensor,
        audio_lengths: torch.Tensor,
        xvec: Optional[torch.Tensor] = None,
        xvec_lengths: Optional[torch.Tensor] = None,
    ):
        """Training step with masked-prompt conditioning.

        Like the base class, but the condition is the target itself with a
        random span masked out (``select_target_prompt``), optionally with
        length normalization / cutting and "first package" sub-sampling.

        Returns:
            (loss, stats, weight) tuple shaped for DataParallel gathering.
        """
        batch_size = audio.shape[0]
        # for data parallel
        # NOTE(review): the autocast(False) scope below was inferred from the
        # collapsed diff; confirm which statements the original nests inside it.
        with autocast(False):
            x = text[:, :text_lengths.max()]
            y = audio[:, :audio_lengths.max()]
            if self.vocab_size is not None and self.vocab_size > 0:
                # -1 marks padding; clamp for lookup, zero via mask
                mask = (x != -1).float().unsqueeze(-1)
                x = self.input_embedding(torch.clamp(x, min=0)) * mask

            # random select a xvec from xvec matrix
            rand_xvec = None
            if xvec is not None:
                xvec = xvec[:, :xvec_lengths.max()]
                rand_xvec = self.norm_and_sample_xvec(xvec, xvec_lengths)
                rand_xvec = self.xvec_proj(rand_xvec)

            y, y_lengths = self.extract_feat(y, audio_lengths)
            if self.length_normalizer_ratio is not None:
                # Cap target length at ratio * text length, clipping frames
                # evenly from both sides of over-long targets.
                max_y_lengths = torch.round(text_lengths * self.length_normalizer_ratio).long()
                raw_lengths = y_lengths.clone()
                y_lengths = torch.where(y_lengths > max_y_lengths, max_y_lengths, y_lengths)
                y, new_y_lengths = self.clip_both_side(y, y_lengths, raw_lengths)
                logging.info(f"normalized y_length from {raw_lengths.cpu().tolist()} to {y_lengths.cpu().tolist()} "
                             f"new_y_length {new_y_lengths.cpu().tolist()}, with text_lengths {text_lengths.cpu().tolist()}")
                y = y[:, :new_y_lengths.max()]
            elif self.io_ratio is not None:
                # Cut targets that exceed the configured input/output ratio
                # (+3 frames of slack).
                hint_once(f"cut output with ratio {self.io_ratio}", "print_ratio", rank=0)
                max_y_lengths = (text_lengths * self.io_ratio + 3).long()
                if y_lengths.max() > max_y_lengths.max():
                    logging.info(f"cut output with ratio {self.io_ratio} from {y_lengths.max()} to {max_y_lengths.max()}")
                y_lengths = torch.where(y_lengths > max_y_lengths, max_y_lengths, y_lengths)
                y = y[:, :y_lengths.max()]

            if self.first_package_conf is not None:
                # Randomly shorten some utterances to simulate streaming
                # "first package" segments.
                x, text_lengths, y, y_lengths = self.sample_first_package(
                    x, text_lengths, y, y_lengths
                )
                x = x[:, :text_lengths.max()]
                y = y[:, :y_lengths.max()]
            h, _, _ = self.encoder(x, text_lengths)
            h_lengths = text_lengths
            h = self.encoder_proj(h)
            h, h_lengths = self.length_regulator(h, h_lengths, y, y_lengths)
            if self.prompt_conf is not None:
                # cond_mask is True where the target is visible as a prompt.
                cond_mask = self.select_target_prompt(y, y_lengths)
                if self.prompt_dropout is not None:
                    hint_once(f"prompt dropout {self.prompt_dropout_rate}", "prompt dropout")
                    y = self.prompt_dropout(y)
                conditions = (y * cond_mask).transpose(1, 2)
            else:
                cond_mask, conditions = None, None

            stats = dict(
                batch_size=batch_size,
                in_lengths=text_lengths.max(),
                out_lengths=y_lengths.max(),
            )

            mask = (~make_pad_mask(y_lengths)).to(y)
            # y, h in (B, T, D); decoder works on (B, D, T)
            loss, _ = self.decoder.compute_loss(
                y.transpose(1, 2).contiguous(),
                mask.unsqueeze(1),
                h.transpose(1, 2).contiguous(),
                rand_xvec,
                cond=conditions,
                reduction="none",
            )
            loss = loss.transpose(1, 2)
            # Mean loss over valid frames; "masked" variant counts only the
            # frames hidden from the prompt condition.
            all_loss = (loss * mask.unsqueeze(-1)).sum() / (mask.sum() * loss.shape[-1])
            if cond_mask is not None:
                masked_loss_mask = mask.unsqueeze(-1) * (~cond_mask)
            else:
                masked_loss_mask = mask.unsqueeze(-1)
            masked_loss = (loss * masked_loss_mask).sum() / (masked_loss_mask.sum() * loss.shape[-1])
            stats["all_loss"] = all_loss.item()
            stats["masked_loss"] = masked_loss.item()

            loss = masked_loss if self.only_mask_loss else all_loss

            stats["loss"] = loss.item()

            # force_gatherable: to-device and to-tensor if scalar for DataParallel
            loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
            return loss, stats, weight
    @torch.no_grad()
    def inference(self, text, text_lens,
                  xvec=None, xvec_lens=None,
                  diff_steps=10, temperature=1.0, prompt: dict = None, y_lens=None):
        """Generate acoustic features from text with optional prompt audio/text.

        Args:
            text: (B, T_text) token ids (or dense features without a vocab).
            text_lens: (B,) valid lengths.
            xvec: optional speaker embeddings, (B, D) or (B, N, D).
            xvec_lens: counts matching ``xvec``.
            diff_steps: number of flow-matching / diffusion sampling steps.
            temperature: sampling temperature passed to the decoder.
            prompt: dict with optional "prompt_text" and "prompt_audio" entries,
                each a (tensor, lengths) pair.
            y_lens: target frame lengths; computed from ``text_lens`` if None.

        Returns:
            (B, D, T) generated features with any prompt frames stripped.
        """
        rand_xvec = None
        if xvec is not None:
            if xvec.dim() == 2:
                # a single embedding per utterance: treat as one candidate
                xvec = xvec.unsqueeze(1)
                xvec_lens = torch.ones_like(xvec_lens)
            rand_xvec = self.norm_and_avg_xvec(xvec, xvec_lens)
            rand_xvec = self.xvec_proj(rand_xvec)

        prompt_text, prompt_text_lens = prompt.get("prompt_text", (None, None))
        prompt_audio, prompt_audio_lens = prompt.get("prompt_audio", (None, None))

        if self.vocab_size is not None and self.vocab_size > 0:
            if prompt_text is not None:
                # prepend the prompt tokens before embedding
                text, text_lens = self.concat_prompt(prompt_text, prompt_text_lens, text, text_lens)
            mask = (text != -1).float().unsqueeze(-1)
            text = self.input_embedding(torch.clamp(text, min=0)) * mask

        h, h_lengths, _ = self.encoder(text, text_lens)
        h = self.encoder_proj(h)
        if y_lens is None:
            y_lens = self.calc_target_len(text_lens)
        # NOTE(review): batch dim is hard-coded to 1 here; this path appears to
        # assume single-utterance inference — confirm for batched callers.
        y = torch.zeros([1, y_lens.max().item(), self.output_size], device=text.device)
        h, h_lengths = self.length_regulator(h, h_lengths, y, y_lens)

        # get conditions: the prompt audio occupies the leading target frames
        if prompt_audio is not None:
            if prompt_audio.ndim == 2:
                prompt_audio, prompt_audio_lens = self.extract_feat(prompt_audio, prompt_audio_lens)
            for i, _len in enumerate(prompt_audio_lens):
                y[i, :_len] = prompt_audio[i]
        conds = y.transpose(1, 2)

        mask = (~make_pad_mask(y_lens)).to(y)
        feat = self.decoder.forward(
            mu=h.transpose(1, 2).contiguous(),
            mask=mask.unsqueeze(1),
            n_timesteps=diff_steps,
            temperature=temperature,
            spks=rand_xvec,
            cond=conds,
        )

        if prompt_text is not None and prompt_audio is not None:
            # strip the generated frames that correspond to the prompt audio
            feat = feat.transpose(1, 2)
            feat_lens = torch.tensor([feat.shape[1]], dtype=torch.int64, device=feat.device)
            feat, feat_lens = self.remove_prompt(None, prompt_audio_lens, feat, feat_lens)
            feat = feat.transpose(1, 2)

        # if prompt_audio is not None:
        #     feat_rmq = torch.sqrt(torch.mean(torch.pow(feat, 2), dim=[1, 2], keepdim=True))
        #     prompt_rmq = torch.sqrt(torch.mean(torch.pow(prompt_audio, 2), dim=[1, 2], keepdim=True))
        #     feat = feat / feat_rmq * prompt_rmq

        return feat
    def inference(self, text, text_lens, xvec=None, xvec_lens=None, diff_steps=10, temperature=1.0,
                  prompt: dict = None):
        """Generate acoustic features, predicting the target length internally.

        Same flow as ``MaskedDiffWithXvec.inference`` except the target frame
        count comes from the learned length predictor (``calc_target_len`` on
        the encoder outputs) instead of a fixed input/output ratio.

        NOTE(review): unlike the parent's inference this is not decorated with
        ``@torch.no_grad()`` — confirm whether gradients are intentionally kept.

        Returns:
            (B, D, T) generated features with any prompt frames stripped.
        """
        rand_xvec = None
        if xvec is not None:
            if xvec.dim() == 2:
                # single embedding per utterance: treat as one candidate
                xvec = xvec.unsqueeze(1)
                xvec_lens = torch.ones_like(xvec_lens)
            rand_xvec = self.norm_and_avg_xvec(xvec, xvec_lens)
            rand_xvec = self.xvec_proj(rand_xvec)

        prompt_text, prompt_text_lens = prompt.get("prompt_text", (None, None))
        prompt_audio, prompt_audio_lens = prompt.get("prompt_audio", (None, None))

        if self.vocab_size is not None and self.vocab_size > 0:
            if prompt_text is not None:
                text, text_lens = self.concat_prompt(prompt_text, prompt_text_lens, text, text_lens)
            mask = (text != -1).float().unsqueeze(-1)
            text = self.input_embedding(torch.clamp(text, min=0)) * mask

        h, _, _ = self.encoder(text, text_lens)
        h_lengths = text_lens
        # Predicted duration from the length predictor, rounded to frames.
        y_lens = self.calc_target_len(h, h_lengths).round().long()
        y = torch.zeros([1, y_lens.max().item(), self.output_size], device=text.device)
        h = self.encoder_proj(h)
        h, h_lengths = self.length_regulator(h, h_lengths, y, y_lens)

        # get conditions: prompt audio fills the leading target frames
        if prompt_audio is not None:
            if prompt_audio.ndim == 2:
                prompt_audio, prompt_audio_lens = self.extract_feat(prompt_audio, prompt_audio_lens)
            for i, _len in enumerate(prompt_audio_lens):
                y[i, :_len] = prompt_audio[i]
        conds = y.transpose(1, 2)

        mask = (~make_pad_mask(y_lens)).to(y)
        feat = self.decoder.forward(
            mu=h.transpose(1, 2).contiguous(),
            mask=mask.unsqueeze(1),
            n_timesteps=diff_steps,
            temperature=temperature,
            spks=rand_xvec,
            cond=conds,
        )

        if prompt_text is not None and prompt_audio is not None:
            # strip the frames corresponding to the prompt audio
            feat = feat.transpose(1, 2)
            feat_lens = torch.tensor([feat.shape[1]], dtype=torch.int64, device=feat.device)
            feat, feat_lens = self.remove_prompt(None, prompt_audio_lens, feat, feat_lens)
            feat = feat.transpose(1, 2)

        return feat
class Audio2Mel(nn.Module):
    """Waveform -> (log-)mel-spectrogram front-end for the vocoder.

    Computes an STFT with a Hann window and projects the power/magnitude
    spectrum onto a mel basis (``librosa.filters.mel``).

    Args:
        n_fft: FFT size.
        hop_length: STFT hop size (frame shift in samples).
        win_length: analysis window length.
        sampling_rate: audio sampling rate used for the mel basis.
        n_mel_channels: number of mel bands.
        mel_fmin / mel_fmax: mel filter-bank frequency range.
        center: passed to ``torch.stft``.
        device: device for the window / mel-basis buffers.
        feat_type: "mag_log10" -> log10 magnitude mels,
            anything else -> natural-log power mels.
    """

    def __init__(
        self,
        n_fft=1024,
        hop_length=256,
        win_length=1024,
        sampling_rate=22050,
        n_mel_channels=80,
        mel_fmin=0.0,
        mel_fmax=None,
        center=False,
        device='cuda',
        feat_type="power_log",
    ):
        super().__init__()
        ##############################################
        # FFT Parameters                              #
        ##############################################
        # NOTE(review): device defaults to 'cuda' — instantiation fails on
        # CPU-only hosts unless the caller overrides it; confirm intended.
        window = torch.hann_window(win_length, device=device).float()
        mel_basis = librosa_mel_fn(
            sr=sampling_rate, n_fft=n_fft, n_mels=n_mel_channels, fmin=mel_fmin, fmax=mel_fmax
        )
        mel_basis = torch.from_numpy(mel_basis).float().to(device)
        self.register_buffer("mel_basis", mel_basis)
        self.register_buffer("window", window)
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.sampling_rate = sampling_rate
        self.n_mel_channels = n_mel_channels
        self.center = center
        self.feat_type = feat_type

    def forward(self, audioin):
        """Compute mel features for a batch of waveforms.

        Args:
            audioin: waveform tensor (B, 1, T).

        Returns:
            (B, n_mel_channels, T') features: log10-magnitude mels when
            ``feat_type == "mag_log10"``, otherwise natural-log power mels.
        """
        p = (self.n_fft - self.hop_length) // 2
        audio = F.pad(audioin, (p, p), "reflect").squeeze(1)
        # Fix: modern torch requires an explicit `return_complex`; keep the
        # legacy real-valued (..., 2) layout via view_as_real so the power
        # computation below is unchanged (DiscriminatorR in this package
        # already uses return_complex=True the same way).
        fft = torch.view_as_real(
            torch.stft(
                audio,
                n_fft=self.n_fft,
                hop_length=self.hop_length,
                win_length=self.win_length,
                window=self.window,
                center=self.center,
                pad_mode="reflect",
                normalized=False,
                onesided=True,
                return_complex=True,
            )
        )
        if self.feat_type == "mag_log10":
            power_spec = torch.sqrt(torch.sum(torch.pow(fft, 2), dim=[-1]))
            mel_output = torch.matmul(self.mel_basis, power_spec)
            return torch.log10(torch.clamp(mel_output, min=1e-5))
        power_spec = torch.sum(torch.pow(fft, 2), dim=[-1])
        mel_spec = torch.matmul(self.mel_basis, torch.sqrt(power_spec + 1e-9))
        return self.spectral_normalize(mel_spec)

    @classmethod
    def spectral_normalize(cls, spec, C=1, clip_val=1e-5):
        """Log-compress a linear (mel) spectrogram: log(clamp(spec, clip_val) * C)."""
        output = cls.dynamic_range_compression(spec, C, clip_val)
        return output

    @classmethod
    def spectral_de_normalize_torch(cls, spec, C=1, clip_val=1e-5):
        """Invert :meth:`spectral_normalize`.

        Fix: ``dynamic_range_decompression`` takes no ``clip_val`` (exp needs
        no clamping); forwarding it used to raise a TypeError. ``clip_val``
        is kept in the signature for interface compatibility but ignored.
        """
        output = cls.dynamic_range_decompression(spec, C)
        return output

    @staticmethod
    def dynamic_range_compression(x, C=1, clip_val=1e-5):
        # C is a (rarely used) pre-log gain; clip_val avoids log(0)
        return torch.log(torch.clamp(x, min=clip_val) * C)

    @staticmethod
    def dynamic_range_decompression(x, C=1):
        return torch.exp(x) / C
+ generator: hifigan generator + discriminator: several discriminators, such as MSD, MPD, MRD + multi_mel_spectral_window: stft window length + multi_mel_spectral_hop: stft hop length + multi_mel_spectral_fft: fft bins + multi_mel_spectral_n_mels: Mel frequency bins + mel_fmin: fmin for mel + mel_fmax: fmax for mel + mel_fmax_for_loss: fmax for multi mel spectral loss + multi_mel_spectral_recon_loss_weight: the weight of frequency-domain reconstruction loss + adversarial_loss_weight: the weight of adversarial loss from discriminator + feat_match_loss_weight: the weight of intermediate feature loss from discriminator + tpr_loss_params: the weight and tau of Truncated Pointwise Relativistic (TPR) loss from discriminator. + """ + super().__init__() + + self.decoder = self.build_decoder(generator) + # Used by task and trainer + self.gen_model_list = [self.decoder] + + # nsf-hifigan or original hifigan + self.nsf_augmented = nsf_augmented + if nsf_augmented: + assert f0_predictor is not None + self.f0_predictor = self.build_f0_predictor(f0_predictor) + # frozen + for param in self.f0_predictor.parameters(): + param.requires_grad = False + self.gen_model_list.append(self.f0_predictor) + + self.discriminator = self.build_discriminator(discriminator) + + self.multi_mel_spec_transforms = nn.ModuleList() + for n_fft, hop_len, win_len, n_mel in zip(multi_mel_spectral_fft, multi_mel_spectral_hop, + multi_mel_spectral_window, multi_mel_spectral_n_mels): + self.multi_mel_spec_transforms.append( + Audio2Mel( + n_fft=n_fft, + hop_length=hop_len, + win_length=win_len, + sampling_rate=target_sample_hz, + n_mel_channels=n_mel, + mel_fmin=mel_fmin, + mel_fmax=mel_fmax_for_loss, + center=False, + ) + ) + + self.mel_spec_transform = Audio2Mel( + n_fft=multi_mel_spectral_fft[0], + hop_length=multi_mel_spectral_hop[0], + win_length=multi_mel_spectral_window[0], + sampling_rate=target_sample_hz, + n_mel_channels=multi_mel_spectral_n_mels[0], + mel_fmin=mel_fmin, + mel_fmax=mel_fmax, + 
center=False, + feat_type=mel_feat_type, + ) + + # loss weights + self.multi_mel_spectral_recon_loss_weight = multi_mel_spectral_recon_loss_weight + self.adversarial_loss_weight = adversarial_loss_weight + self.feat_match_loss_weight = feat_match_loss_weight + self.tpr_loss_weight = tpr_loss_params.get("weight", 0.0) + self.tpr_loss_tau = tpr_loss_params.get("tau", 0.04) + self.register_buffer('zero', torch.tensor([0.]), persistent=False) + self.gen_loss = 0 + self.sample_rate = target_sample_hz + self.forward_step = 0 + + def build_decoder(self, conf): + from funasr.models.llm_asr.hifigan_module.generator import HiFTGenerator + return HiFTGenerator(**conf) + + def build_f0_predictor(self, conf): + from funasr.models.llm_asr.hifigan_module.nsf_utils import ConvRNNF0Predictor + return ConvRNNF0Predictor(**conf) + + def build_discriminator(self, conf): + from funasr.models.llm_asr.hifigan_module.discriminator import MultipleDiscriminator + return MultipleDiscriminator(**conf) + + @property + def generator(self): + return torch.nn.ModuleList(self.gen_model_list) + + def forward( + self, + forward_generator: bool = True, + batch: Dict = None, + ) -> Dict[str, Any]: + """Forward functions of generator and discriminator. + + Args: + forward_generator (bool): Whether to forward generator. + batch (Dict[str, Tensor]): one batch including: + speech (Tensor): Speech waveform tensor (B, T_wav). + speech_lengths (Tensor): Speech length tensor (B,). + + Returns: + Dict[str, Any]: + - loss (Tensor): Loss scalar tensor. + - stats (Dict[str, float]): Statistics to be monitored. + - weight (Tensor): Weight tensor to summarize losses. + - optim_idx (int): Optimizer index (0 for G and 1 for D). 
+ + """ + if forward_generator: + if self.training: + self.forward_step += 1 + return self._forward_generator( + speech=batch["speech"], + speech_lengths=batch["speech_lengths"], + ) + else: + return self._forward_discriminator( + speech=batch["speech"], + speech_lengths=batch["speech_lengths"], + ) + + def _encode(self, x: torch.Tensor) -> torch.Tensor: + """Given a tensor `x`, returns the encoded representation for `x` + """ + assert x.dim() == 3 + _, channel, length = x.size() + assert channel == 1 + mel = self.mel_spec_transform(x) + return mel.squeeze() + + @torch.no_grad() + def _f0_pred(self, x: torch.Tensor) -> torch.Tensor: + """Given a tensor `x`, return the predicted f0 for `x`, x in (B, C, T) + """ + if self.nsf_augmented: + f0 = self.f0_predictor(x) + if len(f0.shape) == 1: + f0 = f0.unsqueeze(0) + return f0 + else: + return torch.zeros_like(x) + + def _decode(self, x: torch.Tensor, g: Union[torch.Tensor] = None) -> torch.Tensor: + """Decode the given representation into a waveform. + + Args: + x (Tensor): Speech representation tensor (B, C1, T) + g (Tensor): Global conditional vector (B, C2, 1). + """ + if self.nsf_augmented: + f0 = self._f0_pred(x) + return self.decoder(x, f0, g) + else: + return self.decoder(x, g) + + def _forward_generator( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + ) -> Dict[str, Any]: + """Perform generator forward. + + Args: + speech (Tensor): Speech waveform tensor (B, T_wav). + speech_lengths (Tensor): Speech length tensor (B,). + + Returns: + Dict[str, Any]: + * loss (Tensor): Loss scalar tensor. + * stats (Dict[str, float]): Statistics to be monitored. + * weight (Tensor): Weight tensor to summarize losses. + * optim_idx (int): Optimizer index (0 for G and 1 for D). 
+ + """ + # setup + batch_size = speech.size(0) + speech = speech.unsqueeze(1) + orig_speech = speech.clone() + + mel = self._encode(speech) # [B, C, T] + recon_speech = self._decode(mel)[:, :, :speech.shape[-1]] + + # L1 Mel-Spectrogram Loss + multi_mel_recon_loss = self.zero + for lamda, mel_transform in zip(self.multi_mel_spectral_recon_loss_weight, self.multi_mel_spec_transforms): + orig_mel, recon_mel = map(mel_transform, (orig_speech, recon_speech)) + multi_mel_recon_loss = multi_mel_recon_loss + lamda * F.l1_loss(orig_mel, recon_mel) + + # calculate discriminator outputs + # disc_outputs in the format [disc1_outputs, disc2_outputs, ...] + # disc1_outputs includes [logits, intermediates] + # intermediates includes [layer_1_intermediate, layer_2_intermediate, ...] + fake_disc_outputs = self.discriminator(recon_speech) + with torch.no_grad(): + # do not store discriminator gradient in generator turn + real_disc_outputs = self.discriminator(orig_speech) + + # calculate discriminator loss including adversarial, feat matching losses and tpr losses [Optional] + adversarial_losses = [] + disc_feature_losses = [] + tpr_losses = [] + for real_output, fake_output in zip(real_disc_outputs, fake_disc_outputs): + real_logits, real_intermediates = real_output + fake_logits, fake_intermediates = fake_output + adversarial_losses.append(torch.mean((1 - fake_logits)**2)) + for real_inter, fake_inter in zip(real_intermediates, fake_intermediates): + _loss = torch.mean(torch.abs(real_inter.detach() - fake_inter)) + disc_feature_losses.append(_loss) + + if self.tpr_loss_weight > 0.0: + tau = self.tpr_loss_tau + m_DG = torch.median((fake_logits - real_logits)) + L_rel = torch.mean((((fake_logits - real_logits) - m_DG) ** 2)[fake_logits < real_logits + m_DG]) + tpr_losses.append(tau - F.relu(tau - L_rel)) + + adversarial_loss = torch.stack(adversarial_losses).sum() + feat_match_loss = torch.stack(disc_feature_losses).sum() + tpr_loss = torch.zeros_like(adversarial_loss) + if 
len(tpr_losses) > 0: + tpr_loss = torch.stack(tpr_losses).sum() + + # calculate losses + gen_loss = multi_mel_recon_loss + \ + adversarial_loss * self.adversarial_loss_weight + \ + feat_match_loss * self.feat_match_loss_weight + \ + tpr_loss * self.tpr_loss_weight + self.gen_loss += gen_loss.item() + loss = gen_loss + + stats = dict( + generator_loss=loss.item(), + generator_multi_mel_recon_loss=multi_mel_recon_loss.item(), + generator_adv_loss=adversarial_loss.item(), + generator_feat_match_loss=feat_match_loss.item(), + generator_tpr_loss=tpr_loss.item(), + batch_size=batch_size, + batch_length=speech.shape[2], + ) + + loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) + + return { + "loss": loss, + "stats": stats, + "weight": weight, + "optim_idx": 0, # needed for trainer + "real": orig_speech, + "fake": recon_speech, + } + + def _forward_discriminator( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + ) -> Dict[str, Any]: + """Perform discriminator forward. + + Args: + speech (Tensor): Speech waveform tensor (B, T_wav). + speech_lengths (Tensor): Speech length tensor (B,). + + Returns: + Dict[str, Any]: + * loss (Tensor): Loss scalar tensor. + * stats (Dict[str, float]): Statistics to be monitored. + * weight (Tensor): Weight tensor to summarize losses. + * optim_idx (int): Optimizer index (0 for G and 1 for D). 
+ """ + # setup + batch_size = speech.size(0) + speech = speech.unsqueeze(1) + orig_speech = speech.clone() + + # A: calculate generator outputs + with torch.no_grad(): + # do not store generator gradient in discriminator turn + mel = self._encode(speech) # [B, C, T] + recon_speech = self._decode(mel)[:, :, :speech.shape[-1]] + + # B: calculate discriminator outputs + real, fake = orig_speech.clone(), recon_speech.detach() + real_disc_outputs = self.discriminator(real) + fake_disc_outputs = self.discriminator(fake) + + # C: calculate discriminator losses, tpr losses [Optional] + disc_losses = [] + tpr_losses = [] + for real_output, fake_output in zip(real_disc_outputs, fake_disc_outputs): + real_logits, real_intermediates = real_output + fake_logits, fake_intermediates = fake_output + one_disc_loss = torch.mean((1-real_logits) ** 2) + torch.mean((0 - fake_logits) ** 2) + disc_losses.append(one_disc_loss) + + if self.tpr_loss_weight > 0.0: + tau = self.tpr_loss_tau + m_DG = torch.median((real_logits - fake_logits)) + L_rel = torch.mean((((real_logits - fake_logits) - m_DG) ** 2)[real_logits < fake_logits + m_DG]) + tpr_losses.append(tau - F.relu(tau - L_rel)) + + disc_loss = torch.stack(disc_losses).sum() + tpr_loss = torch.zeros_like(disc_loss) + if len(tpr_losses) > 0: + tpr_loss = torch.stack(tpr_losses).sum() + + self.gen_loss = 0 + + loss = disc_loss + self.tpr_loss_weight * tpr_loss + + stats = dict( + discriminator_total_loss=loss.item(), + discriminator_loss=disc_loss.item(), + discriminator_tpr_loss=tpr_loss.item(), + ) + loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) + + return { + "loss": loss, + "stats": stats, + "weight": weight, + "optim_idx": 1, # needed for trainer + "real": orig_speech, + "fake": recon_speech, + } + + def inference( + self, + x: torch.Tensor, + ) -> Dict[str, torch.Tensor]: + """Run inference. 
def get_padding(kernel_size, dilation=1):
    """Return the 'SAME' padding for a (dilated) 1-D/2-D convolution axis."""
    return int((kernel_size * dilation - dilation) / 2)


def init_weights(m, mean=0.0, std=0.01):
    """Initialize Conv* weights with N(mean, std); no-op for other modules."""
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)


# Fix: these re-exports previously imported from the unrelated `cosyvoice`
# package and would fail inside funasr; the modules live in this package
# (the same paths funasr/models/llm_asr/hifigan.py imports in its
# build_decoder/build_discriminator/build_f0_predictor helpers).
from funasr.models.llm_asr.hifigan_module.generator import HifiGenerator, NsfHifiGenerator, HiFTGenerator
from funasr.models.llm_asr.hifigan_module.discriminator import MultipleDiscriminator
from funasr.models.llm_asr.hifigan_module.nsf_utils import ConvRNNF0Predictor
class Snake(nn.Module):
    '''
    Sine-based periodic activation: ``f(x) = x + (1/alpha) * sin^2(alpha * x)``.
    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha - trainable per-channel parameter controlling the frequency
    References:
        - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
          https://arxiv.org/abs/2006.08195
    Examples:
        >>> a1 = Snake(256)
        >>> x = torch.randn(256)
        >>> x = a1(x)
    '''
    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
        '''
        Initialization.
        INPUT:
            - in_features: number of channels of the input
            - alpha: initial value of the trainable parameter;
              higher values = higher-frequency. With ``alpha_logscale`` the
              parameter is stored in log space (init 0 == exp(0) = 1 at
              forward time), otherwise it is stored linearly (init ``alpha``).
        '''
        super(Snake, self).__init__()
        self.in_features = in_features
        self.alpha_logscale = alpha_logscale
        # log-scale alphas start at zero, linear-scale alphas at `alpha`
        initial = torch.zeros(in_features) if alpha_logscale else torch.ones(in_features)
        self.alpha = Parameter(initial * alpha)
        self.alpha.requires_grad = alpha_trainable

        # small epsilon guarding the 1/alpha division
        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        '''
        Apply the activation elementwise on a (B, C, T) tensor.
        '''
        # reshape the per-channel alpha to broadcast over batch and time
        a = self.alpha.view(1, -1, 1)
        if self.alpha_logscale:
            a = torch.exp(a)
        return x + (1.0 / (a + self.no_div_by_zero)) * torch.sin(x * a).pow(2)
+ ''' + super(SnakeBeta, self).__init__() + self.in_features = in_features + + # initialize alpha + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: # log scale alphas initialized to zeros + self.alpha = Parameter(torch.zeros(in_features) * alpha) + self.beta = Parameter(torch.zeros(in_features) * alpha) + else: # linear scale alphas initialized to ones + self.alpha = Parameter(torch.ones(in_features) * alpha) + self.beta = Parameter(torch.ones(in_features) * alpha) + + self.alpha.requires_grad = alpha_trainable + self.beta.requires_grad = alpha_trainable + + self.no_div_by_zero = 0.000000001 + + def forward(self, x): + ''' + Forward pass of the function. + Applies the function to the input elementwise. + SnakeBeta ∶= x + 1/b * sin^2 (xa) + ''' + alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] + beta = self.beta.unsqueeze(0).unsqueeze(-1) + if self.alpha_logscale: + alpha = torch.exp(alpha) + beta = torch.exp(beta) + x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2) + + return x \ No newline at end of file diff --git a/funasr/models/llm_asr/hifigan_module/discriminator.py b/funasr/models/llm_asr/hifigan_module/discriminator.py new file mode 100644 index 000000000..959b515e4 --- /dev/null +++ b/funasr/models/llm_asr/hifigan_module/discriminator.py @@ -0,0 +1,299 @@ +"""hifigan based dicriminator implementation. + +This code is modified from https://github.com/jik876/hifi-gan and https://github.com/kan-bayashi/ParallelWaveGAN. 
+ +""" + +import typing as tp + +import torch +import torch.nn.functional as F +import torch.nn as nn +from torch.nn import Conv2d, AvgPool1d, Conv1d +from torch.nn.utils import weight_norm, spectral_norm + +from funasr.models.llm_asr.hifigan_module import get_padding + + +class DiscriminatorP(torch.nn.Module): + def __init__(self, period, kernel_size=5, stride=3, + use_spectral_norm=False, lrelu_slope=0.1): + super(DiscriminatorP, self).__init__() + self.period = period + self.lrelu_slope = lrelu_slope + + norm_f = weight_norm if use_spectral_norm is False else spectral_norm + self.convs = nn.ModuleList([ + norm_f( + Conv2d( + 1, + 32, (kernel_size, 1), (stride, 1), + padding=(get_padding(5, 1), 0))), + norm_f( + Conv2d( + 32, + 128, (kernel_size, 1), (stride, 1), + padding=(get_padding(5, 1), 0))), + norm_f( + Conv2d( + 128, + 512, (kernel_size, 1), (stride, 1), + padding=(get_padding(5, 1), 0))), + norm_f( + Conv2d( + 512, + 1024, (kernel_size, 1), (stride, 1), + padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))), + ]) + self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) + + def forward(self, x): + fmap = [] + + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, self.lrelu_slope) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiPeriodDiscriminator(torch.nn.Module): + def __init__(self, + in_channels: int = 1, + periods: tp.List[int] = [2, 3, 5, 7, 11]): + super(MultiPeriodDiscriminator, self).__init__() + self.discriminators = nn.ModuleList([ + DiscriminatorP(p) for p in periods + ]) + + def forward(self, x: torch.Tensor, return_intermediates: bool = True): + """Calculate forward propagation. 
class DiscriminatorS(torch.nn.Module):
    """HiFi-GAN scale discriminator: a stack of grouped, strided 1-D convs
    applied directly on the (possibly average-pooled) waveform.

    Args:
        use_spectral_norm: use spectral norm instead of weight norm.
        lrelu_slope: negative slope of the LeakyReLU between layers.
    """

    def __init__(self, use_spectral_norm=False, lrelu_slope=0.1):
        super(DiscriminatorS, self).__init__()
        self.lrelu_slope = lrelu_slope
        norm_f = weight_norm if use_spectral_norm is False else spectral_norm
        # (in_ch, out_ch, kernel, stride, groups, padding) for each stage
        conv_specs = [
            (1, 128, 15, 1, 1, 7),
            (128, 128, 41, 2, 4, 20),
            (128, 256, 41, 2, 16, 20),
            (256, 512, 41, 4, 16, 20),
            (512, 1024, 41, 4, 16, 20),
            (1024, 1024, 41, 1, 16, 20),
            (1024, 1024, 5, 1, 1, 2),
        ]
        self.convs = nn.ModuleList(
            norm_f(Conv1d(ic, oc, k, s, groups=g, padding=p))
            for ic, oc, k, s, g, p in conv_specs
        )
        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        """Return (logits, fmap) for waveform x of shape (B, 1, T); `fmap`
        holds the per-layer activations for feature matching."""
        fmap = []
        for layer in self.convs:
            x = F.leaky_relu(layer(x), self.lrelu_slope)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        return torch.flatten(x, 1, -1), fmap
+ + """ + outs = [] + for i, f in enumerate(self.discriminators): + if i != 0: + x = self.meanpools[i - 1](x) + if return_intermediates: + outs.append(f(x)) + else: + outs.append(f(x)[0]) + + return outs + + +class DiscriminatorR(nn.Module): + def __init__( + self, + stft_params: tp.List[int], + lrelu_slope: float = 0.1, + use_spectral_norm: bool = False, + ): + super().__init__() + + self.stft_params = stft_params + self.lrelu_slope = lrelu_slope + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + + self.convs = nn.ModuleList([ + norm_f(nn.Conv2d(1, 32, (3, 9), padding=(1, 4))), + norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))), + norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))), + norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))), + norm_f(nn.Conv2d(32, 32, (3, 3), padding=(1, 1))), + ]) + self.conv_post = norm_f(nn.Conv2d(32, 1, (3, 3), padding=(1, 1))) + + def spectrogram(self, x): + n_fft, hop_length, win_length = self.stft_params + x = F.pad(x, (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)), mode='reflect') + x = x.squeeze(1) + spec = torch.stft(x, n_fft, hop_length=hop_length, win_length=win_length, + center=False, pad_mode='reflect', normalized=False, onesided=True, return_complex=True) + + spec = torch.view_as_real(spec) # [B, F, TT, 2] + mag = torch.norm(spec, p=2, dim =-1) #[B, F, TT] + + return mag + + def forward(self, x): + fmap = [] + + x = self.spectrogram(x).unsqueeze(1) + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, self.lrelu_slope) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiResolutionDiscriminator(nn.Module): + def __init__( + self, + in_channels: int, + fft_sizes: tp.List[int] = [1024, 2048, 512], + hop_sizes: tp.List[int] = [120, 240, 50], + win_lengths: tp.List[int] = [600, 1200, 240], + lrelu_slope: float = 0.1, + ): + super().__init__() + + self.discriminators = 
class MultipleDiscriminator(nn.Module):
    """Container running several discriminators (MPD / MSD / MRD) on the
    same input and concatenating their outputs.

    Args:
        input_size: number of input channels forwarded to each discriminator.
        disc_conf_list: list of per-discriminator config dicts; each must
            contain a ``name`` key in {"mpd", "msd", "mrd"}, the remaining
            keys are passed to that discriminator's constructor.
    """

    def __init__(
        self,
        input_size: int = 1,
        disc_conf_list: tp.List[tp.Dict[str, tp.Any]] = None,
    ):
        super().__init__()

        self.support_disc_choices = dict(
            mpd=MultiPeriodDiscriminator,
            msd=MultiScaleDiscriminator,
            mrd=MultiResolutionDiscriminator,
        )

        # Fix: the default `None` used to be iterated directly, raising a
        # bare TypeError; treat it as "no discriminators configured".
        if disc_conf_list is None:
            disc_conf_list = []

        self.discriminators = nn.ModuleList()
        self.discriminator_type_lst = []
        for args in disc_conf_list:
            assert "name" in args, "disc_conf must have `name` attr to specific disc type."
            disc_type = args.pop("name")
            assert disc_type in self.support_disc_choices, \
                "Unsupported discriminator type, only support {}".format(
                    ",".join(self.support_disc_choices.keys())
                )

            disc_class = self.support_disc_choices[disc_type]
            one_disc = disc_class(in_channels=input_size, **args)
            self.discriminators.append(one_disc)
            # add back to the args for dump config.yaml
            args["name"] = disc_type
            self.discriminator_type_lst.append(disc_type)

    def get_discriminator_type_lst(self) -> tp.List[str]:
        """Return the configured discriminator names, in construction order."""
        return self.discriminator_type_lst

    def forward(self, x, return_intermediates=True):
        """Run every discriminator on `x` (B, 1, T).

        Returns a flat list with one entry per sub-discriminator branch;
        with ``return_intermediates`` each entry is (logits, fmap).
        """
        retval = []
        for disc in self.discriminators:
            out = disc(x, return_intermediates=return_intermediates)
            if isinstance(out, tuple):
                retval.append(out)
            elif isinstance(out, list):
                # multi-branch discriminators return one entry per branch
                retval.extend(out)
            else:
                raise TypeError("The return value of discriminator must be tuple or list[tuple]")

        return retval
class ResBlock(torch.nn.Module):
    """Residual block module in HiFiGAN/BigVGAN.

    Each stage applies activation -> dilated conv (and optionally a second
    activation -> non-dilated conv) with a residual connection around it.

    Args:
        channels: channel count (input == output).
        kernel_size: conv kernel size.
        dilations: dilation of each stage's first conv (treated read-only).
        use_additional_convs: add the second (dilation=1) conv per stage.
        nonlinear_activation: "LeakyReLU", "Snake" or "SnakeBeta".
        nonlinear_activation_params: kwargs for the activation
            (treated read-only).
    """
    def __init__(
        self,
        channels: int = 512,
        kernel_size: int = 3,
        dilations: tp.List[int] = [1, 3, 5],
        use_additional_convs: bool = True,
        nonlinear_activation: str = "LeakyReLU",
        nonlinear_activation_params: tp.Dict[str, tp.Any] = {"negative_slope": 0.1},
    ):
        super(ResBlock, self).__init__()
        self.use_additional_convs = use_additional_convs

        # Build convs1/convs2 interleaved per dilation (same construction
        # order as before, so default random-init draws stay identical).
        self.convs1 = nn.ModuleList()
        if use_additional_convs:
            self.convs2 = nn.ModuleList()
        for dilation in dilations:
            self.convs1.append(self._same_pad_conv(channels, kernel_size, dilation))
            if use_additional_convs:
                self.convs2.append(self._same_pad_conv(channels, kernel_size, 1))

        self.convs1.apply(init_weights)
        if use_additional_convs:
            self.convs2.apply(init_weights)

        self.activations1 = self._build_activations(
            nonlinear_activation, nonlinear_activation_params, channels, len(self.convs1)
        )
        if use_additional_convs:
            self.activations2 = self._build_activations(
                nonlinear_activation, nonlinear_activation_params, channels, len(self.convs2)
            )

    @staticmethod
    def _same_pad_conv(channels, kernel_size, dilation):
        """Weight-normalized Conv1d with 'SAME' padding for the given dilation."""
        return weight_norm(
            Conv1d(
                channels,
                channels,
                kernel_size,
                1,
                dilation=dilation,
                padding=get_padding(kernel_size, dilation),
            )
        )

    @staticmethod
    def _build_activations(name, params, channels, count):
        """Build one activation module per conv layer (LeakyReLU/Snake/SnakeBeta)."""
        if name == "LeakyReLU":
            acts = [nn.LeakyReLU(params["negative_slope"]) for _ in range(count)]
        elif name == "Snake":
            acts = [
                Snake(channels, alpha_logscale=params.get("alpha_logscale", False))
                for _ in range(count)
            ]
        elif name == "SnakeBeta":
            acts = [
                SnakeBeta(channels, alpha_logscale=params.get("alpha_logscale", False))
                for _ in range(count)
            ]
        else:
            raise NotImplementedError
        return nn.ModuleList(acts)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply all residual stages to x of shape (B, channels, T)."""
        for idx in range(len(self.convs1)):
            xt = self.activations1[idx](x)
            xt = self.convs1[idx](xt)
            if self.use_additional_convs:
                xt = self.activations2[idx](xt)
                xt = self.convs2[idx](xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        """Strip weight norm from all convs (for inference export)."""
        for idx in range(len(self.convs1)):
            remove_weight_norm(self.convs1[idx])
            if self.use_additional_convs:
                remove_weight_norm(self.convs2[idx])
cond_in_each_up_layer: bool = False, + lrelu_slope: float = 0.1, + act_pre_each_up_layer: bool = True + ): + super(HifiGenerator, self).__init__() + + self.out_channels = 1 + self.global_channels = global_channels + self.use_additional_convs = use_additional_convs + self.cond_in_each_up_layer = cond_in_each_up_layer if global_channels > 0 else False + self.lrelu_slope = lrelu_slope + self.act_pre_each_up_layer = act_pre_each_up_layer + + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_rates) + + self.conv_pre = weight_norm( + Conv1d(in_channels, base_channels, 7, 1, padding=3) + ) + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): + self.ups.append( + weight_norm( + ConvTranspose1d( + base_channels // (2**i), + base_channels // (2**(i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = base_channels // (2**(i + 1)) + for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): + self.resblocks.append(ResBlock(ch, k, d, use_additional_convs, + resblock_nonlinear_activation, + resblock_nonlinear_activation_params)) + + if self.global_channels > 0: + self.conv_global_cond = weight_norm( + Conv1d(global_channels, base_channels, 1) + ) + self.conv_global_cond.apply(init_weights) + + if self.cond_in_each_up_layer: + self.conv_conds = nn.ModuleList() + for i in range(len(self.ups)): + self.conv_conds.append(weight_norm( + nn.Conv1d(global_channels, base_channels // (2**(i + 1)), 1)) + ) + self.conv_conds.apply(init_weights) + + self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) + self.ups.apply(init_weights) + self.conv_post.apply(init_weights) + + def output_size(self): + return self.out_channels + + def forward(self, x: torch.Tensor, g: tp.Optional[torch.Tensor] = None) -> torch.Tensor: + # x in (B, in_channels, T), g in (B, global_channels, 1) + x = self.conv_pre(x) + if 
self.global_channels > 0 and g is not None: + x = x + self.conv_global_cond(g) + + for i in range(self.num_upsamples): + if self.act_pre_each_up_layer: + x = F.leaky_relu(x, self.lrelu_slope) + x = self.ups[i](x) + + if self.cond_in_each_up_layer and g is not None: + x = x + self.conv_conds[i](g) + + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + print('Removing weight norm...') + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) + if self.global_channels > 0: + remove_weight_norm(self.conv_global_cond) + if self.cond_in_each_up_layer: + for l in self.conv_conds: + remove_weight_norm(l) + + +class NsfHifiGenerator(nn.Module): + """ + Neural Source Filter + HifiGan + """ + def __init__( + self, + in_channels: int = 80, + base_channels: int = 512, + global_channels: int = -1, + nb_harmonics: int = 7, + sampling_rate: int = 22050, + nsf_alpha: float = 0.1, + nsf_sigma: float = 0.003, + nsf_voiced_threshold: float = 10, + upsample_rates: tp.List[int] = [8, 8, 2, 2], + upsample_kernel_sizes: tp.List[int] = [16, 16, 4, 4], + resblock_kernel_sizes: tp.List[int] = [3, 7, 11], + resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + resblock_nonlinear_activation: str = "LeakyReLU", + resblock_nonlinear_activation_params: tp.Dict[str, tp.Any] = {"negative_slope": 0.1}, + use_additional_convs: bool = True, + cond_in_each_up_layer: bool = False, + lrelu_slope: float = 0.1, + act_pre_each_up_layer: bool = True + ): + super(NsfHifiGenerator, self).__init__() + + self.out_channels = 1 + self.global_channels = global_channels + self.nb_harmonics = nb_harmonics + 
self.sampling_rate = sampling_rate + self.use_additional_convs = use_additional_convs + self.cond_in_each_up_layer = cond_in_each_up_layer if global_channels > 0 else False + self.lrelu_slope = lrelu_slope + self.act_pre_each_up_layer = act_pre_each_up_layer + + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_rates) + + self.source_module = SourceModule(nb_harmonics, np.cumprod(upsample_rates)[-1], + sampling_rate, nsf_alpha, nsf_sigma, nsf_voiced_threshold) + self.conv_pre = weight_norm( + Conv1d(in_channels, base_channels, 7, 1, padding=3) + ) + + # Up + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): + self.ups.append( + weight_norm( + ConvTranspose1d( + base_channels // (2**i), + base_channels // (2**(i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ) + # Down + self.source_downs = nn.ModuleList() + downsample_rates = [1] + upsample_rates[::-1][:-1] + downsample_cum_rates = np.cumprod(downsample_rates) + for i, u in enumerate(downsample_cum_rates[::-1]): + if (u == 1): + self.source_downs.append( + weight_norm(Conv1d(1, base_channels // (2 ** (i + 1)), 1, 1)) + ) + else: + self.source_downs.append( + weight_norm(Conv1d(1, base_channels // (2 ** (i + 1)), u*2, u, padding=(u//2))) + ) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = base_channels // (2**(i + 1)) + for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): + self.resblocks.append(ResBlock(ch, k, d, use_additional_convs, + resblock_nonlinear_activation, + resblock_nonlinear_activation_params)) + + if self.global_channels > 0: + self.conv_global_cond = weight_norm( + Conv1d(global_channels, base_channels, 1) + ) + self.conv_global_cond.apply(init_weights) + + if self.cond_in_each_up_layer: + self.conv_conds = nn.ModuleList() + for i in range(len(self.ups)): + self.conv_conds.append(weight_norm( + nn.Conv1d(global_channels, base_channels // (2**(i + 1)), 1)) + ) 
+ self.conv_conds.apply(init_weights) + + self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) + self.ups.apply(init_weights) + self.conv_post.apply(init_weights) + + def output_size(self): + return self.out_channels + + def _f02source(self, f0: torch.Tensor) -> torch.Tensor: + return self.source_module(f0.unsqueeze(1)) + + def forward(self, x: torch.Tensor, f0: torch.Tensor, g: tp.Optional[torch.Tensor] = None) -> torch.Tensor: + # x in (B, in_channels, T), f0 in (B, T), g in (B, global_channels, 1) + + s = self._f02source(f0) + + x = self.conv_pre(x) + if self.global_channels > 0 and g is not None: + x = x + self.conv_global_cond(g) + + for i in range(self.num_upsamples): + if self.act_pre_each_up_layer: + x = F.leaky_relu(x, self.lrelu_slope) + x = self.ups[i](x) + + if self.cond_in_each_up_layer and g is not None: + x = x + self.conv_conds[i](g) + + # fusion + x = x + self.source_downs[i](s) + + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + print('Removing weight norm...') + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) + if self.global_channels > 0: + remove_weight_norm(self.conv_global_cond) + if self.cond_in_each_up_layer: + for l in self.conv_conds: + remove_weight_norm(l) + self.source_module.remove_weight_norm() + for l in self.source_downs: + remove_weight_norm(l) + + +class HiFTGenerator(nn.Module): + """ + HiFTNet Generator: Neural Source Filter + ISTFTNet + https://arxiv.org/abs/2309.09493 + """ + def __init__( + self, + in_channels: int = 80, + base_channels: int = 512, + global_channels: int = -1, + nb_harmonics: int = 8, + sampling_rate: int = 22050, + 
nsf_alpha: float = 0.1, + nsf_sigma: float = 0.003, + nsf_voiced_threshold: float = 10, + upsample_rates: tp.List[int] = [8, 8], + upsample_kernel_sizes: tp.List[int] = [16, 16], + istft_params: tp.Dict[str, int] = {"n_fft": 16, "hop_len": 4}, + resblock_kernel_sizes: tp.List[int] = [3, 7, 11], + resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + resblock_nonlinear_activation: str = "Snake", + resblock_nonlinear_activation_params: tp.Dict[str, tp.Any] = {"alpha_logscale": False}, + source_resblock_kernel_sizes: tp.List[int] = [7, 11], + source_resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5]], + source_resblock_nonlinear_activation: str = "Snake", + source_resblock_nonlinear_activation_params: tp.Dict[str, tp.Any] = {"alpha_logscale": False}, + use_additional_convs: bool = True, + cond_in_each_up_layer: bool = False, + lrelu_slope: float = 0.1, + act_pre_each_up_layer: bool = True, + audio_limit: float = 0.99, + ): + super(HiFTGenerator, self).__init__() + + self.out_channels = 1 + self.global_channels = global_channels + self.nb_harmonics = nb_harmonics + self.sampling_rate = sampling_rate + self.istft_params = istft_params + self.use_additional_convs = use_additional_convs + self.cond_in_each_up_layer = cond_in_each_up_layer if global_channels > 0 else False + self.lrelu_slope = lrelu_slope + self.act_pre_each_up_layer = act_pre_each_up_layer + self.audio_limit = audio_limit + + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_rates) + self.m_source = SourceModuleHnNSF( + sampling_rate=sampling_rate, + upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"], + harmonic_num=nb_harmonics, + sine_amp=nsf_alpha, + add_noise_std=nsf_sigma, + voiced_threshod=nsf_voiced_threshold) + self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"]) + + self.conv_pre = weight_norm( + Conv1d(in_channels, base_channels, 7, 1, padding=3) + ) + + 
# Up + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): + self.ups.append( + weight_norm( + ConvTranspose1d( + base_channels // (2**i), + base_channels // (2**(i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ) + + # Down + self.source_downs = nn.ModuleList() + self.source_resblocks = nn.ModuleList() + downsample_rates = [1] + upsample_rates[::-1][:-1] + downsample_cum_rates = np.cumprod(downsample_rates) + for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes, + source_resblock_dilation_sizes)): + if u == 1: + self.source_downs.append( + Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1) + ) + else: + self.source_downs.append( + Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u*2, u, padding=(u//2)) + ) + + self.source_resblocks.append( + ResBlock(base_channels // (2 ** (i + 1)), k, d, + use_additional_convs, source_resblock_nonlinear_activation, + source_resblock_nonlinear_activation_params) + ) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = base_channels // (2**(i + 1)) + for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): + self.resblocks.append(ResBlock(ch, k, d, use_additional_convs, + resblock_nonlinear_activation, + resblock_nonlinear_activation_params)) + + if self.global_channels > 0: + self.conv_global_cond = weight_norm( + Conv1d(global_channels, base_channels, 1) + ) + self.conv_global_cond.apply(init_weights) + + if self.cond_in_each_up_layer: + self.conv_conds = nn.ModuleList() + for i in range(len(self.ups)): + self.conv_conds.append(weight_norm( + nn.Conv1d(global_channels, base_channels // (2**(i + 1)), 1)) + ) + self.conv_conds.apply(init_weights) + + self.conv_post = weight_norm(Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3)) + self.ups.apply(init_weights) + self.conv_post.apply(init_weights) + + self.reflection_pad = nn.ReflectionPad1d((1, 0)) + 
window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32)) + self.register_buffer("stft_window", window) + + def output_size(self): + return self.out_channels + + def _f02source(self, f0: torch.Tensor) -> torch.Tensor: + f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t + + har_source, _, _ = self.m_source(f0) + return har_source.transpose(1, 2) + + def forward(self, x: torch.Tensor, f0: torch.Tensor, g: tp.Optional[torch.Tensor] = None) -> torch.Tensor: + # x in (B, in_channels, T), f0 in (B, T), g in (B, global_channels, 1) + + s = self._f02source(f0) + + s_stft_real, s_stft_imag = self._stft(s.squeeze(1)) + s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1) + + x = self.conv_pre(x) + if self.global_channels > 0 and g is not None: + x = x + self.conv_global_cond(g) + + for i in range(self.num_upsamples): + if self.act_pre_each_up_layer: + x = F.leaky_relu(x, self.lrelu_slope) + x = self.ups[i](x) + + if self.cond_in_each_up_layer and g is not None: + x = x + self.conv_conds[i](g) + + if i == self.num_upsamples - 1: + x = self.reflection_pad(x) + + # fusion + si = self.source_downs[i](s_stft) + si = self.source_resblocks[i](si) + x = x + si + + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + + x = F.leaky_relu(x) + x = self.conv_post(x) + magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :]) + phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :]) # actually, sin is redundancy + + x = self._istft(magnitude, phase) + x = torch.clamp(x, -self.audio_limit, self.audio_limit) + return x + + def remove_weight_norm(self): + print('Removing weight norm...') + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) + if 
self.global_channels > 0: + remove_weight_norm(self.conv_global_cond) + if self.cond_in_each_up_layer: + for l in self.conv_conds: + remove_weight_norm(l) + self.source_module.remove_weight_norm() + for l in self.source_downs: + remove_weight_norm(l) + for l in self.source_resblocks: + l.remove_weight_norm() + + def _stft(self, x): + spec = torch.stft( + x, + self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window, + return_complex=True) + spec = torch.view_as_real(spec) # [B, F, TT, 2] + return spec[...,0], spec[...,1] + + def _istft(self, magnitude, phase): + magnitude = torch.clip(magnitude, max=1e2) + real = magnitude * torch.cos(phase) + img = magnitude * torch.sin(phase) + inverse_transform = torch.istft( + torch.cat([real.unsqueeze(-1), img.unsqueeze(-1)], dim=-1), + self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window) + + return inverse_transform.unsqueeze(-2) # unsqueeze to stay consistent with conv_transpose1d implementation diff --git a/funasr/models/llm_asr/hifigan_module/mel_spectrum.py b/funasr/models/llm_asr/hifigan_module/mel_spectrum.py new file mode 100644 index 000000000..37e768e75 --- /dev/null +++ b/funasr/models/llm_asr/hifigan_module/mel_spectrum.py @@ -0,0 +1,93 @@ +import torch +import torch.utils.data +import numpy as np +from librosa.filters import mel as librosa_mel_fn + + +def dynamic_range_compression(x, C=1, clip_val=1e-5): + return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) + + +def dynamic_range_decompression(x, C=1): + return np.exp(x) / C + + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def dynamic_range_decompression_torch(x, C=1): + return torch.exp(x) / C + + +def spectral_normalize_torch(magnitudes): + output = dynamic_range_compression_torch(magnitudes) + return output + + +def spectral_de_normalize_torch(magnitudes): + output = 
dynamic_range_decompression_torch(magnitudes) + return output + + +mel_basis = {} +hann_window = {} + + +def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): + if torch.min(y) < -1.: + print('min value is ', torch.min(y)) + if torch.max(y) > 1.: + print('max value is ', torch.max(y)) + + global mel_basis, hann_window + if fmax not in mel_basis: + mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) + mel_basis[str(fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device) + hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) + + y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect') + y = y.squeeze(1) + + spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)], + center=center, pad_mode='reflect', normalized=False, onesided=True) + + spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9)) + + spec = torch.matmul(mel_basis[str(fmax)+'_'+str(y.device)], spec) + spec = spectral_normalize_torch(spec) + + return spec + + +def power_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): + if torch.min(y) < -1.: + print('min value is ', torch.min(y)) + if torch.max(y) > 1.: + print('max value is ', torch.max(y)) + + global mel_basis, hann_window + if fmax not in mel_basis: + mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) + mel_basis[str(fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device) + hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) + + y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect') + y = y.squeeze(1) + + spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)], + center=center, pad_mode='reflect', normalized=False, onesided=True) 
+ + spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9)) + spec = spectral_normalize_torch(spec) + + return spec + + +def mel_from_power_spectrogram(spec, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): + global mel_basis, hann_window + spec = spectral_de_normalize_torch(spec) + spec = torch.matmul(mel_basis[str(fmax) + '_' + str(spec.device)], spec) + spec = spectral_normalize_torch(spec) + + return spec diff --git a/funasr/models/llm_asr/hifigan_module/nsf_utils.py b/funasr/models/llm_asr/hifigan_module/nsf_utils.py new file mode 100644 index 000000000..66d955c35 --- /dev/null +++ b/funasr/models/llm_asr/hifigan_module/nsf_utils.py @@ -0,0 +1,253 @@ +""" +Neural Source Filter based modules implementation. + +Neural source-filter waveform models for statistical parametric speech synthesis + +""" + +import numpy as np +import typing as tp + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.utils import weight_norm, remove_weight_norm +from torch.distributions.uniform import Uniform +from torch.distributions.normal import Normal + +class SineGen(torch.nn.Module): + """ Definition of sine generator + SineGen(samp_rate, harmonic_num = 0, + sine_amp = 0.1, noise_std = 0.003, + voiced_threshold = 0, + flag_for_pulse=False) + samp_rate: sampling rate in Hz + harmonic_num: number of harmonic overtones (default 0) + sine_amp: amplitude of sine-wavefrom (default 0.1) + noise_std: std of Gaussian noise (default 0.003) + voiced_thoreshold: F0 threshold for U/V classification (default 0) + flag_for_pulse: this SinGen is used inside PulseGen (default False) + Note: when flag_for_pulse is True, the first time step of a voiced + segment is always sin(np.pi) or cos(0) + """ + + def __init__(self, samp_rate, harmonic_num=0, + sine_amp=0.1, noise_std=0.003, + voiced_threshold=0): + super(SineGen, self).__init__() + self.sine_amp = sine_amp + self.noise_std = noise_std + self.harmonic_num = harmonic_num + self.sampling_rate = 
samp_rate + self.voiced_threshold = voiced_threshold + + def _f02uv(self, f0): + # generate uv signal + uv = (f0 > self.voiced_threshold).type(torch.float32) + return uv + + @torch.no_grad() + def forward(self, f0): + """ + :param f0: [B, 1, sample_len], Hz + :return: [B, 1, sample_len] + """ + + F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(f0.device) + for i in range(self.harmonic_num + 1): + F_mat[:, i:i+1, :] = f0 * (i+1) / self.sampling_rate + + theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1) + u_dist = Uniform(low=-np.pi, high=np.pi) + phase_vec = u_dist.sample(sample_shape=(f0.size(0), self.harmonic_num + 1, 1)).to(F_mat.device) + phase_vec[:, 0, :] = 0 + + # generate sine waveforms + sine_waves = self.sine_amp * torch.sin(theta_mat + phase_vec) + + # generate uv signal + uv = self._f02uv(f0) + + # noise: for unvoiced should be similar to sine_amp + # std = self.sine_amp/3 -> max value ~ self.sine_amp + # . for voiced regions is self.noise_std + noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3 + noise = noise_amp * torch.randn_like(sine_waves) + + # first: set the unvoiced part to 0 by uv + # then: additive noise + sine_waves = sine_waves * uv + noise + return sine_waves, uv, noise + + +class SourceModuleHnNSF(torch.nn.Module): + """ SourceModule for hn-nsf + SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1, + add_noise_std=0.003, voiced_threshod=0) + sampling_rate: sampling_rate in Hz + harmonic_num: number of harmonic above F0 (default: 0) + sine_amp: amplitude of sine source signal (default: 0.1) + add_noise_std: std of additive Gaussian noise (default: 0.003) + note that amplitude of noise in unvoiced is decided + by sine_amp + voiced_threshold: threhold to set U/V given F0 (default: 0) + Sine_source, noise_source = SourceModuleHnNSF(F0_sampled) + F0_sampled (batchsize, length, 1) + Sine_source (batchsize, length, 1) + noise_source (batchsize, length 1) + uv (batchsize, length, 1) + """ + + def 
__init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1, + add_noise_std=0.003, voiced_threshod=0): + super(SourceModuleHnNSF, self).__init__() + + self.sine_amp = sine_amp + self.noise_std = add_noise_std + + # to produce sine waveforms + self.l_sin_gen = SineGen(sampling_rate, harmonic_num, + sine_amp, add_noise_std, voiced_threshod) + + # to merge source harmonics into a single excitation + self.l_linear = torch.nn.Linear(harmonic_num + 1, 1) + self.l_tanh = torch.nn.Tanh() + + def forward(self, x): + """ + Sine_source, noise_source = SourceModuleHnNSF(F0_sampled) + F0_sampled (batchsize, length, 1) + Sine_source (batchsize, length, 1) + noise_source (batchsize, length 1) + """ + # source for harmonic branch + with torch.no_grad(): + sine_wavs, uv, _ = self.l_sin_gen(x.transpose(1,2)) + sine_wavs = sine_wavs.transpose(1,2) + uv = uv.transpose(1,2) + sine_merge = self.l_tanh(self.l_linear(sine_wavs)) + + # source for noise branch, in the same shape as uv + noise = torch.randn_like(uv) * self.sine_amp / 3 + return sine_merge, noise, uv + + +class SourceModule(torch.nn.Module): + def __init__(self, + nb_harmonics: int, + upsample_ratio: int, + sampling_rate: int, + alpha: float = 0.1, + sigma: float = 0.003, + voiced_threshold: float = 10 + ): + super(SourceModule, self).__init__() + + self.nb_harmonics = nb_harmonics + self.upsample_ratio = upsample_ratio + self.sampling_rate = sampling_rate + self.alpha = alpha + self.sigma = sigma + self.voiced_threshold = voiced_threshold + + self.ffn = nn.Sequential( + weight_norm(nn.Conv1d(self.nb_harmonics + 1, 1, kernel_size=1, stride=1)), + nn.Tanh()) + + def f02uv(self, f0): + # generate uv signal + uv = (f0 > self.voiced_threshold).type(torch.float32) + return uv + + def forward(self, f0): + """ + :param f0: [B, 1, frame_len], Hz + :return: [B, 1, sample_len] + """ + with torch.no_grad(): + uv = self.f02uv(f0) + f0_samples = F.interpolate(f0, scale_factor=(self.upsample_ratio), mode='nearest') + 
uv_samples = F.interpolate(uv, scale_factor=(self.upsample_ratio), mode='nearest') + + F_mat = torch.zeros((f0_samples.size(0), self.nb_harmonics + 1, f0_samples.size(-1))).to(f0_samples.device) + for i in range(self.nb_harmonics + 1): + F_mat[:, i:i+1, :] = f0_samples * (i+1) / self.sampling_rate + + theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1) + u_dist = Uniform(low=-np.pi, high=np.pi) + phase_vec = u_dist.sample(sample_shape=(f0.size(0), self.nb_harmonics + 1, 1)).to(F_mat.device) + phase_vec[:, 0, :] = 0 + + n_dist = Normal(loc=0., scale=self.sigma) + noise = n_dist.sample(sample_shape=(f0_samples.size(0), self.nb_harmonics + 1, f0_samples.size(-1))).to(F_mat.device) + + e_voice = self.alpha * torch.sin(theta_mat + phase_vec) + noise + e_unvoice = self.alpha / 3 / self.sigma * noise + + e = e_voice * uv_samples + e_unvoice * (1 - uv_samples) + + return self.ffn(e) + + def remove_weight_norm(self): + remove_weight_norm(self.ffn[0]) + + +class ConvRNNF0Predictor(nn.Module): + def __init__(self, + num_class: int = 1, + in_channels: int = 80, + cond_channels: int = 512, + use_cond_rnn: bool = True, + bidirectional_rnn: bool = False, + ): + + super().__init__() + + self.num_class = num_class + self.use_cond_rnn = use_cond_rnn + + self.condnet = nn.Sequential( + weight_norm( + nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1) + ), + nn.ELU(), + weight_norm( + nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) + ), + nn.ELU(), + weight_norm( + nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) + ), + nn.ELU(), + weight_norm( + nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) + ), + nn.ELU(), + weight_norm( + nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) + ), + nn.ELU(), + ) + + if self.use_cond_rnn: + self.rnn = nn.GRU( + cond_channels, + cond_channels // 2 if bidirectional_rnn else cond_channels, + num_layers=1, + batch_first=True, + bidirectional=bidirectional_rnn, + ) + + 
self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.condnet(x) + if self.use_cond_rnn: + x, _ = self.rnn(x.transpose(1, 2)) + else: + x = x.transpose(1, 2) + + return torch.abs(self.classifier(x).squeeze(-1)) + + + diff --git a/funasr/models/llm_asr/mel_spectrum.py b/funasr/models/llm_asr/mel_spectrum.py new file mode 100644 index 000000000..37e768e75 --- /dev/null +++ b/funasr/models/llm_asr/mel_spectrum.py @@ -0,0 +1,93 @@ +import torch +import torch.utils.data +import numpy as np +from librosa.filters import mel as librosa_mel_fn + + +def dynamic_range_compression(x, C=1, clip_val=1e-5): + return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) + + +def dynamic_range_decompression(x, C=1): + return np.exp(x) / C + + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def dynamic_range_decompression_torch(x, C=1): + return torch.exp(x) / C + + +def spectral_normalize_torch(magnitudes): + output = dynamic_range_compression_torch(magnitudes) + return output + + +def spectral_de_normalize_torch(magnitudes): + output = dynamic_range_decompression_torch(magnitudes) + return output + + +mel_basis = {} +hann_window = {} + + +def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): + if torch.min(y) < -1.: + print('min value is ', torch.min(y)) + if torch.max(y) > 1.: + print('max value is ', torch.max(y)) + + global mel_basis, hann_window + if fmax not in mel_basis: + mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) + mel_basis[str(fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device) + hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) + + y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect') + y = y.squeeze(1) + + spec = 
torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)], + center=center, pad_mode='reflect', normalized=False, onesided=True) + + spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9)) + + spec = torch.matmul(mel_basis[str(fmax)+'_'+str(y.device)], spec) + spec = spectral_normalize_torch(spec) + + return spec + + +def power_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): + if torch.min(y) < -1.: + print('min value is ', torch.min(y)) + if torch.max(y) > 1.: + print('max value is ', torch.max(y)) + + global mel_basis, hann_window + if fmax not in mel_basis: + mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) + mel_basis[str(fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device) + hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) + + y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect') + y = y.squeeze(1) + + spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)], + center=center, pad_mode='reflect', normalized=False, onesided=True) + + spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9)) + spec = spectral_normalize_torch(spec) + + return spec + + +def mel_from_power_spectrogram(spec, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): + global mel_basis, hann_window + spec = spectral_de_normalize_torch(spec) + spec = torch.matmul(mel_basis[str(fmax) + '_' + str(spec.device)], spec) + spec = spectral_normalize_torch(spec) + + return spec diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py index 539fb5362..911dd291e 100644 --- a/funasr/models/llm_asr/model.py +++ b/funasr/models/llm_asr/model.py @@ -1,4 +1,6 @@ import logging +import os.path +import torchaudio from typing import Union, Dict, List, Tuple, Optional import time @@ -1580,6 +1582,29 @@ class 
LLMASR5(nn.Module): reduction=False, ) + mel_decoder_name = kwargs.get("mel_decoder", None) + mel_decoder_conf = kwargs.get("mel_decoder_conf", None) + self.mel_decoder = self.build_mel_decoder(name=mel_decoder_name, conf=mel_decoder_conf) + vocoder_name = kwargs.get("vocoder", None) + vocoder_conf = kwargs.get("vocoder_conf", None) + self.vocoder = self.build_vocoder(name=vocoder_name, conf=vocoder_conf) + + def build_mel_decoder(self, name: str, conf: dict): + if name is None or conf is None: + return None + if name == "MaskedDiffWithXvec": + from funasr.models.llm_asr.flow_matching import MaskedDiffWithXvec + return MaskedDiffWithXvec(**conf) + return None + + def build_vocoder(self, name: str, conf: dict): + if name is None or conf is None: + return None + if name == "HifiGAN": + from funasr.models.llm_asr.hifigan import HifiGan + return HifiGan(**conf) + return None + def build_audio_decoder(self, name, conf): if name == "transformer": from funasr.models.llm_asr.transformer_lm import TransformerEmbedLM @@ -2205,7 +2230,16 @@ class LLMASR5(nn.Module): target_ids = generated_ids["sequences"] target_emb = self.llm.model.get_input_embeddings()(target_ids) if self.concat_emb_hidden: - hidden_states_select = torch.concat((hidden_states_select, target_emb), dim=-1) + if not self.concat_emb_hidden_norm: + hidden_states_select = torch.concat((hidden_states_select, target_emb), dim=-1) + hidden_states_select = self.audio_decoder_in_proj(hidden_states_select) + else: + outs = self.hidden_norm(hidden_states_select) + outs = self.fusion_dropout(self.fusion_act(outs)) + # emb = model_outputs.hidden_states[0] + emb = self.fusion_dropout(self.fusion_act(self.emb_norm(target_emb))) + outs = self.audio_decoder_in_proj(torch.cat([outs, emb], dim=-1)) + hidden_states_select = self.fusion_act(self.fusion_norm(outs)) speech_tokens = self.audio_decode(hidden_states_select, hidden_states_out_len)[ :, :, 0 @@ -2221,12 +2255,18 @@ class LLMASR5(nn.Module): loss = None + # synthesize 
waveforms
+        spk_emb = kwargs.get("spk_emb", None)
+        feat, wav = self.synthesize_waveform(speech_tokens, spk_emb, inputs_embeds.device)
+
         ibest_writer = None
         if kwargs.get("output_dir") is not None:
             if not hasattr(self, "writer"):
                 self.writer = DatadirWriter(kwargs.get("output_dir"))
             ibest_writer = self.writer[f"{0 + 1}best_recog"]
+
+            self.write_mel_wav(kwargs.get("output_dir"), feat, wav, key[0])
+
         results = []
         response_clean = re.sub(r"[^\w\s\u3000\u4e00-\u9fff]+", "", response)
         result_i = {
@@ -2253,6 +2293,48 @@ class LLMASR5(nn.Module):
 
         return results, meta_data
 
+    def write_mel_wav(self, output_dir, feat, wav, key):
+        out_dir = os.path.join(output_dir, "1best_recog", "mels")
+        os.makedirs(out_dir, exist_ok=True)
+        if feat is not None:
+            feat = feat.cpu().numpy()[0]
+            np.save(os.path.join(out_dir, f"{key}.npy"), feat)
+
+        out_dir = os.path.join(output_dir, "1best_recog", "wavs")
+        os.makedirs(out_dir, exist_ok=True)
+        if wav is not None:
+            path = os.path.join(out_dir, f"{key}.wav")
+            torchaudio.save(
+                path, wav[0], sample_rate=self.vocoder.sample_rate,
+                encoding='PCM_S', bits_per_sample=16
+            )
+
+    def synthesize_waveform(self, speech_tokens, spk_emb, device):
+        mel_feat, wav = None, None
+        if self.mel_decoder is not None and spk_emb is not None:
+            # mel_feat in BxCxT
+            mel_feat = self.token2mel(speech_tokens, spk_emb, device)
+        # guard on mel_feat: without it the vocoder is called with None
+        # whenever no mel decoder / x-vector is available
+        if self.vocoder is not None and mel_feat is not None:
+            wav = self.vocoder.inference(mel_feat.transpose(1, 2))
+
+        return mel_feat, wav
+
+    def token2mel(self, tokens: torch.Tensor, xvec: torch.Tensor, device: torch.device):
+        xvec = torch.tensor(xvec).to(device).unsqueeze(0)
+        xvec_lens = torch.tensor([xvec.shape[1]], device=device, dtype=torch.int64)
+        token_lens = torch.tensor([tokens.shape[1]], device=device, dtype=torch.int64)
+        feat = self.mel_decoder.inference(
+            tokens, token_lens,
+            xvec, xvec_lens,
+            diff_steps=10,
+            temperature=1.0,
+            prompt=dict(
+                prompt_text=(None, None),
+                prompt_audio=(None, None)
+            )
+        )
+        return feat
+
     def audio_decode(
         self,
text: torch.Tensor,
@@ -2263,9 +2345,8 @@ class LLMASR5(nn.Module):
         decoding_length=None,
     ):
         # 1. encode text
-        text = self.audio_decoder_in_proj(text)
+        # text = self.audio_decoder_in_proj(text)
         device = text.device
-        out_tokens = []
         sos_eos_emb = self.audio_decoder_embedding(
             torch.tensor([[self.ad_sos_eos]], dtype=torch.int64, device=device)
         )
@@ -2273,30 +2354,18 @@ class LLMASR5(nn.Module):
             torch.tensor([[self.ad_task_id]], dtype=torch.int64, device=device)
         )
         prompt = torch.cat([sos_eos_emb, text, task_id_emb], dim=1)
-        state, cfg_state = None, None
+        seq_input = torch.zeros(
+            [1, prompt.shape[1] + max_length, prompt.shape[2]],
+            dtype=torch.float32, device=device
+        )
+        seq_input[:, :prompt.shape[1], :] = prompt
+        # token ids must stay integral: this buffer replaces the old int64
+        # result tensor and its entries are later used as token indices
+        out_tokens = torch.zeros([1, max_length, 1], dtype=torch.int64, device=device)
+        out_token_len = 0
+        prompt_len = prompt.shape[1]
+        state, hit_eos = None, False
         for i in range(max_length):
-            if len(out_tokens) > 0:
-                codec_prompt = torch.tensor([out_tokens], dtype=torch.int64, device=device)
-                codec_lengths = torch.tensor([len(out_tokens)], dtype=torch.int64, device=device)
-                # if any quantizer output is eos
-                if torch.any(codec_prompt[:, -1] == (self.codebook_size + self.ad_sos_eos)):
-                    break
-                seq_input, _ = self.prepare_audio_decoder_io(
-                    text, text_lengths, codec_prompt, codec_lengths, need_targets=False
-                )
-            else:
-                seq_input, _ = self.prepare_audio_decoder_io(
-                    text, text_lengths, None, None, need_targets=False
-                )
-            # use state for speedup
             pred, (state, _) = self.audio_decoder.score(seq_input[0], state, prompt[0])
-            if infer_cfg_ratio is not None:
-                cond_len = prompt[0].shape[0]
-                cfg_pred, (cfg_state, _) = self.audio_decoder.score(
-                    seq_input[0][cond_len - 1 :], cfg_state, prompt[0][cond_len - 1 :]
-                )
-                pred = (1 + infer_cfg_ratio) * pred - infer_cfg_ratio * cfg_pred
 
             # sampling all `nq` token ids
             pred = pred.reshape(self.predict_nq, -1)
@@ -2304,49 +2373,44 @@ class LLMASR5(nn.Module):
             pred = torch.log_softmax(pred, dim=-1)
             if min_length is not None and i
< min_length:
                 pred[:, self.codebook_size + self.ad_sos_eos] = float(np.finfo(np.float32).min)
-            top_ids = []
-            for k in range(self.predict_nq):
-                top_ids.append(self.ras_sampling(pred[k], out_tokens)[0].item())
-            out_tokens.append(top_ids)
+            # slice to the generated prefix: passing the whole pre-allocated
+            # buffer would make repetition-aware sampling compare against the
+            # zero padding at the tail of out_tokens, not the recent tokens
+            top_ids = self.ras_sampling(pred[0], out_tokens[0, :out_token_len])
+            out_tokens[0, out_token_len, 0] = top_ids[0]
+            seq_input[0, prompt_len + out_token_len, :] = self.codec_embedder(top_ids)[0]
+            out_token_len += 1
 
-            # remove eos token
-            hit_eos = False
-            if torch.any(
-                torch.tensor(out_tokens[-1], dtype=torch.int64) == self.codebook_size + self.ad_sos_eos
-            ):
-                hit_eos = True
-                out_tokens = out_tokens[:-1]
+            if torch.any(out_tokens[:, out_token_len - 1] == (self.codebook_size + self.ad_sos_eos)):
+                hit_eos = True
+                # drop the eos token itself, matching the pre-refactor output
+                out_tokens = out_tokens[:, :out_token_len - 1, :]
+                break
 
         if decoding_length is None:
-            return torch.tensor([out_tokens], dtype=torch.int64, device=device)
+            return out_tokens
         else:
-            return torch.tensor([out_tokens], dtype=torch.int64, device=device), hit_eos
+            return out_tokens, hit_eos
 
     # Repetition Aware Sampling in VALL-E 2
     def ras_sampling(
         self, weighted_scores, decoded_tokens, *, top_p=0.8, top_k=25, win_size=10, tau_r=0.1
     ):
         top_ids = self.nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k)
-        rep_num = (torch.tensor(decoded_tokens[-win_size:]).to(top_ids) == top_ids).sum().item()
+        rep_num = torch.sum(decoded_tokens[-win_size:] == top_ids).item()
         if rep_num >= win_size * tau_r:
             top_ids = self.random_sampling(weighted_scores)
 
         return top_ids
 
     def nucleus_sampling(self, weighted_scores, top_p=0.8, top_k=25):
-        prob, indices = [], []
         cum_prob = 0.0
         sorted_value, sorted_idx = weighted_scores.softmax(dim=0).sort(descending=True, stable=True)
+        i = len(sorted_idx)
         for i in range(len(sorted_idx)):
             # sampling both top-p and numbers.
-            if cum_prob < top_p and len(prob) < top_k:
+            if cum_prob < top_p and i < top_k:
                 cum_prob += sorted_value[i]
-                prob.append(sorted_value[i])
-                indices.append(sorted_idx[i])
             else:
                 break
-        prob = torch.tensor(prob).to(weighted_scores)
-        indices = torch.tensor(indices, dtype=torch.long).to(weighted_scores.device)
+        else:
+            # for-else: the loop exhausted every candidate without breaking;
+            # keep them all instead of slicing with the stale loop index,
+            # which would drop the last candidate
+            i = len(sorted_idx)
+        prob = sorted_value[:i]
+        indices = sorted_idx[:i]
         sampling_ids = prob.multinomial(1, replacement=True)
         top_ids = indices[sampling_ids]
         return top_ids
diff --git a/funasr/models/specaug/mask_along_axis.py b/funasr/models/specaug/mask_along_axis.py
index d4f237b3d..c7a8e2eb0 100644
--- a/funasr/models/specaug/mask_along_axis.py
+++ b/funasr/models/specaug/mask_along_axis.py
@@ -334,3 +334,45 @@ class MaskAlongAxisLFR(torch.nn.Module):
             replace_with_zero=self.replace_with_zero,
             lfr_rate=self.lfr_rate,
         )
+
+
+class PrefixMaskVariableMaxWidth(torch.nn.Module):
+    def __init__(
+        self,
+        mask_width_ratio_range: Union[float, Sequence[float]] = (0.0, 0.05),
+        replace_value: float = 0.0,
+    ):
+        super().__init__()
+        self.mask_width_ratio_range = mask_width_ratio_range
+        self.replace_value = replace_value
+
+    def extra_repr(self):
+        return (
+            f"mask_width_ratio_range={self.mask_width_ratio_range}, "
+        )
+
+    def forward(self, spec: torch.Tensor, spec_lengths: torch.Tensor = None, return_mask: bool = False):
+        bb, tt, _ = spec.shape
+
+        mask_width_ratio_range = torch.tensor(self.mask_width_ratio_range, dtype=torch.float32, device=spec.device)
+        mask_width_range = (mask_width_ratio_range * tt).long()
+        mask_length = torch.randint(
+            mask_width_range[0],
+            # torch.randint requires high > low; for short inputs the scaled
+            # upper bound collapses to 0 and would raise otherwise
+            max(mask_width_range[1], mask_width_range[0] + 1),
+            (bb, 1),
+            device=spec.device,
+        ).unsqueeze(2)
+
+        # mask_pos: (B, num_mask, 1)
+        mask_pos = tt - mask_length
+
+        aran = torch.arange(tt, device=spec.device)[None, None, :]
+        # mask: (Batch, num_mask, L)
+        mask = (mask_pos <= aran) * (aran < (mask_pos + mask_length))
+        # Multiply masks: (Batch, num_mask, L) -> (Batch, L, 1)
+        mask = mask.any(dim=1).unsqueeze(2)
+
+        spec = spec.masked_fill(mask,
self.replace_value) + if return_mask: + return spec, spec_lengths, mask + return spec, spec_lengths diff --git a/funasr/utils/hinter.py b/funasr/utils/hinter.py new file mode 100644 index 000000000..8c99809f1 --- /dev/null +++ b/funasr/utils/hinter.py @@ -0,0 +1,13 @@ +import sys +import torch.distributed +import logging + +HINTED = set() + + +def hint_once(content, uid, rank=None): + if (rank is None) or (not torch.distributed.is_initialized()) or torch.distributed.get_rank() == rank: + if uid not in HINTED: + logging.info(content) + HINTED.add(uid) +