diff --git a/funasr/models/decoder/sanm_decoder.py b/funasr/models/decoder/sanm_decoder.py index a36d95e77..463918a0f 100644 --- a/funasr/models/decoder/sanm_decoder.py +++ b/funasr/models/decoder/sanm_decoder.py @@ -94,7 +94,7 @@ class DecoderLayerSANM(nn.Module): if self.self_attn: if self.normalize_before: tgt = self.norm2(tgt) - x, cache = self.self_attn(tgt, tgt_mask, cache=cache) + x, _ = self.self_attn(tgt, tgt_mask) x = residual + self.dropout(x) if self.src_attn is not None: @@ -399,7 +399,7 @@ class FsmnDecoderSCAMAOpt(BaseTransformerDecoder): for i in range(self.att_layer_num): decoder = self.decoders[i] c = cache[i] - x, tgt_mask, memory, memory_mask, c_ret = decoder( + x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk( x, tgt_mask, memory, memory_mask, cache=c ) new_cache.append(c_ret) @@ -409,13 +409,13 @@ class FsmnDecoderSCAMAOpt(BaseTransformerDecoder): j = i + self.att_layer_num decoder = self.decoders2[i] c = cache[j] - x, tgt_mask, memory, memory_mask, c_ret = decoder( + x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk( x, tgt_mask, memory, memory_mask, cache=c ) new_cache.append(c_ret) for decoder in self.decoders3: - x, tgt_mask, memory, memory_mask, _ = decoder( + x, tgt_mask, memory, memory_mask, _ = decoder.forward_chunk( x, tgt_mask, memory, None, cache=None ) @@ -1076,7 +1076,7 @@ class ParaformerSANMDecoder(BaseTransformerDecoder): for i in range(self.att_layer_num): decoder = self.decoders[i] c = cache[i] - x, tgt_mask, memory, memory_mask, c_ret = decoder( + x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk( x, tgt_mask, memory, None, cache=c ) new_cache.append(c_ret) @@ -1086,14 +1086,14 @@ class ParaformerSANMDecoder(BaseTransformerDecoder): j = i + self.att_layer_num decoder = self.decoders2[i] c = cache[j] - x, tgt_mask, memory, memory_mask, c_ret = decoder( + x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk( x, tgt_mask, memory, None, cache=c ) new_cache.append(c_ret) for decoder in self.decoders3: - x, tgt_mask, memory, memory_mask, _ = decoder( + x, tgt_mask, memory, memory_mask, _ = decoder.forward_chunk( x, tgt_mask, memory, None, cache=None )