diff --git a/examples/industrial_data_pretraining/ct_transformer_streaming/demo.py b/examples/industrial_data_pretraining/ct_transformer_streaming/demo.py new file mode 100644 index 000000000..5ef83813e --- /dev/null +++ b/examples/industrial_data_pretraining/ct_transformer_streaming/demo.py @@ -0,0 +1,19 @@ +#!/usr/bin/env python3 +# -*- encoding: utf-8 -*- +# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. +# MIT License (https://opensource.org/licenses/MIT) + +from funasr import AutoModel + +model = AutoModel(model="damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727", model_revision="v2.0.1") + +inputs = "跨境河流是养育沿岸|人民的生命之源长期以来为帮助下游地区防灾减灾中方技术人员|在上游地区极为恶劣的自然条件下克服巨大困难甚至冒着生命危险|向印方提供汛期水文资料处理紧急事件中方重视印方在跨境河流问题上的关切|愿意进一步完善双方联合工作机制|凡是|中方能做的我们|都会去做而且会做得更好我请印度朋友们放心中国在上游的|任何开发利用都会经过科学|规划和论证兼顾上下游的利益" +vads = inputs.split("|") +rec_result_all = "outputs: " +cache = {} +for vad in vads: + rec_result = model(input=vad, cache=cache) + print(rec_result) + rec_result_all += rec_result[0]['text'] + +print(rec_result_all) diff --git a/examples/industrial_data_pretraining/ct_transformer_streaming/infer.sh b/examples/industrial_data_pretraining/ct_transformer_streaming/infer.sh new file mode 100644 index 000000000..fa92a6e98 --- /dev/null +++ b/examples/industrial_data_pretraining/ct_transformer_streaming/infer.sh @@ -0,0 +1,10 @@ + +model="damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727" +model_revision="v2.0.1" + +python funasr/bin/inference.py \ ++model=${model} \ ++model_revision=${model_revision} \ ++input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_text/punc_example.txt" \ ++output_dir="./outputs/debug" \ ++device="cpu" diff --git a/funasr/models/ct_transformer/attention.py b/funasr/models/ct_transformer/attention.py deleted file mode 100644 index a35ddee57..000000000 --- a/funasr/models/ct_transformer/attention.py +++ /dev/null @@ -1,1091 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- - -# Copyright 2019 Shigeki Karita -# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) - -"""Multi-Head Attention layer definition.""" - -import math - -import numpy -import torch -from torch import nn -from typing import Optional, Tuple - -import torch.nn.functional as F -from funasr.models.transformer.utils.nets_utils import make_pad_mask -import funasr.models.lora.layers as lora - -class MultiHeadedAttention(nn.Module): - """Multi-Head Attention layer. - - Args: - n_head (int): The number of heads. - n_feat (int): The number of features. - dropout_rate (float): Dropout rate. - - """ - - def __init__(self, n_head, n_feat, dropout_rate): - """Construct an MultiHeadedAttention object.""" - super(MultiHeadedAttention, self).__init__() - assert n_feat % n_head == 0 - # We assume d_v always equals d_k - self.d_k = n_feat // n_head - self.h = n_head - self.linear_q = nn.Linear(n_feat, n_feat) - self.linear_k = nn.Linear(n_feat, n_feat) - self.linear_v = nn.Linear(n_feat, n_feat) - self.linear_out = nn.Linear(n_feat, n_feat) - self.attn = None - self.dropout = nn.Dropout(p=dropout_rate) - - def forward_qkv(self, query, key, value): - """Transform query, key and value. - - Args: - query (torch.Tensor): Query tensor (#batch, time1, size). - key (torch.Tensor): Key tensor (#batch, time2, size). - value (torch.Tensor): Value tensor (#batch, time2, size). - - Returns: - torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k). - torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k). - torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k). - - """ - n_batch = query.size(0) - q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k) - k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k) - v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k) - q = q.transpose(1, 2) # (batch, head, time1, d_k) - k = k.transpose(1, 2) # (batch, head, time2, d_k) - v = v.transpose(1, 2) # (batch, head, time2, d_k) - - return q, k, v - - def forward_attention(self, value, scores, mask): - """Compute attention context vector. - - Args: - value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k). - scores (torch.Tensor): Attention score (#batch, n_head, time1, time2). - mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2). - - Returns: - torch.Tensor: Transformed value (#batch, time1, d_model) - weighted by the attention score (#batch, time1, time2). - - """ - n_batch = value.size(0) - if mask is not None: - mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2) - min_value = float( - numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min - ) - scores = scores.masked_fill(mask, min_value) - self.attn = torch.softmax(scores, dim=-1).masked_fill( - mask, 0.0 - ) # (batch, head, time1, time2) - else: - self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) - - p_attn = self.dropout(self.attn) - x = torch.matmul(p_attn, value) # (batch, head, time1, d_k) - x = ( - x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k) - ) # (batch, time1, d_model) - - return self.linear_out(x) # (batch, time1, d_model) - - def forward(self, query, key, value, mask): - """Compute scaled dot product attention. - - Args: - query (torch.Tensor): Query tensor (#batch, time1, size). - key (torch.Tensor): Key tensor (#batch, time2, size). - value (torch.Tensor): Value tensor (#batch, time2, size). - mask (torch.Tensor): Mask tensor (#batch, 1, time2) or - (#batch, time1, time2). - - Returns: - torch.Tensor: Output tensor (#batch, time1, d_model). - - """ - q, k, v = self.forward_qkv(query, key, value) - scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k) - return self.forward_attention(v, scores, mask) - - -class LegacyRelPositionMultiHeadedAttention(MultiHeadedAttention): - """Multi-Head Attention layer with relative position encoding (old version). - - Details can be found in https://github.com/espnet/espnet/pull/2816. - - Paper: https://arxiv.org/abs/1901.02860 - - Args: - n_head (int): The number of heads. - n_feat (int): The number of features. - dropout_rate (float): Dropout rate. - zero_triu (bool): Whether to zero the upper triangular part of attention matrix. - - """ - - def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False): - """Construct an RelPositionMultiHeadedAttention object.""" - super().__init__(n_head, n_feat, dropout_rate) - self.zero_triu = zero_triu - # linear transformation for positional encoding - self.linear_pos = nn.Linear(n_feat, n_feat, bias=False) - # these two learnable bias are used in matrix c and matrix d - # as described in https://arxiv.org/abs/1901.02860 Section 3.3 - self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k)) - self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k)) - torch.nn.init.xavier_uniform_(self.pos_bias_u) - torch.nn.init.xavier_uniform_(self.pos_bias_v) - - def rel_shift(self, x): - """Compute relative positional encoding. - - Args: - x (torch.Tensor): Input tensor (batch, head, time1, time2). - - Returns: - torch.Tensor: Output tensor. - - """ - zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype) - x_padded = torch.cat([zero_pad, x], dim=-1) - - x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2)) - x = x_padded[:, :, 1:].view_as(x) - - if self.zero_triu: - ones = torch.ones((x.size(2), x.size(3))) - x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :] - - return x - - def forward(self, query, key, value, pos_emb, mask): - """Compute 'Scaled Dot Product Attention' with rel. positional encoding. - - Args: - query (torch.Tensor): Query tensor (#batch, time1, size). - key (torch.Tensor): Key tensor (#batch, time2, size). - value (torch.Tensor): Value tensor (#batch, time2, size). - pos_emb (torch.Tensor): Positional embedding tensor (#batch, time1, size). - mask (torch.Tensor): Mask tensor (#batch, 1, time2) or - (#batch, time1, time2). - - Returns: - torch.Tensor: Output tensor (#batch, time1, d_model). - - """ - q, k, v = self.forward_qkv(query, key, value) - q = q.transpose(1, 2) # (batch, time1, head, d_k) - - n_batch_pos = pos_emb.size(0) - p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) - p = p.transpose(1, 2) # (batch, head, time1, d_k) - - # (batch, head, time1, d_k) - q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) - # (batch, head, time1, d_k) - q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) - - # compute attention score - # first compute matrix a and matrix c - # as described in https://arxiv.org/abs/1901.02860 Section 3.3 - # (batch, head, time1, time2) - matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) - - # compute matrix b and matrix d - # (batch, head, time1, time1) - matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) - matrix_bd = self.rel_shift(matrix_bd) - - scores = (matrix_ac + matrix_bd) / math.sqrt( - self.d_k - ) # (batch, head, time1, time2) - - return self.forward_attention(v, scores, mask) - - -class RelPositionMultiHeadedAttention(MultiHeadedAttention): - """Multi-Head Attention layer with relative position encoding (new implementation). - - Details can be found in https://github.com/espnet/espnet/pull/2816. - - Paper: https://arxiv.org/abs/1901.02860 - - Args: - n_head (int): The number of heads. - n_feat (int): The number of features. - dropout_rate (float): Dropout rate. - zero_triu (bool): Whether to zero the upper triangular part of attention matrix. - - """ - - def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False): - """Construct an RelPositionMultiHeadedAttention object.""" - super().__init__(n_head, n_feat, dropout_rate) - self.zero_triu = zero_triu - # linear transformation for positional encoding - self.linear_pos = nn.Linear(n_feat, n_feat, bias=False) - # these two learnable bias are used in matrix c and matrix d - # as described in https://arxiv.org/abs/1901.02860 Section 3.3 - self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k)) - self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k)) - torch.nn.init.xavier_uniform_(self.pos_bias_u) - torch.nn.init.xavier_uniform_(self.pos_bias_v) - - def rel_shift(self, x): - """Compute relative positional encoding. - - Args: - x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1). - time1 means the length of query vector. - - Returns: - torch.Tensor: Output tensor. - - """ - zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype) - x_padded = torch.cat([zero_pad, x], dim=-1) - - x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2)) - x = x_padded[:, :, 1:].view_as(x)[ - :, :, :, : x.size(-1) // 2 + 1 - ] # only keep the positions from 0 to time2 - - if self.zero_triu: - ones = torch.ones((x.size(2), x.size(3)), device=x.device) - x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :] - - return x - - def forward(self, query, key, value, pos_emb, mask): - """Compute 'Scaled Dot Product Attention' with rel. positional encoding. - - Args: - query (torch.Tensor): Query tensor (#batch, time1, size). - key (torch.Tensor): Key tensor (#batch, time2, size). - value (torch.Tensor): Value tensor (#batch, time2, size). - pos_emb (torch.Tensor): Positional embedding tensor - (#batch, 2*time1-1, size). - mask (torch.Tensor): Mask tensor (#batch, 1, time2) or - (#batch, time1, time2). - - Returns: - torch.Tensor: Output tensor (#batch, time1, d_model). - - """ - q, k, v = self.forward_qkv(query, key, value) - q = q.transpose(1, 2) # (batch, time1, head, d_k) - - n_batch_pos = pos_emb.size(0) - p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) - p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k) - - # (batch, head, time1, d_k) - q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) - # (batch, head, time1, d_k) - q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) - - # compute attention score - # first compute matrix a and matrix c - # as described in https://arxiv.org/abs/1901.02860 Section 3.3 - # (batch, head, time1, time2) - matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) - - # compute matrix b and matrix d - # (batch, head, time1, 2*time1-1) - matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) - matrix_bd = self.rel_shift(matrix_bd) - - scores = (matrix_ac + matrix_bd) / math.sqrt( - self.d_k - ) # (batch, head, time1, time2) - - return self.forward_attention(v, scores, mask) - - -class MultiHeadedAttentionSANM(nn.Module): - """Multi-Head Attention layer. - - Args: - n_head (int): The number of heads. - n_feat (int): The number of features. - dropout_rate (float): Dropout rate. - - """ - - def __init__(self, n_head, in_feat, n_feat, dropout_rate, kernel_size, sanm_shfit=0, lora_list=None, lora_rank=8, lora_alpha=16, lora_dropout=0.1): - """Construct an MultiHeadedAttention object.""" - super(MultiHeadedAttentionSANM, self).__init__() - assert n_feat % n_head == 0 - # We assume d_v always equals d_k - self.d_k = n_feat // n_head - self.h = n_head - # self.linear_q = nn.Linear(n_feat, n_feat) - # self.linear_k = nn.Linear(n_feat, n_feat) - # self.linear_v = nn.Linear(n_feat, n_feat) - if lora_list is not None: - if "o" in lora_list: - self.linear_out = lora.Linear(n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout) - else: - self.linear_out = nn.Linear(n_feat, n_feat) - lora_qkv_list = ["q" in lora_list, "k" in lora_list, "v" in lora_list] - if lora_qkv_list == [False, False, False]: - self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3) - else: - self.linear_q_k_v = lora.MergedLinear(in_feat, n_feat * 3, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout, enable_lora=lora_qkv_list) - else: - self.linear_out = nn.Linear(n_feat, n_feat) - self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3) - self.attn = None - self.dropout = nn.Dropout(p=dropout_rate) - - self.fsmn_block = nn.Conv1d(n_feat, n_feat, kernel_size, stride=1, padding=0, groups=n_feat, bias=False) - # padding - left_padding = (kernel_size - 1) // 2 - if sanm_shfit > 0: - left_padding = left_padding + sanm_shfit - right_padding = kernel_size - 1 - left_padding - self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0) - - def forward_fsmn(self, inputs, mask, mask_shfit_chunk=None): - b, t, d = inputs.size() - if mask is not None: - mask = torch.reshape(mask, (b, -1, 1)) - if mask_shfit_chunk is not None: - mask = mask * mask_shfit_chunk - inputs = inputs * mask - - x = inputs.transpose(1, 2) - x = self.pad_fn(x) - x = self.fsmn_block(x) - x = x.transpose(1, 2) - x += inputs - x = self.dropout(x) - if mask is not None: - x = x * mask - return x - - def forward_qkv(self, x): - """Transform query, key and value. - - Args: - query (torch.Tensor): Query tensor (#batch, time1, size). - key (torch.Tensor): Key tensor (#batch, time2, size). - value (torch.Tensor): Value tensor (#batch, time2, size). - - Returns: - torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k). - torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k). - torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k). - - """ - b, t, d = x.size() - q_k_v = self.linear_q_k_v(x) - q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1) - q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(1, 2) # (batch, head, time1, d_k) - k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(1, 2) # (batch, head, time2, d_k) - v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(1, 2) # (batch, head, time2, d_k) - - return q_h, k_h, v_h, v - - def forward_attention(self, value, scores, mask, mask_att_chunk_encoder=None): - """Compute attention context vector. - - Args: - value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k). - scores (torch.Tensor): Attention score (#batch, n_head, time1, time2). - mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2). - - Returns: - torch.Tensor: Transformed value (#batch, time1, d_model) - weighted by the attention score (#batch, time1, time2). - - """ - n_batch = value.size(0) - if mask is not None: - if mask_att_chunk_encoder is not None: - mask = mask * mask_att_chunk_encoder - - mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2) - - min_value = float( - numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min - ) - scores = scores.masked_fill(mask, min_value) - self.attn = torch.softmax(scores, dim=-1).masked_fill( - mask, 0.0 - ) # (batch, head, time1, time2) - else: - self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) - - p_attn = self.dropout(self.attn) - x = torch.matmul(p_attn, value) # (batch, head, time1, d_k) - x = ( - x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k) - ) # (batch, time1, d_model) - - return self.linear_out(x) # (batch, time1, d_model) - - def forward(self, x, mask, mask_shfit_chunk=None, mask_att_chunk_encoder=None): - """Compute scaled dot product attention. - - Args: - query (torch.Tensor): Query tensor (#batch, time1, size). - key (torch.Tensor): Key tensor (#batch, time2, size). - value (torch.Tensor): Value tensor (#batch, time2, size). - mask (torch.Tensor): Mask tensor (#batch, 1, time2) or - (#batch, time1, time2). - - Returns: - torch.Tensor: Output tensor (#batch, time1, d_model). - - """ - q_h, k_h, v_h, v = self.forward_qkv(x) - fsmn_memory = self.forward_fsmn(v, mask, mask_shfit_chunk) - q_h = q_h * self.d_k ** (-0.5) - scores = torch.matmul(q_h, k_h.transpose(-2, -1)) - att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder) - return att_outs + fsmn_memory - - def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0): - """Compute scaled dot product attention. - - Args: - query (torch.Tensor): Query tensor (#batch, time1, size). - key (torch.Tensor): Key tensor (#batch, time2, size). - value (torch.Tensor): Value tensor (#batch, time2, size). - mask (torch.Tensor): Mask tensor (#batch, 1, time2) or - (#batch, time1, time2). - - Returns: - torch.Tensor: Output tensor (#batch, time1, d_model). - - """ - q_h, k_h, v_h, v = self.forward_qkv(x) - if chunk_size is not None and look_back > 0 or look_back == -1: - if cache is not None: - k_h_stride = k_h[:, :, :-(chunk_size[2]), :] - v_h_stride = v_h[:, :, :-(chunk_size[2]), :] - k_h = torch.cat((cache["k"], k_h), dim=2) - v_h = torch.cat((cache["v"], v_h), dim=2) - - cache["k"] = torch.cat((cache["k"], k_h_stride), dim=2) - cache["v"] = torch.cat((cache["v"], v_h_stride), dim=2) - if look_back != -1: - cache["k"] = cache["k"][:, :, -(look_back * chunk_size[1]):, :] - cache["v"] = cache["v"][:, :, -(look_back * chunk_size[1]):, :] - else: - cache_tmp = {"k": k_h[:, :, :-(chunk_size[2]), :], - "v": v_h[:, :, :-(chunk_size[2]), :]} - cache = cache_tmp - fsmn_memory = self.forward_fsmn(v, None) - q_h = q_h * self.d_k ** (-0.5) - scores = torch.matmul(q_h, k_h.transpose(-2, -1)) - att_outs = self.forward_attention(v_h, scores, None) - return att_outs + fsmn_memory, cache - - -class MultiHeadedAttentionSANMwithMask(MultiHeadedAttentionSANM): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def forward(self, x, mask, mask_shfit_chunk=None, mask_att_chunk_encoder=None): - q_h, k_h, v_h, v = self.forward_qkv(x) - fsmn_memory = self.forward_fsmn(v, mask[0], mask_shfit_chunk) - q_h = q_h * self.d_k ** (-0.5) - scores = torch.matmul(q_h, k_h.transpose(-2, -1)) - att_outs = self.forward_attention(v_h, scores, mask[1], mask_att_chunk_encoder) - return att_outs + fsmn_memory - -class MultiHeadedAttentionSANMDecoder(nn.Module): - """Multi-Head Attention layer. - - Args: - n_head (int): The number of heads. - n_feat (int): The number of features. - dropout_rate (float): Dropout rate. - - """ - - def __init__(self, n_feat, dropout_rate, kernel_size, sanm_shfit=0): - """Construct an MultiHeadedAttention object.""" - super(MultiHeadedAttentionSANMDecoder, self).__init__() - - self.dropout = nn.Dropout(p=dropout_rate) - - self.fsmn_block = nn.Conv1d(n_feat, n_feat, - kernel_size, stride=1, padding=0, groups=n_feat, bias=False) - # padding - # padding - left_padding = (kernel_size - 1) // 2 - if sanm_shfit > 0: - left_padding = left_padding + sanm_shfit - right_padding = kernel_size - 1 - left_padding - self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0) - self.kernel_size = kernel_size - - def forward(self, inputs, mask, cache=None, mask_shfit_chunk=None): - ''' - :param x: (#batch, time1, size). - :param mask: Mask tensor (#batch, 1, time) - :return: - ''' - # print("in fsmn, inputs", inputs.size()) - b, t, d = inputs.size() - # logging.info( - # "mask: {}".format(mask.size())) - if mask is not None: - mask = torch.reshape(mask, (b ,-1, 1)) - # logging.info("in fsmn, mask: {}, {}".format(mask.size(), mask[0:100:50, :, :])) - if mask_shfit_chunk is not None: - # logging.info("in fsmn, mask_fsmn: {}, {}".format(mask_shfit_chunk.size(), mask_shfit_chunk[0:100:50, :, :])) - mask = mask * mask_shfit_chunk - # logging.info("in fsmn, mask_after_fsmn: {}, {}".format(mask.size(), mask[0:100:50, :, :])) - # print("in fsmn, mask", mask.size()) - # print("in fsmn, inputs", inputs.size()) - inputs = inputs * mask - - x = inputs.transpose(1, 2) - b, d, t = x.size() - if cache is None: - # print("in fsmn, cache is None, x", x.size()) - - x = self.pad_fn(x) - if not self.training: - cache = x - else: - # print("in fsmn, cache is not None, x", x.size()) - # x = torch.cat((x, cache), dim=2)[:, :, :-1] - # if t < self.kernel_size: - # x = self.pad_fn(x) - x = torch.cat((cache[:, :, 1:], x), dim=2) - x = x[:, :, -(self.kernel_size+t-1):] - # print("in fsmn, cache is not None, x_cat", x.size()) - cache = x - x = self.fsmn_block(x) - x = x.transpose(1, 2) - # print("in fsmn, fsmn_out", x.size()) - if x.size(1) != inputs.size(1): - inputs = inputs[:, -1, :] - - x = x + inputs - x = self.dropout(x) - if mask is not None: - x = x * mask - return x, cache - -class MultiHeadedAttentionCrossAtt(nn.Module): - """Multi-Head Attention layer. - - Args: - n_head (int): The number of heads. - n_feat (int): The number of features. - dropout_rate (float): Dropout rate. - - """ - - def __init__(self, n_head, n_feat, dropout_rate, lora_list=None, lora_rank=8, lora_alpha=16, lora_dropout=0.1, encoder_output_size=None): - """Construct an MultiHeadedAttention object.""" - super(MultiHeadedAttentionCrossAtt, self).__init__() - assert n_feat % n_head == 0 - # We assume d_v always equals d_k - self.d_k = n_feat // n_head - self.h = n_head - if lora_list is not None: - if "q" in lora_list: - self.linear_q = lora.Linear(n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout) - else: - self.linear_q = nn.Linear(n_feat, n_feat) - lora_kv_list = ["k" in lora_list, "v" in lora_list] - if lora_kv_list == [False, False]: - self.linear_k_v = nn.Linear(n_feat if encoder_output_size is None else encoder_output_size, n_feat*2) - else: - self.linear_k_v = lora.MergedLinear(n_feat if encoder_output_size is None else encoder_output_size, n_feat * 2, - r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout, enable_lora=lora_kv_list) - if "o" in lora_list: - self.linear_out = lora.Linear(n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout) - else: - self.linear_out = nn.Linear(n_feat, n_feat) - else: - self.linear_q = nn.Linear(n_feat, n_feat) - self.linear_k_v = nn.Linear(n_feat if encoder_output_size is None else encoder_output_size, n_feat*2) - self.linear_out = nn.Linear(n_feat, n_feat) - self.attn = None - self.dropout = nn.Dropout(p=dropout_rate) - - def forward_qkv(self, x, memory): - """Transform query, key and value. - - Args: - query (torch.Tensor): Query tensor (#batch, time1, size). - key (torch.Tensor): Key tensor (#batch, time2, size). - value (torch.Tensor): Value tensor (#batch, time2, size). - - Returns: - torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k). - torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k). - torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k). - - """ - - # print("in forward_qkv, x", x.size()) - b = x.size(0) - q = self.linear_q(x) - q_h = torch.reshape(q, (b, -1, self.h, self.d_k)).transpose(1, 2) # (batch, head, time1, d_k) - - k_v = self.linear_k_v(memory) - k, v = torch.split(k_v, int(self.h*self.d_k), dim=-1) - k_h = torch.reshape(k, (b, -1, self.h, self.d_k)).transpose(1, 2) # (batch, head, time2, d_k) - v_h = torch.reshape(v, (b, -1, self.h, self.d_k)).transpose(1, 2) # (batch, head, time2, d_k) - - - return q_h, k_h, v_h - - def forward_attention(self, value, scores, mask): - """Compute attention context vector. - - Args: - value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k). - scores (torch.Tensor): Attention score (#batch, n_head, time1, time2). - mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2). - - Returns: - torch.Tensor: Transformed value (#batch, time1, d_model) - weighted by the attention score (#batch, time1, time2). - - """ - n_batch = value.size(0) - if mask is not None: - mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2) - min_value = float( - numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min - ) - # logging.info( - # "scores: {}, mask_size: {}".format(scores.size(), mask.size())) - scores = scores.masked_fill(mask, min_value) - self.attn = torch.softmax(scores, dim=-1).masked_fill( - mask, 0.0 - ) # (batch, head, time1, time2) - else: - self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) - - p_attn = self.dropout(self.attn) - x = torch.matmul(p_attn, value) # (batch, head, time1, d_k) - x = ( - x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k) - ) # (batch, time1, d_model) - - return self.linear_out(x) # (batch, time1, d_model) - - def forward(self, x, memory, memory_mask): - """Compute scaled dot product attention. - - Args: - query (torch.Tensor): Query tensor (#batch, time1, size). - key (torch.Tensor): Key tensor (#batch, time2, size). - value (torch.Tensor): Value tensor (#batch, time2, size). - mask (torch.Tensor): Mask tensor (#batch, 1, time2) or - (#batch, time1, time2). - - Returns: - torch.Tensor: Output tensor (#batch, time1, d_model). - - """ - q_h, k_h, v_h = self.forward_qkv(x, memory) - q_h = q_h * self.d_k ** (-0.5) - scores = torch.matmul(q_h, k_h.transpose(-2, -1)) - return self.forward_attention(v_h, scores, memory_mask) - - def forward_chunk(self, x, memory, cache=None, chunk_size=None, look_back=0): - """Compute scaled dot product attention. - - Args: - query (torch.Tensor): Query tensor (#batch, time1, size). - key (torch.Tensor): Key tensor (#batch, time2, size). - value (torch.Tensor): Value tensor (#batch, time2, size). - mask (torch.Tensor): Mask tensor (#batch, 1, time2) or - (#batch, time1, time2). - - Returns: - torch.Tensor: Output tensor (#batch, time1, d_model). - - """ - q_h, k_h, v_h = self.forward_qkv(x, memory) - if chunk_size is not None and look_back > 0: - if cache is not None: - k_h = torch.cat((cache["k"], k_h), dim=2) - v_h = torch.cat((cache["v"], v_h), dim=2) - cache["k"] = k_h[:, :, -(look_back * chunk_size[1]):, :] - cache["v"] = v_h[:, :, -(look_back * chunk_size[1]):, :] - else: - cache_tmp = {"k": k_h[:, :, -(look_back * chunk_size[1]):, :], - "v": v_h[:, :, -(look_back * chunk_size[1]):, :]} - cache = cache_tmp - q_h = q_h * self.d_k ** (-0.5) - scores = torch.matmul(q_h, k_h.transpose(-2, -1)) - return self.forward_attention(v_h, scores, None), cache - - -class MultiHeadSelfAttention(nn.Module): - """Multi-Head Attention layer. - - Args: - n_head (int): The number of heads. - n_feat (int): The number of features. - dropout_rate (float): Dropout rate. - - """ - - def __init__(self, n_head, in_feat, n_feat, dropout_rate): - """Construct an MultiHeadedAttention object.""" - super(MultiHeadSelfAttention, self).__init__() - assert n_feat % n_head == 0 - # We assume d_v always equals d_k - self.d_k = n_feat // n_head - self.h = n_head - self.linear_out = nn.Linear(n_feat, n_feat) - self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3) - self.attn = None - self.dropout = nn.Dropout(p=dropout_rate) - - def forward_qkv(self, x): - """Transform query, key and value. - - Args: - query (torch.Tensor): Query tensor (#batch, time1, size). - key (torch.Tensor): Key tensor (#batch, time2, size). - value (torch.Tensor): Value tensor (#batch, time2, size). - - Returns: - torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k). - torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k). - torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k). - - """ - b, t, d = x.size() - q_k_v = self.linear_q_k_v(x) - q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1) - q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(1, 2) # (batch, head, time1, d_k) - k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(1, 2) # (batch, head, time2, d_k) - v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(1, 2) # (batch, head, time2, d_k) - - return q_h, k_h, v_h, v - - def forward_attention(self, value, scores, mask, mask_att_chunk_encoder=None): - """Compute attention context vector. - - Args: - value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k). - scores (torch.Tensor): Attention score (#batch, n_head, time1, time2). - mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2). - - Returns: - torch.Tensor: Transformed value (#batch, time1, d_model) - weighted by the attention score (#batch, time1, time2). - - """ - n_batch = value.size(0) - if mask is not None: - if mask_att_chunk_encoder is not None: - mask = mask * mask_att_chunk_encoder - - mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2) - - min_value = float( - numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min - ) - scores = scores.masked_fill(mask, min_value) - self.attn = torch.softmax(scores, dim=-1).masked_fill( - mask, 0.0 - ) # (batch, head, time1, time2) - else: - self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) - - p_attn = self.dropout(self.attn) - x = torch.matmul(p_attn, value) # (batch, head, time1, d_k) - x = ( - x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k) - ) # (batch, time1, d_model) - - return self.linear_out(x) # (batch, time1, d_model) - - def forward(self, x, mask, mask_att_chunk_encoder=None): - """Compute scaled dot product attention. - - Args: - query (torch.Tensor): Query tensor (#batch, time1, size). - key (torch.Tensor): Key tensor (#batch, time2, size). - value (torch.Tensor): Value tensor (#batch, time2, size). - mask (torch.Tensor): Mask tensor (#batch, 1, time2) or - (#batch, time1, time2). - - Returns: - torch.Tensor: Output tensor (#batch, time1, d_model). - - """ - q_h, k_h, v_h, v = self.forward_qkv(x) - q_h = q_h * self.d_k ** (-0.5) - scores = torch.matmul(q_h, k_h.transpose(-2, -1)) - att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder) - return att_outs - -class RelPositionMultiHeadedAttentionChunk(torch.nn.Module): - """RelPositionMultiHeadedAttention definition. - Args: - num_heads: Number of attention heads. - embed_size: Embedding size. - dropout_rate: Dropout rate. - """ - - def __init__( - self, - num_heads: int, - embed_size: int, - dropout_rate: float = 0.0, - simplified_attention_score: bool = False, - ) -> None: - """Construct an MultiHeadedAttention object.""" - super().__init__() - - self.d_k = embed_size // num_heads - self.num_heads = num_heads - - assert self.d_k * num_heads == embed_size, ( - "embed_size (%d) must be divisible by num_heads (%d)", - (embed_size, num_heads), - ) - - self.linear_q = torch.nn.Linear(embed_size, embed_size) - self.linear_k = torch.nn.Linear(embed_size, embed_size) - self.linear_v = torch.nn.Linear(embed_size, embed_size) - - self.linear_out = torch.nn.Linear(embed_size, embed_size) - - if simplified_attention_score: - self.linear_pos = torch.nn.Linear(embed_size, num_heads) - - self.compute_att_score = self.compute_simplified_attention_score - else: - self.linear_pos = torch.nn.Linear(embed_size, embed_size, bias=False) - - self.pos_bias_u = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k)) - self.pos_bias_v = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k)) - torch.nn.init.xavier_uniform_(self.pos_bias_u) - torch.nn.init.xavier_uniform_(self.pos_bias_v) - - self.compute_att_score = self.compute_attention_score - - self.dropout = torch.nn.Dropout(p=dropout_rate) - self.attn = None - - def rel_shift(self, x: torch.Tensor, left_context: int = 0) -> torch.Tensor: - """Compute relative positional encoding. - Args: - x: Input sequence. (B, H, T_1, 2 * T_1 - 1) - left_context: Number of frames in left context. - Returns: - x: Output sequence. (B, H, T_1, T_2) - """ - batch_size, n_heads, time1, n = x.shape - time2 = time1 + left_context - - batch_stride, n_heads_stride, time1_stride, n_stride = x.stride() - - return x.as_strided( - (batch_size, n_heads, time1, time2), - (batch_stride, n_heads_stride, time1_stride - n_stride, n_stride), - storage_offset=(n_stride * (time1 - 1)), - ) - - def compute_simplified_attention_score( - self, - query: torch.Tensor, - key: torch.Tensor, - pos_enc: torch.Tensor, - left_context: int = 0, - ) -> torch.Tensor: - """Simplified attention score computation. - Reference: https://github.com/k2-fsa/icefall/pull/458 - Args: - query: Transformed query tensor. (B, H, T_1, d_k) - key: Transformed key tensor. (B, H, T_2, d_k) - pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size) - left_context: Number of frames in left context. - Returns: - : Attention score. (B, H, T_1, T_2) - """ - pos_enc = self.linear_pos(pos_enc) - - matrix_ac = torch.matmul(query, key.transpose(2, 3)) - - matrix_bd = self.rel_shift( - pos_enc.transpose(1, 2).unsqueeze(2).repeat(1, 1, query.size(2), 1), - left_context=left_context, - ) - - return (matrix_ac + matrix_bd) / math.sqrt(self.d_k) - - def compute_attention_score( - self, - query: torch.Tensor, - key: torch.Tensor, - pos_enc: torch.Tensor, - left_context: int = 0, - ) -> torch.Tensor: - """Attention score computation. - Args: - query: Transformed query tensor. (B, H, T_1, d_k) - key: Transformed key tensor. (B, H, T_2, d_k) - pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size) - left_context: Number of frames in left context. - Returns: - : Attention score. (B, H, T_1, T_2) - """ - p = self.linear_pos(pos_enc).view(pos_enc.size(0), -1, self.num_heads, self.d_k) - - query = query.transpose(1, 2) - q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2) - q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2) - - matrix_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1)) - - matrix_bd = torch.matmul(q_with_bias_v, p.permute(0, 2, 3, 1)) - matrix_bd = self.rel_shift(matrix_bd, left_context=left_context) - - return (matrix_ac + matrix_bd) / math.sqrt(self.d_k) - - def forward_qkv( - self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Transform query, key and value. - Args: - query: Query tensor. (B, T_1, size) - key: Key tensor. (B, T_2, size) - v: Value tensor. (B, T_2, size) - Returns: - q: Transformed query tensor. (B, H, T_1, d_k) - k: Transformed key tensor. (B, H, T_2, d_k) - v: Transformed value tensor. (B, H, T_2, d_k) - """ - n_batch = query.size(0) - - q = ( - self.linear_q(query) - .view(n_batch, -1, self.num_heads, self.d_k) - .transpose(1, 2) - ) - k = ( - self.linear_k(key) - .view(n_batch, -1, self.num_heads, self.d_k) - .transpose(1, 2) - ) - v = ( - self.linear_v(value) - .view(n_batch, -1, self.num_heads, self.d_k) - .transpose(1, 2) - ) - - return q, k, v - - def forward_attention( - self, - value: torch.Tensor, - scores: torch.Tensor, - mask: torch.Tensor, - chunk_mask: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """Compute attention context vector. - Args: - value: Transformed value. (B, H, T_2, d_k) - scores: Attention score. (B, H, T_1, T_2) - mask: Source mask. (B, T_2) - chunk_mask: Chunk mask. (T_1, T_1) - Returns: - attn_output: Transformed value weighted by attention score. (B, T_1, H * d_k) - """ - batch_size = scores.size(0) - mask = mask.unsqueeze(1).unsqueeze(2) - if chunk_mask is not None: - mask = chunk_mask.unsqueeze(0).unsqueeze(1) | mask - scores = scores.masked_fill(mask, float("-inf")) - self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) - - attn_output = self.dropout(self.attn) - attn_output = torch.matmul(attn_output, value) - - attn_output = self.linear_out( - attn_output.transpose(1, 2) - .contiguous() - .view(batch_size, -1, self.num_heads * self.d_k) - ) - - return attn_output - - def forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - pos_enc: torch.Tensor, - mask: torch.Tensor, - chunk_mask: Optional[torch.Tensor] = None, - left_context: int = 0, - ) -> torch.Tensor: - """Compute scaled dot product attention with rel. positional encoding. - Args: - query: Query tensor. (B, T_1, size) - key: Key tensor. (B, T_2, size) - value: Value tensor. (B, T_2, size) - pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size) - mask: Source mask. (B, T_2) - chunk_mask: Chunk mask. (T_1, T_1) - left_context: Number of frames in left context. - Returns: - : Output tensor. (B, T_1, H * d_k) - """ - q, k, v = self.forward_qkv(query, key, value) - scores = self.compute_att_score(q, k, pos_enc, left_context=left_context) - return self.forward_attention(v, scores, mask, chunk_mask=chunk_mask) - - -class CosineDistanceAttention(nn.Module): - """ Compute Cosine Distance between spk decoder output and speaker profile - Args: - profile_path: speaker profile file path (.npy file) - """ - - def __init__(self): - super().__init__() - self.softmax = nn.Softmax(dim=-1) - - def forward(self, spk_decoder_out, profile, profile_lens=None): - """ - Args: - spk_decoder_out(torch.Tensor):(B, L, D) - spk_profiles(torch.Tensor):(B, N, D) - """ - x = spk_decoder_out.unsqueeze(2) # (B, L, 1, D) - if profile_lens is not None: - - mask = (make_pad_mask(profile_lens)[:, None, :]).to(profile.device) - min_value = float( - numpy.finfo(torch.tensor(0, dtype=x.dtype).numpy().dtype).min - ) - weights_not_softmax=F.cosine_similarity(x, profile.unsqueeze(1), dim=-1).masked_fill(mask, min_value) - weights = self.softmax(weights_not_softmax).masked_fill(mask, 0.0) # (B, L, N) - else: - x = x[:, -1:, :, :] - weights_not_softmax=F.cosine_similarity(x, profile.unsqueeze(1).to(x.device), dim=-1) - weights = self.softmax(weights_not_softmax) # (B, 1, N) - spk_embedding = torch.matmul(weights, profile.to(weights.device)) # (B, L, D) - - return spk_embedding, weights diff --git a/funasr/models/ct_transformer/encoder.py b/funasr/models/ct_transformer/encoder.py deleted file mode 100644 index 784baf37d..000000000 --- a/funasr/models/ct_transformer/encoder.py +++ /dev/null @@ -1,383 +0,0 @@ -from typing import List -from typing import Optional -from typing import Sequence -from typing import Tuple -from typing import Union -import logging -import torch -import torch.nn as nn -import torch.nn.functional as F -from funasr.models.scama.chunk_utilis import overlap_chunk -import numpy as np -from funasr.train_utils.device_funcs import to_device -from funasr.models.transformer.utils.nets_utils import make_pad_mask -from funasr.models.sanm.attention import MultiHeadedAttention -from funasr.models.ct_transformer.attention import MultiHeadedAttentionSANMwithMask -from funasr.models.transformer.embedding import SinusoidalPositionEncoder, StreamSinusoidalPositionEncoder -from funasr.models.transformer.layer_norm import LayerNorm -from funasr.models.transformer.utils.multi_layer_conv import Conv1dLinear -from funasr.models.transformer.utils.multi_layer_conv import MultiLayeredConv1d -from funasr.models.transformer.positionwise_feed_forward import ( - PositionwiseFeedForward, # noqa: H301 -) -from funasr.models.transformer.utils.repeat import repeat -from funasr.models.transformer.utils.subsampling import Conv2dSubsampling -from funasr.models.transformer.utils.subsampling import Conv2dSubsampling2 -from funasr.models.transformer.utils.subsampling import Conv2dSubsampling6 -from funasr.models.transformer.utils.subsampling import Conv2dSubsampling8 -from funasr.models.transformer.utils.subsampling import TooShortUttError -from funasr.models.transformer.utils.subsampling import check_short_utt -from funasr.models.transformer.utils.mask import subsequent_mask, vad_mask - -from funasr.models.ctc.ctc import CTC - -from funasr.register import tables - -class EncoderLayerSANM(nn.Module): - def __init__( - self, - in_size, - size, - self_attn, - feed_forward, - dropout_rate, - normalize_before=True, - concat_after=False, - stochastic_depth_rate=0.0, - ): - """Construct an EncoderLayer object.""" - super(EncoderLayerSANM, self).__init__() - self.self_attn = self_attn - self.feed_forward = feed_forward - self.norm1 = LayerNorm(in_size) - self.norm2 = LayerNorm(size) - self.dropout = nn.Dropout(dropout_rate) - self.in_size = in_size - self.size = size - self.normalize_before = normalize_before - self.concat_after = concat_after - if self.concat_after: - self.concat_linear = nn.Linear(size + size, size) - self.stochastic_depth_rate = stochastic_depth_rate - self.dropout_rate = dropout_rate - - def forward(self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None): - """Compute encoded features. - - Args: - x_input (torch.Tensor): Input tensor (#batch, time, size). - mask (torch.Tensor): Mask tensor for the input (#batch, time). - cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size). - - Returns: - torch.Tensor: Output tensor (#batch, time, size). - torch.Tensor: Mask tensor (#batch, time). - - """ - skip_layer = False - # with stochastic depth, residual connection `x + f(x)` becomes - # `x <- x + 1 / (1 - p) * f(x)` at training time. - stoch_layer_coeff = 1.0 - if self.training and self.stochastic_depth_rate > 0: - skip_layer = torch.rand(1).item() < self.stochastic_depth_rate - stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate) - - if skip_layer: - if cache is not None: - x = torch.cat([cache, x], dim=1) - return x, mask - - residual = x - if self.normalize_before: - x = self.norm1(x) - - if self.concat_after: - x_concat = torch.cat((x, self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)), dim=-1) - if self.in_size == self.size: - x = residual + stoch_layer_coeff * self.concat_linear(x_concat) - else: - x = stoch_layer_coeff * self.concat_linear(x_concat) - else: - if self.in_size == self.size: - x = residual + stoch_layer_coeff * self.dropout( - self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder) - ) - else: - x = stoch_layer_coeff * self.dropout( - self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder) - ) - if not self.normalize_before: - x = self.norm1(x) - - residual = x - if self.normalize_before: - x = self.norm2(x) - x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x)) - if not self.normalize_before: - x = self.norm2(x) - - return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder - - def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0): - """Compute encoded features. - - Args: - x_input (torch.Tensor): Input tensor (#batch, time, size). - mask (torch.Tensor): Mask tensor for the input (#batch, time). - cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size). - - Returns: - torch.Tensor: Output tensor (#batch, time, size). - torch.Tensor: Mask tensor (#batch, time). - - """ - - residual = x - if self.normalize_before: - x = self.norm1(x) - - if self.in_size == self.size: - attn, cache = self.self_attn.forward_chunk(x, cache, chunk_size, look_back) - x = residual + attn - else: - x, cache = self.self_attn.forward_chunk(x, cache, chunk_size, look_back) - - if not self.normalize_before: - x = self.norm1(x) - - residual = x - if self.normalize_before: - x = self.norm2(x) - x = residual + self.feed_forward(x) - if not self.normalize_before: - x = self.norm2(x) - - return x, cache - - -@tables.register("encoder_classes", "SANMVadEncoder") -class SANMVadEncoder(nn.Module): - """ - Author: Speech Lab of DAMO Academy, Alibaba Group - - """ - - def __init__( - self, - input_size: int, - output_size: int = 256, - attention_heads: int = 4, - linear_units: int = 2048, - num_blocks: int = 6, - dropout_rate: float = 0.1, - positional_dropout_rate: float = 0.1, - attention_dropout_rate: float = 0.0, - input_layer: Optional[str] = "conv2d", - pos_enc_class=SinusoidalPositionEncoder, - normalize_before: bool = True, - concat_after: bool = False, - positionwise_layer_type: str = "linear", - positionwise_conv_kernel_size: int = 1, - padding_idx: int = -1, - interctc_layer_idx: List[int] = [], - interctc_use_conditioning: bool = False, - kernel_size : int = 11, - sanm_shfit : int = 0, - selfattention_layer_type: str = "sanm", - ): - super().__init__() - self._output_size = output_size - - if input_layer == "linear": - self.embed = torch.nn.Sequential( - torch.nn.Linear(input_size, output_size), - torch.nn.LayerNorm(output_size), - torch.nn.Dropout(dropout_rate), - torch.nn.ReLU(), - pos_enc_class(output_size, positional_dropout_rate), - ) - elif input_layer == "conv2d": - self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate) - elif input_layer == "conv2d2": - self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate) - elif input_layer == "conv2d6": - self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate) - elif input_layer == "conv2d8": - self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate) - elif input_layer == "embed": - self.embed = torch.nn.Sequential( - torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx), - SinusoidalPositionEncoder(), - ) - elif input_layer is None: - if input_size == output_size: - self.embed = None - else: - self.embed = torch.nn.Linear(input_size, output_size) - elif input_layer == "pe": - self.embed = SinusoidalPositionEncoder() - else: - raise ValueError("unknown input_layer: " + input_layer) - self.normalize_before = normalize_before - if positionwise_layer_type == "linear": - positionwise_layer = PositionwiseFeedForward - positionwise_layer_args = ( - output_size, - linear_units, - dropout_rate, - ) - elif positionwise_layer_type == "conv1d": - positionwise_layer = MultiLayeredConv1d - positionwise_layer_args = ( - output_size, - linear_units, - positionwise_conv_kernel_size, - dropout_rate, - ) - elif positionwise_layer_type == "conv1d-linear": - positionwise_layer = Conv1dLinear - positionwise_layer_args = ( - output_size, - linear_units, - positionwise_conv_kernel_size, - dropout_rate, - ) - else: - raise NotImplementedError("Support only linear or conv1d.") - - if selfattention_layer_type == "selfattn": - encoder_selfattn_layer = MultiHeadedAttention - encoder_selfattn_layer_args = ( - attention_heads, - output_size, - attention_dropout_rate, - ) - - elif selfattention_layer_type == "sanm": - self.encoder_selfattn_layer = MultiHeadedAttentionSANMwithMask - encoder_selfattn_layer_args0 = ( - attention_heads, - input_size, - output_size, - attention_dropout_rate, - kernel_size, - sanm_shfit, - ) - - encoder_selfattn_layer_args = ( - attention_heads, - output_size, - output_size, - attention_dropout_rate, - kernel_size, - sanm_shfit, - ) - - self.encoders0 = repeat( - 1, - lambda lnum: EncoderLayerSANM( - input_size, - output_size, - self.encoder_selfattn_layer(*encoder_selfattn_layer_args0), - positionwise_layer(*positionwise_layer_args), - dropout_rate, - normalize_before, - concat_after, - ), - ) - - self.encoders = repeat( - num_blocks-1, - lambda lnum: EncoderLayerSANM( - output_size, - output_size, - self.encoder_selfattn_layer(*encoder_selfattn_layer_args), - positionwise_layer(*positionwise_layer_args), - dropout_rate, - normalize_before, - concat_after, - ), - ) - if self.normalize_before: - self.after_norm = LayerNorm(output_size) - - self.interctc_layer_idx = interctc_layer_idx - if len(interctc_layer_idx) > 0: - assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks - self.interctc_use_conditioning = interctc_use_conditioning - self.conditioning_layer = None - self.dropout = nn.Dropout(dropout_rate) - - def output_size(self) -> int: - return self._output_size - - def forward( - self, - xs_pad: torch.Tensor, - ilens: torch.Tensor, - vad_indexes: torch.Tensor, - prev_states: torch.Tensor = None, - ctc: CTC = None, - ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - """Embed positions in tensor. - - Args: - xs_pad: input tensor (B, L, D) - ilens: input length (B) - prev_states: Not to be used now. - Returns: - position embedded tensor and mask - """ - masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device) - sub_masks = subsequent_mask(masks.size(-1), device=xs_pad.device).unsqueeze(0) - no_future_masks = masks & sub_masks - xs_pad *= self.output_size()**0.5 - if self.embed is None: - xs_pad = xs_pad - elif (isinstance(self.embed, Conv2dSubsampling) or isinstance(self.embed, Conv2dSubsampling2) - or isinstance(self.embed, Conv2dSubsampling6) or isinstance(self.embed, Conv2dSubsampling8)): - short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1)) - if short_status: - raise TooShortUttError( - f"has {xs_pad.size(1)} frames and is too short for subsampling " + - f"(it needs more than {limit_size} frames), return empty results", - xs_pad.size(1), - limit_size, - ) - xs_pad, masks = self.embed(xs_pad, masks) - else: - xs_pad = self.embed(xs_pad) - - # xs_pad = self.dropout(xs_pad) - mask_tup0 = [masks, no_future_masks] - encoder_outs = self.encoders0(xs_pad, mask_tup0) - xs_pad, _ = encoder_outs[0], encoder_outs[1] - intermediate_outs = [] - - - for layer_idx, encoder_layer in enumerate(self.encoders): - if layer_idx + 1 == len(self.encoders): - # This is last layer. - coner_mask = torch.ones(masks.size(0), - masks.size(-1), - masks.size(-1), - device=xs_pad.device, - dtype=torch.bool) - for word_index, length in enumerate(ilens): - coner_mask[word_index, :, :] = vad_mask(masks.size(-1), - vad_indexes[word_index], - device=xs_pad.device) - layer_mask = masks & coner_mask - else: - layer_mask = no_future_masks - mask_tup1 = [masks, layer_mask] - encoder_outs = encoder_layer(xs_pad, mask_tup1) - xs_pad, layer_mask = encoder_outs[0], encoder_outs[1] - - if self.normalize_before: - xs_pad = self.after_norm(xs_pad) - - olens = masks.squeeze(1).sum(1) - if len(intermediate_outs) > 0: - return (xs_pad, intermediate_outs), olens, None - return xs_pad, olens, None diff --git a/funasr/models/ct_transformer/model.py b/funasr/models/ct_transformer/model.py index d84368636..7187f4580 100644 --- a/funasr/models/ct_transformer/model.py +++ b/funasr/models/ct_transformer/model.py @@ -60,7 +60,7 @@ class CTTransformer(nn.Module): - def punc_forward(self, text: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]: + def punc_forward(self, text: torch.Tensor, text_lengths: torch.Tensor, **kwargs): """Compute loss value from buffer sequences. Args: diff --git a/funasr/models/ct_transformer/utils.py b/funasr/models/ct_transformer/utils.py index 917f2e035..c5f85e6cb 100644 --- a/funasr/models/ct_transformer/utils.py +++ b/funasr/models/ct_transformer/utils.py @@ -14,26 +14,6 @@ def split_to_mini_sentence(words: list, word_limit: int = 20): return sentences -# def split_words(text: str, **kwargs): -# words = [] -# segs = text.split() -# for seg in segs: -# # There is no space in seg. -# current_word = "" -# for c in seg: -# if len(c.encode()) == 1: -# # This is an ASCII char. -# current_word += c -# else: -# # This is a Chinese char. -# if len(current_word) > 0: -# words.append(current_word) -# current_word = "" -# words.append(c) -# if len(current_word) > 0: -# words.append(current_word) -# -# return words def split_words(text: str, jieba_usr_dict=None, **kwargs): if jieba_usr_dict: diff --git a/funasr/models/ct_transformer/vad_realtime_transformer.py b/funasr/models/ct_transformer/vad_realtime_transformer.py deleted file mode 100644 index 155057ce8..000000000 --- a/funasr/models/ct_transformer/vad_realtime_transformer.py +++ /dev/null @@ -1,135 +0,0 @@ -from typing import Any -from typing import List -from typing import Tuple - -import torch -import torch.nn as nn - -from funasr.models.transformer.embedding import SinusoidalPositionEncoder -from funasr.models.ct_transformer.sanm_encoder import SANMVadEncoder as Encoder - - -class VadRealtimeTransformer(torch.nn.Module): - """ - Author: Speech Lab of DAMO Academy, Alibaba Group - CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection - https://arxiv.org/pdf/2003.01309.pdf - """ - def __init__( - self, - vocab_size: int, - punc_size: int, - pos_enc: str = None, - embed_unit: int = 128, - att_unit: int = 256, - head: int = 2, - unit: int = 1024, - layer: int = 4, - dropout_rate: float = 0.5, - kernel_size: int = 11, - sanm_shfit: int = 0, - ): - super().__init__() - if pos_enc == "sinusoidal": - # pos_enc_class = PositionalEncoding - pos_enc_class = SinusoidalPositionEncoder - elif pos_enc is None: - - def pos_enc_class(*args, **kwargs): - return nn.Sequential() # indentity - - else: - raise ValueError(f"unknown pos-enc option: {pos_enc}") - - self.embed = nn.Embedding(vocab_size, embed_unit) - self.encoder = Encoder( - input_size=embed_unit, - output_size=att_unit, - attention_heads=head, - linear_units=unit, - num_blocks=layer, - dropout_rate=dropout_rate, - input_layer="pe", - # pos_enc_class=pos_enc_class, - padding_idx=0, - kernel_size=kernel_size, - sanm_shfit=sanm_shfit, - ) - self.decoder = nn.Linear(att_unit, punc_size) - - -# def _target_mask(self, ys_in_pad): -# ys_mask = ys_in_pad != 0 -# m = subsequent_n_mask(ys_mask.size(-1), 5, device=ys_mask.device).unsqueeze(0) -# return ys_mask.unsqueeze(-2) & m - - def forward(self, input: torch.Tensor, text_lengths: torch.Tensor, - vad_indexes: torch.Tensor) -> Tuple[torch.Tensor, None]: - """Compute loss value from buffer sequences. - - Args: - input (torch.Tensor): Input ids. (batch, len) - hidden (torch.Tensor): Target ids. (batch, len) - - """ - x = self.embed(input) - # mask = self._target_mask(input) - h, _, _ = self.encoder(x, text_lengths, vad_indexes) - y = self.decoder(h) - return y, None - - def with_vad(self): - return True - - def score(self, y: torch.Tensor, state: Any, x: torch.Tensor) -> Tuple[torch.Tensor, Any]: - """Score new token. - - Args: - y (torch.Tensor): 1D torch.int64 prefix tokens. - state: Scorer state for prefix tokens - x (torch.Tensor): encoder feature that generates ys. - - Returns: - tuple[torch.Tensor, Any]: Tuple of - torch.float32 scores for next token (vocab_size) - and next state for ys - - """ - y = y.unsqueeze(0) - h, _, cache = self.encoder.forward_one_step(self.embed(y), self._target_mask(y), cache=state) - h = self.decoder(h[:, -1]) - logp = h.log_softmax(dim=-1).squeeze(0) - return logp, cache - - def batch_score(self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor) -> Tuple[torch.Tensor, List[Any]]: - """Score new token batch. - - Args: - ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen). - states (List[Any]): Scorer states for prefix tokens. - xs (torch.Tensor): - The encoder feature that generates ys (n_batch, xlen, n_feat). - - Returns: - tuple[torch.Tensor, List[Any]]: Tuple of - batchfied scores for next token with shape of `(n_batch, vocab_size)` - and next state list for ys. - - """ - # merge states - n_batch = len(ys) - n_layers = len(self.encoder.encoders) - if states[0] is None: - batch_state = None - else: - # transpose state of [batch, layer] into [layer, batch] - batch_state = [torch.stack([states[b][i] for b in range(n_batch)]) for i in range(n_layers)] - - # batch decoding - h, _, states = self.encoder.forward_one_step(self.embed(ys), self._target_mask(ys), cache=batch_state) - h = self.decoder(h[:, -1]) - logp = h.log_softmax(dim=-1) - - # transpose state of [layer, batch] into [batch, layer] - state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)] - return logp, state_list diff --git a/funasr/models/ct_transformer_streaming/attention.py b/funasr/models/ct_transformer_streaming/attention.py index a35ddee57..382334ec1 100644 --- a/funasr/models/ct_transformer_streaming/attention.py +++ b/funasr/models/ct_transformer_streaming/attention.py @@ -11,487 +11,12 @@ import math import numpy import torch from torch import nn +import torch.nn.functional as F from typing import Optional, Tuple -import torch.nn.functional as F -from funasr.models.transformer.utils.nets_utils import make_pad_mask -import funasr.models.lora.layers as lora +from funasr.models.sanm.attention import MultiHeadedAttentionSANM -class MultiHeadedAttention(nn.Module): - """Multi-Head Attention layer. - Args: - n_head (int): The number of heads. - n_feat (int): The number of features. - dropout_rate (float): Dropout rate. - - """ - - def __init__(self, n_head, n_feat, dropout_rate): - """Construct an MultiHeadedAttention object.""" - super(MultiHeadedAttention, self).__init__() - assert n_feat % n_head == 0 - # We assume d_v always equals d_k - self.d_k = n_feat // n_head - self.h = n_head - self.linear_q = nn.Linear(n_feat, n_feat) - self.linear_k = nn.Linear(n_feat, n_feat) - self.linear_v = nn.Linear(n_feat, n_feat) - self.linear_out = nn.Linear(n_feat, n_feat) - self.attn = None - self.dropout = nn.Dropout(p=dropout_rate) - - def forward_qkv(self, query, key, value): - """Transform query, key and value. - - Args: - query (torch.Tensor): Query tensor (#batch, time1, size). - key (torch.Tensor): Key tensor (#batch, time2, size). - value (torch.Tensor): Value tensor (#batch, time2, size). - - Returns: - torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k). - torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k). - torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k). - - """ - n_batch = query.size(0) - q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k) - k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k) - v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k) - q = q.transpose(1, 2) # (batch, head, time1, d_k) - k = k.transpose(1, 2) # (batch, head, time2, d_k) - v = v.transpose(1, 2) # (batch, head, time2, d_k) - - return q, k, v - - def forward_attention(self, value, scores, mask): - """Compute attention context vector. - - Args: - value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k). - scores (torch.Tensor): Attention score (#batch, n_head, time1, time2). - mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2). - - Returns: - torch.Tensor: Transformed value (#batch, time1, d_model) - weighted by the attention score (#batch, time1, time2). - - """ - n_batch = value.size(0) - if mask is not None: - mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2) - min_value = float( - numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min - ) - scores = scores.masked_fill(mask, min_value) - self.attn = torch.softmax(scores, dim=-1).masked_fill( - mask, 0.0 - ) # (batch, head, time1, time2) - else: - self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) - - p_attn = self.dropout(self.attn) - x = torch.matmul(p_attn, value) # (batch, head, time1, d_k) - x = ( - x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k) - ) # (batch, time1, d_model) - - return self.linear_out(x) # (batch, time1, d_model) - - def forward(self, query, key, value, mask): - """Compute scaled dot product attention. - - Args: - query (torch.Tensor): Query tensor (#batch, time1, size). - key (torch.Tensor): Key tensor (#batch, time2, size). - value (torch.Tensor): Value tensor (#batch, time2, size). - mask (torch.Tensor): Mask tensor (#batch, 1, time2) or - (#batch, time1, time2). - - Returns: - torch.Tensor: Output tensor (#batch, time1, d_model). - - """ - q, k, v = self.forward_qkv(query, key, value) - scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k) - return self.forward_attention(v, scores, mask) - - -class LegacyRelPositionMultiHeadedAttention(MultiHeadedAttention): - """Multi-Head Attention layer with relative position encoding (old version). - - Details can be found in https://github.com/espnet/espnet/pull/2816. - - Paper: https://arxiv.org/abs/1901.02860 - - Args: - n_head (int): The number of heads. - n_feat (int): The number of features. - dropout_rate (float): Dropout rate. - zero_triu (bool): Whether to zero the upper triangular part of attention matrix. - - """ - - def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False): - """Construct an RelPositionMultiHeadedAttention object.""" - super().__init__(n_head, n_feat, dropout_rate) - self.zero_triu = zero_triu - # linear transformation for positional encoding - self.linear_pos = nn.Linear(n_feat, n_feat, bias=False) - # these two learnable bias are used in matrix c and matrix d - # as described in https://arxiv.org/abs/1901.02860 Section 3.3 - self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k)) - self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k)) - torch.nn.init.xavier_uniform_(self.pos_bias_u) - torch.nn.init.xavier_uniform_(self.pos_bias_v) - - def rel_shift(self, x): - """Compute relative positional encoding. - - Args: - x (torch.Tensor): Input tensor (batch, head, time1, time2). - - Returns: - torch.Tensor: Output tensor. - - """ - zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype) - x_padded = torch.cat([zero_pad, x], dim=-1) - - x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2)) - x = x_padded[:, :, 1:].view_as(x) - - if self.zero_triu: - ones = torch.ones((x.size(2), x.size(3))) - x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :] - - return x - - def forward(self, query, key, value, pos_emb, mask): - """Compute 'Scaled Dot Product Attention' with rel. positional encoding. - - Args: - query (torch.Tensor): Query tensor (#batch, time1, size). - key (torch.Tensor): Key tensor (#batch, time2, size). - value (torch.Tensor): Value tensor (#batch, time2, size). - pos_emb (torch.Tensor): Positional embedding tensor (#batch, time1, size). - mask (torch.Tensor): Mask tensor (#batch, 1, time2) or - (#batch, time1, time2). - - Returns: - torch.Tensor: Output tensor (#batch, time1, d_model). - - """ - q, k, v = self.forward_qkv(query, key, value) - q = q.transpose(1, 2) # (batch, time1, head, d_k) - - n_batch_pos = pos_emb.size(0) - p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) - p = p.transpose(1, 2) # (batch, head, time1, d_k) - - # (batch, head, time1, d_k) - q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) - # (batch, head, time1, d_k) - q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) - - # compute attention score - # first compute matrix a and matrix c - # as described in https://arxiv.org/abs/1901.02860 Section 3.3 - # (batch, head, time1, time2) - matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) - - # compute matrix b and matrix d - # (batch, head, time1, time1) - matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) - matrix_bd = self.rel_shift(matrix_bd) - - scores = (matrix_ac + matrix_bd) / math.sqrt( - self.d_k - ) # (batch, head, time1, time2) - - return self.forward_attention(v, scores, mask) - - -class RelPositionMultiHeadedAttention(MultiHeadedAttention): - """Multi-Head Attention layer with relative position encoding (new implementation). - - Details can be found in https://github.com/espnet/espnet/pull/2816. - - Paper: https://arxiv.org/abs/1901.02860 - - Args: - n_head (int): The number of heads. - n_feat (int): The number of features. - dropout_rate (float): Dropout rate. - zero_triu (bool): Whether to zero the upper triangular part of attention matrix. - - """ - - def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False): - """Construct an RelPositionMultiHeadedAttention object.""" - super().__init__(n_head, n_feat, dropout_rate) - self.zero_triu = zero_triu - # linear transformation for positional encoding - self.linear_pos = nn.Linear(n_feat, n_feat, bias=False) - # these two learnable bias are used in matrix c and matrix d - # as described in https://arxiv.org/abs/1901.02860 Section 3.3 - self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k)) - self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k)) - torch.nn.init.xavier_uniform_(self.pos_bias_u) - torch.nn.init.xavier_uniform_(self.pos_bias_v) - - def rel_shift(self, x): - """Compute relative positional encoding. - - Args: - x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1). - time1 means the length of query vector. - - Returns: - torch.Tensor: Output tensor. - - """ - zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype) - x_padded = torch.cat([zero_pad, x], dim=-1) - - x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2)) - x = x_padded[:, :, 1:].view_as(x)[ - :, :, :, : x.size(-1) // 2 + 1 - ] # only keep the positions from 0 to time2 - - if self.zero_triu: - ones = torch.ones((x.size(2), x.size(3)), device=x.device) - x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :] - - return x - - def forward(self, query, key, value, pos_emb, mask): - """Compute 'Scaled Dot Product Attention' with rel. positional encoding. - - Args: - query (torch.Tensor): Query tensor (#batch, time1, size). - key (torch.Tensor): Key tensor (#batch, time2, size). - value (torch.Tensor): Value tensor (#batch, time2, size). - pos_emb (torch.Tensor): Positional embedding tensor - (#batch, 2*time1-1, size). - mask (torch.Tensor): Mask tensor (#batch, 1, time2) or - (#batch, time1, time2). - - Returns: - torch.Tensor: Output tensor (#batch, time1, d_model). - - """ - q, k, v = self.forward_qkv(query, key, value) - q = q.transpose(1, 2) # (batch, time1, head, d_k) - - n_batch_pos = pos_emb.size(0) - p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) - p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k) - - # (batch, head, time1, d_k) - q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) - # (batch, head, time1, d_k) - q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) - - # compute attention score - # first compute matrix a and matrix c - # as described in https://arxiv.org/abs/1901.02860 Section 3.3 - # (batch, head, time1, time2) - matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) - - # compute matrix b and matrix d - # (batch, head, time1, 2*time1-1) - matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) - matrix_bd = self.rel_shift(matrix_bd) - - scores = (matrix_ac + matrix_bd) / math.sqrt( - self.d_k - ) # (batch, head, time1, time2) - - return self.forward_attention(v, scores, mask) - - -class MultiHeadedAttentionSANM(nn.Module): - """Multi-Head Attention layer. - - Args: - n_head (int): The number of heads. - n_feat (int): The number of features. - dropout_rate (float): Dropout rate. - - """ - - def __init__(self, n_head, in_feat, n_feat, dropout_rate, kernel_size, sanm_shfit=0, lora_list=None, lora_rank=8, lora_alpha=16, lora_dropout=0.1): - """Construct an MultiHeadedAttention object.""" - super(MultiHeadedAttentionSANM, self).__init__() - assert n_feat % n_head == 0 - # We assume d_v always equals d_k - self.d_k = n_feat // n_head - self.h = n_head - # self.linear_q = nn.Linear(n_feat, n_feat) - # self.linear_k = nn.Linear(n_feat, n_feat) - # self.linear_v = nn.Linear(n_feat, n_feat) - if lora_list is not None: - if "o" in lora_list: - self.linear_out = lora.Linear(n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout) - else: - self.linear_out = nn.Linear(n_feat, n_feat) - lora_qkv_list = ["q" in lora_list, "k" in lora_list, "v" in lora_list] - if lora_qkv_list == [False, False, False]: - self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3) - else: - self.linear_q_k_v = lora.MergedLinear(in_feat, n_feat * 3, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout, enable_lora=lora_qkv_list) - else: - self.linear_out = nn.Linear(n_feat, n_feat) - self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3) - self.attn = None - self.dropout = nn.Dropout(p=dropout_rate) - - self.fsmn_block = nn.Conv1d(n_feat, n_feat, kernel_size, stride=1, padding=0, groups=n_feat, bias=False) - # padding - left_padding = (kernel_size - 1) // 2 - if sanm_shfit > 0: - left_padding = left_padding + sanm_shfit - right_padding = kernel_size - 1 - left_padding - self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0) - - def forward_fsmn(self, inputs, mask, mask_shfit_chunk=None): - b, t, d = inputs.size() - if mask is not None: - mask = torch.reshape(mask, (b, -1, 1)) - if mask_shfit_chunk is not None: - mask = mask * mask_shfit_chunk - inputs = inputs * mask - - x = inputs.transpose(1, 2) - x = self.pad_fn(x) - x = self.fsmn_block(x) - x = x.transpose(1, 2) - x += inputs - x = self.dropout(x) - if mask is not None: - x = x * mask - return x - - def forward_qkv(self, x): - """Transform query, key and value. - - Args: - query (torch.Tensor): Query tensor (#batch, time1, size). - key (torch.Tensor): Key tensor (#batch, time2, size). - value (torch.Tensor): Value tensor (#batch, time2, size). - - Returns: - torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k). - torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k). - torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k). - - """ - b, t, d = x.size() - q_k_v = self.linear_q_k_v(x) - q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1) - q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(1, 2) # (batch, head, time1, d_k) - k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(1, 2) # (batch, head, time2, d_k) - v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(1, 2) # (batch, head, time2, d_k) - - return q_h, k_h, v_h, v - - def forward_attention(self, value, scores, mask, mask_att_chunk_encoder=None): - """Compute attention context vector. - - Args: - value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k). - scores (torch.Tensor): Attention score (#batch, n_head, time1, time2). - mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2). - - Returns: - torch.Tensor: Transformed value (#batch, time1, d_model) - weighted by the attention score (#batch, time1, time2). - - """ - n_batch = value.size(0) - if mask is not None: - if mask_att_chunk_encoder is not None: - mask = mask * mask_att_chunk_encoder - - mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2) - - min_value = float( - numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min - ) - scores = scores.masked_fill(mask, min_value) - self.attn = torch.softmax(scores, dim=-1).masked_fill( - mask, 0.0 - ) # (batch, head, time1, time2) - else: - self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) - - p_attn = self.dropout(self.attn) - x = torch.matmul(p_attn, value) # (batch, head, time1, d_k) - x = ( - x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k) - ) # (batch, time1, d_model) - - return self.linear_out(x) # (batch, time1, d_model) - - def forward(self, x, mask, mask_shfit_chunk=None, mask_att_chunk_encoder=None): - """Compute scaled dot product attention. - - Args: - query (torch.Tensor): Query tensor (#batch, time1, size). - key (torch.Tensor): Key tensor (#batch, time2, size). - value (torch.Tensor): Value tensor (#batch, time2, size). - mask (torch.Tensor): Mask tensor (#batch, 1, time2) or - (#batch, time1, time2). - - Returns: - torch.Tensor: Output tensor (#batch, time1, d_model). - - """ - q_h, k_h, v_h, v = self.forward_qkv(x) - fsmn_memory = self.forward_fsmn(v, mask, mask_shfit_chunk) - q_h = q_h * self.d_k ** (-0.5) - scores = torch.matmul(q_h, k_h.transpose(-2, -1)) - att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder) - return att_outs + fsmn_memory - - def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0): - """Compute scaled dot product attention. - - Args: - query (torch.Tensor): Query tensor (#batch, time1, size). - key (torch.Tensor): Key tensor (#batch, time2, size). - value (torch.Tensor): Value tensor (#batch, time2, size). - mask (torch.Tensor): Mask tensor (#batch, 1, time2) or - (#batch, time1, time2). - - Returns: - torch.Tensor: Output tensor (#batch, time1, d_model). - - """ - q_h, k_h, v_h, v = self.forward_qkv(x) - if chunk_size is not None and look_back > 0 or look_back == -1: - if cache is not None: - k_h_stride = k_h[:, :, :-(chunk_size[2]), :] - v_h_stride = v_h[:, :, :-(chunk_size[2]), :] - k_h = torch.cat((cache["k"], k_h), dim=2) - v_h = torch.cat((cache["v"], v_h), dim=2) - - cache["k"] = torch.cat((cache["k"], k_h_stride), dim=2) - cache["v"] = torch.cat((cache["v"], v_h_stride), dim=2) - if look_back != -1: - cache["k"] = cache["k"][:, :, -(look_back * chunk_size[1]):, :] - cache["v"] = cache["v"][:, :, -(look_back * chunk_size[1]):, :] - else: - cache_tmp = {"k": k_h[:, :, :-(chunk_size[2]), :], - "v": v_h[:, :, :-(chunk_size[2]), :]} - cache = cache_tmp - fsmn_memory = self.forward_fsmn(v, None) - q_h = q_h * self.d_k ** (-0.5) - scores = torch.matmul(q_h, k_h.transpose(-2, -1)) - att_outs = self.forward_attention(v_h, scores, None) - return att_outs + fsmn_memory, cache class MultiHeadedAttentionSANMwithMask(MultiHeadedAttentionSANM): @@ -506,586 +31,4 @@ class MultiHeadedAttentionSANMwithMask(MultiHeadedAttentionSANM): att_outs = self.forward_attention(v_h, scores, mask[1], mask_att_chunk_encoder) return att_outs + fsmn_memory -class MultiHeadedAttentionSANMDecoder(nn.Module): - """Multi-Head Attention layer. - Args: - n_head (int): The number of heads. - n_feat (int): The number of features. - dropout_rate (float): Dropout rate. - - """ - - def __init__(self, n_feat, dropout_rate, kernel_size, sanm_shfit=0): - """Construct an MultiHeadedAttention object.""" - super(MultiHeadedAttentionSANMDecoder, self).__init__() - - self.dropout = nn.Dropout(p=dropout_rate) - - self.fsmn_block = nn.Conv1d(n_feat, n_feat, - kernel_size, stride=1, padding=0, groups=n_feat, bias=False) - # padding - # padding - left_padding = (kernel_size - 1) // 2 - if sanm_shfit > 0: - left_padding = left_padding + sanm_shfit - right_padding = kernel_size - 1 - left_padding - self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0) - self.kernel_size = kernel_size - - def forward(self, inputs, mask, cache=None, mask_shfit_chunk=None): - ''' - :param x: (#batch, time1, size). - :param mask: Mask tensor (#batch, 1, time) - :return: - ''' - # print("in fsmn, inputs", inputs.size()) - b, t, d = inputs.size() - # logging.info( - # "mask: {}".format(mask.size())) - if mask is not None: - mask = torch.reshape(mask, (b ,-1, 1)) - # logging.info("in fsmn, mask: {}, {}".format(mask.size(), mask[0:100:50, :, :])) - if mask_shfit_chunk is not None: - # logging.info("in fsmn, mask_fsmn: {}, {}".format(mask_shfit_chunk.size(), mask_shfit_chunk[0:100:50, :, :])) - mask = mask * mask_shfit_chunk - # logging.info("in fsmn, mask_after_fsmn: {}, {}".format(mask.size(), mask[0:100:50, :, :])) - # print("in fsmn, mask", mask.size()) - # print("in fsmn, inputs", inputs.size()) - inputs = inputs * mask - - x = inputs.transpose(1, 2) - b, d, t = x.size() - if cache is None: - # print("in fsmn, cache is None, x", x.size()) - - x = self.pad_fn(x) - if not self.training: - cache = x - else: - # print("in fsmn, cache is not None, x", x.size()) - # x = torch.cat((x, cache), dim=2)[:, :, :-1] - # if t < self.kernel_size: - # x = self.pad_fn(x) - x = torch.cat((cache[:, :, 1:], x), dim=2) - x = x[:, :, -(self.kernel_size+t-1):] - # print("in fsmn, cache is not None, x_cat", x.size()) - cache = x - x = self.fsmn_block(x) - x = x.transpose(1, 2) - # print("in fsmn, fsmn_out", x.size()) - if x.size(1) != inputs.size(1): - inputs = inputs[:, -1, :] - - x = x + inputs - x = self.dropout(x) - if mask is not None: - x = x * mask - return x, cache - -class MultiHeadedAttentionCrossAtt(nn.Module): - """Multi-Head Attention layer. - - Args: - n_head (int): The number of heads. - n_feat (int): The number of features. - dropout_rate (float): Dropout rate. - - """ - - def __init__(self, n_head, n_feat, dropout_rate, lora_list=None, lora_rank=8, lora_alpha=16, lora_dropout=0.1, encoder_output_size=None): - """Construct an MultiHeadedAttention object.""" - super(MultiHeadedAttentionCrossAtt, self).__init__() - assert n_feat % n_head == 0 - # We assume d_v always equals d_k - self.d_k = n_feat // n_head - self.h = n_head - if lora_list is not None: - if "q" in lora_list: - self.linear_q = lora.Linear(n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout) - else: - self.linear_q = nn.Linear(n_feat, n_feat) - lora_kv_list = ["k" in lora_list, "v" in lora_list] - if lora_kv_list == [False, False]: - self.linear_k_v = nn.Linear(n_feat if encoder_output_size is None else encoder_output_size, n_feat*2) - else: - self.linear_k_v = lora.MergedLinear(n_feat if encoder_output_size is None else encoder_output_size, n_feat * 2, - r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout, enable_lora=lora_kv_list) - if "o" in lora_list: - self.linear_out = lora.Linear(n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout) - else: - self.linear_out = nn.Linear(n_feat, n_feat) - else: - self.linear_q = nn.Linear(n_feat, n_feat) - self.linear_k_v = nn.Linear(n_feat if encoder_output_size is None else encoder_output_size, n_feat*2) - self.linear_out = nn.Linear(n_feat, n_feat) - self.attn = None - self.dropout = nn.Dropout(p=dropout_rate) - - def forward_qkv(self, x, memory): - """Transform query, key and value. - - Args: - query (torch.Tensor): Query tensor (#batch, time1, size). - key (torch.Tensor): Key tensor (#batch, time2, size). - value (torch.Tensor): Value tensor (#batch, time2, size). - - Returns: - torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k). - torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k). - torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k). - - """ - - # print("in forward_qkv, x", x.size()) - b = x.size(0) - q = self.linear_q(x) - q_h = torch.reshape(q, (b, -1, self.h, self.d_k)).transpose(1, 2) # (batch, head, time1, d_k) - - k_v = self.linear_k_v(memory) - k, v = torch.split(k_v, int(self.h*self.d_k), dim=-1) - k_h = torch.reshape(k, (b, -1, self.h, self.d_k)).transpose(1, 2) # (batch, head, time2, d_k) - v_h = torch.reshape(v, (b, -1, self.h, self.d_k)).transpose(1, 2) # (batch, head, time2, d_k) - - - return q_h, k_h, v_h - - def forward_attention(self, value, scores, mask): - """Compute attention context vector. - - Args: - value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k). - scores (torch.Tensor): Attention score (#batch, n_head, time1, time2). - mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2). - - Returns: - torch.Tensor: Transformed value (#batch, time1, d_model) - weighted by the attention score (#batch, time1, time2). - - """ - n_batch = value.size(0) - if mask is not None: - mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2) - min_value = float( - numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min - ) - # logging.info( - # "scores: {}, mask_size: {}".format(scores.size(), mask.size())) - scores = scores.masked_fill(mask, min_value) - self.attn = torch.softmax(scores, dim=-1).masked_fill( - mask, 0.0 - ) # (batch, head, time1, time2) - else: - self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) - - p_attn = self.dropout(self.attn) - x = torch.matmul(p_attn, value) # (batch, head, time1, d_k) - x = ( - x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k) - ) # (batch, time1, d_model) - - return self.linear_out(x) # (batch, time1, d_model) - - def forward(self, x, memory, memory_mask): - """Compute scaled dot product attention. - - Args: - query (torch.Tensor): Query tensor (#batch, time1, size). - key (torch.Tensor): Key tensor (#batch, time2, size). - value (torch.Tensor): Value tensor (#batch, time2, size). - mask (torch.Tensor): Mask tensor (#batch, 1, time2) or - (#batch, time1, time2). - - Returns: - torch.Tensor: Output tensor (#batch, time1, d_model). - - """ - q_h, k_h, v_h = self.forward_qkv(x, memory) - q_h = q_h * self.d_k ** (-0.5) - scores = torch.matmul(q_h, k_h.transpose(-2, -1)) - return self.forward_attention(v_h, scores, memory_mask) - - def forward_chunk(self, x, memory, cache=None, chunk_size=None, look_back=0): - """Compute scaled dot product attention. - - Args: - query (torch.Tensor): Query tensor (#batch, time1, size). - key (torch.Tensor): Key tensor (#batch, time2, size). - value (torch.Tensor): Value tensor (#batch, time2, size). - mask (torch.Tensor): Mask tensor (#batch, 1, time2) or - (#batch, time1, time2). - - Returns: - torch.Tensor: Output tensor (#batch, time1, d_model). - - """ - q_h, k_h, v_h = self.forward_qkv(x, memory) - if chunk_size is not None and look_back > 0: - if cache is not None: - k_h = torch.cat((cache["k"], k_h), dim=2) - v_h = torch.cat((cache["v"], v_h), dim=2) - cache["k"] = k_h[:, :, -(look_back * chunk_size[1]):, :] - cache["v"] = v_h[:, :, -(look_back * chunk_size[1]):, :] - else: - cache_tmp = {"k": k_h[:, :, -(look_back * chunk_size[1]):, :], - "v": v_h[:, :, -(look_back * chunk_size[1]):, :]} - cache = cache_tmp - q_h = q_h * self.d_k ** (-0.5) - scores = torch.matmul(q_h, k_h.transpose(-2, -1)) - return self.forward_attention(v_h, scores, None), cache - - -class MultiHeadSelfAttention(nn.Module): - """Multi-Head Attention layer. - - Args: - n_head (int): The number of heads. - n_feat (int): The number of features. - dropout_rate (float): Dropout rate. - - """ - - def __init__(self, n_head, in_feat, n_feat, dropout_rate): - """Construct an MultiHeadedAttention object.""" - super(MultiHeadSelfAttention, self).__init__() - assert n_feat % n_head == 0 - # We assume d_v always equals d_k - self.d_k = n_feat // n_head - self.h = n_head - self.linear_out = nn.Linear(n_feat, n_feat) - self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3) - self.attn = None - self.dropout = nn.Dropout(p=dropout_rate) - - def forward_qkv(self, x): - """Transform query, key and value. - - Args: - query (torch.Tensor): Query tensor (#batch, time1, size). - key (torch.Tensor): Key tensor (#batch, time2, size). - value (torch.Tensor): Value tensor (#batch, time2, size). - - Returns: - torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k). - torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k). - torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k). - - """ - b, t, d = x.size() - q_k_v = self.linear_q_k_v(x) - q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1) - q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(1, 2) # (batch, head, time1, d_k) - k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(1, 2) # (batch, head, time2, d_k) - v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(1, 2) # (batch, head, time2, d_k) - - return q_h, k_h, v_h, v - - def forward_attention(self, value, scores, mask, mask_att_chunk_encoder=None): - """Compute attention context vector. - - Args: - value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k). - scores (torch.Tensor): Attention score (#batch, n_head, time1, time2). - mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2). - - Returns: - torch.Tensor: Transformed value (#batch, time1, d_model) - weighted by the attention score (#batch, time1, time2). - - """ - n_batch = value.size(0) - if mask is not None: - if mask_att_chunk_encoder is not None: - mask = mask * mask_att_chunk_encoder - - mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2) - - min_value = float( - numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min - ) - scores = scores.masked_fill(mask, min_value) - self.attn = torch.softmax(scores, dim=-1).masked_fill( - mask, 0.0 - ) # (batch, head, time1, time2) - else: - self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) - - p_attn = self.dropout(self.attn) - x = torch.matmul(p_attn, value) # (batch, head, time1, d_k) - x = ( - x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k) - ) # (batch, time1, d_model) - - return self.linear_out(x) # (batch, time1, d_model) - - def forward(self, x, mask, mask_att_chunk_encoder=None): - """Compute scaled dot product attention. - - Args: - query (torch.Tensor): Query tensor (#batch, time1, size). - key (torch.Tensor): Key tensor (#batch, time2, size). - value (torch.Tensor): Value tensor (#batch, time2, size). - mask (torch.Tensor): Mask tensor (#batch, 1, time2) or - (#batch, time1, time2). - - Returns: - torch.Tensor: Output tensor (#batch, time1, d_model). - - """ - q_h, k_h, v_h, v = self.forward_qkv(x) - q_h = q_h * self.d_k ** (-0.5) - scores = torch.matmul(q_h, k_h.transpose(-2, -1)) - att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder) - return att_outs - -class RelPositionMultiHeadedAttentionChunk(torch.nn.Module): - """RelPositionMultiHeadedAttention definition. - Args: - num_heads: Number of attention heads. - embed_size: Embedding size. - dropout_rate: Dropout rate. - """ - - def __init__( - self, - num_heads: int, - embed_size: int, - dropout_rate: float = 0.0, - simplified_attention_score: bool = False, - ) -> None: - """Construct an MultiHeadedAttention object.""" - super().__init__() - - self.d_k = embed_size // num_heads - self.num_heads = num_heads - - assert self.d_k * num_heads == embed_size, ( - "embed_size (%d) must be divisible by num_heads (%d)", - (embed_size, num_heads), - ) - - self.linear_q = torch.nn.Linear(embed_size, embed_size) - self.linear_k = torch.nn.Linear(embed_size, embed_size) - self.linear_v = torch.nn.Linear(embed_size, embed_size) - - self.linear_out = torch.nn.Linear(embed_size, embed_size) - - if simplified_attention_score: - self.linear_pos = torch.nn.Linear(embed_size, num_heads) - - self.compute_att_score = self.compute_simplified_attention_score - else: - self.linear_pos = torch.nn.Linear(embed_size, embed_size, bias=False) - - self.pos_bias_u = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k)) - self.pos_bias_v = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k)) - torch.nn.init.xavier_uniform_(self.pos_bias_u) - torch.nn.init.xavier_uniform_(self.pos_bias_v) - - self.compute_att_score = self.compute_attention_score - - self.dropout = torch.nn.Dropout(p=dropout_rate) - self.attn = None - - def rel_shift(self, x: torch.Tensor, left_context: int = 0) -> torch.Tensor: - """Compute relative positional encoding. - Args: - x: Input sequence. (B, H, T_1, 2 * T_1 - 1) - left_context: Number of frames in left context. - Returns: - x: Output sequence. (B, H, T_1, T_2) - """ - batch_size, n_heads, time1, n = x.shape - time2 = time1 + left_context - - batch_stride, n_heads_stride, time1_stride, n_stride = x.stride() - - return x.as_strided( - (batch_size, n_heads, time1, time2), - (batch_stride, n_heads_stride, time1_stride - n_stride, n_stride), - storage_offset=(n_stride * (time1 - 1)), - ) - - def compute_simplified_attention_score( - self, - query: torch.Tensor, - key: torch.Tensor, - pos_enc: torch.Tensor, - left_context: int = 0, - ) -> torch.Tensor: - """Simplified attention score computation. - Reference: https://github.com/k2-fsa/icefall/pull/458 - Args: - query: Transformed query tensor. (B, H, T_1, d_k) - key: Transformed key tensor. (B, H, T_2, d_k) - pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size) - left_context: Number of frames in left context. - Returns: - : Attention score. (B, H, T_1, T_2) - """ - pos_enc = self.linear_pos(pos_enc) - - matrix_ac = torch.matmul(query, key.transpose(2, 3)) - - matrix_bd = self.rel_shift( - pos_enc.transpose(1, 2).unsqueeze(2).repeat(1, 1, query.size(2), 1), - left_context=left_context, - ) - - return (matrix_ac + matrix_bd) / math.sqrt(self.d_k) - - def compute_attention_score( - self, - query: torch.Tensor, - key: torch.Tensor, - pos_enc: torch.Tensor, - left_context: int = 0, - ) -> torch.Tensor: - """Attention score computation. - Args: - query: Transformed query tensor. (B, H, T_1, d_k) - key: Transformed key tensor. (B, H, T_2, d_k) - pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size) - left_context: Number of frames in left context. - Returns: - : Attention score. (B, H, T_1, T_2) - """ - p = self.linear_pos(pos_enc).view(pos_enc.size(0), -1, self.num_heads, self.d_k) - - query = query.transpose(1, 2) - q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2) - q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2) - - matrix_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1)) - - matrix_bd = torch.matmul(q_with_bias_v, p.permute(0, 2, 3, 1)) - matrix_bd = self.rel_shift(matrix_bd, left_context=left_context) - - return (matrix_ac + matrix_bd) / math.sqrt(self.d_k) - - def forward_qkv( - self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Transform query, key and value. - Args: - query: Query tensor. (B, T_1, size) - key: Key tensor. (B, T_2, size) - v: Value tensor. (B, T_2, size) - Returns: - q: Transformed query tensor. (B, H, T_1, d_k) - k: Transformed key tensor. (B, H, T_2, d_k) - v: Transformed value tensor. (B, H, T_2, d_k) - """ - n_batch = query.size(0) - - q = ( - self.linear_q(query) - .view(n_batch, -1, self.num_heads, self.d_k) - .transpose(1, 2) - ) - k = ( - self.linear_k(key) - .view(n_batch, -1, self.num_heads, self.d_k) - .transpose(1, 2) - ) - v = ( - self.linear_v(value) - .view(n_batch, -1, self.num_heads, self.d_k) - .transpose(1, 2) - ) - - return q, k, v - - def forward_attention( - self, - value: torch.Tensor, - scores: torch.Tensor, - mask: torch.Tensor, - chunk_mask: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """Compute attention context vector. - Args: - value: Transformed value. (B, H, T_2, d_k) - scores: Attention score. (B, H, T_1, T_2) - mask: Source mask. (B, T_2) - chunk_mask: Chunk mask. (T_1, T_1) - Returns: - attn_output: Transformed value weighted by attention score. (B, T_1, H * d_k) - """ - batch_size = scores.size(0) - mask = mask.unsqueeze(1).unsqueeze(2) - if chunk_mask is not None: - mask = chunk_mask.unsqueeze(0).unsqueeze(1) | mask - scores = scores.masked_fill(mask, float("-inf")) - self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) - - attn_output = self.dropout(self.attn) - attn_output = torch.matmul(attn_output, value) - - attn_output = self.linear_out( - attn_output.transpose(1, 2) - .contiguous() - .view(batch_size, -1, self.num_heads * self.d_k) - ) - - return attn_output - - def forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - pos_enc: torch.Tensor, - mask: torch.Tensor, - chunk_mask: Optional[torch.Tensor] = None, - left_context: int = 0, - ) -> torch.Tensor: - """Compute scaled dot product attention with rel. positional encoding. - Args: - query: Query tensor. (B, T_1, size) - key: Key tensor. (B, T_2, size) - value: Value tensor. (B, T_2, size) - pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size) - mask: Source mask. (B, T_2) - chunk_mask: Chunk mask. (T_1, T_1) - left_context: Number of frames in left context. - Returns: - : Output tensor. (B, T_1, H * d_k) - """ - q, k, v = self.forward_qkv(query, key, value) - scores = self.compute_att_score(q, k, pos_enc, left_context=left_context) - return self.forward_attention(v, scores, mask, chunk_mask=chunk_mask) - - -class CosineDistanceAttention(nn.Module): - """ Compute Cosine Distance between spk decoder output and speaker profile - Args: - profile_path: speaker profile file path (.npy file) - """ - - def __init__(self): - super().__init__() - self.softmax = nn.Softmax(dim=-1) - - def forward(self, spk_decoder_out, profile, profile_lens=None): - """ - Args: - spk_decoder_out(torch.Tensor):(B, L, D) - spk_profiles(torch.Tensor):(B, N, D) - """ - x = spk_decoder_out.unsqueeze(2) # (B, L, 1, D) - if profile_lens is not None: - - mask = (make_pad_mask(profile_lens)[:, None, :]).to(profile.device) - min_value = float( - numpy.finfo(torch.tensor(0, dtype=x.dtype).numpy().dtype).min - ) - weights_not_softmax=F.cosine_similarity(x, profile.unsqueeze(1), dim=-1).masked_fill(mask, min_value) - weights = self.softmax(weights_not_softmax).masked_fill(mask, 0.0) # (B, L, N) - else: - x = x[:, -1:, :, :] - weights_not_softmax=F.cosine_similarity(x, profile.unsqueeze(1).to(x.device), dim=-1) - weights = self.softmax(weights_not_softmax) # (B, 1, N) - spk_embedding = torch.matmul(weights, profile.to(weights.device)) # (B, L, D) - - return spk_embedding, weights diff --git a/funasr/models/ct_transformer_streaming/encoder.py b/funasr/models/ct_transformer_streaming/encoder.py index 784baf37d..32ee2f2f2 100644 --- a/funasr/models/ct_transformer_streaming/encoder.py +++ b/funasr/models/ct_transformer_streaming/encoder.py @@ -12,7 +12,7 @@ import numpy as np from funasr.train_utils.device_funcs import to_device from funasr.models.transformer.utils.nets_utils import make_pad_mask from funasr.models.sanm.attention import MultiHeadedAttention -from funasr.models.ct_transformer.attention import MultiHeadedAttentionSANMwithMask +from funasr.models.ct_transformer_streaming.attention import MultiHeadedAttentionSANMwithMask from funasr.models.transformer.embedding import SinusoidalPositionEncoder, StreamSinusoidalPositionEncoder from funasr.models.transformer.layer_norm import LayerNorm from funasr.models.transformer.utils.multi_layer_conv import Conv1dLinear diff --git a/funasr/models/ct_transformer_streaming/model.py b/funasr/models/ct_transformer_streaming/model.py index 4c84261b6..5254d15e1 100644 --- a/funasr/models/ct_transformer_streaming/model.py +++ b/funasr/models/ct_transformer_streaming/model.py @@ -12,11 +12,12 @@ import torch import torch.nn as nn from funasr.models.ct_transformer.utils import split_to_mini_sentence, split_words from funasr.utils.load_utils import load_audio_text_image_video +from funasr.models.ct_transformer.model import CTTransformer from funasr.register import tables @tables.register("model_classes", "CTTransformerStreaming") -class CTTransformerStreaming(nn.Module): +class CTTransformerStreaming(CTTransformer): """ Author: Speech Lab of DAMO Academy, Alibaba Group CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection @@ -24,43 +25,13 @@ class CTTransformerStreaming(nn.Module): """ def __init__( self, - encoder: str = None, - encoder_conf: dict = None, - vocab_size: int = -1, - punc_list: list = None, - punc_weight: list = None, - embed_unit: int = 128, - att_unit: int = 256, - dropout_rate: float = 0.5, - ignore_id: int = -1, - sos: int = 1, - eos: int = 2, - sentence_end_id: int = 3, + *args, **kwargs, ): - super().__init__() + super().__init__(*args, **kwargs) - punc_size = len(punc_list) - if punc_weight is None: - punc_weight = [1] * punc_size - - - self.embed = nn.Embedding(vocab_size, embed_unit) - encoder_class = tables.encoder_classes.get(encoder.lower()) - encoder = encoder_class(**encoder_conf) - self.decoder = nn.Linear(att_unit, punc_size) - self.encoder = encoder - self.punc_list = punc_list - self.punc_weight = punc_weight - self.ignore_id = ignore_id - self.sos = sos - self.eos = eos - self.sentence_end_id = sentence_end_id - - - - def punc_forward(self, text: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]: + def punc_forward(self, text: torch.Tensor, text_lengths: torch.Tensor, vad_indexes: torch.Tensor, **kwargs): """Compute loss value from buffer sequences. Args: @@ -70,146 +41,14 @@ class CTTransformerStreaming(nn.Module): """ x = self.embed(text) # mask = self._target_mask(input) - h, _, _ = self.encoder(x, text_lengths) + h, _, _ = self.encoder(x, text_lengths, vad_indexes=vad_indexes) y = self.decoder(h) return y, None def with_vad(self): - return False - - def score(self, y: torch.Tensor, state: Any, x: torch.Tensor) -> Tuple[torch.Tensor, Any]: - """Score new token. - - Args: - y (torch.Tensor): 1D torch.int64 prefix tokens. - state: Scorer state for prefix tokens - x (torch.Tensor): encoder feature that generates ys. - - Returns: - tuple[torch.Tensor, Any]: Tuple of - torch.float32 scores for next token (vocab_size) - and next state for ys - - """ - y = y.unsqueeze(0) - h, _, cache = self.encoder.forward_one_step(self.embed(y), self._target_mask(y), cache=state) - h = self.decoder(h[:, -1]) - logp = h.log_softmax(dim=-1).squeeze(0) - return logp, cache - - def batch_score(self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor) -> Tuple[torch.Tensor, List[Any]]: - """Score new token batch. - - Args: - ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen). - states (List[Any]): Scorer states for prefix tokens. - xs (torch.Tensor): - The encoder feature that generates ys (n_batch, xlen, n_feat). - - Returns: - tuple[torch.Tensor, List[Any]]: Tuple of - batchfied scores for next token with shape of `(n_batch, vocab_size)` - and next state list for ys. - - """ - # merge states - n_batch = len(ys) - n_layers = len(self.encoder.encoders) - if states[0] is None: - batch_state = None - else: - # transpose state of [batch, layer] into [layer, batch] - batch_state = [torch.stack([states[b][i] for b in range(n_batch)]) for i in range(n_layers)] - - # batch decoding - h, _, states = self.encoder.forward_one_step(self.embed(ys), self._target_mask(ys), cache=batch_state) - h = self.decoder(h[:, -1]) - logp = h.log_softmax(dim=-1) - - # transpose state of [layer, batch] into [batch, layer] - state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)] - return logp, state_list - - def nll( - self, - text: torch.Tensor, - punc: torch.Tensor, - text_lengths: torch.Tensor, - punc_lengths: torch.Tensor, - max_length: Optional[int] = None, - vad_indexes: Optional[torch.Tensor] = None, - vad_indexes_lengths: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Compute negative log likelihood(nll) - - Normally, this function is called in batchify_nll. - Args: - text: (Batch, Length) - punc: (Batch, Length) - text_lengths: (Batch,) - max_lengths: int - """ - batch_size = text.size(0) - # For data parallel - if max_length is None: - text = text[:, :text_lengths.max()] - punc = punc[:, :text_lengths.max()] - else: - text = text[:, :max_length] - punc = punc[:, :max_length] - - if self.with_vad(): - # Should be VadRealtimeTransformer - assert vad_indexes is not None - y, _ = self.punc_forward(text, text_lengths, vad_indexes) - else: - # Should be TargetDelayTransformer, - y, _ = self.punc_forward(text, text_lengths) - - # Calc negative log likelihood - # nll: (BxL,) - if self.training == False: - _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1) - from sklearn.metrics import f1_score - f1_score = f1_score(punc.view(-1).detach().cpu().numpy(), - indices.squeeze(-1).detach().cpu().numpy(), - average='micro') - nll = torch.Tensor([f1_score]).repeat(text_lengths.sum()) - return nll, text_lengths - else: - self.punc_weight = self.punc_weight.to(punc.device) - nll = F.cross_entropy(y.view(-1, y.shape[-1]), punc.view(-1), self.punc_weight, reduction="none", - ignore_index=self.ignore_id) - # nll: (BxL,) -> (BxL,) - if max_length is None: - nll.masked_fill_(make_pad_mask(text_lengths).to(nll.device).view(-1), 0.0) - else: - nll.masked_fill_( - make_pad_mask(text_lengths, maxlen=max_length + 1).to(nll.device).view(-1), - 0.0, - ) - # nll: (BxL,) -> (B, L) - nll = nll.view(batch_size, -1) - return nll, text_lengths + return True - def forward( - self, - text: torch.Tensor, - punc: torch.Tensor, - text_lengths: torch.Tensor, - punc_lengths: torch.Tensor, - vad_indexes: Optional[torch.Tensor] = None, - vad_indexes_lengths: Optional[torch.Tensor] = None, - ): - nll, y_lengths = self.nll(text, punc, text_lengths, punc_lengths, vad_indexes=vad_indexes) - ntokens = y_lengths.sum() - loss = nll.sum() / ntokens - stats = dict(loss=loss.detach()) - - # force_gatherable: to-device and to-tensor if scalar for DataParallel - loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device) - return loss, stats, weight def generate(self, data_in, @@ -217,22 +56,20 @@ class CTTransformerStreaming(nn.Module): key: list = None, tokenizer=None, frontend=None, + cache: dict = {}, **kwargs, ): assert len(data_in) == 1 + + if len(cache) == 0: + cache["pre_text"] = [] text = load_audio_text_image_video(data_in, data_type=kwargs.get("kwargs", "text"))[0] - vad_indexes = kwargs.get("vad_indexes", None) - # text = data_in[0] - # text_lengths = data_lengths[0] if data_lengths is not None else None + text = "".join(cache["pre_text"]) + " " + text + + split_size = kwargs.get("split_size", 20) - jieba_usr_dict = kwargs.get("jieba_usr_dict", None) - if jieba_usr_dict and isinstance(jieba_usr_dict, str): - import jieba - jieba.load_userdict(jieba_usr_dict) - jieba_usr_dict = jieba - kwargs["jieba_usr_dict"] = "jieba_usr_dict" - tokens = split_words(text, jieba_usr_dict=jieba_usr_dict) + tokens = split_words(text) tokens_int = tokenizer.encode(tokens) mini_sentences = split_to_mini_sentence(tokens, split_size) @@ -240,8 +77,9 @@ class CTTransformerStreaming(nn.Module): assert len(mini_sentences) == len(mini_sentences_id) cache_sent = [] cache_sent_id = torch.from_numpy(np.array([], dtype='int32')) - new_mini_sentence = "" - new_mini_sentence_punc = [] + skip_num = 0 + sentence_punc_list = [] + sentence_words_list = [] cache_pop_trigger_limit = 200 results = [] meta_data = {} @@ -254,6 +92,7 @@ class CTTransformerStreaming(nn.Module): data = { "text": torch.unsqueeze(torch.from_numpy(mini_sentence_id), 0), "text_lengths": torch.from_numpy(np.array([len(mini_sentence_id)], dtype='int32')), + "vad_indexes": torch.from_numpy(np.array([len(cache["pre_text"])], dtype='int32')), } data = to_device(data, kwargs["device"]) # y, _ = self.wrapped_model(**data) @@ -288,52 +127,42 @@ class CTTransformerStreaming(nn.Module): # continue punctuations_np = punctuations.cpu().numpy() - new_mini_sentence_punc += [int(x) for x in punctuations_np] - words_with_punc = [] - for i in range(len(mini_sentence)): - if (i==0 or self.punc_list[punctuations[i-1]] == "。" or self.punc_list[punctuations[i-1]] == "?") and len(mini_sentence[i][0].encode()) == 1: - mini_sentence[i] = mini_sentence[i].capitalize() - if i == 0: - if len(mini_sentence[i][0].encode()) == 1: - mini_sentence[i] = " " + mini_sentence[i] - if i > 0: - if len(mini_sentence[i][0].encode()) == 1 and len(mini_sentence[i - 1][0].encode()) == 1: - mini_sentence[i] = " " + mini_sentence[i] - words_with_punc.append(mini_sentence[i]) - if self.punc_list[punctuations[i]] != "_": - punc_res = self.punc_list[punctuations[i]] - if len(mini_sentence[i][0].encode()) == 1: - if punc_res == ",": - punc_res = "," - elif punc_res == "。": - punc_res = "." - elif punc_res == "?": - punc_res = "?" - words_with_punc.append(punc_res) - new_mini_sentence += "".join(words_with_punc) - # Add Period for the end of the sentence - new_mini_sentence_out = new_mini_sentence - new_mini_sentence_punc_out = new_mini_sentence_punc - if mini_sentence_i == len(mini_sentences) - 1: - if new_mini_sentence[-1] == "," or new_mini_sentence[-1] == "、": - new_mini_sentence_out = new_mini_sentence[:-1] + "。" - new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id] - elif new_mini_sentence[-1] == ",": - new_mini_sentence_out = new_mini_sentence[:-1] + "." - new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id] - elif new_mini_sentence[-1] != "。" and new_mini_sentence[-1] != "?" and len(new_mini_sentence[-1].encode())==0: - new_mini_sentence_out = new_mini_sentence + "。" - new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id] - elif new_mini_sentence[-1] != "." and new_mini_sentence[-1] != "?" and len(new_mini_sentence[-1].encode())==1: - new_mini_sentence_out = new_mini_sentence + "." - new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id] - # keep a punctuations array for punc segment - if punc_array is None: - punc_array = punctuations + sentence_punc_list += [self.punc_list[int(x)] for x in punctuations_np] + sentence_words_list += mini_sentence + + assert len(sentence_punc_list) == len(sentence_words_list) + words_with_punc = [] + sentence_punc_list_out = [] + for i in range(0, len(sentence_words_list)): + if i > 0: + if len(sentence_words_list[i][0].encode()) == 1 and len(sentence_words_list[i - 1][-1].encode()) == 1: + sentence_words_list[i] = " " + sentence_words_list[i] + if skip_num < len(cache["pre_text"]): + skip_num += 1 else: - punc_array = torch.cat([punc_array, punctuations], dim=0) + words_with_punc.append(sentence_words_list[i]) + if skip_num >= len(cache["pre_text"]): + sentence_punc_list_out.append(sentence_punc_list[i]) + if sentence_punc_list[i] != "_": + words_with_punc.append(sentence_punc_list[i]) + sentence_out = "".join(words_with_punc) + + sentenceEnd = -1 + for i in range(len(sentence_punc_list) - 2, 1, -1): + if sentence_punc_list[i] == "。" or sentence_punc_list[i] == "?": + sentenceEnd = i + break + cache["pre_text"] = sentence_words_list[sentenceEnd + 1:] + if sentence_out[-1] in self.punc_list: + sentence_out = sentence_out[:-1] + sentence_punc_list_out[-1] = "_" + # keep a punctuations array for punc segment + if punc_array is None: + punc_array = punctuations + else: + punc_array = torch.cat([punc_array, punctuations], dim=0) - result_i = {"key": key[0], "text": new_mini_sentence_out, "punc_array": punc_array} + result_i = {"key": key[0], "text": sentence_out, "punc_array": punc_array} results.append(result_i) return results, meta_data diff --git a/funasr/models/ct_transformer_streaming/template.yaml b/funasr/models/ct_transformer_streaming/template.yaml index c20a09896..2477ac2be 100644 --- a/funasr/models/ct_transformer_streaming/template.yaml +++ b/funasr/models/ct_transformer_streaming/template.yaml @@ -27,13 +27,13 @@ model_conf: - 1.0 sentence_end_id: 3 -encoder: SANMEncoder +encoder: SANMVadEncoder encoder_conf: input_size: 256 output_size: 256 attention_heads: 8 linear_units: 1024 - num_blocks: 4 + num_blocks: 3 dropout_rate: 0.1 positional_dropout_rate: 0.1 attention_dropout_rate: 0.0 @@ -41,13 +41,10 @@ encoder_conf: pos_enc_class: SinusoidalPositionEncoder normalize_before: true kernel_size: 11 - sanm_shfit: 0 + sanm_shfit: 5 selfattention_layer_type: sanm padding_idx: 0 tokenizer: CharTokenizer tokenizer_conf: - unk_symbol: - - - + unk_symbol: \ No newline at end of file diff --git a/funasr/models/ct_transformer_streaming/utils.py b/funasr/models/ct_transformer_streaming/utils.py deleted file mode 100644 index 917f2e035..000000000 --- a/funasr/models/ct_transformer_streaming/utils.py +++ /dev/null @@ -1,111 +0,0 @@ -import re - -def split_to_mini_sentence(words: list, word_limit: int = 20): - assert word_limit > 1 - if len(words) <= word_limit: - return [words] - sentences = [] - length = len(words) - sentence_len = length // word_limit - for i in range(sentence_len): - sentences.append(words[i * word_limit:(i + 1) * word_limit]) - if length % word_limit > 0: - sentences.append(words[sentence_len * word_limit:]) - return sentences - - -# def split_words(text: str, **kwargs): -# words = [] -# segs = text.split() -# for seg in segs: -# # There is no space in seg. -# current_word = "" -# for c in seg: -# if len(c.encode()) == 1: -# # This is an ASCII char. -# current_word += c -# else: -# # This is a Chinese char. -# if len(current_word) > 0: -# words.append(current_word) -# current_word = "" -# words.append(c) -# if len(current_word) > 0: -# words.append(current_word) -# -# return words - -def split_words(text: str, jieba_usr_dict=None, **kwargs): - if jieba_usr_dict: - input_list = text.split() - token_list_all = [] - langauge_list = [] - token_list_tmp = [] - language_flag = None - for token in input_list: - if isEnglish(token) and language_flag == 'Chinese': - token_list_all.append(token_list_tmp) - langauge_list.append('Chinese') - token_list_tmp = [] - elif not isEnglish(token) and language_flag == 'English': - token_list_all.append(token_list_tmp) - langauge_list.append('English') - token_list_tmp = [] - - token_list_tmp.append(token) - - if isEnglish(token): - language_flag = 'English' - else: - language_flag = 'Chinese' - - if token_list_tmp: - token_list_all.append(token_list_tmp) - langauge_list.append(language_flag) - - result_list = [] - for token_list_tmp, language_flag in zip(token_list_all, langauge_list): - if language_flag == 'English': - result_list.extend(token_list_tmp) - else: - seg_list = jieba_usr_dict.cut(join_chinese_and_english(token_list_tmp), HMM=False) - result_list.extend(seg_list) - - return result_list - - else: - words = [] - segs = text.split() - for seg in segs: - # There is no space in seg. - current_word = "" - for c in seg: - if len(c.encode()) == 1: - # This is an ASCII char. - current_word += c - else: - # This is a Chinese char. - if len(current_word) > 0: - words.append(current_word) - current_word = "" - words.append(c) - if len(current_word) > 0: - words.append(current_word) - return words - -def isEnglish(text:str): - if re.search('^[a-zA-Z\']+$', text): - return True - else: - return False - -def join_chinese_and_english(input_list): - line = '' - for token in input_list: - if isEnglish(token): - line = line + ' ' + token - else: - line = line + token - - line = line.strip() - return line diff --git a/funasr/models/ct_transformer_streaming/vad_realtime_transformer.py b/funasr/models/ct_transformer_streaming/vad_realtime_transformer.py deleted file mode 100644 index 155057ce8..000000000 --- a/funasr/models/ct_transformer_streaming/vad_realtime_transformer.py +++ /dev/null @@ -1,135 +0,0 @@ -from typing import Any -from typing import List -from typing import Tuple - -import torch -import torch.nn as nn - -from funasr.models.transformer.embedding import SinusoidalPositionEncoder -from funasr.models.ct_transformer.sanm_encoder import SANMVadEncoder as Encoder - - -class VadRealtimeTransformer(torch.nn.Module): - """ - Author: Speech Lab of DAMO Academy, Alibaba Group - CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection - https://arxiv.org/pdf/2003.01309.pdf - """ - def __init__( - self, - vocab_size: int, - punc_size: int, - pos_enc: str = None, - embed_unit: int = 128, - att_unit: int = 256, - head: int = 2, - unit: int = 1024, - layer: int = 4, - dropout_rate: float = 0.5, - kernel_size: int = 11, - sanm_shfit: int = 0, - ): - super().__init__() - if pos_enc == "sinusoidal": - # pos_enc_class = PositionalEncoding - pos_enc_class = SinusoidalPositionEncoder - elif pos_enc is None: - - def pos_enc_class(*args, **kwargs): - return nn.Sequential() # indentity - - else: - raise ValueError(f"unknown pos-enc option: {pos_enc}") - - self.embed = nn.Embedding(vocab_size, embed_unit) - self.encoder = Encoder( - input_size=embed_unit, - output_size=att_unit, - attention_heads=head, - linear_units=unit, - num_blocks=layer, - dropout_rate=dropout_rate, - input_layer="pe", - # pos_enc_class=pos_enc_class, - padding_idx=0, - kernel_size=kernel_size, - sanm_shfit=sanm_shfit, - ) - self.decoder = nn.Linear(att_unit, punc_size) - - -# def _target_mask(self, ys_in_pad): -# ys_mask = ys_in_pad != 0 -# m = subsequent_n_mask(ys_mask.size(-1), 5, device=ys_mask.device).unsqueeze(0) -# return ys_mask.unsqueeze(-2) & m - - def forward(self, input: torch.Tensor, text_lengths: torch.Tensor, - vad_indexes: torch.Tensor) -> Tuple[torch.Tensor, None]: - """Compute loss value from buffer sequences. - - Args: - input (torch.Tensor): Input ids. (batch, len) - hidden (torch.Tensor): Target ids. (batch, len) - - """ - x = self.embed(input) - # mask = self._target_mask(input) - h, _, _ = self.encoder(x, text_lengths, vad_indexes) - y = self.decoder(h) - return y, None - - def with_vad(self): - return True - - def score(self, y: torch.Tensor, state: Any, x: torch.Tensor) -> Tuple[torch.Tensor, Any]: - """Score new token. - - Args: - y (torch.Tensor): 1D torch.int64 prefix tokens. - state: Scorer state for prefix tokens - x (torch.Tensor): encoder feature that generates ys. - - Returns: - tuple[torch.Tensor, Any]: Tuple of - torch.float32 scores for next token (vocab_size) - and next state for ys - - """ - y = y.unsqueeze(0) - h, _, cache = self.encoder.forward_one_step(self.embed(y), self._target_mask(y), cache=state) - h = self.decoder(h[:, -1]) - logp = h.log_softmax(dim=-1).squeeze(0) - return logp, cache - - def batch_score(self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor) -> Tuple[torch.Tensor, List[Any]]: - """Score new token batch. - - Args: - ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen). - states (List[Any]): Scorer states for prefix tokens. - xs (torch.Tensor): - The encoder feature that generates ys (n_batch, xlen, n_feat). - - Returns: - tuple[torch.Tensor, List[Any]]: Tuple of - batchfied scores for next token with shape of `(n_batch, vocab_size)` - and next state list for ys. - - """ - # merge states - n_batch = len(ys) - n_layers = len(self.encoder.encoders) - if states[0] is None: - batch_state = None - else: - # transpose state of [batch, layer] into [layer, batch] - batch_state = [torch.stack([states[b][i] for b in range(n_batch)]) for i in range(n_layers)] - - # batch decoding - h, _, states = self.encoder.forward_one_step(self.embed(ys), self._target_mask(ys), cache=batch_state) - h = self.decoder(h[:, -1]) - logp = h.log_softmax(dim=-1) - - # transpose state of [layer, batch] into [batch, layer] - state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)] - return logp, state_list diff --git a/funasr/train_utils/load_pretrained_model.py b/funasr/train_utils/load_pretrained_model.py index a6596a00a..ef9d93a14 100644 --- a/funasr/train_utils/load_pretrained_model.py +++ b/funasr/train_utils/load_pretrained_model.py @@ -125,3 +125,4 @@ def load_pretrained_model( logging.debug("Loaded dst_state keys: {}".format(dst_state.keys())) dst_state.update(src_state) obj.load_state_dict(dst_state) + \ No newline at end of file