From 9bed46a31e28095a203806ccdd20932476124b47 Mon Sep 17 00:00:00 2001 From: "shixian.shi" Date: Wed, 21 Feb 2024 16:50:49 +0800 Subject: [PATCH] transducer inference --- funasr/models/bat/attention.py | 238 ------- funasr/models/bat/cif_predictor.py | 220 ------ funasr/models/bat/conformer_chunk_encoder.py | 701 ------------------ funasr/models/bat/model.py | 706 ++++++++++--------- 4 files changed, 373 insertions(+), 1492 deletions(-) delete mode 100644 funasr/models/bat/attention.py delete mode 100644 funasr/models/bat/cif_predictor.py delete mode 100644 funasr/models/bat/conformer_chunk_encoder.py diff --git a/funasr/models/bat/attention.py b/funasr/models/bat/attention.py deleted file mode 100644 index 11645b3c8..000000000 --- a/funasr/models/bat/attention.py +++ /dev/null @@ -1,238 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- - -# Copyright 2019 Shigeki Karita -# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) - -"""Multi-Head Attention layer definition.""" - -import math - -import numpy -import torch -from torch import nn -from typing import Optional, Tuple - -import torch.nn.functional as F -from funasr.models.transformer.utils.nets_utils import make_pad_mask -import funasr.models.lora.layers as lora - - -class RelPositionMultiHeadedAttentionChunk(torch.nn.Module): - """RelPositionMultiHeadedAttention definition. - Args: - num_heads: Number of attention heads. - embed_size: Embedding size. - dropout_rate: Dropout rate. - """ - - def __init__( - self, - num_heads: int, - embed_size: int, - dropout_rate: float = 0.0, - simplified_attention_score: bool = False, - ) -> None: - """Construct an MultiHeadedAttention object.""" - super().__init__() - - self.d_k = embed_size // num_heads - self.num_heads = num_heads - - assert self.d_k * num_heads == embed_size, ( - "embed_size (%d) must be divisible by num_heads (%d)", - (embed_size, num_heads), - ) - - self.linear_q = torch.nn.Linear(embed_size, embed_size) - self.linear_k = torch.nn.Linear(embed_size, embed_size) - self.linear_v = torch.nn.Linear(embed_size, embed_size) - - self.linear_out = torch.nn.Linear(embed_size, embed_size) - - if simplified_attention_score: - self.linear_pos = torch.nn.Linear(embed_size, num_heads) - - self.compute_att_score = self.compute_simplified_attention_score - else: - self.linear_pos = torch.nn.Linear(embed_size, embed_size, bias=False) - - self.pos_bias_u = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k)) - self.pos_bias_v = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k)) - torch.nn.init.xavier_uniform_(self.pos_bias_u) - torch.nn.init.xavier_uniform_(self.pos_bias_v) - - self.compute_att_score = self.compute_attention_score - - self.dropout = torch.nn.Dropout(p=dropout_rate) - self.attn = None - - def rel_shift(self, x: torch.Tensor, left_context: int = 0) -> torch.Tensor: - """Compute relative positional encoding. - Args: - x: Input sequence. (B, H, T_1, 2 * T_1 - 1) - left_context: Number of frames in left context. - Returns: - x: Output sequence. 
(B, H, T_1, T_2) - """ - batch_size, n_heads, time1, n = x.shape - time2 = time1 + left_context - - batch_stride, n_heads_stride, time1_stride, n_stride = x.stride() - - return x.as_strided( - (batch_size, n_heads, time1, time2), - (batch_stride, n_heads_stride, time1_stride - n_stride, n_stride), - storage_offset=(n_stride * (time1 - 1)), - ) - - def compute_simplified_attention_score( - self, - query: torch.Tensor, - key: torch.Tensor, - pos_enc: torch.Tensor, - left_context: int = 0, - ) -> torch.Tensor: - """Simplified attention score computation. - Reference: https://github.com/k2-fsa/icefall/pull/458 - Args: - query: Transformed query tensor. (B, H, T_1, d_k) - key: Transformed key tensor. (B, H, T_2, d_k) - pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size) - left_context: Number of frames in left context. - Returns: - : Attention score. (B, H, T_1, T_2) - """ - pos_enc = self.linear_pos(pos_enc) - - matrix_ac = torch.matmul(query, key.transpose(2, 3)) - - matrix_bd = self.rel_shift( - pos_enc.transpose(1, 2).unsqueeze(2).repeat(1, 1, query.size(2), 1), - left_context=left_context, - ) - - return (matrix_ac + matrix_bd) / math.sqrt(self.d_k) - - def compute_attention_score( - self, - query: torch.Tensor, - key: torch.Tensor, - pos_enc: torch.Tensor, - left_context: int = 0, - ) -> torch.Tensor: - """Attention score computation. - Args: - query: Transformed query tensor. (B, H, T_1, d_k) - key: Transformed key tensor. (B, H, T_2, d_k) - pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size) - left_context: Number of frames in left context. - Returns: - : Attention score. (B, H, T_1, T_2) - """ - p = self.linear_pos(pos_enc).view(pos_enc.size(0), -1, self.num_heads, self.d_k) - - query = query.transpose(1, 2) - q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2) - q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2) - - matrix_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1)) - - matrix_bd = torch.matmul(q_with_bias_v, p.permute(0, 2, 3, 1)) - matrix_bd = self.rel_shift(matrix_bd, left_context=left_context) - - return (matrix_ac + matrix_bd) / math.sqrt(self.d_k) - - def forward_qkv( - self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Transform query, key and value. - Args: - query: Query tensor. (B, T_1, size) - key: Key tensor. (B, T_2, size) - v: Value tensor. (B, T_2, size) - Returns: - q: Transformed query tensor. (B, H, T_1, d_k) - k: Transformed key tensor. (B, H, T_2, d_k) - v: Transformed value tensor. (B, H, T_2, d_k) - """ - n_batch = query.size(0) - - q = ( - self.linear_q(query) - .view(n_batch, -1, self.num_heads, self.d_k) - .transpose(1, 2) - ) - k = ( - self.linear_k(key) - .view(n_batch, -1, self.num_heads, self.d_k) - .transpose(1, 2) - ) - v = ( - self.linear_v(value) - .view(n_batch, -1, self.num_heads, self.d_k) - .transpose(1, 2) - ) - - return q, k, v - - def forward_attention( - self, - value: torch.Tensor, - scores: torch.Tensor, - mask: torch.Tensor, - chunk_mask: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """Compute attention context vector. - Args: - value: Transformed value. (B, H, T_2, d_k) - scores: Attention score. (B, H, T_1, T_2) - mask: Source mask. (B, T_2) - chunk_mask: Chunk mask. (T_1, T_1) - Returns: - attn_output: Transformed value weighted by attention score. 
(B, T_1, H * d_k) - """ - batch_size = scores.size(0) - mask = mask.unsqueeze(1).unsqueeze(2) - if chunk_mask is not None: - mask = chunk_mask.unsqueeze(0).unsqueeze(1) | mask - scores = scores.masked_fill(mask, float("-inf")) - self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) - - attn_output = self.dropout(self.attn) - attn_output = torch.matmul(attn_output, value) - - attn_output = self.linear_out( - attn_output.transpose(1, 2) - .contiguous() - .view(batch_size, -1, self.num_heads * self.d_k) - ) - - return attn_output - - def forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - pos_enc: torch.Tensor, - mask: torch.Tensor, - chunk_mask: Optional[torch.Tensor] = None, - left_context: int = 0, - ) -> torch.Tensor: - """Compute scaled dot product attention with rel. positional encoding. - Args: - query: Query tensor. (B, T_1, size) - key: Key tensor. (B, T_2, size) - value: Value tensor. (B, T_2, size) - pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size) - mask: Source mask. (B, T_2) - chunk_mask: Chunk mask. (T_1, T_1) - left_context: Number of frames in left context. - Returns: - : Output tensor. (B, T_1, H * d_k) - """ - q, k, v = self.forward_qkv(query, key, value) - scores = self.compute_att_score(q, k, pos_enc, left_context=left_context) - return self.forward_attention(v, scores, mask, chunk_mask=chunk_mask) - diff --git a/funasr/models/bat/cif_predictor.py b/funasr/models/bat/cif_predictor.py deleted file mode 100644 index d8915c226..000000000 --- a/funasr/models/bat/cif_predictor.py +++ /dev/null @@ -1,220 +0,0 @@ -# import torch -# from torch import nn -# from torch import Tensor -# import logging -# import numpy as np -# from funasr.train_utils.device_funcs import to_device -# from funasr.models.transformer.utils.nets_utils import make_pad_mask -# from funasr.models.scama.utils import sequence_mask -# from typing import Optional, Tuple -# -# from funasr.register import tables -# -# class mae_loss(nn.Module): -# -# def __init__(self, normalize_length=False): -# super(mae_loss, self).__init__() -# self.normalize_length = normalize_length -# self.criterion = torch.nn.L1Loss(reduction='sum') -# -# def forward(self, token_length, pre_token_length): -# loss_token_normalizer = token_length.size(0) -# if self.normalize_length: -# loss_token_normalizer = token_length.sum().type(torch.float32) -# loss = self.criterion(token_length, pre_token_length) -# loss = loss / loss_token_normalizer -# return loss -# -# -# def cif(hidden, alphas, threshold): -# batch_size, len_time, hidden_size = hidden.size() -# -# # loop varss -# integrate = torch.zeros([batch_size], device=hidden.device) -# frame = torch.zeros([batch_size, hidden_size], device=hidden.device) -# # intermediate vars along time -# list_fires = [] -# list_frames = [] -# -# for t in range(len_time): -# alpha = alphas[:, t] -# distribution_completion = torch.ones([batch_size], device=hidden.device) - integrate -# -# integrate += alpha -# list_fires.append(integrate) -# -# fire_place = integrate >= threshold -# integrate = torch.where(fire_place, -# integrate - torch.ones([batch_size], device=hidden.device), -# integrate) -# cur = torch.where(fire_place, -# distribution_completion, -# alpha) -# remainds = alpha - cur -# -# frame += cur[:, None] * hidden[:, t, :] -# list_frames.append(frame) -# frame = torch.where(fire_place[:, None].repeat(1, hidden_size), -# remainds[:, None] * hidden[:, t, :], -# frame) -# -# fires = torch.stack(list_fires, 1) -# frames = 
torch.stack(list_frames, 1) -# list_ls = [] -# len_labels = torch.round(alphas.sum(-1)).int() -# max_label_len = len_labels.max() -# for b in range(batch_size): -# fire = fires[b, :] -# l = torch.index_select(frames[b, :, :], 0, torch.nonzero(fire >= threshold).squeeze()) -# pad_l = torch.zeros([max_label_len - l.size(0), hidden_size], device=hidden.device) -# list_ls.append(torch.cat([l, pad_l], 0)) -# return torch.stack(list_ls, 0), fires -# -# -# def cif_wo_hidden(alphas, threshold): -# batch_size, len_time = alphas.size() -# -# # loop varss -# integrate = torch.zeros([batch_size], device=alphas.device) -# # intermediate vars along time -# list_fires = [] -# -# for t in range(len_time): -# alpha = alphas[:, t] -# -# integrate += alpha -# list_fires.append(integrate) -# -# fire_place = integrate >= threshold -# integrate = torch.where(fire_place, -# integrate - torch.ones([batch_size], device=alphas.device)*threshold, -# integrate) -# -# fires = torch.stack(list_fires, 1) -# return fires -# -# @tables.register("predictor_classes", "BATPredictor") -# class BATPredictor(nn.Module): -# def __init__(self, idim, l_order, r_order, threshold=1.0, dropout=0.1, smooth_factor=1.0, noise_threshold=0, return_accum=False): -# super(BATPredictor, self).__init__() -# -# self.pad = nn.ConstantPad1d((l_order, r_order), 0) -# self.cif_conv1d = nn.Conv1d(idim, idim, l_order + r_order + 1, groups=idim) -# self.cif_output = nn.Linear(idim, 1) -# self.dropout = torch.nn.Dropout(p=dropout) -# self.threshold = threshold -# self.smooth_factor = smooth_factor -# self.noise_threshold = noise_threshold -# self.return_accum = return_accum -# -# def cif( -# self, -# input: Tensor, -# alpha: Tensor, -# beta: float = 1.0, -# return_accum: bool = False, -# ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: -# B, S, C = input.size() -# assert tuple(alpha.size()) == (B, S), f"{alpha.size()} != {(B, S)}" -# -# dtype = alpha.dtype -# alpha = alpha.float() -# -# alpha_sum = alpha.sum(1) -# feat_lengths = (alpha_sum / beta).floor().long() -# T = feat_lengths.max() -# -# # aggregate and integrate -# csum = alpha.cumsum(-1) -# with torch.no_grad(): -# # indices used for scattering -# right_idx = (csum / beta).floor().long().clip(max=T) -# left_idx = right_idx.roll(1, dims=1) -# left_idx[:, 0] = 0 -# -# # count # of fires from each source -# fire_num = right_idx - left_idx -# extra_weights = (fire_num - 1).clip(min=0) -# # The extra entry in last dim is for -# output = input.new_zeros((B, T + 1, C)) -# source_range = torch.arange(1, 1 + S).unsqueeze(0).type_as(input) -# zero = alpha.new_zeros((1,)) -# -# # right scatter -# fire_mask = fire_num > 0 -# right_weight = torch.where( -# fire_mask, -# csum - right_idx.type_as(alpha) * beta, -# zero -# ).type_as(input) -# # assert right_weight.ge(0).all(), f"{right_weight} should be non-negative." 
-# output.scatter_add_( -# 1, -# right_idx.unsqueeze(-1).expand(-1, -1, C), -# right_weight.unsqueeze(-1) * input -# ) -# -# # left scatter -# left_weight = ( -# alpha - right_weight - extra_weights.type_as(alpha) * beta -# ).type_as(input) -# output.scatter_add_( -# 1, -# left_idx.unsqueeze(-1).expand(-1, -1, C), -# left_weight.unsqueeze(-1) * input -# ) -# -# # extra scatters -# if extra_weights.ge(0).any(): -# extra_steps = extra_weights.max().item() -# tgt_idx = left_idx -# src_feats = input * beta -# for _ in range(extra_steps): -# tgt_idx = (tgt_idx + 1).clip(max=T) -# # (B, S, 1) -# src_mask = (extra_weights > 0) -# output.scatter_add_( -# 1, -# tgt_idx.unsqueeze(-1).expand(-1, -1, C), -# src_feats * src_mask.unsqueeze(2) -# ) -# extra_weights -= 1 -# -# output = output[:, :T, :] -# -# if return_accum: -# return output, csum -# else: -# return output, alpha -# -# def forward(self, hidden, target_label=None, mask=None, ignore_id=-1, mask_chunk_predictor=None, target_label_length=None): -# h = hidden -# context = h.transpose(1, 2) -# queries = self.pad(context) -# memory = self.cif_conv1d(queries) -# output = memory + context -# output = self.dropout(output) -# output = output.transpose(1, 2) -# output = torch.relu(output) -# output = self.cif_output(output) -# alphas = torch.sigmoid(output) -# alphas = torch.nn.functional.relu(alphas*self.smooth_factor - self.noise_threshold) -# if mask is not None: -# alphas = alphas * mask.transpose(-1, -2).float() -# if mask_chunk_predictor is not None: -# alphas = alphas * mask_chunk_predictor -# alphas = alphas.squeeze(-1) -# if target_label_length is not None: -# target_length = target_label_length -# elif target_label is not None: -# target_length = (target_label != ignore_id).float().sum(-1) -# # logging.info("target_length: {}".format(target_length)) -# else: -# target_length = None -# token_num = alphas.sum(-1) -# if target_length is not None: -# # length_noise = torch.rand(alphas.size(0), device=alphas.device) - 0.5 -# # target_length = length_noise + target_length -# alphas *= ((target_length + 1e-4) / token_num)[:, None].repeat(1, alphas.size(1)) -# acoustic_embeds, cif_peak = self.cif(hidden, alphas, self.threshold, self.return_accum) -# return acoustic_embeds, token_num, alphas, cif_peak diff --git a/funasr/models/bat/conformer_chunk_encoder.py b/funasr/models/bat/conformer_chunk_encoder.py deleted file mode 100644 index 7635c0289..000000000 --- a/funasr/models/bat/conformer_chunk_encoder.py +++ /dev/null @@ -1,701 +0,0 @@ - -"""Conformer encoder definition.""" - -import logging -from typing import Union, Dict, List, Tuple, Optional - -import torch -from torch import nn - - -from funasr.models.bat.attention import ( - RelPositionMultiHeadedAttentionChunk, -) -from funasr.models.transformer.embedding import ( - StreamingRelPositionalEncoding, -) -from funasr.models.transformer.layer_norm import LayerNorm -from funasr.models.transformer.utils.nets_utils import get_activation -from funasr.models.transformer.utils.nets_utils import ( - TooShortUttError, - check_short_utt, - make_chunk_mask, - make_source_mask, -) -from funasr.models.transformer.positionwise_feed_forward import ( - PositionwiseFeedForward, -) -from funasr.models.transformer.utils.repeat import repeat, MultiBlocks -from funasr.models.transformer.utils.subsampling import TooShortUttError -from funasr.models.transformer.utils.subsampling import check_short_utt -from funasr.models.transformer.utils.subsampling import StreamingConvInput -from funasr.register import tables - - - 
-class ChunkEncoderLayer(nn.Module): - """Chunk Conformer module definition. - Args: - block_size: Input/output size. - self_att: Self-attention module instance. - feed_forward: Feed-forward module instance. - feed_forward_macaron: Feed-forward module instance for macaron network. - conv_mod: Convolution module instance. - norm_class: Normalization module class. - norm_args: Normalization module arguments. - dropout_rate: Dropout rate. - """ - - def __init__( - self, - block_size: int, - self_att: torch.nn.Module, - feed_forward: torch.nn.Module, - feed_forward_macaron: torch.nn.Module, - conv_mod: torch.nn.Module, - norm_class: torch.nn.Module = LayerNorm, - norm_args: Dict = {}, - dropout_rate: float = 0.0, - ) -> None: - """Construct a Conformer object.""" - super().__init__() - - self.self_att = self_att - - self.feed_forward = feed_forward - self.feed_forward_macaron = feed_forward_macaron - self.feed_forward_scale = 0.5 - - self.conv_mod = conv_mod - - self.norm_feed_forward = norm_class(block_size, **norm_args) - self.norm_self_att = norm_class(block_size, **norm_args) - - self.norm_macaron = norm_class(block_size, **norm_args) - self.norm_conv = norm_class(block_size, **norm_args) - self.norm_final = norm_class(block_size, **norm_args) - - self.dropout = torch.nn.Dropout(dropout_rate) - - self.block_size = block_size - self.cache = None - - def reset_streaming_cache(self, left_context: int, device: torch.device) -> None: - """Initialize/Reset self-attention and convolution modules cache for streaming. - Args: - left_context: Number of left frames during chunk-by-chunk inference. - device: Device to use for cache tensor. - """ - self.cache = [ - torch.zeros( - (1, left_context, self.block_size), - device=device, - ), - torch.zeros( - ( - 1, - self.block_size, - self.conv_mod.kernel_size - 1, - ), - device=device, - ), - ] - - def forward( - self, - x: torch.Tensor, - pos_enc: torch.Tensor, - mask: torch.Tensor, - chunk_mask: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Encode input sequences. - Args: - x: Conformer input sequences. (B, T, D_block) - pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block) - mask: Source mask. (B, T) - chunk_mask: Chunk mask. (T_2, T_2) - Returns: - x: Conformer output sequences. (B, T, D_block) - mask: Source mask. (B, T) - pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block) - """ - residual = x - - x = self.norm_macaron(x) - x = residual + self.feed_forward_scale * self.dropout( - self.feed_forward_macaron(x) - ) - - residual = x - x = self.norm_self_att(x) - x_q = x - x = residual + self.dropout( - self.self_att( - x_q, - x, - x, - pos_enc, - mask, - chunk_mask=chunk_mask, - ) - ) - - residual = x - - x = self.norm_conv(x) - x, _ = self.conv_mod(x) - x = residual + self.dropout(x) - residual = x - - x = self.norm_feed_forward(x) - x = residual + self.feed_forward_scale * self.dropout(self.feed_forward(x)) - - x = self.norm_final(x) - return x, mask, pos_enc - - def chunk_forward( - self, - x: torch.Tensor, - pos_enc: torch.Tensor, - mask: torch.Tensor, - chunk_size: int = 16, - left_context: int = 0, - right_context: int = 0, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Encode chunk of input sequence. - Args: - x: Conformer input sequences. (B, T, D_block) - pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block) - mask: Source mask. (B, T_2) - left_context: Number of frames in left context. - right_context: Number of frames in right context. 
- Returns: - x: Conformer output sequences. (B, T, D_block) - pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block) - """ - residual = x - - x = self.norm_macaron(x) - x = residual + self.feed_forward_scale * self.feed_forward_macaron(x) - - residual = x - x = self.norm_self_att(x) - if left_context > 0: - key = torch.cat([self.cache[0], x], dim=1) - else: - key = x - val = key - - if right_context > 0: - att_cache = key[:, -(left_context + right_context) : -right_context, :] - else: - att_cache = key[:, -left_context:, :] - x = residual + self.self_att( - x, - key, - val, - pos_enc, - mask, - left_context=left_context, - ) - - residual = x - x = self.norm_conv(x) - x, conv_cache = self.conv_mod( - x, cache=self.cache[1], right_context=right_context - ) - x = residual + x - residual = x - - x = self.norm_feed_forward(x) - x = residual + self.feed_forward_scale * self.feed_forward(x) - - x = self.norm_final(x) - self.cache = [att_cache, conv_cache] - - return x, pos_enc - - - -class CausalConvolution(nn.Module): - """ConformerConvolution module definition. - Args: - channels: The number of channels. - kernel_size: Size of the convolving kernel. - activation: Type of activation function. - norm_args: Normalization module arguments. - causal: Whether to use causal convolution (set to True if streaming). - """ - - def __init__( - self, - channels: int, - kernel_size: int, - activation: torch.nn.Module = torch.nn.ReLU(), - norm_args: Dict = {}, - causal: bool = False, - ) -> None: - """Construct an ConformerConvolution object.""" - super().__init__() - - assert (kernel_size - 1) % 2 == 0 - - self.kernel_size = kernel_size - - self.pointwise_conv1 = torch.nn.Conv1d( - channels, - 2 * channels, - kernel_size=1, - stride=1, - padding=0, - ) - - if causal: - self.lorder = kernel_size - 1 - padding = 0 - else: - self.lorder = 0 - padding = (kernel_size - 1) // 2 - - self.depthwise_conv = torch.nn.Conv1d( - channels, - channels, - kernel_size, - stride=1, - padding=padding, - groups=channels, - ) - self.norm = torch.nn.BatchNorm1d(channels, **norm_args) - self.pointwise_conv2 = torch.nn.Conv1d( - channels, - channels, - kernel_size=1, - stride=1, - padding=0, - ) - - self.activation = activation - - def forward( - self, - x: torch.Tensor, - cache: Optional[torch.Tensor] = None, - right_context: int = 0, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Compute convolution module. - Args: - x: ConformerConvolution input sequences. (B, T, D_hidden) - cache: ConformerConvolution input cache. (1, conv_kernel, D_hidden) - right_context: Number of frames in right context. - Returns: - x: ConformerConvolution output sequences. (B, T, D_hidden) - cache: ConformerConvolution output cache. (1, conv_kernel, D_hidden) - """ - x = self.pointwise_conv1(x.transpose(1, 2)) - x = torch.nn.functional.glu(x, dim=1) - - if self.lorder > 0: - if cache is None: - x = torch.nn.functional.pad(x, (self.lorder, 0), "constant", 0.0) - else: - x = torch.cat([cache, x], dim=2) - - if right_context > 0: - cache = x[:, :, -(self.lorder + right_context) : -right_context] - else: - cache = x[:, :, -self.lorder :] - - x = self.depthwise_conv(x) - x = self.activation(self.norm(x)) - - x = self.pointwise_conv2(x).transpose(1, 2) - - return x, cache - -@tables.register("encoder_classes", "ConformerChunkEncoder") -class ConformerChunkEncoder(nn.Module): - """Encoder module definition. - Args: - input_size: Input size. - body_conf: Encoder body configuration. - input_conf: Encoder input configuration. 
- main_conf: Encoder main configuration. - """ - - def __init__( - self, - input_size: int, - output_size: int = 256, - attention_heads: int = 4, - linear_units: int = 2048, - num_blocks: int = 6, - dropout_rate: float = 0.1, - positional_dropout_rate: float = 0.1, - attention_dropout_rate: float = 0.0, - embed_vgg_like: bool = False, - normalize_before: bool = True, - concat_after: bool = False, - positionwise_layer_type: str = "linear", - positionwise_conv_kernel_size: int = 3, - macaron_style: bool = False, - rel_pos_type: str = "legacy", - pos_enc_layer_type: str = "rel_pos", - selfattention_layer_type: str = "rel_selfattn", - activation_type: str = "swish", - use_cnn_module: bool = True, - zero_triu: bool = False, - norm_type: str = "layer_norm", - cnn_module_kernel: int = 31, - conv_mod_norm_eps: float = 0.00001, - conv_mod_norm_momentum: float = 0.1, - simplified_att_score: bool = False, - dynamic_chunk_training: bool = False, - short_chunk_threshold: float = 0.75, - short_chunk_size: int = 25, - left_chunk_size: int = 0, - time_reduction_factor: int = 1, - unified_model_training: bool = False, - default_chunk_size: int = 16, - jitter_range: int = 4, - subsampling_factor: int = 1, - ) -> None: - """Construct an Encoder object.""" - super().__init__() - - - self.embed = StreamingConvInput( - input_size, - output_size, - subsampling_factor, - vgg_like=embed_vgg_like, - output_size=output_size, - ) - - self.pos_enc = StreamingRelPositionalEncoding( - output_size, - positional_dropout_rate, - ) - - activation = get_activation( - activation_type - ) - - pos_wise_args = ( - output_size, - linear_units, - positional_dropout_rate, - activation, - ) - - conv_mod_norm_args = { - "eps": conv_mod_norm_eps, - "momentum": conv_mod_norm_momentum, - } - - conv_mod_args = ( - output_size, - cnn_module_kernel, - activation, - conv_mod_norm_args, - dynamic_chunk_training or unified_model_training, - ) - - mult_att_args = ( - attention_heads, - output_size, - attention_dropout_rate, - simplified_att_score, - ) - - - fn_modules = [] - for _ in range(num_blocks): - module = lambda: ChunkEncoderLayer( - output_size, - RelPositionMultiHeadedAttentionChunk(*mult_att_args), - PositionwiseFeedForward(*pos_wise_args), - PositionwiseFeedForward(*pos_wise_args), - CausalConvolution(*conv_mod_args), - dropout_rate=dropout_rate, - ) - fn_modules.append(module) - - self.encoders = MultiBlocks( - [fn() for fn in fn_modules], - output_size, - ) - - self._output_size = output_size - - self.dynamic_chunk_training = dynamic_chunk_training - self.short_chunk_threshold = short_chunk_threshold - self.short_chunk_size = short_chunk_size - self.left_chunk_size = left_chunk_size - - self.unified_model_training = unified_model_training - self.default_chunk_size = default_chunk_size - self.jitter_range = jitter_range - - self.time_reduction_factor = time_reduction_factor - - def output_size(self) -> int: - return self._output_size - - def get_encoder_input_raw_size(self, size: int, hop_length: int) -> int: - """Return the corresponding number of sample for a given chunk size, in frames. - Where size is the number of features frames after applying subsampling. - Args: - size: Number of frames after subsampling. - hop_length: Frontend's hop length - Returns: - : Number of raw samples - """ - return self.embed.get_size_before_subsampling(size) * hop_length - - def get_encoder_input_size(self, size: int) -> int: - """Return the corresponding number of sample for a given chunk size, in frames. 
- Where size is the number of features frames after applying subsampling. - Args: - size: Number of frames after subsampling. - Returns: - : Number of raw samples - """ - return self.embed.get_size_before_subsampling(size) - - - def reset_streaming_cache(self, left_context: int, device: torch.device) -> None: - """Initialize/Reset encoder streaming cache. - Args: - left_context: Number of frames in left context. - device: Device ID. - """ - return self.encoders.reset_streaming_cache(left_context, device) - - def forward( - self, - x: torch.Tensor, - x_len: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Encode input sequences. - Args: - x: Encoder input features. (B, T_in, F) - x_len: Encoder input features lengths. (B,) - Returns: - x: Encoder outputs. (B, T_out, D_enc) - x_len: Encoder outputs lenghts. (B,) - """ - short_status, limit_size = check_short_utt( - self.embed.subsampling_factor, x.size(1) - ) - - if short_status: - raise TooShortUttError( - f"has {x.size(1)} frames and is too short for subsampling " - + f"(it needs more than {limit_size} frames), return empty results", - x.size(1), - limit_size, - ) - - mask = make_source_mask(x_len).to(x.device) - - if self.unified_model_training: - if self.training: - chunk_size = self.default_chunk_size + torch.randint(-self.jitter_range, self.jitter_range+1, (1,)).item() - else: - chunk_size = self.default_chunk_size - x, mask = self.embed(x, mask, chunk_size) - pos_enc = self.pos_enc(x) - chunk_mask = make_chunk_mask( - x.size(1), - chunk_size, - left_chunk_size=self.left_chunk_size, - device=x.device, - ) - x_utt = self.encoders( - x, - pos_enc, - mask, - chunk_mask=None, - ) - x_chunk = self.encoders( - x, - pos_enc, - mask, - chunk_mask=chunk_mask, - ) - - olens = mask.eq(0).sum(1) - if self.time_reduction_factor > 1: - x_utt = x_utt[:,::self.time_reduction_factor,:] - x_chunk = x_chunk[:,::self.time_reduction_factor,:] - olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1 - - return x_utt, x_chunk, olens - - elif self.dynamic_chunk_training: - max_len = x.size(1) - if self.training: - chunk_size = torch.randint(1, max_len, (1,)).item() - - if chunk_size > (max_len * self.short_chunk_threshold): - chunk_size = max_len - else: - chunk_size = (chunk_size % self.short_chunk_size) + 1 - else: - chunk_size = self.default_chunk_size - - x, mask = self.embed(x, mask, chunk_size) - pos_enc = self.pos_enc(x) - - chunk_mask = make_chunk_mask( - x.size(1), - chunk_size, - left_chunk_size=self.left_chunk_size, - device=x.device, - ) - else: - x, mask = self.embed(x, mask, None) - pos_enc = self.pos_enc(x) - chunk_mask = None - x = self.encoders( - x, - pos_enc, - mask, - chunk_mask=chunk_mask, - ) - - olens = mask.eq(0).sum(1) - if self.time_reduction_factor > 1: - x = x[:,::self.time_reduction_factor,:] - olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1 - - return x, olens, None - - def full_utt_forward( - self, - x: torch.Tensor, - x_len: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Encode input sequences. - Args: - x: Encoder input features. (B, T_in, F) - x_len: Encoder input features lengths. (B,) - Returns: - x: Encoder outputs. (B, T_out, D_enc) - x_len: Encoder outputs lenghts. 
(B,) - """ - short_status, limit_size = check_short_utt( - self.embed.subsampling_factor, x.size(1) - ) - - if short_status: - raise TooShortUttError( - f"has {x.size(1)} frames and is too short for subsampling " - + f"(it needs more than {limit_size} frames), return empty results", - x.size(1), - limit_size, - ) - - mask = make_source_mask(x_len).to(x.device) - x, mask = self.embed(x, mask, None) - pos_enc = self.pos_enc(x) - x_utt = self.encoders( - x, - pos_enc, - mask, - chunk_mask=None, - ) - - if self.time_reduction_factor > 1: - x_utt = x_utt[:,::self.time_reduction_factor,:] - return x_utt - - def simu_chunk_forward( - self, - x: torch.Tensor, - x_len: torch.Tensor, - chunk_size: int = 16, - left_context: int = 32, - right_context: int = 0, - ) -> torch.Tensor: - short_status, limit_size = check_short_utt( - self.embed.subsampling_factor, x.size(1) - ) - - if short_status: - raise TooShortUttError( - f"has {x.size(1)} frames and is too short for subsampling " - + f"(it needs more than {limit_size} frames), return empty results", - x.size(1), - limit_size, - ) - - mask = make_source_mask(x_len) - - x, mask = self.embed(x, mask, chunk_size) - pos_enc = self.pos_enc(x) - chunk_mask = make_chunk_mask( - x.size(1), - chunk_size, - left_chunk_size=self.left_chunk_size, - device=x.device, - ) - - x = self.encoders( - x, - pos_enc, - mask, - chunk_mask=chunk_mask, - ) - olens = mask.eq(0).sum(1) - if self.time_reduction_factor > 1: - x = x[:,::self.time_reduction_factor,:] - - return x - - def chunk_forward( - self, - x: torch.Tensor, - x_len: torch.Tensor, - processed_frames: torch.tensor, - chunk_size: int = 16, - left_context: int = 32, - right_context: int = 0, - ) -> torch.Tensor: - """Encode input sequences as chunks. - Args: - x: Encoder input features. (1, T_in, F) - x_len: Encoder input features lengths. (1,) - processed_frames: Number of frames already seen. - left_context: Number of frames in left context. - right_context: Number of frames in right context. - Returns: - x: Encoder outputs. (B, T_out, D_enc) - """ - mask = make_source_mask(x_len) - x, mask = self.embed(x, mask, None) - - if left_context > 0: - processed_mask = ( - torch.arange(left_context, device=x.device) - .view(1, left_context) - .flip(1) - ) - processed_mask = processed_mask >= processed_frames - mask = torch.cat([processed_mask, mask], dim=1) - pos_enc = self.pos_enc(x, left_context=left_context) - x = self.encoders.chunk_forward( - x, - pos_enc, - mask, - chunk_size=chunk_size, - left_context=left_context, - right_context=right_context, - ) - - if right_context > 0: - x = x[:, 0:-right_context, :] - - if self.time_reduction_factor > 1: - x = x[:,::self.time_reduction_factor,:] - return x diff --git a/funasr/models/bat/model.py b/funasr/models/bat/model.py index 3fed9aa55..8e76b458a 100644 --- a/funasr/models/bat/model.py +++ b/funasr/models/bat/model.py @@ -3,137 +3,145 @@ # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. 
# MIT License (https://opensource.org/licenses/MIT) - +import time import torch import logging -import torch.nn as nn +from contextlib import contextmanager +from typing import Dict, Optional, Tuple +from distutils.version import LooseVersion -from typing import Dict, List, Optional, Tuple, Union - - -from torch.cuda.amp import autocast -from funasr.losses.label_smoothing_loss import ( - LabelSmoothingLoss, # noqa: H301 -) - -from funasr.models.transformer.utils.nets_utils import get_transducer_task_io -from funasr.models.transformer.utils.nets_utils import make_pad_mask -from funasr.models.transformer.utils.add_sos_eos import add_sos_eos +from funasr.register import tables +from funasr.utils import postprocess_utils +from funasr.utils.datadir_writer import DatadirWriter from funasr.train_utils.device_funcs import force_gatherable +from funasr.models.transformer.scorers.ctc import CTCPrefixScorer +from funasr.losses.label_smoothing_loss import LabelSmoothingLoss +from funasr.models.transformer.scorers.length_bonus import LengthBonus +from funasr.models.transformer.utils.nets_utils import get_transducer_task_io +from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank +from funasr.models.transducer.beam_search_transducer import BeamSearchTransducer +if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): + from torch.cuda.amp import autocast +else: + # Nothing to do if torch<1.6.0 + @contextmanager + def autocast(enabled=True): + yield - -class BATModel(nn.Module): - """BATModel module definition. - - Args: - vocab_size: Size of complete vocabulary (w/ EOS and blank included). - token_list: List of token - frontend: Frontend module. - specaug: SpecAugment module. - normalize: Normalization module. - encoder: Encoder module. - decoder: Decoder module. - joint_network: Joint Network module. - transducer_weight: Weight of the Transducer loss. - fastemit_lambda: FastEmit lambda value. - auxiliary_ctc_weight: Weight of auxiliary CTC loss. - auxiliary_ctc_dropout_rate: Dropout rate for auxiliary CTC loss inputs. - auxiliary_lm_loss_weight: Weight of auxiliary LM loss. - auxiliary_lm_loss_smoothing: Smoothing rate for LM loss' label smoothing. - ignore_id: Initial padding ID. - sym_space: Space symbol. - sym_blank: Blank Symbol - report_cer: Whether to report Character Error Rate during validation. - report_wer: Whether to report Word Error Rate during validation. - extract_feats_in_collect_stats: Whether to use extract_feats stats collection. 
- - """ - +@tables.register("model_classes", "BAT") # TODO: BAT training +class BAT(torch.nn.Module): def __init__( self, - - cif_weight: float = 1.0, + frontend: Optional[str] = None, + frontend_conf: Optional[Dict] = None, + specaug: Optional[str] = None, + specaug_conf: Optional[Dict] = None, + normalize: str = None, + normalize_conf: Optional[Dict] = None, + encoder: str = None, + encoder_conf: Optional[Dict] = None, + decoder: str = None, + decoder_conf: Optional[Dict] = None, + joint_network: str = None, + joint_network_conf: Optional[Dict] = None, + transducer_weight: float = 1.0, fastemit_lambda: float = 0.0, auxiliary_ctc_weight: float = 0.0, auxiliary_ctc_dropout_rate: float = 0.0, auxiliary_lm_loss_weight: float = 0.0, auxiliary_lm_loss_smoothing: float = 0.0, + input_size: int = 80, + vocab_size: int = -1, ignore_id: int = -1, - sym_space: str = "", - sym_blank: str = "", - report_cer: bool = True, - report_wer: bool = True, - extract_feats_in_collect_stats: bool = True, + blank_id: int = 0, + sos: int = 1, + eos: int = 2, lsm_weight: float = 0.0, length_normalized_loss: bool = False, - r_d: int = 5, - r_u: int = 5, + # report_cer: bool = True, + # report_wer: bool = True, + # sym_space: str = "", + # sym_blank: str = "", + # extract_feats_in_collect_stats: bool = True, + share_embedding: bool = False, + # preencoder: Optional[AbsPreEncoder] = None, + # postencoder: Optional[AbsPostEncoder] = None, **kwargs, - ) -> None: - """Construct an BATModel object.""" + ): + super().__init__() - # The following labels ID are reserved: 0 (blank) and vocab_size - 1 (sos/eos) - self.blank_id = 0 + if specaug is not None: + specaug_class = tables.specaug_classes.get(specaug) + specaug = specaug_class(**specaug_conf) + if normalize is not None: + normalize_class = tables.normalize_classes.get(normalize) + normalize = normalize_class(**normalize_conf) + encoder_class = tables.encoder_classes.get(encoder) + encoder = encoder_class(input_size=input_size, **encoder_conf) + encoder_output_size = encoder.output_size() + + decoder_class = tables.decoder_classes.get(decoder) + decoder = decoder_class( + vocab_size=vocab_size, + **decoder_conf, + ) + decoder_output_size = decoder.output_size + + joint_network_class = tables.joint_network_classes.get(joint_network) + joint_network = joint_network_class( + vocab_size, + encoder_output_size, + decoder_output_size, + **joint_network_conf, + ) + + self.criterion_transducer = None + self.error_calculator = None + + self.use_auxiliary_ctc = auxiliary_ctc_weight > 0 + self.use_auxiliary_lm_loss = auxiliary_lm_loss_weight > 0 + + if self.use_auxiliary_ctc: + self.ctc_lin = torch.nn.Linear(encoder.output_size(), vocab_size) + self.ctc_dropout_rate = auxiliary_ctc_dropout_rate + + if self.use_auxiliary_lm_loss: + self.lm_lin = torch.nn.Linear(decoder.output_size, vocab_size) + self.lm_loss_smoothing = auxiliary_lm_loss_smoothing + + self.transducer_weight = transducer_weight + self.fastemit_lambda = fastemit_lambda + + self.auxiliary_ctc_weight = auxiliary_ctc_weight + self.auxiliary_lm_loss_weight = auxiliary_lm_loss_weight + self.blank_id = blank_id + self.sos = sos if sos is not None else vocab_size - 1 + self.eos = eos if eos is not None else vocab_size - 1 self.vocab_size = vocab_size self.ignore_id = ignore_id - self.token_list = token_list.copy() - - self.sym_space = sym_space - self.sym_blank = sym_blank - self.frontend = frontend self.specaug = specaug self.normalize = normalize - self.encoder = encoder self.decoder = decoder self.joint_network = 
joint_network - self.criterion_transducer = None - self.error_calculator = None - - self.use_auxiliary_ctc = auxiliary_ctc_weight > 0 - self.use_auxiliary_lm_loss = auxiliary_lm_loss_weight > 0 - - if self.use_auxiliary_ctc: - self.ctc_lin = torch.nn.Linear(encoder.output_size(), vocab_size) - self.ctc_dropout_rate = auxiliary_ctc_dropout_rate - - if self.use_auxiliary_lm_loss: - self.lm_lin = torch.nn.Linear(decoder.output_size, vocab_size) - self.lm_loss_smoothing = auxiliary_lm_loss_smoothing - - self.transducer_weight = transducer_weight - self.fastemit_lambda = fastemit_lambda - - self.auxiliary_ctc_weight = auxiliary_ctc_weight - self.auxiliary_lm_loss_weight = auxiliary_lm_loss_weight - - self.report_cer = report_cer - self.report_wer = report_wer - - self.extract_feats_in_collect_stats = extract_feats_in_collect_stats - - self.criterion_pre = torch.nn.L1Loss() - self.predictor_weight = predictor_weight - self.predictor = predictor - - self.cif_weight = cif_weight - if self.cif_weight > 0: - self.cif_output_layer = torch.nn.Linear(encoder.output_size(), vocab_size) - self.criterion_cif = LabelSmoothingLoss( - size=vocab_size, - padding_idx=ignore_id, - smoothing=lsm_weight, - normalize_length=length_normalized_loss, - ) - self.r_d = r_d - self.r_u = r_u + self.criterion_att = LabelSmoothingLoss( + size=vocab_size, + padding_idx=ignore_id, + smoothing=lsm_weight, + normalize_length=length_normalized_loss, + ) + self.length_normalized_loss = length_normalized_loss + self.beam_search = None + self.ctc = None + self.ctc_weight = 0.0 + def forward( self, speech: torch.Tensor, @@ -142,111 +150,167 @@ class BATModel(nn.Module): text_lengths: torch.Tensor, **kwargs, ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: - """Forward architecture and compute loss(es). - + """Encoder + Decoder + Calc loss Args: - speech: Speech sequences. (B, S) - speech_lengths: Speech sequences lengths. (B,) - text: Label ID sequences. (B, L) - text_lengths: Label ID sequences lengths. (B,) - kwargs: Contains "utts_id". - - Return: - loss: Main loss value. - stats: Task statistics. - weight: Task weights. - + speech: (Batch, Length, ...) + speech_lengths: (Batch, ) + text: (Batch, Length) + text_lengths: (Batch,) """ - assert text_lengths.dim() == 1, text_lengths.shape - assert ( - speech.shape[0] - == speech_lengths.shape[0] - == text.shape[0] - == text_lengths.shape[0] - ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape) - + if len(text_lengths.size()) > 1: + text_lengths = text_lengths[:, 0] + if len(speech_lengths.size()) > 1: + speech_lengths = speech_lengths[:, 0] + batch_size = speech.shape[0] - text = text[:, : text_lengths.max()] - # 1. Encoder encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) if hasattr(self.encoder, 'overlap_chunk_cls') and self.encoder.overlap_chunk_cls is not None: encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out, encoder_out_lens, chunk_outs=None) - - encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(encoder_out.device) # 2. Transducer-related I/O preparation decoder_in, target, t_len, u_len = get_transducer_task_io( text, encoder_out_lens, ignore_id=self.ignore_id, ) - + # 3. 
Decoder self.decoder.set_device(encoder_out.device) decoder_out = self.decoder(decoder_in, u_len) - - pre_acoustic_embeds, pre_token_length, _, pre_peak_index = self.predictor(encoder_out, text, encoder_out_mask, ignore_id=self.ignore_id) - loss_pre = self.criterion_pre(text_lengths.type_as(pre_token_length), pre_token_length) - - if self.cif_weight > 0.0: - cif_predict = self.cif_output_layer(pre_acoustic_embeds) - loss_cif = self.criterion_cif(cif_predict, text) - else: - loss_cif = 0.0 - - # 5. Losses - boundary = torch.zeros((encoder_out.size(0), 4), dtype=torch.int64, device=encoder_out.device) - boundary[:, 2] = u_len.long().detach() - boundary[:, 3] = t_len.long().detach() - - pre_peak_index = torch.floor(pre_peak_index).long() - s_begin = pre_peak_index - self.r_d - - T = encoder_out.size(1) - B = encoder_out.size(0) - U = decoder_out.size(1) - - mask = torch.arange(0, T, device=encoder_out.device).reshape(1, T).expand(B, T) - mask = mask <= boundary[:, 3].reshape(B, 1) - 1 - - s_begin_padding = boundary[:, 2].reshape(B, 1) - (self.r_u+self.r_d) + 1 - # handle the cases where `len(symbols) < s_range` - s_begin_padding = torch.clamp(s_begin_padding, min=0) - - s_begin = torch.where(mask, s_begin, s_begin_padding) - mask2 = s_begin < boundary[:, 2].reshape(B, 1) - (self.r_u+self.r_d) + 1 - - s_begin = torch.where(mask2, s_begin, boundary[:, 2].reshape(B, 1) - (self.r_u+self.r_d) + 1) - - s_begin = torch.clamp(s_begin, min=0) - - ranges = s_begin.reshape((B, T, 1)).expand((B, T, min(self.r_u+self.r_d, min(u_len)))) + torch.arange(min(self.r_d+self.r_u, min(u_len)), device=encoder_out.device) - - import fast_rnnt - am_pruned, lm_pruned = fast_rnnt.do_rnnt_pruning( - am=self.joint_network.lin_enc(encoder_out), - lm=self.joint_network.lin_dec(decoder_out), - ranges=ranges, + # 4. Joint Network + joint_out = self.joint_network( + encoder_out.unsqueeze(2), decoder_out.unsqueeze(1) ) - - logits = self.joint_network(am_pruned, lm_pruned, project_input=False) - - with torch.cuda.amp.autocast(enabled=False): - loss_trans = fast_rnnt.rnnt_loss_pruned( - logits=logits.float(), - symbols=target.long(), - ranges=ranges, - termination_symbol=self.blank_id, - boundary=boundary, - reduction="sum", + + # 5. Losses + loss_trans, cer_trans, wer_trans = self._calc_transducer_loss( + encoder_out, + joint_out, + target, + t_len, + u_len, + ) + + loss_ctc, loss_lm = 0.0, 0.0 + + if self.use_auxiliary_ctc: + loss_ctc = self._calc_ctc_loss( + encoder_out, + target, + t_len, + u_len, ) + + if self.use_auxiliary_lm_loss: + loss_lm = self._calc_lm_loss(decoder_out, target) + + loss = ( + self.transducer_weight * loss_trans + + self.auxiliary_ctc_weight * loss_ctc + + self.auxiliary_lm_loss_weight * loss_lm + ) + + stats = dict( + loss=loss.detach(), + loss_transducer=loss_trans.detach(), + aux_ctc_loss=loss_ctc.detach() if loss_ctc > 0.0 else None, + aux_lm_loss=loss_lm.detach() if loss_lm > 0.0 else None, + cer_transducer=cer_trans, + wer_transducer=wer_trans, + ) + + # force_gatherable: to-device and to-tensor if scalar for DataParallel + loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) + + return loss, stats, weight - cer_trans, wer_trans = None, None + def encode( + self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Frontend + Encoder. Note that this method is used by asr_inference.py + Args: + speech: (Batch, Length, ...) 
+ speech_lengths: (Batch, ) + ind: int + """ + with autocast(False): + + # Data augmentation + if self.specaug is not None and self.training: + speech, speech_lengths = self.specaug(speech, speech_lengths) + + # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN + if self.normalize is not None: + speech, speech_lengths = self.normalize(speech, speech_lengths) + + # Forward encoder + # feats: (Batch, Length, Dim) + # -> encoder_out: (Batch, Length2, Dim2) + encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths) + intermediate_outs = None + if isinstance(encoder_out, tuple): + intermediate_outs = encoder_out[1] + encoder_out = encoder_out[0] + + if intermediate_outs is not None: + return (encoder_out, intermediate_outs), encoder_out_lens + + return encoder_out, encoder_out_lens + + def _calc_transducer_loss( + self, + encoder_out: torch.Tensor, + joint_out: torch.Tensor, + target: torch.Tensor, + t_len: torch.Tensor, + u_len: torch.Tensor, + ) -> Tuple[torch.Tensor, Optional[float], Optional[float]]: + """Compute Transducer loss. + + Args: + encoder_out: Encoder output sequences. (B, T, D_enc) + joint_out: Joint Network output sequences (B, T, U, D_joint) + target: Target label ID sequences. (B, L) + t_len: Encoder output sequences lengths. (B,) + u_len: Target label ID sequences lengths. (B,) + + Return: + loss_transducer: Transducer loss value. + cer_transducer: Character error rate for Transducer. + wer_transducer: Word Error Rate for Transducer. + + """ + if self.criterion_transducer is None: + try: + from warp_rnnt import rnnt_loss as RNNTLoss + self.criterion_transducer = RNNTLoss + + except ImportError: + logging.error( + "warp-rnnt was not installed." + "Please consult the installation documentation." + ) + exit(1) + + log_probs = torch.log_softmax(joint_out, dim=-1) + + loss_transducer = self.criterion_transducer( + log_probs, + target, + t_len, + u_len, + reduction="mean", + blank=self.blank_id, + fastemit_lambda=self.fastemit_lambda, + gather=True, + ) + if not self.training and (self.report_cer or self.report_wer): if self.error_calculator is None: from funasr.metrics import ErrorCalculatorTransducer as ErrorCalculator + self.error_calculator = ErrorCalculator( self.decoder, self.joint_network, @@ -256,149 +320,13 @@ class BATModel(nn.Module): report_cer=self.report_cer, report_wer=self.report_wer, ) - cer_trans, wer_trans = self.error_calculator(encoder_out, target, t_len) - - loss_ctc, loss_lm = 0.0, 0.0 - - if self.use_auxiliary_ctc: - loss_ctc = self._calc_ctc_loss( - encoder_out, - target, - t_len, - u_len, - ) - - if self.use_auxiliary_lm_loss: - loss_lm = self._calc_lm_loss(decoder_out, target) - - loss = ( - self.transducer_weight * loss_trans - + self.auxiliary_ctc_weight * loss_ctc - + self.auxiliary_lm_loss_weight * loss_lm - + self.predictor_weight * loss_pre - + self.cif_weight * loss_cif - ) - - stats = dict( - loss=loss.detach(), - loss_transducer=loss_trans.detach(), - loss_pre=loss_pre.detach(), - loss_cif=loss_cif.detach() if loss_cif > 0.0 else None, - aux_ctc_loss=loss_ctc.detach() if loss_ctc > 0.0 else None, - aux_lm_loss=loss_lm.detach() if loss_lm > 0.0 else None, - cer_transducer=cer_trans, - wer_transducer=wer_trans, - ) - - # force_gatherable: to-device and to-tensor if scalar for DataParallel - loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) - - return loss, stats, weight - - def collect_feats( - self, - speech: torch.Tensor, - speech_lengths: torch.Tensor, - text: torch.Tensor, - text_lengths: 
torch.Tensor, - **kwargs, - ) -> Dict[str, torch.Tensor]: - """Collect features sequences and features lengths sequences. - - Args: - speech: Speech sequences. (B, S) - speech_lengths: Speech sequences lengths. (B,) - text: Label ID sequences. (B, L) - text_lengths: Label ID sequences lengths. (B,) - kwargs: Contains "utts_id". - - Return: - {}: "feats": Features sequences. (B, T, D_feats), - "feats_lengths": Features sequences lengths. (B,) - - """ - if self.extract_feats_in_collect_stats: - feats, feats_lengths = self._extract_feats(speech, speech_lengths) - else: - # Generate dummy stats if extract_feats_in_collect_stats is False - logging.warning( - "Generating dummy stats for feats and feats_lengths, " - "because encoder_conf.extract_feats_in_collect_stats is " - f"{self.extract_feats_in_collect_stats}" - ) - - feats, feats_lengths = speech, speech_lengths - - return {"feats": feats, "feats_lengths": feats_lengths} - - def encode( - self, - speech: torch.Tensor, - speech_lengths: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Encoder speech sequences. - - Args: - speech: Speech sequences. (B, S) - speech_lengths: Speech sequences lengths. (B,) - - Return: - encoder_out: Encoder outputs. (B, T, D_enc) - encoder_out_lens: Encoder outputs lengths. (B,) - - """ - with autocast(False): - # 1. Extract feats - feats, feats_lengths = self._extract_feats(speech, speech_lengths) - - # 2. Data augmentation - if self.specaug is not None and self.training: - feats, feats_lengths = self.specaug(feats, feats_lengths) - - # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN - if self.normalize is not None: - feats, feats_lengths = self.normalize(feats, feats_lengths) - - # 4. Forward encoder - encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths) - - assert encoder_out.size(0) == speech.size(0), ( - encoder_out.size(), - speech.size(0), - ) - assert encoder_out.size(1) <= encoder_out_lens.max(), ( - encoder_out.size(), - encoder_out_lens.max(), - ) - - return encoder_out, encoder_out_lens - - def _extract_feats( - self, speech: torch.Tensor, speech_lengths: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Extract features sequences and features sequences lengths. - - Args: - speech: Speech sequences. (B, S) - speech_lengths: Speech sequences lengths. (B,) - - Return: - feats: Features sequences. (B, T, D_feats) - feats_lengths: Features sequences lengths. 
(B,) - - """ - assert speech_lengths.dim() == 1, speech_lengths.shape - - # for data-parallel - speech = speech[:, : speech_lengths.max()] - - if self.frontend is not None: - feats, feats_lengths = self.frontend(speech, speech_lengths) - else: - feats, feats_lengths = speech, speech_lengths - - return feats, feats_lengths - + + cer_transducer, wer_transducer = self.error_calculator(encoder_out, target, t_len) + + return loss_transducer, cer_transducer, wer_transducer + + return loss_transducer, None, None + def _calc_ctc_loss( self, encoder_out: torch.Tensor, @@ -422,10 +350,10 @@ class BATModel(nn.Module): torch.nn.functional.dropout(encoder_out, p=self.ctc_dropout_rate) ) ctc_in = torch.log_softmax(ctc_in.transpose(0, 1), dim=-1) - + target_mask = target != 0 ctc_target = target[target_mask].cpu() - + with torch.backends.cudnn.flags(deterministic=True): loss_ctc = torch.nn.functional.ctc_loss( ctc_in, @@ -436,9 +364,9 @@ class BATModel(nn.Module): reduction="sum", ) loss_ctc /= target.size(0) - + return loss_ctc - + def _calc_lm_loss( self, decoder_out: torch.Tensor, @@ -456,17 +384,17 @@ class BATModel(nn.Module): """ lm_loss_in = self.lm_lin(decoder_out[:, :-1, :]).view(-1, self.vocab_size) lm_target = target.view(-1).type(torch.int64) - + with torch.no_grad(): true_dist = lm_loss_in.clone() true_dist.fill_(self.lm_loss_smoothing / (self.vocab_size - 1)) - + # Ignore blank ID (0) ignore = lm_target == 0 lm_target = lm_target.masked_fill(ignore, 0) - + true_dist.scatter_(1, lm_target.unsqueeze(1), (1 - self.lm_loss_smoothing)) - + loss_lm = torch.nn.functional.kl_div( torch.log_softmax(lm_loss_in, dim=1), true_dist, @@ -475,5 +403,117 @@ class BATModel(nn.Module): loss_lm = loss_lm.masked_fill(ignore.unsqueeze(1), 0).sum() / decoder_out.size( 0 ) - + return loss_lm + + def init_beam_search(self, + **kwargs, + ): + + # 1. Build ASR model + scorers = {} + + if self.ctc != None: + ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos) + scorers.update( + ctc=ctc + ) + token_list = kwargs.get("token_list") + scorers.update( + length_bonus=LengthBonus(len(token_list)), + ) + + # 3. 
Build ngram model + # ngram is not supported now + ngram = None + scorers["ngram"] = ngram + + beam_search = BeamSearchTransducer( + self.decoder, + self.joint_network, + kwargs.get("beam_size", 2), + nbest=1, + ) + # beam_search.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval() + # for scorer in scorers.values(): + # if isinstance(scorer, torch.nn.Module): + # scorer.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval() + self.beam_search = beam_search + + def inference(self, + data_in: list, + data_lengths: list=None, + key: list=None, + tokenizer=None, + **kwargs, + ): + + if kwargs.get("batch_size", 1) > 1: + raise NotImplementedError("batch decoding is not implemented") + + # init beamsearch + is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None + is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None + # if self.beam_search is None and (is_use_lm or is_use_ctc): + logging.info("enable beam_search") + self.init_beam_search(**kwargs) + self.nbest = kwargs.get("nbest", 1) + + meta_data = {} + # extract fbank feats + time1 = time.perf_counter() + audio_sample_list = load_audio_text_image_video(data_in, fs=self.frontend.fs, audio_fs=kwargs.get("fs", 16000)) + time2 = time.perf_counter() + meta_data["load_data"] = f"{time2 - time1:0.3f}" + speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=self.frontend) + time3 = time.perf_counter() + meta_data["extract_feat"] = f"{time3 - time2:0.3f}" + meta_data["batch_data_time"] = speech_lengths.sum().item() * self.frontend.frame_shift * self.frontend.lfr_n / 1000 + + speech = speech.to(device=kwargs["device"]) + speech_lengths = speech_lengths.to(device=kwargs["device"]) + + # Encoder + encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) + if isinstance(encoder_out, tuple): + encoder_out = encoder_out[0] + + # c. Passed the encoder result and the beam search + nbest_hyps = self.beam_search(encoder_out[0], is_final=True) + nbest_hyps = nbest_hyps[: self.nbest] + + results = [] + b, n, d = encoder_out.size() + for i in range(b): + + for nbest_idx, hyp in enumerate(nbest_hyps): + ibest_writer = None + if kwargs.get("output_dir") is not None: + if not hasattr(self, "writer"): + self.writer = DatadirWriter(kwargs.get("output_dir")) + ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"] + # remove sos/eos and get results + last_pos = -1 + if isinstance(hyp.yseq, list): + token_int = hyp.yseq#[1:last_pos] + else: + token_int = hyp.yseq#[1:last_pos].tolist() + + # remove blank symbol id, which is assumed to be 0 + token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int)) + + # Change integer-ids to tokens + token = tokenizer.ids2tokens(token_int) + text = tokenizer.tokens2text(token) + + text_postprocessed, _ = postprocess_utils.sentence_postprocess(token) + result_i = {"key": key[i], "token": token, "text": text, "text_postprocessed": text_postprocessed} + results.append(result_i) + + if ibest_writer is not None: + ibest_writer["token"][key[i]] = " ".join(token) + ibest_writer["text"][key[i]] = text + ibest_writer["text_postprocessed"][key[i]] = text_postprocessed + + return results, meta_data +
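
A minimal usage sketch of the new decoding path added by this patch, for reviewers. It assumes FunASR's AutoModel wrapper, a hypothetical local BAT checkpoint directory (./exp/bat_transducer), an example wav file, and that keyword arguments such as beam_size and output_dir are forwarded through **kwargs down to BAT.inference(); only the "BAT" model registration, the inference() result keys, and the beam_size default of 2 are taken from the diff itself.

# Hedged sketch: single-utterance transducer decoding with the patched BAT model.
# The model directory, wav path, and the assumption that AutoModel forwards these
# kwargs unchanged to inference() are illustrative; batch decoding is not
# implemented in this patch (batch_size > 1 raises NotImplementedError).
from funasr import AutoModel

model = AutoModel(model="./exp/bat_transducer")  # hypothetical exported BAT model dir

res = model.generate(
    input="example.wav",        # 16 kHz audio, matching the audio_fs default in inference()
    beam_size=5,                # read via kwargs in init_beam_search (patch default: 2)
    output_dir="./decode_out",  # optional: enables the DatadirWriter token/text outputs
)
print(res[0]["text"])           # result dicts carry "key", "token", "text", "text_postprocessed"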