From cdc70650084f9a69bacd842b7434a008354e2ea0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E8=AF=AD=E5=B8=86?=
Date: Thu, 22 Feb 2024 17:23:20 +0800
Subject: [PATCH] Add lcbnet attention and encoder modules; remove pdb traces from AutoModel

---
 funasr/auto/auto_model.py         |   7 +-
 funasr/models/lcbnet/attention.py | 112 +++++++++
 funasr/models/lcbnet/encoder.py   | 392 ++++++++++++++++++++++++++++++
 3 files changed, 506 insertions(+), 5 deletions(-)
 create mode 100644 funasr/models/lcbnet/attention.py
 create mode 100644 funasr/models/lcbnet/encoder.py

diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py
index 7c8630356..a5341eacf 100644
--- a/funasr/auto/auto_model.py
+++ b/funasr/auto/auto_model.py
@@ -161,18 +161,15 @@ class AutoModel:
             vocab_size = len(tokenizer.token_list)
         else:
             vocab_size = -1
-        pdb.set_trace()
         # build frontend
         frontend = kwargs.get("frontend", None)
-        pdb.set_trace()
+
         if frontend is not None:
-            pdb.set_trace()
            frontend_class = tables.frontend_classes.get(frontend)
            frontend = frontend_class(**kwargs["frontend_conf"])
-            pdb.set_trace()
            kwargs["frontend"] = frontend
            kwargs["input_size"] = frontend.output_size()
-        pdb.set_trace()
+
         # build model
         model_class = tables.model_classes.get(kwargs["model"])
         model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=vocab_size)
diff --git a/funasr/models/lcbnet/attention.py b/funasr/models/lcbnet/attention.py
new file mode 100644
index 000000000..8e8c5943a
--- /dev/null
+++ b/funasr/models/lcbnet/attention.py
@@ -0,0 +1,112 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+# Copyright 2024 yufan
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""Multi-Head Attention Return Weight layer definition."""
+
+import math
+
+import torch
+from torch import nn
+
+class MultiHeadedAttentionReturnWeight(nn.Module):
+    """Multi-Head Attention layer that also returns the attention weights.
+
+    Args:
+        n_head (int): The number of heads.
+        n_feat (int): The number of features.
+        dropout_rate (float): Dropout rate.
+
+    """
+
+    def __init__(self, n_head, n_feat, dropout_rate):
+        """Construct a MultiHeadedAttentionReturnWeight object."""
+        super(MultiHeadedAttentionReturnWeight, self).__init__()
+        assert n_feat % n_head == 0
+        # We assume d_v always equals d_k
+        self.d_k = n_feat // n_head
+        self.h = n_head
+        self.linear_q = nn.Linear(n_feat, n_feat)
+        self.linear_k = nn.Linear(n_feat, n_feat)
+        self.linear_v = nn.Linear(n_feat, n_feat)
+        self.linear_out = nn.Linear(n_feat, n_feat)
+        self.attn = None
+        self.dropout = nn.Dropout(p=dropout_rate)
+
+    def forward_qkv(self, query, key, value):
+        """Transform query, key and value.
+
+        Args:
+            query (torch.Tensor): Query tensor (#batch, time1, size).
+            key (torch.Tensor): Key tensor (#batch, time2, size).
+            value (torch.Tensor): Value tensor (#batch, time2, size).
+
+        Returns:
+            torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
+            torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
+            torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
+
+        """
+        n_batch = query.size(0)
+        q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
+        k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
+        v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
+        q = q.transpose(1, 2)  # (batch, head, time1, d_k)
+        k = k.transpose(1, 2)  # (batch, head, time2, d_k)
+        v = v.transpose(1, 2)  # (batch, head, time2, d_k)
+
+        return q, k, v
+
+    def forward_attention(self, value, scores, mask):
+        """Compute attention context vector.
+
+        Args:
+            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
+            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
+            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
+
+        Returns:
+            torch.Tensor: Transformed value (#batch, time1, d_model) weighted by the score.
+            torch.Tensor: Attention weight tensor (#batch, n_head, time1, time2).
+
+        """
+        n_batch = value.size(0)
+        if mask is not None:
+            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
+            min_value = torch.finfo(scores.dtype).min
+            scores = scores.masked_fill(mask, min_value)
+            self.attn = torch.softmax(scores, dim=-1).masked_fill(
+                mask, 0.0
+            )  # (batch, head, time1, time2)
+        else:
+            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
+
+        p_attn = self.dropout(self.attn)
+        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
+        x = (
+            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
+        )  # (batch, time1, d_model)
+
+        return self.linear_out(x), self.attn  # (batch, time1, d_model)
+
+    def forward(self, query, key, value, mask):
+        """Compute scaled dot product attention.
+
+        Args:
+            query (torch.Tensor): Query tensor (#batch, time1, size).
+            key (torch.Tensor): Key tensor (#batch, time2, size).
+            value (torch.Tensor): Value tensor (#batch, time2, size).
+            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+                (#batch, time1, time2).
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, time1, d_model).
+            torch.Tensor: Attention weight tensor (#batch, n_head, time1, time2).
+
+        """
+        q, k, v = self.forward_qkv(query, key, value)
+        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
+        return self.forward_attention(v, scores, mask)
+
+
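An illustrative usage sketch for the new attention module (not part of the patch; the tensor sizes below are assumptions). Unlike the stock MultiHeadedAttention, the forward pass returns both the attended output and the attention weights:

    import torch
    from funasr.models.lcbnet.attention import MultiHeadedAttentionReturnWeight

    attn = MultiHeadedAttentionReturnWeight(n_head=4, n_feat=256, dropout_rate=0.1)
    query = torch.randn(2, 10, 256)    # (batch, time1, n_feat), e.g. a text encoding
    memory = torch.randn(2, 50, 256)   # (batch, time2, n_feat), e.g. an ASR encoding
    mask = torch.ones(2, 1, 50, dtype=torch.bool)  # (batch, 1, time2), True = keep
    out, weights = attn(query, memory, memory, mask)
    print(out.shape)      # torch.Size([2, 10, 256])
    print(weights.shape)  # torch.Size([2, 4, 10, 50])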
+ """ + + def __init__( + self, + size, + self_attn, + feed_forward, + dropout_rate, + normalize_before=True, + concat_after=False, + stochastic_depth_rate=0.0, + ): + """Construct an EncoderLayer object.""" + super(EncoderLayer, self).__init__() + self.self_attn = self_attn + self.feed_forward = feed_forward + self.norm1 = LayerNorm(size) + self.norm2 = LayerNorm(size) + self.dropout = nn.Dropout(dropout_rate) + self.size = size + self.normalize_before = normalize_before + self.concat_after = concat_after + if self.concat_after: + self.concat_linear = nn.Linear(size + size, size) + self.stochastic_depth_rate = stochastic_depth_rate + + def forward(self, x, mask, cache=None): + """Compute encoded features. + + Args: + x_input (torch.Tensor): Input tensor (#batch, time, size). + mask (torch.Tensor): Mask tensor for the input (#batch, time). + cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size). + + Returns: + torch.Tensor: Output tensor (#batch, time, size). + torch.Tensor: Mask tensor (#batch, time). + + """ + skip_layer = False + # with stochastic depth, residual connection `x + f(x)` becomes + # `x <- x + 1 / (1 - p) * f(x)` at training time. + stoch_layer_coeff = 1.0 + if self.training and self.stochastic_depth_rate > 0: + skip_layer = torch.rand(1).item() < self.stochastic_depth_rate + stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate) + + if skip_layer: + if cache is not None: + x = torch.cat([cache, x], dim=1) + return x, mask + + residual = x + if self.normalize_before: + x = self.norm1(x) + + if cache is None: + x_q = x + else: + assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size) + x_q = x[:, -1:, :] + residual = residual[:, -1:, :] + mask = None if mask is None else mask[:, -1:, :] + + if self.concat_after: + x_concat = torch.cat((x, self.self_attn(x_q, x, x, mask)), dim=-1) + x = residual + stoch_layer_coeff * self.concat_linear(x_concat) + else: + x = residual + stoch_layer_coeff * self.dropout( + self.self_attn(x_q, x, x, mask) + ) + if not self.normalize_before: + x = self.norm1(x) + + residual = x + if self.normalize_before: + x = self.norm2(x) + x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x)) + if not self.normalize_before: + x = self.norm2(x) + + if cache is not None: + x = torch.cat([cache, x], dim=1) + + return x, mask + +@tables.register("encoder_classes", "TransformerTextEncoder") +class TransformerTextEncoder(nn.Module): + """Transformer text encoder module. + + Args: + input_size: input dim + output_size: dimension of attention + attention_heads: the number of heads of multi head attention + linear_units: the number of units of position-wise feed forward + num_blocks: the number of decoder blocks + dropout_rate: dropout rate + attention_dropout_rate: dropout rate in attention + positional_dropout_rate: dropout rate after adding positional encoding + input_layer: input layer type + pos_enc_class: PositionalEncoding or ScaledPositionalEncoding + normalize_before: whether to use layer_norm before the first block + concat_after: whether to concat attention layer's input and output + if True, additional linear will be applied. + i.e. x -> x + linear(concat(x, att(x))) + if False, no additional linear will be applied. + i.e. 
x -> x + att(x) + positionwise_layer_type: linear of conv1d + positionwise_conv_kernel_size: kernel size of positionwise conv1d layer + padding_idx: padding_idx for input_layer=embed + """ + + def __init__( + self, + input_size: int, + output_size: int = 256, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + attention_dropout_rate: float = 0.0, + pos_enc_class=PositionalEncoding, + normalize_before: bool = True, + concat_after: bool = False, + ): + super().__init__() + self._output_size = output_size + + self.embed = torch.nn.Sequential( + torch.nn.Embedding(input_size, output_size), + pos_enc_class(output_size, positional_dropout_rate), + ) + + self.normalize_before = normalize_before + + positionwise_layer = PositionwiseFeedForward + positionwise_layer_args = ( + output_size, + linear_units, + dropout_rate, + ) + self.encoders = repeat( + num_blocks, + lambda lnum: EncoderLayer( + output_size, + MultiHeadedAttention( + attention_heads, output_size, attention_dropout_rate + ), + positionwise_layer(*positionwise_layer_args), + dropout_rate, + normalize_before, + concat_after, + ), + ) + if self.normalize_before: + self.after_norm = LayerNorm(output_size) + + def output_size(self) -> int: + return self._output_size + + def forward( + self, + xs_pad: torch.Tensor, + ilens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """Embed positions in tensor. + + Args: + xs_pad: input tensor (B, L, D) + ilens: input length (B) + Returns: + position embedded tensor and mask + """ + masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device) + xs_pad = self.embed(xs_pad) + + xs_pad, masks = self.encoders(xs_pad, masks) + + if self.normalize_before: + xs_pad = self.after_norm(xs_pad) + + olens = masks.squeeze(1).sum(1) + return xs_pad, olens, None + + + + +@tables.register("encoder_classes", "FusionSANEncoder") +class SelfSrcAttention(nn.Module): + """Single decoder layer module. + + Args: + size (int): Input dimension. + self_attn (torch.nn.Module): Self-attention module instance. + `MultiHeadedAttention` instance can be used as the argument. + src_attn (torch.nn.Module): Self-attention module instance. + `MultiHeadedAttention` instance can be used as the argument. + feed_forward (torch.nn.Module): Feed-forward module instance. + `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance + can be used as the argument. + dropout_rate (float): Dropout rate. + normalize_before (bool): Whether to use layer_norm before the first block. + concat_after (bool): Whether to concat attention layer's input and output. + if True, additional linear will be applied. + i.e. x -> x + linear(concat(x, att(x))) + if False, no additional linear will be applied. i.e. 
x -> x + att(x) + + + """ + def __init__( + self, + size, + attention_heads, + attention_dim, + linear_units, + self_attention_dropout_rate, + src_attention_dropout_rate, + positional_dropout_rate, + dropout_rate, + normalize_before=True, + concat_after=False, + ): + """Construct an SelfSrcAttention object.""" + super(SelfSrcAttention, self).__init__() + self.size = size + self.self_attn = MultiHeadedAttention(attention_heads, attention_dim, self_attention_dropout_rate) + self.src_attn = MultiHeadedAttentionReturnWeight(attention_heads, attention_dim, src_attention_dropout_rate) + self.feed_forward = PositionwiseFeedForward(attention_dim, linear_units, positional_dropout_rate) + self.norm1 = LayerNorm(size) + self.norm2 = LayerNorm(size) + self.norm3 = LayerNorm(size) + self.dropout = nn.Dropout(dropout_rate) + self.normalize_before = normalize_before + self.concat_after = concat_after + if self.concat_after: + self.concat_linear1 = nn.Linear(size + size, size) + self.concat_linear2 = nn.Linear(size + size, size) + + def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None): + """Compute decoded features. + + Args: + tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size). + tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out). + memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size). + memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in). + cache (List[torch.Tensor]): List of cached tensors. + Each tensor shape should be (#batch, maxlen_out - 1, size). + + Returns: + torch.Tensor: Output tensor(#batch, maxlen_out, size). + torch.Tensor: Mask for output tensor (#batch, maxlen_out). + torch.Tensor: Encoded memory (#batch, maxlen_in, size). + torch.Tensor: Encoded memory mask (#batch, maxlen_in). 
+ + """ + residual = tgt + if self.normalize_before: + tgt = self.norm1(tgt) + + if cache is None: + tgt_q = tgt + tgt_q_mask = tgt_mask + else: + # compute only the last frame query keeping dim: max_time_out -> 1 + assert cache.shape == ( + tgt.shape[0], + tgt.shape[1] - 1, + self.size, + ), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}" + tgt_q = tgt[:, -1:, :] + residual = residual[:, -1:, :] + tgt_q_mask = None + if tgt_mask is not None: + tgt_q_mask = tgt_mask[:, -1:, :] + + if self.concat_after: + tgt_concat = torch.cat( + (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1 + ) + x = residual + self.concat_linear1(tgt_concat) + else: + x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)) + if not self.normalize_before: + x = self.norm1(x) + + residual = x + if self.normalize_before: + x = self.norm2(x) + if self.concat_after: + x_concat = torch.cat( + (x, self.src_attn(x, memory, memory, memory_mask)), dim=-1 + ) + x = residual + self.concat_linear2(x_concat) + else: + x, score = self.src_attn(x, memory, memory, memory_mask) + x = residual + self.dropout(x) + if not self.normalize_before: + x = self.norm2(x) + + residual = x + if self.normalize_before: + x = self.norm3(x) + x = residual + self.dropout(self.feed_forward(x)) + if not self.normalize_before: + x = self.norm3(x) + + if cache is not None: + x = torch.cat([cache, x], dim=1) + + return x, tgt_mask, memory, memory_mask + + + +class ConvPredictor(nn.Module): + def __init__(self, size=256, l_order=3, r_order=3, attention_heads=4, attention_dropout_rate=0.1, linear_units=2048): + super().__init__() + self.atten = MultiHeadedAttention(attention_heads, size, attention_dropout_rate) + self.norm1 = LayerNorm(size) + self.feed_forward = PositionwiseFeedForward(size, linear_units, attention_dropout_rate) + self.norm2 = LayerNorm(size) + self.pad = nn.ConstantPad1d((l_order, r_order), 0) + self.conv1d = nn.Conv1d(size, size, l_order + r_order + 1, groups=size) + self.output_linear = nn.Linear(size, 1) + + + def forward(self, text_enc, asr_enc): + # stage1 cross-attention + residual = text_enc + text_enc = residual + self.atten(text_enc, asr_enc, asr_enc, None) + + # stage2 FFN + residual = text_enc + text_enc = self.norm1(text_enc) + text_enc = residual + self.feed_forward(text_enc) + + # stage Conv predictor + text_enc = self.norm2(text_enc) + context = text_enc.transpose(1, 2) + queries = self.pad(context) + memory = self.conv1d(queries) + output = memory + context + output = output.transpose(1, 2) + output = torch.relu(output) + output = self.output_linear(output) + if output.dim()==3: + output = output.squeeze(2) + return output
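A sketch of how the new encoder pieces might be wired together (illustrative only, not part of the patch; the vocabulary size, tensor shapes, and the precomputed ASR encoding are assumptions, and the actual LCB-NET model wiring may differ):

    import torch
    from funasr.models.lcbnet.encoder import (
        TransformerTextEncoder,
        SelfSrcAttention,
        ConvPredictor,
    )

    text_encoder = TransformerTextEncoder(input_size=1000, output_size=256,
                                          attention_heads=4, linear_units=2048,
                                          num_blocks=6)
    fusion = SelfSrcAttention(size=256, attention_heads=4, attention_dim=256,
                              linear_units=2048, self_attention_dropout_rate=0.1,
                              src_attention_dropout_rate=0.1,
                              positional_dropout_rate=0.1, dropout_rate=0.1)
    predictor = ConvPredictor(size=256)

    tokens = torch.randint(0, 1000, (2, 12))  # (batch, text_len) biasing-text token ids
    ilens = torch.tensor([12, 9])             # per-utterance text lengths
    asr_enc = torch.randn(2, 80, 256)         # (batch, frames, dim) acoustic encoder output (assumed)

    text_enc, olens, _ = text_encoder(tokens, ilens)        # (2, 12, 256)
    fused, _, _, _ = fusion(text_enc, None, asr_enc, None)  # text attends to the ASR encoding
    scores = predictor(text_enc, asr_enc)                   # (2, 12), one score per text token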