# FunASR/funasr/models/paraformer_online/sanm_decoder.py

from typing import List
from typing import Tuple
import logging
import torch
import torch.nn as nn
import numpy as np
from funasr.models.scama import utils as myutils
from funasr.models.transformer.decoder import BaseTransformerDecoder
from funasr.models.sanm.attention import MultiHeadedAttentionSANMDecoder, MultiHeadedAttentionCrossAtt
from funasr.models.transformer.embedding import PositionalEncoding
from funasr.models.transformer.layer_norm import LayerNorm
from funasr.models.sanm.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
from funasr.models.transformer.utils.repeat import repeat
from funasr.register import tables
class DecoderLayerSANM(nn.Module):
"""Single decoder layer module.
Args:
size (int): Input dimension.
self_attn (torch.nn.Module): Self-attention module instance.
`MultiHeadedAttention` instance can be used as the argument.
        src_attn (torch.nn.Module): Source-attention (cross-attention) module
            instance. `MultiHeadedAttention` instance can be used as the argument.
feed_forward (torch.nn.Module): Feed-forward module instance.
`PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
can be used as the argument.
dropout_rate (float): Dropout rate.
normalize_before (bool): Whether to use layer_norm before the first block.
concat_after (bool): Whether to concat attention layer's input and output.
if True, additional linear will be applied.
i.e. x -> x + linear(concat(x, att(x)))
if False, no additional linear will be applied. i.e. x -> x + att(x)
"""
def __init__(
self,
size,
self_attn,
src_attn,
feed_forward,
dropout_rate,
normalize_before=True,
concat_after=False,
):
"""Construct an DecoderLayer object."""
super(DecoderLayerSANM, self).__init__()
self.size = size
self.self_attn = self_attn
self.src_attn = src_attn
self.feed_forward = feed_forward
self.norm1 = LayerNorm(size)
if self_attn is not None:
self.norm2 = LayerNorm(size)
if src_attn is not None:
self.norm3 = LayerNorm(size)
self.dropout = nn.Dropout(dropout_rate)
self.normalize_before = normalize_before
self.concat_after = concat_after
if self.concat_after:
self.concat_linear1 = nn.Linear(size + size, size)
self.concat_linear2 = nn.Linear(size + size, size)
def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
"""Compute decoded features.
Args:
tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
cache (List[torch.Tensor]): List of cached tensors.
Each tensor shape should be (#batch, maxlen_out - 1, size).
        Returns:
            torch.Tensor: Output tensor (#batch, maxlen_out, size).
            torch.Tensor: Mask for output tensor (#batch, maxlen_out).
            torch.Tensor: Encoded memory (#batch, maxlen_in, size).
            torch.Tensor: Encoded memory mask (#batch, maxlen_in).
            List[torch.Tensor]: Cache, passed through unchanged.
"""
# tgt = self.dropout(tgt)
residual = tgt
if self.normalize_before:
tgt = self.norm1(tgt)
tgt = self.feed_forward(tgt)
x = tgt
if self.self_attn:
if self.normalize_before:
tgt = self.norm2(tgt)
x, _ = self.self_attn(tgt, tgt_mask)
x = residual + self.dropout(x)
if self.src_attn is not None:
residual = x
if self.normalize_before:
x = self.norm3(x)
x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
return x, tgt_mask, memory, memory_mask, cache
def forward_one_step(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
"""Compute decoded features.
Args:
tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
cache (List[torch.Tensor]): List of cached tensors.
Each tensor shape should be (#batch, maxlen_out - 1, size).
        Returns:
            torch.Tensor: Output tensor (#batch, maxlen_out, size).
            torch.Tensor: Mask for output tensor (#batch, maxlen_out).
            torch.Tensor: Encoded memory (#batch, maxlen_in, size).
            torch.Tensor: Encoded memory mask (#batch, maxlen_in).
            torch.Tensor: Updated cache of the FSMN self-attention block.
"""
# tgt = self.dropout(tgt)
residual = tgt
if self.normalize_before:
tgt = self.norm1(tgt)
tgt = self.feed_forward(tgt)
x = tgt
if self.self_attn:
if self.normalize_before:
tgt = self.norm2(tgt)
if self.training:
cache = None
x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
x = residual + self.dropout(x)
if self.src_attn is not None:
residual = x
if self.normalize_before:
x = self.norm3(x)
x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
return x, tgt_mask, memory, memory_mask, cache
def forward_chunk(self, tgt, memory, fsmn_cache=None, opt_cache=None, chunk_size=None, look_back=0):
"""Compute decoded features.
Args:
tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
cache (List[torch.Tensor]): List of cached tensors.
Each tensor shape should be (#batch, maxlen_out - 1, size).
Returns:
torch.Tensor: Output tensor(#batch, maxlen_out, size).
torch.Tensor: Mask for output tensor (#batch, maxlen_out).
torch.Tensor: Encoded memory (#batch, maxlen_in, size).
torch.Tensor: Encoded memory mask (#batch, maxlen_in).
"""
residual = tgt
if self.normalize_before:
tgt = self.norm1(tgt)
tgt = self.feed_forward(tgt)
x = tgt
if self.self_attn:
if self.normalize_before:
tgt = self.norm2(tgt)
x, fsmn_cache = self.self_attn(tgt, None, fsmn_cache)
x = residual + self.dropout(x)
if self.src_attn is not None:
residual = x
if self.normalize_before:
x = self.norm3(x)
x, opt_cache = self.src_attn.forward_chunk(x, memory, opt_cache, chunk_size, look_back)
x = residual + x
return x, memory, fsmn_cache, opt_cache
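# --- Usage sketch (illustrative, not part of the upstream file) ---
# A minimal, self-contained demonstration of how DecoderLayerSANM composes its
# sub-modules. The _Stub* classes below are hypothetical stand-ins that only
# mimic the call signatures the layer expects; the real modules are
# MultiHeadedAttentionSANMDecoder / MultiHeadedAttentionCrossAtt /
# PositionwiseFeedForwardDecoderSANM, constructed as in
# ParaformerSANMDecoder.__init__ below.
class _StubSelfAttn(nn.Module):
    def forward(self, x, mask, cache=None):
        return x, cache  # (output, updated FSMN cache)
class _StubSrcAttn(nn.Module):
    def forward(self, x, memory, memory_mask):
        return x
def _demo_decoder_layer_sanm():
    size = 8
    layer = DecoderLayerSANM(
        size,
        _StubSelfAttn(),
        _StubSrcAttn(),
        nn.Identity(),  # stands in for the decoder feed-forward block
        dropout_rate=0.1,
    )
    tgt = torch.randn(2, 5, size)     # (#batch, maxlen_out, size)
    tgt_mask = torch.ones(2, 5, 1)    # (#batch, maxlen_out, 1)
    memory = torch.randn(2, 7, size)  # (#batch, maxlen_in, size)
    memory_mask = torch.ones(2, 1, 7)
    x, *_ = layer(tgt, tgt_mask, memory, memory_mask)
    return x.shape  # torch.Size([2, 5, 8])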
@tables.register("decoder_classes", "ParaformerSANMDecoder")
class ParaformerSANMDecoder(BaseTransformerDecoder):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2206.08317
"""
def __init__(
self,
vocab_size: int,
encoder_output_size: int,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
self_attention_dropout_rate: float = 0.0,
src_attention_dropout_rate: float = 0.0,
input_layer: str = "embed",
use_output_layer: bool = True,
pos_enc_class=PositionalEncoding,
normalize_before: bool = True,
concat_after: bool = False,
att_layer_num: int = 6,
kernel_size: int = 21,
sanm_shfit: int = 0,
lora_list: List[str] = None,
lora_rank: int = 8,
lora_alpha: int = 16,
lora_dropout: float = 0.1,
chunk_multiply_factor: tuple = (1,),
tf2torch_tensor_name_prefix_torch: str = "decoder",
tf2torch_tensor_name_prefix_tf: str = "seq2seq/decoder",
):
super().__init__(
vocab_size=vocab_size,
encoder_output_size=encoder_output_size,
dropout_rate=dropout_rate,
positional_dropout_rate=positional_dropout_rate,
input_layer=input_layer,
use_output_layer=use_output_layer,
pos_enc_class=pos_enc_class,
normalize_before=normalize_before,
)
attention_dim = encoder_output_size
if input_layer == "embed":
self.embed = torch.nn.Sequential(
torch.nn.Embedding(vocab_size, attention_dim),
# pos_enc_class(attention_dim, positional_dropout_rate),
)
elif input_layer == "linear":
self.embed = torch.nn.Sequential(
torch.nn.Linear(vocab_size, attention_dim),
torch.nn.LayerNorm(attention_dim),
torch.nn.Dropout(dropout_rate),
torch.nn.ReLU(),
pos_enc_class(attention_dim, positional_dropout_rate),
)
else:
raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")
self.normalize_before = normalize_before
if self.normalize_before:
self.after_norm = LayerNorm(attention_dim)
if use_output_layer:
self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
else:
self.output_layer = None
self.att_layer_num = att_layer_num
self.num_blocks = num_blocks
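        # A None shift presumably centers the FSMN look-around window on the
        # current frame (illustrative reading of the default below).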
if sanm_shfit is None:
sanm_shfit = (kernel_size - 1) // 2
self.decoders = repeat(
att_layer_num,
lambda lnum: DecoderLayerSANM(
attention_dim,
MultiHeadedAttentionSANMDecoder(
attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
),
MultiHeadedAttentionCrossAtt(
attention_heads, attention_dim, src_attention_dropout_rate, lora_list, lora_rank, lora_alpha, lora_dropout
),
PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
dropout_rate,
normalize_before,
concat_after,
),
)
if num_blocks - att_layer_num <= 0:
self.decoders2 = None
else:
self.decoders2 = repeat(
num_blocks - att_layer_num,
lambda lnum: DecoderLayerSANM(
attention_dim,
MultiHeadedAttentionSANMDecoder(
attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=0
),
None,
PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
dropout_rate,
normalize_before,
concat_after,
),
)
self.decoders3 = repeat(
1,
lambda lnum: DecoderLayerSANM(
attention_dim,
None,
None,
PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
dropout_rate,
normalize_before,
concat_after,
),
)
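        # Layer layout: `self.decoders` = att_layer_num blocks with FSMN
        # self-attention + cross-attention; `self.decoders2` = the remaining
        # num_blocks - att_layer_num blocks with FSMN self-attention only;
        # `self.decoders3` = one feed-forward-only block applied last.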
self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
self.chunk_multiply_factor = chunk_multiply_factor
def forward(
self,
hs_pad: torch.Tensor,
hlens: torch.Tensor,
ys_in_pad: torch.Tensor,
ys_in_lens: torch.Tensor,
chunk_mask: torch.Tensor = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward decoder.
Args:
hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
hlens: (batch)
            ys_in_pad: decoder input embeddings, float32 (batch, maxlen_out, feat);
                this forward consumes them directly and does not apply self.embed
            ys_in_lens: (batch)
            chunk_mask: optional streaming mask multiplied into memory_mask
Returns:
(tuple): tuple containing:
x: decoded token score before softmax (batch, maxlen_out, token)
if use_output_layer is True,
olens: (batch, )
"""
tgt = ys_in_pad
tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
memory = hs_pad
memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
if chunk_mask is not None:
memory_mask = memory_mask * chunk_mask
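        # If the streaming chunk_mask leaves memory_mask one step shorter than
        # tgt_mask, pad it by repeating its second-to-last row (this assumes
        # the lengths differ by exactly one).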
if tgt_mask.size(1) != memory_mask.size(1):
memory_mask = torch.cat((memory_mask, memory_mask[:, -2:-1, :]), dim=1)
x = tgt
x, tgt_mask, memory, memory_mask, _ = self.decoders(
x, tgt_mask, memory, memory_mask
)
if self.decoders2 is not None:
x, tgt_mask, memory, memory_mask, _ = self.decoders2(
x, tgt_mask, memory, memory_mask
)
x, tgt_mask, memory, memory_mask, _ = self.decoders3(
x, tgt_mask, memory, memory_mask
)
if self.normalize_before:
x = self.after_norm(x)
if self.output_layer is not None:
x = self.output_layer(x)
olens = tgt_mask.sum(1)
return x, olens
def score(self, ys, state, x):
"""Score."""
ys_mask = myutils.sequence_mask(torch.tensor([len(ys)], dtype=torch.int32), device=x.device)[:, :, None]
logp, state = self.forward_one_step(
ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state
)
return logp.squeeze(0), state
def forward_chunk(
self,
memory: torch.Tensor,
tgt: torch.Tensor,
cache: dict = None,
    ) -> torch.Tensor:
"""Forward decoder.
Args:
hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
hlens: (batch)
ys_in_pad:
input token ids, int64 (batch, maxlen_out)
if input_layer == "embed"
input tensor (batch, maxlen_out, #mels) in the other cases
ys_in_lens: (batch)
Returns:
(tuple): tuple containing:
x: decoded token score before softmax (batch, maxlen_out, token)
if use_output_layer is True,
olens: (batch, )
"""
x = tgt
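        # Streaming state layout: one FSMN cache per self-attention layer
        # (decoders + decoders2) and one cross-attention cache per layer of
        # `self.decoders`.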
if cache["decode_fsmn"] is None:
cache_layer_num = len(self.decoders)
if self.decoders2 is not None:
cache_layer_num += len(self.decoders2)
fsmn_cache = [None] * cache_layer_num
else:
fsmn_cache = cache["decode_fsmn"]
if cache["opt"] is None:
cache_layer_num = len(self.decoders)
opt_cache = [None] * cache_layer_num
else:
opt_cache = cache["opt"]
for i in range(self.att_layer_num):
decoder = self.decoders[i]
x, memory, fsmn_cache[i], opt_cache[i] = decoder.forward_chunk(
x, memory, fsmn_cache=fsmn_cache[i], opt_cache=opt_cache[i],
chunk_size=cache["chunk_size"], look_back=cache["decoder_chunk_look_back"]
)
        if self.num_blocks - self.att_layer_num > 0:  # run decoders2 whenever it exists, matching forward()
for i in range(self.num_blocks - self.att_layer_num):
j = i + self.att_layer_num
decoder = self.decoders2[i]
x, memory, fsmn_cache[j], _ = decoder.forward_chunk(
x, memory, fsmn_cache=fsmn_cache[j]
)
for decoder in self.decoders3:
x, memory, _, _ = decoder.forward_chunk(
x, memory
)
if self.normalize_before:
x = self.after_norm(x)
if self.output_layer is not None:
x = self.output_layer(x)
cache["decode_fsmn"] = fsmn_cache
if cache["decoder_chunk_look_back"] > 0 or cache["decoder_chunk_look_back"] == -1:
cache["opt"] = opt_cache
return x
def forward_one_step(
self,
tgt: torch.Tensor,
tgt_mask: torch.Tensor,
memory: torch.Tensor,
cache: List[torch.Tensor] = None,
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
"""Forward one step.
Args:
tgt: input token ids, int64 (batch, maxlen_out)
tgt_mask: input token mask, (batch, maxlen_out)
                dtype=torch.uint8 before PyTorch 1.2,
                dtype=torch.bool from PyTorch 1.2 onward
memory: encoded memory, float32 (batch, maxlen_in, feat)
cache: cached output list of (batch, max_time_out-1, size)
Returns:
            y, cache: NN output value and cache per `self.decoders`.
                `y.shape` is (batch, token)
"""
x = self.embed(tgt)
if cache is None:
cache_layer_num = len(self.decoders)
if self.decoders2 is not None:
cache_layer_num += len(self.decoders2)
cache = [None] * cache_layer_num
new_cache = []
# for c, decoder in zip(cache, self.decoders):
for i in range(self.att_layer_num):
decoder = self.decoders[i]
c = cache[i]
x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_one_step(
x, tgt_mask, memory, None, cache=c
)
new_cache.append(c_ret)
        if self.num_blocks - self.att_layer_num > 0:  # run decoders2 whenever it exists, matching forward()
for i in range(self.num_blocks - self.att_layer_num):
j = i + self.att_layer_num
decoder = self.decoders2[i]
c = cache[j]
x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_one_step(
x, tgt_mask, memory, None, cache=c
)
new_cache.append(c_ret)
for decoder in self.decoders3:
x, tgt_mask, memory, memory_mask, _ = decoder.forward_one_step(
x, tgt_mask, memory, None, cache=None
)
if self.normalize_before:
y = self.after_norm(x[:, -1])
else:
y = x[:, -1]
if self.output_layer is not None:
y = torch.log_softmax(self.output_layer(y), dim=-1)
return y, new_cache
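# --- Smoke test sketch (illustrative, not part of the upstream file) ---
# Exercises forward_one_step with a deliberately tiny configuration; every
# size below is an arbitrary example value, and the mask is built the same
# way score() builds it above.
if __name__ == "__main__":
    decoder = ParaformerSANMDecoder(
        vocab_size=50,
        encoder_output_size=32,
        attention_heads=4,
        linear_units=64,
        num_blocks=2,
        att_layer_num=2,
        kernel_size=11,
    )
    decoder.eval()
    memory = torch.randn(1, 20, 32)    # (batch, maxlen_in, feat)
    ys = torch.randint(0, 50, (1, 3))  # partial hypothesis, int64
    ys_mask = myutils.sequence_mask(
        torch.tensor([3], dtype=torch.int32), device=ys.device
    )[:, :, None]
    with torch.no_grad():
        logp, cache = decoder.forward_one_step(ys, ys_mask, memory)
    print(logp.shape, len(cache))      # torch.Size([1, 50]), 2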