diff --git a/funasr/models/encoder/sanm_encoder.py b/funasr/models/encoder/sanm_encoder.py
index ac4240cdb..e04b9e716 100644
--- a/funasr/models/encoder/sanm_encoder.py
+++ b/funasr/models/encoder/sanm_encoder.py
@@ -945,11 +945,11 @@ class SANMEncoderChunkOpt(AbsEncoder):
         for layer_idx, encoder_layer in enumerate(self.encoders):
             encoder_outs = encoder_layer.forward_chunk(xs_pad, new_cache[layer_idx+len(self.encoders0)],
                                                        cache["chunk_size"], cache["encoder_chunk_look_back"])
-            xs_pad, new_cache[layer_idx+1] = encoder_outs[0], encoder_outs[1]
+            xs_pad, new_cache[layer_idx+len(self.encoders0)] = encoder_outs[0], encoder_outs[1]
 
         if self.normalize_before:
             xs_pad = self.after_norm(xs_pad)
-        if cache["encoder_chunk_look_back"] > 0:
+        if cache["encoder_chunk_look_back"] > 0 or cache["encoder_chunk_look_back"] == -1:
             cache["opt"] = new_cache
 
         return xs_pad, ilens, None
diff --git a/funasr/modules/attention.py b/funasr/modules/attention.py
index f5430e1ca..157a2c50c 100644
--- a/funasr/modules/attention.py
+++ b/funasr/modules/attention.py
@@ -471,15 +471,21 @@ class MultiHeadedAttentionSANM(nn.Module):
 
         """
         q_h, k_h, v_h, v = self.forward_qkv(x)
-        if chunk_size is not None and look_back > 0:
+        if chunk_size is not None and look_back > 0 or look_back == -1:
             if cache is not None:
+                k_h_stride = k_h[:, :, :-(chunk_size[2]), :]
+                v_h_stride = v_h[:, :, :-(chunk_size[2]), :]
                 k_h = torch.cat((cache["k"], k_h), dim=2)
                 v_h = torch.cat((cache["v"], v_h), dim=2)
-                cache["k"] = k_h[:, :, -(look_back * chunk_size[1]):, :]
-                cache["v"] = v_h[:, :, -(look_back * chunk_size[1]):, :]
+
+                cache["k"] = torch.cat((cache["k"], k_h_stride), dim=2)
+                cache["v"] = torch.cat((cache["v"], v_h_stride), dim=2)
+                if look_back != -1:
+                    cache["k"] = cache["k"][:, :, -(look_back * chunk_size[1]):, :]
+                    cache["v"] = cache["v"][:, :, -(look_back * chunk_size[1]):, :]
             else:
-                cache_tmp = {"k": k_h[:, :, -(look_back * chunk_size[1]):, :],
-                             "v": v_h[:, :, -(look_back * chunk_size[1]):, :]}
+                cache_tmp = {"k": k_h[:, :, :-(chunk_size[2]), :],
+                             "v": v_h[:, :, :-(chunk_size[2]), :]}
                 cache = cache_tmp
         fsmn_memory = self.forward_fsmn(v, None)
         q_h = q_h * self.d_k ** (-0.5)
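
The patch above changes how MultiHeadedAttentionSANM maintains its streaming key/value cache: the incoming chunk's right-context ("stride") frames, chunk_size[2] of them, are stripped before being appended to the cache, and when look_back == -1 the cache is no longer trimmed to look_back * chunk_size[1] frames, so the full history is retained. Below is a minimal standalone sketch of that trimming rule, assuming chunk_size follows a [left context, chunk, right context] frame layout; update_kv_cache and the tensor shapes are hypothetical names for illustration, not FunASR APIs.

import torch


def update_kv_cache(cache, k_h, v_h, chunk_size, look_back):
    # Drop the right-context ("stride") frames of the incoming chunk before caching,
    # mirroring the k_h_stride / v_h_stride slices in the diff.
    k_new = k_h[:, :, :-chunk_size[2], :]
    v_new = v_h[:, :, :-chunk_size[2], :]
    if cache is None:
        # First chunk: start the cache from the stripped frames.
        return {"k": k_new, "v": v_new}
    cache["k"] = torch.cat((cache["k"], k_new), dim=2)
    cache["v"] = torch.cat((cache["v"], v_new), dim=2)
    if look_back != -1:
        # Finite look-back: keep only the last look_back chunks of history.
        cache["k"] = cache["k"][:, :, -(look_back * chunk_size[1]):, :]
        cache["v"] = cache["v"][:, :, -(look_back * chunk_size[1]):, :]
    return cache


if __name__ == "__main__":
    chunk_size = [5, 10, 5]    # assumed [left context, chunk, right context] in frames
    cache = None
    for _ in range(4):
        # (batch, heads, chunk + right context, d_k)
        k_h = torch.randn(1, 4, chunk_size[1] + chunk_size[2], 64)
        v_h = torch.randn(1, 4, chunk_size[1] + chunk_size[2], 64)
        cache = update_kv_cache(cache, k_h, v_h, chunk_size, look_back=-1)
    print(cache["k"].shape)    # look_back == -1 keeps the full history: (1, 4, 40, 64)

With a positive look_back the same loop would cap the cache at look_back * chunk_size[1] frames, which is the pre-existing behaviour the diff preserves under the new look_back != -1 guard.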