diff --git a/funasr/models/decoder/sanm_decoder.py b/funasr/models/decoder/sanm_decoder.py index 011743022..3bfcffc3f 100644 --- a/funasr/models/decoder/sanm_decoder.py +++ b/funasr/models/decoder/sanm_decoder.py @@ -90,6 +90,47 @@ class DecoderLayerSANM(nn.Module): tgt = self.norm1(tgt) tgt = self.feed_forward(tgt) + x = tgt + if self.self_attn: + if self.normalize_before: + tgt = self.norm2(tgt) + x, _ = self.self_attn(tgt, tgt_mask) + x = residual + self.dropout(x) + + if self.src_attn is not None: + residual = x + if self.normalize_before: + x = self.norm3(x) + + x = residual + self.dropout(self.src_attn(x, memory, memory_mask)) + + + return x, tgt_mask, memory, memory_mask, cache + + def forward_chunk(self, tgt, tgt_mask, memory, memory_mask=None, cache=None): + """Compute decoded features. + + Args: + tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size). + tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out). + memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size). + memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in). + cache (List[torch.Tensor]): List of cached tensors. + Each tensor shape should be (#batch, maxlen_out - 1, size). + + Returns: + torch.Tensor: Output tensor(#batch, maxlen_out, size). + torch.Tensor: Mask for output tensor (#batch, maxlen_out). + torch.Tensor: Encoded memory (#batch, maxlen_in, size). + torch.Tensor: Encoded memory mask (#batch, maxlen_in). + + """ + # tgt = self.dropout(tgt) + residual = tgt + if self.normalize_before: + tgt = self.norm1(tgt) + tgt = self.feed_forward(tgt) + x = tgt if self.self_attn: if self.normalize_before: @@ -109,7 +150,6 @@ class DecoderLayerSANM(nn.Module): return x, tgt_mask, memory, memory_mask, cache - class FsmnDecoderSCAMAOpt(BaseTransformerDecoder): """ author: Speech Lab, Alibaba Group, China @@ -980,7 +1020,7 @@ class ParaformerSANMDecoder(BaseTransformerDecoder): new_cache = cache["decode_fsmn"] for i in range(self.att_layer_num): decoder = self.decoders[i] - x, tgt_mask, memory, memory_mask, c_ret = decoder( + x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk( x, None, memory, None, cache=new_cache[i] ) new_cache[i] = c_ret @@ -989,14 +1029,14 @@ class ParaformerSANMDecoder(BaseTransformerDecoder): for i in range(self.num_blocks - self.att_layer_num): j = i + self.att_layer_num decoder = self.decoders2[i] - x, tgt_mask, memory, memory_mask, c_ret = decoder( + x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk( x, None, memory, None, cache=new_cache[j] ) new_cache[j] = c_ret for decoder in self.decoders3: - x, tgt_mask, memory, memory_mask, _ = decoder( + x, tgt_mask, memory, memory_mask, _ = decoder.forward_chunk( x, None, memory, None, cache=None ) if self.normalize_before: