fix decoder cache

This commit is contained in:
haoneng.lhn 2023-04-03 19:57:26 +08:00
parent 0ca6876f58
commit 6be782d9fd

View File

@ -94,7 +94,7 @@ class DecoderLayerSANM(nn.Module):
if self.self_attn:
if self.normalize_before:
tgt = self.norm2(tgt)
x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
x, _ = self.self_attn(tgt, tgt_mask)
x = residual + self.dropout(x)
if self.src_attn is not None:
@ -399,7 +399,7 @@ class FsmnDecoderSCAMAOpt(BaseTransformerDecoder):
for i in range(self.att_layer_num):
decoder = self.decoders[i]
c = cache[i]
x, tgt_mask, memory, memory_mask, c_ret = decoder(
x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
x, tgt_mask, memory, memory_mask, cache=c
)
new_cache.append(c_ret)
@ -409,13 +409,13 @@ class FsmnDecoderSCAMAOpt(BaseTransformerDecoder):
j = i + self.att_layer_num
decoder = self.decoders2[i]
c = cache[j]
x, tgt_mask, memory, memory_mask, c_ret = decoder(
x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
x, tgt_mask, memory, memory_mask, cache=c
)
new_cache.append(c_ret)
for decoder in self.decoders3:
x, tgt_mask, memory, memory_mask, _ = decoder(
x, tgt_mask, memory, memory_mask, _ = decoder.forward_chunk(
x, tgt_mask, memory, None, cache=None
)
@ -1076,7 +1076,7 @@ class ParaformerSANMDecoder(BaseTransformerDecoder):
for i in range(self.att_layer_num):
decoder = self.decoders[i]
c = cache[i]
x, tgt_mask, memory, memory_mask, c_ret = decoder(
x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
x, tgt_mask, memory, None, cache=c
)
new_cache.append(c_ret)
@ -1086,14 +1086,14 @@ class ParaformerSANMDecoder(BaseTransformerDecoder):
j = i + self.att_layer_num
decoder = self.decoders2[i]
c = cache[j]
x, tgt_mask, memory, memory_mask, c_ret = decoder(
x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
x, tgt_mask, memory, None, cache=c
)
new_cache.append(c_ret)
for decoder in self.decoders3:
x, tgt_mask, memory, memory_mask, _ = decoder(
x, tgt_mask, memory, memory_mask, _ = decoder.forward_chunk(
x, tgt_mask, memory, None, cache=None
)