from typing import List from typing import Tuple import logging import torch import torch.nn as nn import numpy as np from funasr.models.scama import utils as myutils from funasr.models.transformer.decoder import BaseTransformerDecoder from funasr.models.sanm.attention import MultiHeadedAttentionSANMDecoder, MultiHeadedAttentionCrossAtt from funasr.models.transformer.embedding import PositionalEncoding from funasr.models.transformer.layer_norm import LayerNorm from funasr.models.sanm.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM from funasr.models.transformer.utils.repeat import repeat from funasr.register import tables class DecoderLayerSANM(nn.Module): """Single decoder layer module. Args: size (int): Input dimension. self_attn (torch.nn.Module): Self-attention module instance. `MultiHeadedAttention` instance can be used as the argument. src_attn (torch.nn.Module): Self-attention module instance. `MultiHeadedAttention` instance can be used as the argument. feed_forward (torch.nn.Module): Feed-forward module instance. `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance can be used as the argument. dropout_rate (float): Dropout rate. normalize_before (bool): Whether to use layer_norm before the first block. concat_after (bool): Whether to concat attention layer's input and output. if True, additional linear will be applied. i.e. x -> x + linear(concat(x, att(x))) if False, no additional linear will be applied. i.e. x -> x + att(x) """ def __init__( self, size, self_attn, src_attn, feed_forward, dropout_rate, normalize_before=True, concat_after=False, ): """Construct an DecoderLayer object.""" super(DecoderLayerSANM, self).__init__() self.size = size self.self_attn = self_attn self.src_attn = src_attn self.feed_forward = feed_forward self.norm1 = LayerNorm(size) if self_attn is not None: self.norm2 = LayerNorm(size) if src_attn is not None: self.norm3 = LayerNorm(size) self.dropout = nn.Dropout(dropout_rate) self.normalize_before = normalize_before self.concat_after = concat_after if self.concat_after: self.concat_linear1 = nn.Linear(size + size, size) self.concat_linear2 = nn.Linear(size + size, size) def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None): """Compute decoded features. Args: tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size). tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out). memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size). memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in). cache (List[torch.Tensor]): List of cached tensors. Each tensor shape should be (#batch, maxlen_out - 1, size). Returns: torch.Tensor: Output tensor(#batch, maxlen_out, size). torch.Tensor: Mask for output tensor (#batch, maxlen_out). torch.Tensor: Encoded memory (#batch, maxlen_in, size). torch.Tensor: Encoded memory mask (#batch, maxlen_in). """ # tgt = self.dropout(tgt) residual = tgt if self.normalize_before: tgt = self.norm1(tgt) tgt = self.feed_forward(tgt) x = tgt if self.self_attn: if self.normalize_before: tgt = self.norm2(tgt) x, _ = self.self_attn(tgt, tgt_mask) x = residual + self.dropout(x) if self.src_attn is not None: residual = x if self.normalize_before: x = self.norm3(x) x = residual + self.dropout(self.src_attn(x, memory, memory_mask)) return x, tgt_mask, memory, memory_mask, cache def forward_one_step(self, tgt, tgt_mask, memory, memory_mask=None, cache=None): """Compute decoded features. Args: tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size). tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out). memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size). memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in). cache (List[torch.Tensor]): List of cached tensors. Each tensor shape should be (#batch, maxlen_out - 1, size). Returns: torch.Tensor: Output tensor(#batch, maxlen_out, size). torch.Tensor: Mask for output tensor (#batch, maxlen_out). torch.Tensor: Encoded memory (#batch, maxlen_in, size). torch.Tensor: Encoded memory mask (#batch, maxlen_in). """ # tgt = self.dropout(tgt) residual = tgt if self.normalize_before: tgt = self.norm1(tgt) tgt = self.feed_forward(tgt) x = tgt if self.self_attn: if self.normalize_before: tgt = self.norm2(tgt) if self.training: cache = None x, cache = self.self_attn(tgt, tgt_mask, cache=cache) x = residual + self.dropout(x) if self.src_attn is not None: residual = x if self.normalize_before: x = self.norm3(x) x = residual + self.dropout(self.src_attn(x, memory, memory_mask)) return x, tgt_mask, memory, memory_mask, cache def forward_chunk(self, tgt, memory, fsmn_cache=None, opt_cache=None, chunk_size=None, look_back=0): """Compute decoded features. Args: tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size). tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out). memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size). memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in). cache (List[torch.Tensor]): List of cached tensors. Each tensor shape should be (#batch, maxlen_out - 1, size). Returns: torch.Tensor: Output tensor(#batch, maxlen_out, size). torch.Tensor: Mask for output tensor (#batch, maxlen_out). torch.Tensor: Encoded memory (#batch, maxlen_in, size). torch.Tensor: Encoded memory mask (#batch, maxlen_in). """ residual = tgt if self.normalize_before: tgt = self.norm1(tgt) tgt = self.feed_forward(tgt) x = tgt if self.self_attn: if self.normalize_before: tgt = self.norm2(tgt) x, fsmn_cache = self.self_attn(tgt, None, fsmn_cache) x = residual + self.dropout(x) if self.src_attn is not None: residual = x if self.normalize_before: x = self.norm3(x) x, opt_cache = self.src_attn.forward_chunk(x, memory, opt_cache, chunk_size, look_back) x = residual + x return x, memory, fsmn_cache, opt_cache @tables.register("decoder_classes", "FsmnDecoderSCAMAOpt") class FsmnDecoderSCAMAOpt(BaseTransformerDecoder): """ Author: Speech Lab of DAMO Academy, Alibaba Group SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition https://arxiv.org/abs/2006.01713 """ def __init__( self, vocab_size: int, encoder_output_size: int, attention_heads: int = 4, linear_units: int = 2048, num_blocks: int = 6, dropout_rate: float = 0.1, positional_dropout_rate: float = 0.1, self_attention_dropout_rate: float = 0.0, src_attention_dropout_rate: float = 0.0, input_layer: str = "embed", use_output_layer: bool = True, pos_enc_class=PositionalEncoding, normalize_before: bool = True, concat_after: bool = False, att_layer_num: int = 6, kernel_size: int = 21, sanm_shfit: int = None, concat_embeds: bool = False, attention_dim: int = None, tf2torch_tensor_name_prefix_torch: str = "decoder", tf2torch_tensor_name_prefix_tf: str = "seq2seq/decoder", embed_tensor_name_prefix_tf: str = None, ): super().__init__( vocab_size=vocab_size, encoder_output_size=encoder_output_size, dropout_rate=dropout_rate, positional_dropout_rate=positional_dropout_rate, input_layer=input_layer, use_output_layer=use_output_layer, pos_enc_class=pos_enc_class, normalize_before=normalize_before, ) if attention_dim is None: attention_dim = encoder_output_size if input_layer == "embed": self.embed = torch.nn.Sequential( torch.nn.Embedding(vocab_size, attention_dim), ) elif input_layer == "linear": self.embed = torch.nn.Sequential( torch.nn.Linear(vocab_size, attention_dim), torch.nn.LayerNorm(attention_dim), torch.nn.Dropout(dropout_rate), torch.nn.ReLU(), pos_enc_class(attention_dim, positional_dropout_rate), ) else: raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}") self.normalize_before = normalize_before if self.normalize_before: self.after_norm = LayerNorm(attention_dim) if use_output_layer: self.output_layer = torch.nn.Linear(attention_dim, vocab_size) else: self.output_layer = None self.att_layer_num = att_layer_num self.num_blocks = num_blocks if sanm_shfit is None: sanm_shfit = (kernel_size - 1) // 2 self.decoders = repeat( att_layer_num, lambda lnum: DecoderLayerSANM( attention_dim, MultiHeadedAttentionSANMDecoder( attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit ), MultiHeadedAttentionCrossAtt( attention_heads, attention_dim, src_attention_dropout_rate, encoder_output_size=encoder_output_size ), PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate), dropout_rate, normalize_before, concat_after, ), ) if num_blocks - att_layer_num <= 0: self.decoders2 = None else: self.decoders2 = repeat( num_blocks - att_layer_num, lambda lnum: DecoderLayerSANM( attention_dim, MultiHeadedAttentionSANMDecoder( attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit ), None, PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate), dropout_rate, normalize_before, concat_after, ), ) self.decoders3 = repeat( 1, lambda lnum: DecoderLayerSANM( attention_dim, None, None, PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate), dropout_rate, normalize_before, concat_after, ), ) if concat_embeds: self.embed_concat_ffn = repeat( 1, lambda lnum: DecoderLayerSANM( attention_dim + encoder_output_size, None, None, PositionwiseFeedForwardDecoderSANM(attention_dim + encoder_output_size, linear_units, dropout_rate, adim=attention_dim), dropout_rate, normalize_before, concat_after, ), ) else: self.embed_concat_ffn = None self.concat_embeds = concat_embeds self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf self.embed_tensor_name_prefix_tf = embed_tensor_name_prefix_tf def forward( self, hs_pad: torch.Tensor, hlens: torch.Tensor, ys_in_pad: torch.Tensor, ys_in_lens: torch.Tensor, chunk_mask: torch.Tensor = None, pre_acoustic_embeds: torch.Tensor = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """Forward decoder. Args: hs_pad: encoded memory, float32 (batch, maxlen_in, feat) hlens: (batch) ys_in_pad: input token ids, int64 (batch, maxlen_out) if input_layer == "embed" input tensor (batch, maxlen_out, #mels) in the other cases ys_in_lens: (batch) Returns: (tuple): tuple containing: x: decoded token score before softmax (batch, maxlen_out, token) if use_output_layer is True, olens: (batch, ) """ tgt = ys_in_pad tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None] memory = hs_pad memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :] if chunk_mask is not None: memory_mask = memory_mask * chunk_mask if tgt_mask.size(1) != memory_mask.size(1): memory_mask = torch.cat((memory_mask, memory_mask[:, -2:-1, :]), dim=1) x = self.embed(tgt) if pre_acoustic_embeds is not None and self.concat_embeds: x = torch.cat((x, pre_acoustic_embeds), dim=-1) x, _, _, _, _ = self.embed_concat_ffn(x, None, None, None, None) x, tgt_mask, memory, memory_mask, _ = self.decoders( x, tgt_mask, memory, memory_mask ) if self.decoders2 is not None: x, tgt_mask, memory, memory_mask, _ = self.decoders2( x, tgt_mask, memory, memory_mask ) x, tgt_mask, memory, memory_mask, _ = self.decoders3( x, tgt_mask, memory, memory_mask ) if self.normalize_before: x = self.after_norm(x) if self.output_layer is not None: x = self.output_layer(x) olens = tgt_mask.sum(1) return x, olens def score(self, ys, state, x, x_mask=None, pre_acoustic_embeds: torch.Tensor = None, ): """Score.""" ys_mask = myutils.sequence_mask(torch.tensor([len(ys)], dtype=torch.int32), device=x.device)[:, :, None] logp, state = self.forward_one_step( ys.unsqueeze(0), ys_mask, x.unsqueeze(0), memory_mask=x_mask, pre_acoustic_embeds=pre_acoustic_embeds, cache=state ) return logp.squeeze(0), state def forward_one_step( self, tgt: torch.Tensor, tgt_mask: torch.Tensor, memory: torch.Tensor, memory_mask: torch.Tensor = None, pre_acoustic_embeds: torch.Tensor = None, cache: List[torch.Tensor] = None, ) -> Tuple[torch.Tensor, List[torch.Tensor]]: """Forward one step. Args: tgt: input token ids, int64 (batch, maxlen_out) tgt_mask: input token mask, (batch, maxlen_out) dtype=torch.uint8 in PyTorch 1.2- dtype=torch.bool in PyTorch 1.2+ (include 1.2) memory: encoded memory, float32 (batch, maxlen_in, feat) cache: cached output list of (batch, max_time_out-1, size) Returns: y, cache: NN output value and cache per `self.decoders`. y.shape` is (batch, maxlen_out, token) """ x = tgt[:, -1:] tgt_mask = None x = self.embed(x) if pre_acoustic_embeds is not None and self.concat_embeds: x = torch.cat((x, pre_acoustic_embeds), dim=-1) x, _, _, _, _ = self.embed_concat_ffn(x, None, None, None, None) if cache is None: cache_layer_num = len(self.decoders) if self.decoders2 is not None: cache_layer_num += len(self.decoders2) cache = [None] * cache_layer_num new_cache = [] # for c, decoder in zip(cache, self.decoders): for i in range(self.att_layer_num): decoder = self.decoders[i] c = cache[i] x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_one_step( x, tgt_mask, memory, memory_mask, cache=c ) new_cache.append(c_ret) if self.num_blocks - self.att_layer_num >= 1: for i in range(self.num_blocks - self.att_layer_num): j = i + self.att_layer_num decoder = self.decoders2[i] c = cache[j] x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_one_step( x, tgt_mask, memory, memory_mask, cache=c ) new_cache.append(c_ret) for decoder in self.decoders3: x, tgt_mask, memory, memory_mask, _ = decoder.forward_one_step( x, tgt_mask, memory, None, cache=None ) if self.normalize_before: y = self.after_norm(x[:, -1]) else: y = x[:, -1] if self.output_layer is not None: y = self.output_layer(y) y = torch.log_softmax(y, dim=-1) return y, new_cache def gen_tf2torch_map_dict(self): tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf embed_tensor_name_prefix_tf = self.embed_tensor_name_prefix_tf if self.embed_tensor_name_prefix_tf is not None else tensor_name_prefix_tf map_dict_local = { ## decoder # ffn "{}.decoders.layeridx.norm1.weight".format(tensor_name_prefix_torch): {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm/gamma".format(tensor_name_prefix_tf), "squeeze": None, "transpose": None, }, # (256,),(256,) "{}.decoders.layeridx.norm1.bias".format(tensor_name_prefix_torch): {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm/beta".format(tensor_name_prefix_tf), "squeeze": None, "transpose": None, }, # (256,),(256,) "{}.decoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch): {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/kernel".format(tensor_name_prefix_tf), "squeeze": 0, "transpose": (1, 0), }, # (1024,256),(1,256,1024) "{}.decoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch): {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/bias".format(tensor_name_prefix_tf), "squeeze": None, "transpose": None, }, # (1024,),(1024,) "{}.decoders.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch): {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm_1/gamma".format(tensor_name_prefix_tf), "squeeze": None, "transpose": None, }, # (1024,),(1024,) "{}.decoders.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch): {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm_1/beta".format(tensor_name_prefix_tf), "squeeze": None, "transpose": None, }, # (1024,),(1024,) "{}.decoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch): {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d_1/kernel".format(tensor_name_prefix_tf), "squeeze": 0, "transpose": (1, 0), }, # (256,1024),(1,1024,256) # fsmn "{}.decoders.layeridx.norm2.weight".format(tensor_name_prefix_torch): {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/gamma".format( tensor_name_prefix_tf), "squeeze": None, "transpose": None, }, # (256,),(256,) "{}.decoders.layeridx.norm2.bias".format(tensor_name_prefix_torch): {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/beta".format( tensor_name_prefix_tf), "squeeze": None, "transpose": None, }, # (256,),(256,) "{}.decoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch): {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/depth_conv_w".format( tensor_name_prefix_tf), "squeeze": 0, "transpose": (1, 2, 0), }, # (256,1,31),(1,31,256,1) # src att "{}.decoders.layeridx.norm3.weight".format(tensor_name_prefix_torch): {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf), "squeeze": None, "transpose": None, }, # (256,),(256,) "{}.decoders.layeridx.norm3.bias".format(tensor_name_prefix_torch): {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf), "squeeze": None, "transpose": None, }, # (256,),(256,) "{}.decoders.layeridx.src_attn.linear_q.weight".format(tensor_name_prefix_torch): {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf), "squeeze": 0, "transpose": (1, 0), }, # (256,256),(1,256,256) "{}.decoders.layeridx.src_attn.linear_q.bias".format(tensor_name_prefix_torch): {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf), "squeeze": None, "transpose": None, }, # (256,),(256,) "{}.decoders.layeridx.src_attn.linear_k_v.weight".format(tensor_name_prefix_torch): {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf), "squeeze": 0, "transpose": (1, 0), }, # (1024,256),(1,256,1024) "{}.decoders.layeridx.src_attn.linear_k_v.bias".format(tensor_name_prefix_torch): {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf), "squeeze": None, "transpose": None, }, # (1024,),(1024,) "{}.decoders.layeridx.src_attn.linear_out.weight".format(tensor_name_prefix_torch): {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_2/kernel".format(tensor_name_prefix_tf), "squeeze": 0, "transpose": (1, 0), }, # (256,256),(1,256,256) "{}.decoders.layeridx.src_attn.linear_out.bias".format(tensor_name_prefix_torch): {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_2/bias".format(tensor_name_prefix_tf), "squeeze": None, "transpose": None, }, # (256,),(256,) # dnn "{}.decoders3.layeridx.norm1.weight".format(tensor_name_prefix_torch): {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm/gamma".format(tensor_name_prefix_tf), "squeeze": None, "transpose": None, }, # (256,),(256,) "{}.decoders3.layeridx.norm1.bias".format(tensor_name_prefix_torch): {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm/beta".format(tensor_name_prefix_tf), "squeeze": None, "transpose": None, }, # (256,),(256,) "{}.decoders3.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch): {"name": "{}/decoder_dnn_layer_layeridx/conv1d/kernel".format(tensor_name_prefix_tf), "squeeze": 0, "transpose": (1, 0), }, # (1024,256),(1,256,1024) "{}.decoders3.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch): {"name": "{}/decoder_dnn_layer_layeridx/conv1d/bias".format(tensor_name_prefix_tf), "squeeze": None, "transpose": None, }, # (1024,),(1024,) "{}.decoders3.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch): {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm_1/gamma".format(tensor_name_prefix_tf), "squeeze": None, "transpose": None, }, # (1024,),(1024,) "{}.decoders3.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch): {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm_1/beta".format(tensor_name_prefix_tf), "squeeze": None, "transpose": None, }, # (1024,),(1024,) "{}.decoders3.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch): {"name": "{}/decoder_dnn_layer_layeridx/conv1d_1/kernel".format(tensor_name_prefix_tf), "squeeze": 0, "transpose": (1, 0), }, # (256,1024),(1,1024,256) # embed_concat_ffn "{}.embed_concat_ffn.layeridx.norm1.weight".format(tensor_name_prefix_torch): {"name": "{}/cif_concat/LayerNorm/gamma".format(tensor_name_prefix_tf), "squeeze": None, "transpose": None, }, # (256,),(256,) "{}.embed_concat_ffn.layeridx.norm1.bias".format(tensor_name_prefix_torch): {"name": "{}/cif_concat/LayerNorm/beta".format(tensor_name_prefix_tf), "squeeze": None, "transpose": None, }, # (256,),(256,) "{}.embed_concat_ffn.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch): {"name": "{}/cif_concat/conv1d/kernel".format(tensor_name_prefix_tf), "squeeze": 0, "transpose": (1, 0), }, # (1024,256),(1,256,1024) "{}.embed_concat_ffn.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch): {"name": "{}/cif_concat/conv1d/bias".format(tensor_name_prefix_tf), "squeeze": None, "transpose": None, }, # (1024,),(1024,) "{}.embed_concat_ffn.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch): {"name": "{}/cif_concat/LayerNorm_1/gamma".format(tensor_name_prefix_tf), "squeeze": None, "transpose": None, }, # (1024,),(1024,) "{}.embed_concat_ffn.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch): {"name": "{}/cif_concat/LayerNorm_1/beta".format(tensor_name_prefix_tf), "squeeze": None, "transpose": None, }, # (1024,),(1024,) "{}.embed_concat_ffn.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch): {"name": "{}/cif_concat/conv1d_1/kernel".format(tensor_name_prefix_tf), "squeeze": 0, "transpose": (1, 0), }, # (256,1024),(1,1024,256) # out norm "{}.after_norm.weight".format(tensor_name_prefix_torch): {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf), "squeeze": None, "transpose": None, }, # (256,),(256,) "{}.after_norm.bias".format(tensor_name_prefix_torch): {"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf), "squeeze": None, "transpose": None, }, # (256,),(256,) # in embed "{}.embed.0.weight".format(tensor_name_prefix_torch): {"name": "{}/w_embs".format(embed_tensor_name_prefix_tf), "squeeze": None, "transpose": None, }, # (4235,256),(4235,256) # out layer "{}.output_layer.weight".format(tensor_name_prefix_torch): {"name": ["{}/dense/kernel".format(tensor_name_prefix_tf), "{}/w_embs".format(embed_tensor_name_prefix_tf)], "squeeze": [None, None], "transpose": [(1, 0), None], }, # (4235,256),(256,4235) "{}.output_layer.bias".format(tensor_name_prefix_torch): {"name": ["{}/dense/bias".format(tensor_name_prefix_tf), "seq2seq/2bias" if tensor_name_prefix_tf == "seq2seq/decoder/inputter_1" else "seq2seq/bias"], "squeeze": [None, None], "transpose": [None, None], }, # (4235,),(4235,) } return map_dict_local def convert_tf2torch(self, var_dict_tf, var_dict_torch, ): map_dict = self.gen_tf2torch_map_dict() var_dict_torch_update = dict() decoder_layeridx_sets = set() for name in sorted(var_dict_torch.keys(), reverse=False): names = name.split('.') if names[0] == self.tf2torch_tensor_name_prefix_torch: if names[1] == "decoders": layeridx = int(names[2]) name_q = name.replace(".{}.".format(layeridx), ".layeridx.") layeridx_bias = 0 layeridx += layeridx_bias decoder_layeridx_sets.add(layeridx) if name_q in map_dict.keys(): name_v = map_dict[name_q]["name"] name_tf = name_v.replace("layeridx", "{}".format(layeridx)) data_tf = var_dict_tf[name_tf] if map_dict[name_q]["squeeze"] is not None: data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"]) if map_dict[name_q]["transpose"] is not None: data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"]) data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu") assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf, var_dict_torch[ name].size(), data_tf.size()) var_dict_torch_update[name] = data_tf logging.info( "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v, var_dict_tf[name_tf].shape)) elif names[1] == "decoders2": layeridx = int(names[2]) name_q = name.replace(".{}.".format(layeridx), ".layeridx.") name_q = name_q.replace("decoders2", "decoders") layeridx_bias = len(decoder_layeridx_sets) layeridx += layeridx_bias if "decoders." in name: decoder_layeridx_sets.add(layeridx) if name_q in map_dict.keys(): name_v = map_dict[name_q]["name"] name_tf = name_v.replace("layeridx", "{}".format(layeridx)) data_tf = var_dict_tf[name_tf] if map_dict[name_q]["squeeze"] is not None: data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"]) if map_dict[name_q]["transpose"] is not None: data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"]) data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu") assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf, var_dict_torch[ name].size(), data_tf.size()) var_dict_torch_update[name] = data_tf logging.info( "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v, var_dict_tf[name_tf].shape)) elif names[1] == "decoders3": layeridx = int(names[2]) name_q = name.replace(".{}.".format(layeridx), ".layeridx.") layeridx_bias = 0 layeridx += layeridx_bias if "decoders." in name: decoder_layeridx_sets.add(layeridx) if name_q in map_dict.keys(): name_v = map_dict[name_q]["name"] name_tf = name_v.replace("layeridx", "{}".format(layeridx)) data_tf = var_dict_tf[name_tf] if map_dict[name_q]["squeeze"] is not None: data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"]) if map_dict[name_q]["transpose"] is not None: data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"]) data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu") assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf, var_dict_torch[ name].size(), data_tf.size()) var_dict_torch_update[name] = data_tf logging.info( "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v, var_dict_tf[name_tf].shape)) elif names[1] == "embed" or names[1] == "output_layer": name_tf = map_dict[name]["name"] if isinstance(name_tf, list): idx_list = 0 if name_tf[idx_list] in var_dict_tf.keys(): pass else: idx_list = 1 data_tf = var_dict_tf[name_tf[idx_list]] if map_dict[name]["squeeze"][idx_list] is not None: data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"][idx_list]) if map_dict[name]["transpose"][idx_list] is not None: data_tf = np.transpose(data_tf, map_dict[name]["transpose"][idx_list]) data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu") assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf, var_dict_torch[ name].size(), data_tf.size()) var_dict_torch_update[name] = data_tf logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf[idx_list], var_dict_tf[name_tf[ idx_list]].shape)) else: data_tf = var_dict_tf[name_tf] if map_dict[name]["squeeze"] is not None: data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"]) if map_dict[name]["transpose"] is not None: data_tf = np.transpose(data_tf, map_dict[name]["transpose"]) data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu") assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf, var_dict_torch[ name].size(), data_tf.size()) var_dict_torch_update[name] = data_tf logging.info( "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape)) elif names[1] == "after_norm": name_tf = map_dict[name]["name"] data_tf = var_dict_tf[name_tf] data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu") var_dict_torch_update[name] = data_tf logging.info( "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape)) elif names[1] == "embed_concat_ffn": layeridx = int(names[2]) name_q = name.replace(".{}.".format(layeridx), ".layeridx.") layeridx_bias = 0 layeridx += layeridx_bias if "decoders." in name: decoder_layeridx_sets.add(layeridx) if name_q in map_dict.keys(): name_v = map_dict[name_q]["name"] name_tf = name_v.replace("layeridx", "{}".format(layeridx)) data_tf = var_dict_tf[name_tf] if map_dict[name_q]["squeeze"] is not None: data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"]) if map_dict[name_q]["transpose"] is not None: data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"]) data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu") assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf, var_dict_torch[ name].size(), data_tf.size()) var_dict_torch_update[name] = data_tf logging.info( "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v, var_dict_tf[name_tf].shape)) return var_dict_torch_update