diff --git a/funasr/models/decoder/sanm_decoder.py b/funasr/models/decoder/sanm_decoder.py
index 3e4e554f2..bbfe0ef4f 100644
--- a/funasr/models/decoder/sanm_decoder.py
+++ b/funasr/models/decoder/sanm_decoder.py
@@ -105,48 +105,48 @@ class DecoderLayerSANM(nn.Module):
         return x, tgt_mask, memory, memory_mask, cache
 
-    #def forward_chunk(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
-    #    """Compute decoded features.
+    def forward_one_step(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
+        """Compute decoded features.
 
-    #    Args:
-    #        tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
-    #        tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
-    #        memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
-    #        memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
-    #        cache (List[torch.Tensor]): List of cached tensors.
-    #            Each tensor shape should be (#batch, maxlen_out - 1, size).
+        Args:
+            tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
+            tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
+            memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
+            memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
+            cache (List[torch.Tensor]): List of cached tensors.
+                Each tensor shape should be (#batch, maxlen_out - 1, size).
 
-    #    Returns:
-    #        torch.Tensor: Output tensor(#batch, maxlen_out, size).
-    #        torch.Tensor: Mask for output tensor (#batch, maxlen_out).
-    #        torch.Tensor: Encoded memory (#batch, maxlen_in, size).
-    #        torch.Tensor: Encoded memory mask (#batch, maxlen_in).
+        Returns:
+            torch.Tensor: Output tensor(#batch, maxlen_out, size).
+            torch.Tensor: Mask for output tensor (#batch, maxlen_out).
+            torch.Tensor: Encoded memory (#batch, maxlen_in, size).
+            torch.Tensor: Encoded memory mask (#batch, maxlen_in).
 
-    #    """
-    #    # tgt = self.dropout(tgt)
-    #    residual = tgt
-    #    if self.normalize_before:
-    #        tgt = self.norm1(tgt)
-    #    tgt = self.feed_forward(tgt)
+        """
+        # tgt = self.dropout(tgt)
+        residual = tgt
+        if self.normalize_before:
+            tgt = self.norm1(tgt)
+        tgt = self.feed_forward(tgt)
 
-    #    x = tgt
-    #    if self.self_attn:
-    #        if self.normalize_before:
-    #            tgt = self.norm2(tgt)
-    #        if self.training:
-    #            cache = None
-    #        x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
-    #        x = residual + self.dropout(x)
+        x = tgt
+        if self.self_attn:
+            if self.normalize_before:
+                tgt = self.norm2(tgt)
+            if self.training:
+                cache = None
+            x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
+            x = residual + self.dropout(x)
 
-    #    if self.src_attn is not None:
-    #        residual = x
-    #        if self.normalize_before:
-    #            x = self.norm3(x)
+        if self.src_attn is not None:
+            residual = x
+            if self.normalize_before:
+                x = self.norm3(x)
 
-    #        x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
+            x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
 
-    #    return x, tgt_mask, memory, memory_mask, cache
+        return x, tgt_mask, memory, memory_mask, cache
 
     def forward_chunk(self, tgt, memory, fsmn_cache=None, opt_cache=None, chunk_size=None, look_back=0):
         """Compute decoded features.
@@ -438,7 +438,7 @@ class FsmnDecoderSCAMAOpt(BaseTransformerDecoder):
         for i in range(self.att_layer_num):
             decoder = self.decoders[i]
             c = cache[i]
-            x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
+            x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_one_step(
                 x, tgt_mask, memory, memory_mask, cache=c
             )
             new_cache.append(c_ret)
@@ -448,13 +448,13 @@ class FsmnDecoderSCAMAOpt(BaseTransformerDecoder):
                 j = i + self.att_layer_num
                 decoder = self.decoders2[i]
                 c = cache[j]
-                x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
+                x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_one_step(
                     x, tgt_mask, memory, memory_mask, cache=c
                 )
                 new_cache.append(c_ret)
 
         for decoder in self.decoders3:
-            x, tgt_mask, memory, memory_mask, _ = decoder.forward_chunk(
+            x, tgt_mask, memory, memory_mask, _ = decoder.forward_one_step(
                 x, tgt_mask, memory, None, cache=None
             )
@@ -878,6 +878,7 @@ class ParaformerSANMDecoder(BaseTransformerDecoder):
         lora_rank: int = 8,
         lora_alpha: int = 16,
         lora_dropout: float = 0.1,
+        chunk_multiply_factor: tuple = (1,),
         tf2torch_tensor_name_prefix_torch: str = "decoder",
         tf2torch_tensor_name_prefix_tf: str = "seq2seq/decoder",
     ):
@@ -970,6 +971,7 @@ class ParaformerSANMDecoder(BaseTransformerDecoder):
         )
         self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
         self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
+        self.chunk_multiply_factor = chunk_multiply_factor
 
     def forward(
         self,
@@ -1190,7 +1192,7 @@ class ParaformerSANMDecoder(BaseTransformerDecoder):
         for i in range(self.att_layer_num):
             decoder = self.decoders[i]
             c = cache[i]
-            x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
+            x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_one_step(
                 x, tgt_mask, memory, None, cache=c
            )
             new_cache.append(c_ret)
@@ -1200,14 +1202,14 @@ class ParaformerSANMDecoder(BaseTransformerDecoder):
                 j = i + self.att_layer_num
                 decoder = self.decoders2[i]
                 c = cache[j]
-                x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
+                x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_one_step(
                     x, tgt_mask, memory, None, cache=c
                 )
                 new_cache.append(c_ret)
 
         for decoder in self.decoders3:
-            x, tgt_mask, memory, memory_mask, _ = decoder.forward_chunk(
+            x, tgt_mask, memory, memory_mask, _ = decoder.forward_one_step(
                 x, tgt_mask, memory, None, cache=None
             )
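For context on the call pattern these hunks rename, below is a minimal, hypothetical sketch of the per-layer cache loop that drives `forward_one_step` in the decoder classes above. Only the five-tuple return contract and the one-cache-slot-per-layer list are taken from the diff; `DummyLayer`, `decode_step`, and the tensor shapes are illustrative stand-ins, not FunASR code.

```python
import torch


class DummyLayer:
    """Stand-in with the same forward_one_step contract as DecoderLayerSANM in the diff."""

    def forward_one_step(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
        # A real layer would apply feed-forward, cached FSMN self-attention,
        # and source attention here; this stub just echoes its inputs.
        return tgt, tgt_mask, memory, memory_mask, cache


def decode_step(layers, x, tgt_mask, memory, memory_mask, cache):
    """One incremental decoding step: run every layer and collect per-layer caches."""
    new_cache = []
    for layer, c in zip(layers, cache):
        x, tgt_mask, memory, memory_mask, c_ret = layer.forward_one_step(
            x, tgt_mask, memory, memory_mask, cache=c
        )
        new_cache.append(c_ret)
    return x, new_cache


layers = [DummyLayer() for _ in range(3)]
x = torch.zeros(1, 1, 4)        # (#batch, maxlen_out, size)
tgt_mask = torch.ones(1, 1, 1)  # mask for the current output positions
memory = torch.zeros(1, 10, 4)  # encoder output (#batch, maxlen_in, size)
cache = [None] * len(layers)    # one cache slot per layer, as in the diff
x, cache = decode_step(layers, x, tgt_mask, memory, None, cache)
```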