transducer inference

This commit is contained in:
shixian.shi 2024-02-21 16:50:49 +08:00
parent 509bc88903
commit 9bed46a31e
4 changed files with 373 additions and 1492 deletions

View File

@@ -1,238 +0,0 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Multi-Head Attention layer definition."""
import math
import numpy
import torch
from torch import nn
from typing import Optional, Tuple
import torch.nn.functional as F
from funasr.models.transformer.utils.nets_utils import make_pad_mask
import funasr.models.lora.layers as lora
class RelPositionMultiHeadedAttentionChunk(torch.nn.Module):
"""RelPositionMultiHeadedAttention definition.
Args:
num_heads: Number of attention heads.
embed_size: Embedding size.
dropout_rate: Dropout rate.
"""
def __init__(
self,
num_heads: int,
embed_size: int,
dropout_rate: float = 0.0,
simplified_attention_score: bool = False,
) -> None:
"""Construct an MultiHeadedAttention object."""
super().__init__()
self.d_k = embed_size // num_heads
self.num_heads = num_heads
assert self.d_k * num_heads == embed_size, (
"embed_size (%d) must be divisible by num_heads (%d)"
% (embed_size, num_heads)
)
self.linear_q = torch.nn.Linear(embed_size, embed_size)
self.linear_k = torch.nn.Linear(embed_size, embed_size)
self.linear_v = torch.nn.Linear(embed_size, embed_size)
self.linear_out = torch.nn.Linear(embed_size, embed_size)
if simplified_attention_score:
self.linear_pos = torch.nn.Linear(embed_size, num_heads)
self.compute_att_score = self.compute_simplified_attention_score
else:
self.linear_pos = torch.nn.Linear(embed_size, embed_size, bias=False)
self.pos_bias_u = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k))
self.pos_bias_v = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k))
torch.nn.init.xavier_uniform_(self.pos_bias_u)
torch.nn.init.xavier_uniform_(self.pos_bias_v)
self.compute_att_score = self.compute_attention_score
self.dropout = torch.nn.Dropout(p=dropout_rate)
self.attn = None
def rel_shift(self, x: torch.Tensor, left_context: int = 0) -> torch.Tensor:
"""Compute relative positional encoding.
Args:
x: Input sequence. (B, H, T_1, 2 * T_1 - 1)
left_context: Number of frames in left context.
Returns:
x: Output sequence. (B, H, T_1, T_2)
"""
batch_size, n_heads, time1, n = x.shape
time2 = time1 + left_context
batch_stride, n_heads_stride, time1_stride, n_stride = x.stride()
return x.as_strided(
(batch_size, n_heads, time1, time2),
(batch_stride, n_heads_stride, time1_stride - n_stride, n_stride),
storage_offset=(n_stride * (time1 - 1)),
)
def compute_simplified_attention_score(
self,
query: torch.Tensor,
key: torch.Tensor,
pos_enc: torch.Tensor,
left_context: int = 0,
) -> torch.Tensor:
"""Simplified attention score computation.
Reference: https://github.com/k2-fsa/icefall/pull/458
Args:
query: Transformed query tensor. (B, H, T_1, d_k)
key: Transformed key tensor. (B, H, T_2, d_k)
pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
left_context: Number of frames in left context.
Returns:
: Attention score. (B, H, T_1, T_2)
"""
pos_enc = self.linear_pos(pos_enc)
matrix_ac = torch.matmul(query, key.transpose(2, 3))
matrix_bd = self.rel_shift(
pos_enc.transpose(1, 2).unsqueeze(2).repeat(1, 1, query.size(2), 1),
left_context=left_context,
)
return (matrix_ac + matrix_bd) / math.sqrt(self.d_k)
def compute_attention_score(
self,
query: torch.Tensor,
key: torch.Tensor,
pos_enc: torch.Tensor,
left_context: int = 0,
) -> torch.Tensor:
"""Attention score computation.
Args:
query: Transformed query tensor. (B, H, T_1, d_k)
key: Transformed key tensor. (B, H, T_2, d_k)
pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
left_context: Number of frames in left context.
Returns:
: Attention score. (B, H, T_1, T_2)
"""
p = self.linear_pos(pos_enc).view(pos_enc.size(0), -1, self.num_heads, self.d_k)
query = query.transpose(1, 2)
q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2)
q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2)
matrix_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1))
matrix_bd = torch.matmul(q_with_bias_v, p.permute(0, 2, 3, 1))
matrix_bd = self.rel_shift(matrix_bd, left_context=left_context)
return (matrix_ac + matrix_bd) / math.sqrt(self.d_k)
def forward_qkv(
self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Transform query, key and value.
Args:
query: Query tensor. (B, T_1, size)
key: Key tensor. (B, T_2, size)
value: Value tensor. (B, T_2, size)
Returns:
q: Transformed query tensor. (B, H, T_1, d_k)
k: Transformed key tensor. (B, H, T_2, d_k)
v: Transformed value tensor. (B, H, T_2, d_k)
"""
n_batch = query.size(0)
q = (
self.linear_q(query)
.view(n_batch, -1, self.num_heads, self.d_k)
.transpose(1, 2)
)
k = (
self.linear_k(key)
.view(n_batch, -1, self.num_heads, self.d_k)
.transpose(1, 2)
)
v = (
self.linear_v(value)
.view(n_batch, -1, self.num_heads, self.d_k)
.transpose(1, 2)
)
return q, k, v
def forward_attention(
self,
value: torch.Tensor,
scores: torch.Tensor,
mask: torch.Tensor,
chunk_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Compute attention context vector.
Args:
value: Transformed value. (B, H, T_2, d_k)
scores: Attention score. (B, H, T_1, T_2)
mask: Source mask. (B, T_2)
chunk_mask: Chunk mask. (T_1, T_1)
Returns:
attn_output: Transformed value weighted by attention score. (B, T_1, H * d_k)
"""
batch_size = scores.size(0)
mask = mask.unsqueeze(1).unsqueeze(2)
if chunk_mask is not None:
mask = chunk_mask.unsqueeze(0).unsqueeze(1) | mask
scores = scores.masked_fill(mask, float("-inf"))
self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)
attn_output = self.dropout(self.attn)
attn_output = torch.matmul(attn_output, value)
attn_output = self.linear_out(
attn_output.transpose(1, 2)
.contiguous()
.view(batch_size, -1, self.num_heads * self.d_k)
)
return attn_output
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
pos_enc: torch.Tensor,
mask: torch.Tensor,
chunk_mask: Optional[torch.Tensor] = None,
left_context: int = 0,
) -> torch.Tensor:
"""Compute scaled dot product attention with rel. positional encoding.
Args:
query: Query tensor. (B, T_1, size)
key: Key tensor. (B, T_2, size)
value: Value tensor. (B, T_2, size)
pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
mask: Source mask. (B, T_2)
chunk_mask: Chunk mask. (T_1, T_1)
left_context: Number of frames in left context.
Returns:
: Output tensor. (B, T_1, H * d_k)
"""
q, k, v = self.forward_qkv(query, key, value)
scores = self.compute_att_score(q, k, pos_enc, left_context=left_context)
return self.forward_attention(v, scores, mask, chunk_mask=chunk_mask)
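As a quick sanity check of the module above, the following snippet (illustrative only; batch size, sequence length and embedding size are arbitrary, and it assumes the imports at the top of this file) exercises the default attention-score path and shows the expected tensor shapes:
att = RelPositionMultiHeadedAttentionChunk(num_heads=4, embed_size=256)
x = torch.randn(2, 16, 256)                  # query/key/value: (B, T, size)
pos = torch.randn(2, 2 * 16 - 1, 256)        # positional embedding: (B, 2 * T - 1, size)
mask = torch.zeros(2, 16, dtype=torch.bool)  # source mask: True marks padded frames
out = att(x, x, x, pos, mask)                # -> (2, 16, 256)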

View File

@@ -1,220 +0,0 @@
# import torch
# from torch import nn
# from torch import Tensor
# import logging
# import numpy as np
# from funasr.train_utils.device_funcs import to_device
# from funasr.models.transformer.utils.nets_utils import make_pad_mask
# from funasr.models.scama.utils import sequence_mask
# from typing import Optional, Tuple
#
# from funasr.register import tables
#
# class mae_loss(nn.Module):
#
# def __init__(self, normalize_length=False):
# super(mae_loss, self).__init__()
# self.normalize_length = normalize_length
# self.criterion = torch.nn.L1Loss(reduction='sum')
#
# def forward(self, token_length, pre_token_length):
# loss_token_normalizer = token_length.size(0)
# if self.normalize_length:
# loss_token_normalizer = token_length.sum().type(torch.float32)
# loss = self.criterion(token_length, pre_token_length)
# loss = loss / loss_token_normalizer
# return loss
#
#
# def cif(hidden, alphas, threshold):
# batch_size, len_time, hidden_size = hidden.size()
#
# # loop vars
# integrate = torch.zeros([batch_size], device=hidden.device)
# frame = torch.zeros([batch_size, hidden_size], device=hidden.device)
# # intermediate vars along time
# list_fires = []
# list_frames = []
#
# for t in range(len_time):
# alpha = alphas[:, t]
# distribution_completion = torch.ones([batch_size], device=hidden.device) - integrate
#
# integrate += alpha
# list_fires.append(integrate)
#
# fire_place = integrate >= threshold
# integrate = torch.where(fire_place,
# integrate - torch.ones([batch_size], device=hidden.device),
# integrate)
# cur = torch.where(fire_place,
# distribution_completion,
# alpha)
# remainds = alpha - cur
#
# frame += cur[:, None] * hidden[:, t, :]
# list_frames.append(frame)
# frame = torch.where(fire_place[:, None].repeat(1, hidden_size),
# remainds[:, None] * hidden[:, t, :],
# frame)
#
# fires = torch.stack(list_fires, 1)
# frames = torch.stack(list_frames, 1)
# list_ls = []
# len_labels = torch.round(alphas.sum(-1)).int()
# max_label_len = len_labels.max()
# for b in range(batch_size):
# fire = fires[b, :]
# l = torch.index_select(frames[b, :, :], 0, torch.nonzero(fire >= threshold).squeeze())
# pad_l = torch.zeros([max_label_len - l.size(0), hidden_size], device=hidden.device)
# list_ls.append(torch.cat([l, pad_l], 0))
# return torch.stack(list_ls, 0), fires
#
#
# def cif_wo_hidden(alphas, threshold):
# batch_size, len_time = alphas.size()
#
# # loop vars
# integrate = torch.zeros([batch_size], device=alphas.device)
# # intermediate vars along time
# list_fires = []
#
# for t in range(len_time):
# alpha = alphas[:, t]
#
# integrate += alpha
# list_fires.append(integrate)
#
# fire_place = integrate >= threshold
# integrate = torch.where(fire_place,
# integrate - torch.ones([batch_size], device=alphas.device)*threshold,
# integrate)
#
# fires = torch.stack(list_fires, 1)
# return fires
#
# @tables.register("predictor_classes", "BATPredictor")
# class BATPredictor(nn.Module):
# def __init__(self, idim, l_order, r_order, threshold=1.0, dropout=0.1, smooth_factor=1.0, noise_threshold=0, return_accum=False):
# super(BATPredictor, self).__init__()
#
# self.pad = nn.ConstantPad1d((l_order, r_order), 0)
# self.cif_conv1d = nn.Conv1d(idim, idim, l_order + r_order + 1, groups=idim)
# self.cif_output = nn.Linear(idim, 1)
# self.dropout = torch.nn.Dropout(p=dropout)
# self.threshold = threshold
# self.smooth_factor = smooth_factor
# self.noise_threshold = noise_threshold
# self.return_accum = return_accum
#
# def cif(
# self,
# input: Tensor,
# alpha: Tensor,
# beta: float = 1.0,
# return_accum: bool = False,
# ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
# B, S, C = input.size()
# assert tuple(alpha.size()) == (B, S), f"{alpha.size()} != {(B, S)}"
#
# dtype = alpha.dtype
# alpha = alpha.float()
#
# alpha_sum = alpha.sum(1)
# feat_lengths = (alpha_sum / beta).floor().long()
# T = feat_lengths.max()
#
# # aggregate and integrate
# csum = alpha.cumsum(-1)
# with torch.no_grad():
# # indices used for scattering
# right_idx = (csum / beta).floor().long().clip(max=T)
# left_idx = right_idx.roll(1, dims=1)
# left_idx[:, 0] = 0
#
# # count # of fires from each source
# fire_num = right_idx - left_idx
# extra_weights = (fire_num - 1).clip(min=0)
# # The extra (T + 1)-th entry along the time dim absorbs clipped overflow indices and is dropped below
# output = input.new_zeros((B, T + 1, C))
# source_range = torch.arange(1, 1 + S).unsqueeze(0).type_as(input)
# zero = alpha.new_zeros((1,))
#
# # right scatter
# fire_mask = fire_num > 0
# right_weight = torch.where(
# fire_mask,
# csum - right_idx.type_as(alpha) * beta,
# zero
# ).type_as(input)
# # assert right_weight.ge(0).all(), f"{right_weight} should be non-negative."
# output.scatter_add_(
# 1,
# right_idx.unsqueeze(-1).expand(-1, -1, C),
# right_weight.unsqueeze(-1) * input
# )
#
# # left scatter
# left_weight = (
# alpha - right_weight - extra_weights.type_as(alpha) * beta
# ).type_as(input)
# output.scatter_add_(
# 1,
# left_idx.unsqueeze(-1).expand(-1, -1, C),
# left_weight.unsqueeze(-1) * input
# )
#
# # extra scatters
# if extra_weights.ge(0).any():
# extra_steps = extra_weights.max().item()
# tgt_idx = left_idx
# src_feats = input * beta
# for _ in range(extra_steps):
# tgt_idx = (tgt_idx + 1).clip(max=T)
# # (B, S, 1)
# src_mask = (extra_weights > 0)
# output.scatter_add_(
# 1,
# tgt_idx.unsqueeze(-1).expand(-1, -1, C),
# src_feats * src_mask.unsqueeze(2)
# )
# extra_weights -= 1
#
# output = output[:, :T, :]
#
# if return_accum:
# return output, csum
# else:
# return output, alpha
#
# def forward(self, hidden, target_label=None, mask=None, ignore_id=-1, mask_chunk_predictor=None, target_label_length=None):
# h = hidden
# context = h.transpose(1, 2)
# queries = self.pad(context)
# memory = self.cif_conv1d(queries)
# output = memory + context
# output = self.dropout(output)
# output = output.transpose(1, 2)
# output = torch.relu(output)
# output = self.cif_output(output)
# alphas = torch.sigmoid(output)
# alphas = torch.nn.functional.relu(alphas*self.smooth_factor - self.noise_threshold)
# if mask is not None:
# alphas = alphas * mask.transpose(-1, -2).float()
# if mask_chunk_predictor is not None:
# alphas = alphas * mask_chunk_predictor
# alphas = alphas.squeeze(-1)
# if target_label_length is not None:
# target_length = target_label_length
# elif target_label is not None:
# target_length = (target_label != ignore_id).float().sum(-1)
# # logging.info("target_length: {}".format(target_length))
# else:
# target_length = None
# token_num = alphas.sum(-1)
# if target_length is not None:
# # length_noise = torch.rand(alphas.size(0), device=alphas.device) - 0.5
# # target_length = length_noise + target_length
# alphas *= ((target_length + 1e-4) / token_num)[:, None].repeat(1, alphas.size(1))
# acoustic_embeds, cif_peak = self.cif(hidden, alphas, self.threshold, self.return_accum)
# return acoustic_embeds, token_num, alphas, cif_peak
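The integrate-and-fire rule used by the (commented-out) predictor above can be pictured with a tiny standalone example; the alpha values and the threshold of 1.0 below are made up:
import torch
alphas = torch.tensor([[0.4, 0.5, 0.6, 0.3, 0.7]])  # (B=1, T=5) predicted weights
threshold = 1.0
integrate = torch.zeros(1)
fired = []
for t in range(alphas.size(1)):
    integrate = integrate + alphas[:, t]
    fired.append(bool((integrate >= threshold).item()))
    # subtract the threshold whenever a boundary fires, keeping the remainder
    integrate = torch.where(integrate >= threshold, integrate - threshold, integrate)
print(fired)  # [False, False, True, False, True] -> two acoustic boundaries fire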

View File

@@ -1,701 +0,0 @@
"""Conformer encoder definition."""
import logging
from typing import Union, Dict, List, Tuple, Optional
import torch
from torch import nn
from funasr.models.bat.attention import (
RelPositionMultiHeadedAttentionChunk,
)
from funasr.models.transformer.embedding import (
StreamingRelPositionalEncoding,
)
from funasr.models.transformer.layer_norm import LayerNorm
from funasr.models.transformer.utils.nets_utils import get_activation
from funasr.models.transformer.utils.nets_utils import (
TooShortUttError,
check_short_utt,
make_chunk_mask,
make_source_mask,
)
from funasr.models.transformer.positionwise_feed_forward import (
PositionwiseFeedForward,
)
from funasr.models.transformer.utils.repeat import repeat, MultiBlocks
from funasr.models.transformer.utils.subsampling import TooShortUttError
from funasr.models.transformer.utils.subsampling import check_short_utt
from funasr.models.transformer.utils.subsampling import StreamingConvInput
from funasr.register import tables
class ChunkEncoderLayer(nn.Module):
"""Chunk Conformer module definition.
Args:
block_size: Input/output size.
self_att: Self-attention module instance.
feed_forward: Feed-forward module instance.
feed_forward_macaron: Feed-forward module instance for macaron network.
conv_mod: Convolution module instance.
norm_class: Normalization module class.
norm_args: Normalization module arguments.
dropout_rate: Dropout rate.
"""
def __init__(
self,
block_size: int,
self_att: torch.nn.Module,
feed_forward: torch.nn.Module,
feed_forward_macaron: torch.nn.Module,
conv_mod: torch.nn.Module,
norm_class: torch.nn.Module = LayerNorm,
norm_args: Dict = {},
dropout_rate: float = 0.0,
) -> None:
"""Construct a Conformer object."""
super().__init__()
self.self_att = self_att
self.feed_forward = feed_forward
self.feed_forward_macaron = feed_forward_macaron
self.feed_forward_scale = 0.5
self.conv_mod = conv_mod
self.norm_feed_forward = norm_class(block_size, **norm_args)
self.norm_self_att = norm_class(block_size, **norm_args)
self.norm_macaron = norm_class(block_size, **norm_args)
self.norm_conv = norm_class(block_size, **norm_args)
self.norm_final = norm_class(block_size, **norm_args)
self.dropout = torch.nn.Dropout(dropout_rate)
self.block_size = block_size
self.cache = None
def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
"""Initialize/Reset self-attention and convolution modules cache for streaming.
Args:
left_context: Number of left frames during chunk-by-chunk inference.
device: Device to use for cache tensor.
"""
self.cache = [
torch.zeros(
(1, left_context, self.block_size),
device=device,
),
torch.zeros(
(
1,
self.block_size,
self.conv_mod.kernel_size - 1,
),
device=device,
),
]
def forward(
self,
x: torch.Tensor,
pos_enc: torch.Tensor,
mask: torch.Tensor,
chunk_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Encode input sequences.
Args:
x: Conformer input sequences. (B, T, D_block)
pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
mask: Source mask. (B, T)
chunk_mask: Chunk mask. (T_2, T_2)
Returns:
x: Conformer output sequences. (B, T, D_block)
mask: Source mask. (B, T)
pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
"""
residual = x
x = self.norm_macaron(x)
x = residual + self.feed_forward_scale * self.dropout(
self.feed_forward_macaron(x)
)
residual = x
x = self.norm_self_att(x)
x_q = x
x = residual + self.dropout(
self.self_att(
x_q,
x,
x,
pos_enc,
mask,
chunk_mask=chunk_mask,
)
)
residual = x
x = self.norm_conv(x)
x, _ = self.conv_mod(x)
x = residual + self.dropout(x)
residual = x
x = self.norm_feed_forward(x)
x = residual + self.feed_forward_scale * self.dropout(self.feed_forward(x))
x = self.norm_final(x)
return x, mask, pos_enc
def chunk_forward(
self,
x: torch.Tensor,
pos_enc: torch.Tensor,
mask: torch.Tensor,
chunk_size: int = 16,
left_context: int = 0,
right_context: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Encode chunk of input sequence.
Args:
x: Conformer input sequences. (B, T, D_block)
pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
mask: Source mask. (B, T_2)
left_context: Number of frames in left context.
right_context: Number of frames in right context.
Returns:
x: Conformer output sequences. (B, T, D_block)
pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
"""
residual = x
x = self.norm_macaron(x)
x = residual + self.feed_forward_scale * self.feed_forward_macaron(x)
residual = x
x = self.norm_self_att(x)
if left_context > 0:
key = torch.cat([self.cache[0], x], dim=1)
else:
key = x
val = key
if right_context > 0:
att_cache = key[:, -(left_context + right_context) : -right_context, :]
else:
att_cache = key[:, -left_context:, :]
x = residual + self.self_att(
x,
key,
val,
pos_enc,
mask,
left_context=left_context,
)
residual = x
x = self.norm_conv(x)
x, conv_cache = self.conv_mod(
x, cache=self.cache[1], right_context=right_context
)
x = residual + x
residual = x
x = self.norm_feed_forward(x)
x = residual + self.feed_forward_scale * self.feed_forward(x)
x = self.norm_final(x)
self.cache = [att_cache, conv_cache]
return x, pos_enc
class CausalConvolution(nn.Module):
"""ConformerConvolution module definition.
Args:
channels: The number of channels.
kernel_size: Size of the convolving kernel.
activation: Type of activation function.
norm_args: Normalization module arguments.
causal: Whether to use causal convolution (set to True if streaming).
"""
def __init__(
self,
channels: int,
kernel_size: int,
activation: torch.nn.Module = torch.nn.ReLU(),
norm_args: Dict = {},
causal: bool = False,
) -> None:
"""Construct an ConformerConvolution object."""
super().__init__()
assert (kernel_size - 1) % 2 == 0
self.kernel_size = kernel_size
self.pointwise_conv1 = torch.nn.Conv1d(
channels,
2 * channels,
kernel_size=1,
stride=1,
padding=0,
)
if causal:
self.lorder = kernel_size - 1
padding = 0
else:
self.lorder = 0
padding = (kernel_size - 1) // 2
self.depthwise_conv = torch.nn.Conv1d(
channels,
channels,
kernel_size,
stride=1,
padding=padding,
groups=channels,
)
self.norm = torch.nn.BatchNorm1d(channels, **norm_args)
self.pointwise_conv2 = torch.nn.Conv1d(
channels,
channels,
kernel_size=1,
stride=1,
padding=0,
)
self.activation = activation
def forward(
self,
x: torch.Tensor,
cache: Optional[torch.Tensor] = None,
right_context: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute convolution module.
Args:
x: ConformerConvolution input sequences. (B, T, D_hidden)
cache: ConformerConvolution input cache. (1, conv_kernel, D_hidden)
right_context: Number of frames in right context.
Returns:
x: ConformerConvolution output sequences. (B, T, D_hidden)
cache: ConformerConvolution output cache. (1, conv_kernel, D_hidden)
"""
x = self.pointwise_conv1(x.transpose(1, 2))
x = torch.nn.functional.glu(x, dim=1)
if self.lorder > 0:
if cache is None:
x = torch.nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
else:
x = torch.cat([cache, x], dim=2)
if right_context > 0:
cache = x[:, :, -(self.lorder + right_context) : -right_context]
else:
cache = x[:, :, -self.lorder :]
x = self.depthwise_conv(x)
x = self.activation(self.norm(x))
x = self.pointwise_conv2(x).transpose(1, 2)
return x, cache
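The cache handling in CausalConvolution.forward can be checked against a toy depthwise causal convolution; the sketch below (toy sizes, plain torch, not the module itself) shows that chunked processing with a (kernel_size - 1)-frame input cache reproduces the full-sequence output:
import torch
torch.manual_seed(0)
channels, kernel_size, T, chunk = 4, 3, 12, 4
conv = torch.nn.Conv1d(channels, channels, kernel_size, groups=channels)
x = torch.randn(1, channels, T)
full = conv(torch.nn.functional.pad(x, (kernel_size - 1, 0)))  # causal: left padding only
cache = torch.zeros(1, channels, kernel_size - 1)
outs = []
for start in range(0, T, chunk):
    piece = torch.cat([cache, x[:, :, start:start + chunk]], dim=2)
    cache = piece[:, :, -(kernel_size - 1):]  # keep the trailing lorder frames as the next cache
    outs.append(conv(piece))
print(torch.allclose(full, torch.cat(outs, dim=2), atol=1e-6))  # True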
@tables.register("encoder_classes", "ConformerChunkEncoder")
class ConformerChunkEncoder(nn.Module):
"""Encoder module definition.
Args:
input_size: Input size.
body_conf: Encoder body configuration.
input_conf: Encoder input configuration.
main_conf: Encoder main configuration.
"""
def __init__(
self,
input_size: int,
output_size: int = 256,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
attention_dropout_rate: float = 0.0,
embed_vgg_like: bool = False,
normalize_before: bool = True,
concat_after: bool = False,
positionwise_layer_type: str = "linear",
positionwise_conv_kernel_size: int = 3,
macaron_style: bool = False,
rel_pos_type: str = "legacy",
pos_enc_layer_type: str = "rel_pos",
selfattention_layer_type: str = "rel_selfattn",
activation_type: str = "swish",
use_cnn_module: bool = True,
zero_triu: bool = False,
norm_type: str = "layer_norm",
cnn_module_kernel: int = 31,
conv_mod_norm_eps: float = 0.00001,
conv_mod_norm_momentum: float = 0.1,
simplified_att_score: bool = False,
dynamic_chunk_training: bool = False,
short_chunk_threshold: float = 0.75,
short_chunk_size: int = 25,
left_chunk_size: int = 0,
time_reduction_factor: int = 1,
unified_model_training: bool = False,
default_chunk_size: int = 16,
jitter_range: int = 4,
subsampling_factor: int = 1,
) -> None:
"""Construct an Encoder object."""
super().__init__()
self.embed = StreamingConvInput(
input_size,
output_size,
subsampling_factor,
vgg_like=embed_vgg_like,
output_size=output_size,
)
self.pos_enc = StreamingRelPositionalEncoding(
output_size,
positional_dropout_rate,
)
activation = get_activation(
activation_type
)
pos_wise_args = (
output_size,
linear_units,
positional_dropout_rate,
activation,
)
conv_mod_norm_args = {
"eps": conv_mod_norm_eps,
"momentum": conv_mod_norm_momentum,
}
conv_mod_args = (
output_size,
cnn_module_kernel,
activation,
conv_mod_norm_args,
dynamic_chunk_training or unified_model_training,
)
mult_att_args = (
attention_heads,
output_size,
attention_dropout_rate,
simplified_att_score,
)
fn_modules = []
for _ in range(num_blocks):
module = lambda: ChunkEncoderLayer(
output_size,
RelPositionMultiHeadedAttentionChunk(*mult_att_args),
PositionwiseFeedForward(*pos_wise_args),
PositionwiseFeedForward(*pos_wise_args),
CausalConvolution(*conv_mod_args),
dropout_rate=dropout_rate,
)
fn_modules.append(module)
self.encoders = MultiBlocks(
[fn() for fn in fn_modules],
output_size,
)
self._output_size = output_size
self.dynamic_chunk_training = dynamic_chunk_training
self.short_chunk_threshold = short_chunk_threshold
self.short_chunk_size = short_chunk_size
self.left_chunk_size = left_chunk_size
self.unified_model_training = unified_model_training
self.default_chunk_size = default_chunk_size
self.jitter_range = jitter_range
self.time_reduction_factor = time_reduction_factor
def output_size(self) -> int:
return self._output_size
def get_encoder_input_raw_size(self, size: int, hop_length: int) -> int:
"""Return the corresponding number of sample for a given chunk size, in frames.
Where size is the number of features frames after applying subsampling.
Args:
size: Number of frames after subsampling.
hop_length: Frontend's hop length
Returns:
: Number of raw samples
"""
return self.embed.get_size_before_subsampling(size) * hop_length
def get_encoder_input_size(self, size: int) -> int:
"""Return the corresponding number of sample for a given chunk size, in frames.
Where size is the number of features frames after applying subsampling.
Args:
size: Number of frames after subsampling.
Returns:
: Number of frames before subsampling
"""
return self.embed.get_size_before_subsampling(size)
def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
"""Initialize/Reset encoder streaming cache.
Args:
left_context: Number of frames in left context.
device: Device ID.
"""
return self.encoders.reset_streaming_cache(left_context, device)
def forward(
self,
x: torch.Tensor,
x_len: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Encode input sequences.
Args:
x: Encoder input features. (B, T_in, F)
x_len: Encoder input features lengths. (B,)
Returns:
x: Encoder outputs. (B, T_out, D_enc)
x_len: Encoder output lengths. (B,)
"""
short_status, limit_size = check_short_utt(
self.embed.subsampling_factor, x.size(1)
)
if short_status:
raise TooShortUttError(
f"has {x.size(1)} frames and is too short for subsampling "
+ f"(it needs more than {limit_size} frames), return empty results",
x.size(1),
limit_size,
)
mask = make_source_mask(x_len).to(x.device)
if self.unified_model_training:
if self.training:
chunk_size = self.default_chunk_size + torch.randint(-self.jitter_range, self.jitter_range+1, (1,)).item()
else:
chunk_size = self.default_chunk_size
x, mask = self.embed(x, mask, chunk_size)
pos_enc = self.pos_enc(x)
chunk_mask = make_chunk_mask(
x.size(1),
chunk_size,
left_chunk_size=self.left_chunk_size,
device=x.device,
)
x_utt = self.encoders(
x,
pos_enc,
mask,
chunk_mask=None,
)
x_chunk = self.encoders(
x,
pos_enc,
mask,
chunk_mask=chunk_mask,
)
olens = mask.eq(0).sum(1)
if self.time_reduction_factor > 1:
x_utt = x_utt[:,::self.time_reduction_factor,:]
x_chunk = x_chunk[:,::self.time_reduction_factor,:]
olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1
return x_utt, x_chunk, olens
elif self.dynamic_chunk_training:
max_len = x.size(1)
if self.training:
chunk_size = torch.randint(1, max_len, (1,)).item()
if chunk_size > (max_len * self.short_chunk_threshold):
chunk_size = max_len
else:
chunk_size = (chunk_size % self.short_chunk_size) + 1
else:
chunk_size = self.default_chunk_size
x, mask = self.embed(x, mask, chunk_size)
pos_enc = self.pos_enc(x)
chunk_mask = make_chunk_mask(
x.size(1),
chunk_size,
left_chunk_size=self.left_chunk_size,
device=x.device,
)
else:
x, mask = self.embed(x, mask, None)
pos_enc = self.pos_enc(x)
chunk_mask = None
x = self.encoders(
x,
pos_enc,
mask,
chunk_mask=chunk_mask,
)
olens = mask.eq(0).sum(1)
if self.time_reduction_factor > 1:
x = x[:,::self.time_reduction_factor,:]
olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1
return x, olens, None
def full_utt_forward(
self,
x: torch.Tensor,
x_len: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Encode input sequences.
Args:
x: Encoder input features. (B, T_in, F)
x_len: Encoder input features lengths. (B,)
Returns:
x: Encoder outputs. (B, T_out, D_enc)
x_len: Encoder output lengths. (B,)
"""
short_status, limit_size = check_short_utt(
self.embed.subsampling_factor, x.size(1)
)
if short_status:
raise TooShortUttError(
f"has {x.size(1)} frames and is too short for subsampling "
+ f"(it needs more than {limit_size} frames), return empty results",
x.size(1),
limit_size,
)
mask = make_source_mask(x_len).to(x.device)
x, mask = self.embed(x, mask, None)
pos_enc = self.pos_enc(x)
x_utt = self.encoders(
x,
pos_enc,
mask,
chunk_mask=None,
)
if self.time_reduction_factor > 1:
x_utt = x_utt[:,::self.time_reduction_factor,:]
return x_utt
def simu_chunk_forward(
self,
x: torch.Tensor,
x_len: torch.Tensor,
chunk_size: int = 16,
left_context: int = 32,
right_context: int = 0,
) -> torch.Tensor:
short_status, limit_size = check_short_utt(
self.embed.subsampling_factor, x.size(1)
)
if short_status:
raise TooShortUttError(
f"has {x.size(1)} frames and is too short for subsampling "
+ f"(it needs more than {limit_size} frames), return empty results",
x.size(1),
limit_size,
)
mask = make_source_mask(x_len)
x, mask = self.embed(x, mask, chunk_size)
pos_enc = self.pos_enc(x)
chunk_mask = make_chunk_mask(
x.size(1),
chunk_size,
left_chunk_size=self.left_chunk_size,
device=x.device,
)
x = self.encoders(
x,
pos_enc,
mask,
chunk_mask=chunk_mask,
)
olens = mask.eq(0).sum(1)
if self.time_reduction_factor > 1:
x = x[:,::self.time_reduction_factor,:]
return x
def chunk_forward(
self,
x: torch.Tensor,
x_len: torch.Tensor,
processed_frames: torch.Tensor,
chunk_size: int = 16,
left_context: int = 32,
right_context: int = 0,
) -> torch.Tensor:
"""Encode input sequences as chunks.
Args:
x: Encoder input features. (1, T_in, F)
x_len: Encoder input features lengths. (1,)
processed_frames: Number of frames already seen.
left_context: Number of frames in left context.
right_context: Number of frames in right context.
Returns:
x: Encoder outputs. (B, T_out, D_enc)
"""
mask = make_source_mask(x_len)
x, mask = self.embed(x, mask, None)
if left_context > 0:
processed_mask = (
torch.arange(left_context, device=x.device)
.view(1, left_context)
.flip(1)
)
processed_mask = processed_mask >= processed_frames
mask = torch.cat([processed_mask, mask], dim=1)
pos_enc = self.pos_enc(x, left_context=left_context)
x = self.encoders.chunk_forward(
x,
pos_enc,
mask,
chunk_size=chunk_size,
left_context=left_context,
right_context=right_context,
)
if right_context > 0:
x = x[:, 0:-right_context, :]
if self.time_reduction_factor > 1:
x = x[:,::self.time_reduction_factor,:]
return x
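For reference, a hedged sketch of the kind of chunk mask consumed above (this is not funasr's make_chunk_mask; it ignores left_chunk_size and simply lets each frame attend to everything up to the end of its own chunk, with True marking blocked positions):
import torch
def toy_chunk_mask(size: int, chunk_size: int) -> torch.Tensor:
    chunk_idx = torch.arange(size) // chunk_size
    # mask[i, j] is True when frame j lies in a later chunk than frame i
    return chunk_idx.unsqueeze(0) > chunk_idx.unsqueeze(1)
print(toy_chunk_mask(6, 2).int())
# tensor([[0, 0, 1, 1, 1, 1],
#         [0, 0, 1, 1, 1, 1],
#         [0, 0, 0, 0, 1, 1],
#         [0, 0, 0, 0, 1, 1],
#         [0, 0, 0, 0, 0, 0],
#         [0, 0, 0, 0, 0, 0]])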

View File

@@ -3,137 +3,145 @@
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
import time
import torch
import logging
import torch.nn as nn
from contextlib import contextmanager
from typing import Dict, Optional, Tuple
from distutils.version import LooseVersion
from typing import Dict, List, Optional, Tuple, Union
from torch.cuda.amp import autocast
from funasr.losses.label_smoothing_loss import (
LabelSmoothingLoss, # noqa: H301
)
from funasr.models.transformer.utils.nets_utils import get_transducer_task_io
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.register import tables
from funasr.utils import postprocess_utils
from funasr.utils.datadir_writer import DatadirWriter
from funasr.train_utils.device_funcs import force_gatherable
from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
from funasr.models.transformer.scorers.length_bonus import LengthBonus
from funasr.models.transformer.utils.nets_utils import get_transducer_task_io
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.models.transducer.beam_search_transducer import BeamSearchTransducer
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
from torch.cuda.amp import autocast
else:
# Nothing to do if torch<1.6.0
@contextmanager
def autocast(enabled=True):
yield
class BATModel(nn.Module):
"""BATModel module definition.
Args:
vocab_size: Size of complete vocabulary (w/ EOS and blank included).
token_list: List of tokens.
frontend: Frontend module.
specaug: SpecAugment module.
normalize: Normalization module.
encoder: Encoder module.
decoder: Decoder module.
joint_network: Joint Network module.
transducer_weight: Weight of the Transducer loss.
fastemit_lambda: FastEmit lambda value.
auxiliary_ctc_weight: Weight of auxiliary CTC loss.
auxiliary_ctc_dropout_rate: Dropout rate for auxiliary CTC loss inputs.
auxiliary_lm_loss_weight: Weight of auxiliary LM loss.
auxiliary_lm_loss_smoothing: Smoothing rate for LM loss' label smoothing.
ignore_id: Initial padding ID.
sym_space: Space symbol.
sym_blank: Blank symbol.
report_cer: Whether to report Character Error Rate during validation.
report_wer: Whether to report Word Error Rate during validation.
extract_feats_in_collect_stats: Whether to use extract_feats stats collection.
"""
@tables.register("model_classes", "BAT") # TODO: BAT training
class BAT(torch.nn.Module):
def __init__(
self,
cif_weight: float = 1.0,
frontend: Optional[str] = None,
frontend_conf: Optional[Dict] = None,
specaug: Optional[str] = None,
specaug_conf: Optional[Dict] = None,
normalize: str = None,
normalize_conf: Optional[Dict] = None,
encoder: str = None,
encoder_conf: Optional[Dict] = None,
decoder: str = None,
decoder_conf: Optional[Dict] = None,
joint_network: str = None,
joint_network_conf: Optional[Dict] = None,
transducer_weight: float = 1.0,
fastemit_lambda: float = 0.0,
auxiliary_ctc_weight: float = 0.0,
auxiliary_ctc_dropout_rate: float = 0.0,
auxiliary_lm_loss_weight: float = 0.0,
auxiliary_lm_loss_smoothing: float = 0.0,
input_size: int = 80,
vocab_size: int = -1,
ignore_id: int = -1,
sym_space: str = "<space>",
sym_blank: str = "<blank>",
report_cer: bool = True,
report_wer: bool = True,
extract_feats_in_collect_stats: bool = True,
blank_id: int = 0,
sos: int = 1,
eos: int = 2,
lsm_weight: float = 0.0,
length_normalized_loss: bool = False,
r_d: int = 5,
r_u: int = 5,
# report_cer: bool = True,
# report_wer: bool = True,
# sym_space: str = "<space>",
# sym_blank: str = "<blank>",
# extract_feats_in_collect_stats: bool = True,
share_embedding: bool = False,
# preencoder: Optional[AbsPreEncoder] = None,
# postencoder: Optional[AbsPostEncoder] = None,
**kwargs,
) -> None:
"""Construct a BATModel object."""
super().__init__()
# The following label IDs are reserved: 0 (blank) and vocab_size - 1 (sos/eos)
self.blank_id = 0
if specaug is not None:
specaug_class = tables.specaug_classes.get(specaug)
specaug = specaug_class(**specaug_conf)
if normalize is not None:
normalize_class = tables.normalize_classes.get(normalize)
normalize = normalize_class(**normalize_conf)
encoder_class = tables.encoder_classes.get(encoder)
encoder = encoder_class(input_size=input_size, **encoder_conf)
encoder_output_size = encoder.output_size()
decoder_class = tables.decoder_classes.get(decoder)
decoder = decoder_class(
vocab_size=vocab_size,
**decoder_conf,
)
decoder_output_size = decoder.output_size
joint_network_class = tables.joint_network_classes.get(joint_network)
joint_network = joint_network_class(
vocab_size,
encoder_output_size,
decoder_output_size,
**joint_network_conf,
)
self.blank_id = blank_id
self.sos = sos if sos is not None else vocab_size - 1
self.eos = eos if eos is not None else vocab_size - 1
self.vocab_size = vocab_size
self.ignore_id = ignore_id
self.token_list = token_list.copy()
self.sym_space = sym_space
self.sym_blank = sym_blank
self.frontend = frontend
self.specaug = specaug
self.normalize = normalize
self.encoder = encoder
self.decoder = decoder
self.joint_network = joint_network
self.criterion_transducer = None
self.error_calculator = None
self.use_auxiliary_ctc = auxiliary_ctc_weight > 0
self.use_auxiliary_lm_loss = auxiliary_lm_loss_weight > 0
if self.use_auxiliary_ctc:
self.ctc_lin = torch.nn.Linear(encoder.output_size(), vocab_size)
self.ctc_dropout_rate = auxiliary_ctc_dropout_rate
if self.use_auxiliary_lm_loss:
self.lm_lin = torch.nn.Linear(decoder.output_size, vocab_size)
self.lm_loss_smoothing = auxiliary_lm_loss_smoothing
self.transducer_weight = transducer_weight
self.fastemit_lambda = fastemit_lambda
self.auxiliary_ctc_weight = auxiliary_ctc_weight
self.auxiliary_lm_loss_weight = auxiliary_lm_loss_weight
self.report_cer = report_cer
self.report_wer = report_wer
self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
self.criterion_pre = torch.nn.L1Loss()
self.predictor_weight = predictor_weight
self.predictor = predictor
self.cif_weight = cif_weight
if self.cif_weight > 0:
self.cif_output_layer = torch.nn.Linear(encoder.output_size(), vocab_size)
self.criterion_cif = LabelSmoothingLoss(
size=vocab_size,
padding_idx=ignore_id,
smoothing=lsm_weight,
normalize_length=length_normalized_loss,
)
self.r_d = r_d
self.r_u = r_u
self.criterion_att = LabelSmoothingLoss(
size=vocab_size,
padding_idx=ignore_id,
smoothing=lsm_weight,
normalize_length=length_normalized_loss,
)
self.length_normalized_loss = length_normalized_loss
self.beam_search = None
self.ctc = None
self.ctc_weight = 0.0
def forward(
self,
speech: torch.Tensor,
@@ -142,111 +150,167 @@ class BATModel(nn.Module):
text_lengths: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Forward architecture and compute loss(es).
"""Encoder + Decoder + Calc loss
Args:
speech: Speech sequences. (B, S)
speech_lengths: Speech sequences lengths. (B,)
text: Label ID sequences. (B, L)
text_lengths: Label ID sequences lengths. (B,)
kwargs: Contains "utts_id".
Return:
loss: Main loss value.
stats: Task statistics.
weight: Task weights.
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
"""
assert text_lengths.dim() == 1, text_lengths.shape
assert (
speech.shape[0]
== speech_lengths.shape[0]
== text.shape[0]
== text_lengths.shape[0]
), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
if len(text_lengths.size()) > 1:
text_lengths = text_lengths[:, 0]
if len(speech_lengths.size()) > 1:
speech_lengths = speech_lengths[:, 0]
batch_size = speech.shape[0]
text = text[:, : text_lengths.max()]
# 1. Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
if hasattr(self.encoder, 'overlap_chunk_cls') and self.encoder.overlap_chunk_cls is not None:
encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out, encoder_out_lens,
chunk_outs=None)
encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(encoder_out.device)
# 2. Transducer-related I/O preparation
decoder_in, target, t_len, u_len = get_transducer_task_io(
text,
encoder_out_lens,
ignore_id=self.ignore_id,
)
# 3. Decoder
self.decoder.set_device(encoder_out.device)
decoder_out = self.decoder(decoder_in, u_len)
pre_acoustic_embeds, pre_token_length, _, pre_peak_index = self.predictor(encoder_out, text, encoder_out_mask, ignore_id=self.ignore_id)
loss_pre = self.criterion_pre(text_lengths.type_as(pre_token_length), pre_token_length)
if self.cif_weight > 0.0:
cif_predict = self.cif_output_layer(pre_acoustic_embeds)
loss_cif = self.criterion_cif(cif_predict, text)
else:
loss_cif = 0.0
# 5. Losses
boundary = torch.zeros((encoder_out.size(0), 4), dtype=torch.int64, device=encoder_out.device)
boundary[:, 2] = u_len.long().detach()
boundary[:, 3] = t_len.long().detach()
pre_peak_index = torch.floor(pre_peak_index).long()
s_begin = pre_peak_index - self.r_d
T = encoder_out.size(1)
B = encoder_out.size(0)
U = decoder_out.size(1)
mask = torch.arange(0, T, device=encoder_out.device).reshape(1, T).expand(B, T)
mask = mask <= boundary[:, 3].reshape(B, 1) - 1
s_begin_padding = boundary[:, 2].reshape(B, 1) - (self.r_u+self.r_d) + 1
# handle the cases where `len(symbols) < s_range`
s_begin_padding = torch.clamp(s_begin_padding, min=0)
s_begin = torch.where(mask, s_begin, s_begin_padding)
mask2 = s_begin < boundary[:, 2].reshape(B, 1) - (self.r_u+self.r_d) + 1
s_begin = torch.where(mask2, s_begin, boundary[:, 2].reshape(B, 1) - (self.r_u+self.r_d) + 1)
s_begin = torch.clamp(s_begin, min=0)
ranges = s_begin.reshape((B, T, 1)).expand((B, T, min(self.r_u+self.r_d, min(u_len)))) + torch.arange(min(self.r_d+self.r_u, min(u_len)), device=encoder_out.device)
import fast_rnnt
am_pruned, lm_pruned = fast_rnnt.do_rnnt_pruning(
am=self.joint_network.lin_enc(encoder_out),
lm=self.joint_network.lin_dec(decoder_out),
ranges=ranges,
# 4. Joint Network
joint_out = self.joint_network(
encoder_out.unsqueeze(2), decoder_out.unsqueeze(1)
)
logits = self.joint_network(am_pruned, lm_pruned, project_input=False)
with torch.cuda.amp.autocast(enabled=False):
loss_trans = fast_rnnt.rnnt_loss_pruned(
logits=logits.float(),
symbols=target.long(),
ranges=ranges,
termination_symbol=self.blank_id,
boundary=boundary,
reduction="sum",
# 5. Losses
loss_trans, cer_trans, wer_trans = self._calc_transducer_loss(
encoder_out,
joint_out,
target,
t_len,
u_len,
)
loss_ctc, loss_lm = 0.0, 0.0
if self.use_auxiliary_ctc:
loss_ctc = self._calc_ctc_loss(
encoder_out,
target,
t_len,
u_len,
)
if self.use_auxiliary_lm_loss:
loss_lm = self._calc_lm_loss(decoder_out, target)
loss = (
self.transducer_weight * loss_trans
+ self.auxiliary_ctc_weight * loss_ctc
+ self.auxiliary_lm_loss_weight * loss_lm
)
stats = dict(
loss=loss.detach(),
loss_transducer=loss_trans.detach(),
aux_ctc_loss=loss_ctc.detach() if loss_ctc > 0.0 else None,
aux_lm_loss=loss_lm.detach() if loss_lm > 0.0 else None,
cer_transducer=cer_trans,
wer_transducer=wer_trans,
)
# force_gatherable: to-device and to-tensor if scalar for DataParallel
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
return loss, stats, weight
cer_trans, wer_trans = None, None
def encode(
self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder. Note that this method is used by asr_inference.py
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
ind: int
"""
with autocast(False):
# Data augmentation
if self.specaug is not None and self.training:
speech, speech_lengths = self.specaug(speech, speech_lengths)
# Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
if self.normalize is not None:
speech, speech_lengths = self.normalize(speech, speech_lengths)
# Forward encoder
# feats: (Batch, Length, Dim)
# -> encoder_out: (Batch, Length2, Dim2)
encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
intermediate_outs = None
if isinstance(encoder_out, tuple):
intermediate_outs = encoder_out[1]
encoder_out = encoder_out[0]
if intermediate_outs is not None:
return (encoder_out, intermediate_outs), encoder_out_lens
return encoder_out, encoder_out_lens
def _calc_transducer_loss(
self,
encoder_out: torch.Tensor,
joint_out: torch.Tensor,
target: torch.Tensor,
t_len: torch.Tensor,
u_len: torch.Tensor,
) -> Tuple[torch.Tensor, Optional[float], Optional[float]]:
"""Compute Transducer loss.
Args:
encoder_out: Encoder output sequences. (B, T, D_enc)
joint_out: Joint Network output sequences (B, T, U, D_joint)
target: Target label ID sequences. (B, L)
t_len: Encoder output sequences lengths. (B,)
u_len: Target label ID sequences lengths. (B,)
Return:
loss_transducer: Transducer loss value.
cer_transducer: Character error rate for Transducer.
wer_transducer: Word Error Rate for Transducer.
"""
if self.criterion_transducer is None:
try:
from warp_rnnt import rnnt_loss as RNNTLoss
self.criterion_transducer = RNNTLoss
except ImportError:
logging.error(
"warp-rnnt was not installed."
"Please consult the installation documentation."
)
exit(1)
log_probs = torch.log_softmax(joint_out, dim=-1)
loss_transducer = self.criterion_transducer(
log_probs,
target,
t_len,
u_len,
reduction="mean",
blank=self.blank_id,
fastemit_lambda=self.fastemit_lambda,
gather=True,
)
if not self.training and (self.report_cer or self.report_wer):
if self.error_calculator is None:
from funasr.metrics import ErrorCalculatorTransducer as ErrorCalculator
self.error_calculator = ErrorCalculator(
self.decoder,
self.joint_network,
@@ -256,149 +320,13 @@ class BATModel(nn.Module):
report_cer=self.report_cer,
report_wer=self.report_wer,
)
cer_trans, wer_trans = self.error_calculator(encoder_out, target, t_len)
loss_ctc, loss_lm = 0.0, 0.0
if self.use_auxiliary_ctc:
loss_ctc = self._calc_ctc_loss(
encoder_out,
target,
t_len,
u_len,
)
if self.use_auxiliary_lm_loss:
loss_lm = self._calc_lm_loss(decoder_out, target)
loss = (
self.transducer_weight * loss_trans
+ self.auxiliary_ctc_weight * loss_ctc
+ self.auxiliary_lm_loss_weight * loss_lm
+ self.predictor_weight * loss_pre
+ self.cif_weight * loss_cif
)
stats = dict(
loss=loss.detach(),
loss_transducer=loss_trans.detach(),
loss_pre=loss_pre.detach(),
loss_cif=loss_cif.detach() if loss_cif > 0.0 else None,
aux_ctc_loss=loss_ctc.detach() if loss_ctc > 0.0 else None,
aux_lm_loss=loss_lm.detach() if loss_lm > 0.0 else None,
cer_transducer=cer_trans,
wer_transducer=wer_trans,
)
# force_gatherable: to-device and to-tensor if scalar for DataParallel
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
return loss, stats, weight
def collect_feats(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
**kwargs,
) -> Dict[str, torch.Tensor]:
"""Collect features sequences and features lengths sequences.
Args:
speech: Speech sequences. (B, S)
speech_lengths: Speech sequences lengths. (B,)
text: Label ID sequences. (B, L)
text_lengths: Label ID sequences lengths. (B,)
kwargs: Contains "utts_id".
Return:
{}: "feats": Features sequences. (B, T, D_feats),
"feats_lengths": Features sequences lengths. (B,)
"""
if self.extract_feats_in_collect_stats:
feats, feats_lengths = self._extract_feats(speech, speech_lengths)
else:
# Generate dummy stats if extract_feats_in_collect_stats is False
logging.warning(
"Generating dummy stats for feats and feats_lengths, "
"because encoder_conf.extract_feats_in_collect_stats is "
f"{self.extract_feats_in_collect_stats}"
)
feats, feats_lengths = speech, speech_lengths
return {"feats": feats, "feats_lengths": feats_lengths}
def encode(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Encoder speech sequences.
Args:
speech: Speech sequences. (B, S)
speech_lengths: Speech sequences lengths. (B,)
Return:
encoder_out: Encoder outputs. (B, T, D_enc)
encoder_out_lens: Encoder outputs lengths. (B,)
"""
with autocast(False):
# 1. Extract feats
feats, feats_lengths = self._extract_feats(speech, speech_lengths)
# 2. Data augmentation
if self.specaug is not None and self.training:
feats, feats_lengths = self.specaug(feats, feats_lengths)
# 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
if self.normalize is not None:
feats, feats_lengths = self.normalize(feats, feats_lengths)
# 4. Forward encoder
encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
assert encoder_out.size(0) == speech.size(0), (
encoder_out.size(),
speech.size(0),
)
assert encoder_out.size(1) <= encoder_out_lens.max(), (
encoder_out.size(),
encoder_out_lens.max(),
)
return encoder_out, encoder_out_lens
def _extract_feats(
self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Extract features sequences and features sequences lengths.
Args:
speech: Speech sequences. (B, S)
speech_lengths: Speech sequences lengths. (B,)
Return:
feats: Features sequences. (B, T, D_feats)
feats_lengths: Features sequences lengths. (B,)
"""
assert speech_lengths.dim() == 1, speech_lengths.shape
# for data-parallel
speech = speech[:, : speech_lengths.max()]
if self.frontend is not None:
feats, feats_lengths = self.frontend(speech, speech_lengths)
else:
feats, feats_lengths = speech, speech_lengths
return feats, feats_lengths
cer_transducer, wer_transducer = self.error_calculator(encoder_out, target, t_len)
return loss_transducer, cer_transducer, wer_transducer
return loss_transducer, None, None
def _calc_ctc_loss(
self,
encoder_out: torch.Tensor,
@@ -422,10 +350,10 @@ class BATModel(nn.Module):
torch.nn.functional.dropout(encoder_out, p=self.ctc_dropout_rate)
)
ctc_in = torch.log_softmax(ctc_in.transpose(0, 1), dim=-1)
target_mask = target != 0
ctc_target = target[target_mask].cpu()
with torch.backends.cudnn.flags(deterministic=True):
loss_ctc = torch.nn.functional.ctc_loss(
ctc_in,
@@ -436,9 +364,9 @@ class BATModel(nn.Module):
reduction="sum",
)
loss_ctc /= target.size(0)
return loss_ctc
def _calc_lm_loss(
self,
decoder_out: torch.Tensor,
@@ -456,17 +384,17 @@ class BATModel(nn.Module):
"""
lm_loss_in = self.lm_lin(decoder_out[:, :-1, :]).view(-1, self.vocab_size)
lm_target = target.view(-1).type(torch.int64)
with torch.no_grad():
true_dist = lm_loss_in.clone()
true_dist.fill_(self.lm_loss_smoothing / (self.vocab_size - 1))
# Ignore blank ID (0)
ignore = lm_target == 0
lm_target = lm_target.masked_fill(ignore, 0)
true_dist.scatter_(1, lm_target.unsqueeze(1), (1 - self.lm_loss_smoothing))
loss_lm = torch.nn.functional.kl_div(
torch.log_softmax(lm_loss_in, dim=1),
true_dist,
@@ -475,5 +403,117 @@ class BATModel(nn.Module):
loss_lm = loss_lm.masked_fill(ignore.unsqueeze(1), 0).sum() / decoder_out.size(
0
)
return loss_lm
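The smoothed target distribution built inside _calc_lm_loss can be illustrated with a small standalone example (vocab_size = 5 and smoothing = 0.1 are made-up values):
import torch
vocab_size, smoothing = 5, 0.1
targets = torch.tensor([3, 1])
true_dist = torch.full((2, vocab_size), smoothing / (vocab_size - 1))
true_dist.scatter_(1, targets.unsqueeze(1), 1.0 - smoothing)
print(true_dist)
# tensor([[0.0250, 0.0250, 0.0250, 0.9000, 0.0250],
#         [0.0250, 0.9000, 0.0250, 0.0250, 0.0250]])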
def init_beam_search(self,
**kwargs,
):
# 1. Build ASR model
scorers = {}
if self.ctc is not None:
ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
scorers.update(
ctc=ctc
)
token_list = kwargs.get("token_list")
scorers.update(
length_bonus=LengthBonus(len(token_list)),
)
# 3. Build ngram model
# ngram is not supported now
ngram = None
scorers["ngram"] = ngram
beam_search = BeamSearchTransducer(
self.decoder,
self.joint_network,
kwargs.get("beam_size", 2),
nbest=1,
)
# beam_search.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval()
# for scorer in scorers.values():
# if isinstance(scorer, torch.nn.Module):
# scorer.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval()
self.beam_search = beam_search
def inference(self,
data_in: list,
data_lengths: list=None,
key: list=None,
tokenizer=None,
**kwargs,
):
if kwargs.get("batch_size", 1) > 1:
raise NotImplementedError("batch decoding is not implemented")
# init beamsearch
is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc is not None
is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
# if self.beam_search is None and (is_use_lm or is_use_ctc):
logging.info("enable beam_search")
self.init_beam_search(**kwargs)
self.nbest = kwargs.get("nbest", 1)
meta_data = {}
# extract fbank feats
time1 = time.perf_counter()
audio_sample_list = load_audio_text_image_video(data_in, fs=self.frontend.fs, audio_fs=kwargs.get("fs", 16000))
time2 = time.perf_counter()
meta_data["load_data"] = f"{time2 - time1:0.3f}"
speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=self.frontend)
time3 = time.perf_counter()
meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
meta_data["batch_data_time"] = speech_lengths.sum().item() * self.frontend.frame_shift * self.frontend.lfr_n / 1000
speech = speech.to(device=kwargs["device"])
speech_lengths = speech_lengths.to(device=kwargs["device"])
# Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
if isinstance(encoder_out, tuple):
encoder_out = encoder_out[0]
# c. Pass the encoder output to the beam search
nbest_hyps = self.beam_search(encoder_out[0], is_final=True)
nbest_hyps = nbest_hyps[: self.nbest]
results = []
b, n, d = encoder_out.size()
for i in range(b):
for nbest_idx, hyp in enumerate(nbest_hyps):
ibest_writer = None
if kwargs.get("output_dir") is not None:
if not hasattr(self, "writer"):
self.writer = DatadirWriter(kwargs.get("output_dir"))
ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"]
# remove sos/eos and get results
last_pos = -1
if isinstance(hyp.yseq, list):
token_int = hyp.yseq#[1:last_pos]
else:
token_int = hyp.yseq#[1:last_pos].tolist()
# remove blank symbol id, which is assumed to be 0
token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
# Change integer-ids to tokens
token = tokenizer.ids2tokens(token_int)
text = tokenizer.tokens2text(token)
text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
result_i = {"key": key[i], "token": token, "text": text, "text_postprocessed": text_postprocessed}
results.append(result_i)
if ibest_writer is not None:
ibest_writer["token"][key[i]] = " ".join(token)
ibest_writer["text"][key[i]] = text
ibest_writer["text_postprocessed"][key[i]] = text_postprocessed
return results, meta_data