From 8ae9fa8365eba7d33c8d8f5fa51d12903ca6a409 Mon Sep 17 00:00:00 2001 From: "haoneng.lhn" Date: Thu, 21 Sep 2023 16:26:51 +0800 Subject: [PATCH] update --- funasr/models/decoder/sanm_decoder.py | 59 --------------------------- funasr/models/encoder/sanm_encoder.py | 46 --------------------- 2 files changed, 105 deletions(-) diff --git a/funasr/models/decoder/sanm_decoder.py b/funasr/models/decoder/sanm_decoder.py index bbfe0ef4f..ff35e463c 100644 --- a/funasr/models/decoder/sanm_decoder.py +++ b/funasr/models/decoder/sanm_decoder.py @@ -1035,65 +1035,6 @@ class ParaformerSANMDecoder(BaseTransformerDecoder): ) return logp.squeeze(0), state - #def forward_chunk( - # self, - # memory: torch.Tensor, - # tgt: torch.Tensor, - # cache: dict = None, - #) -> Tuple[torch.Tensor, torch.Tensor]: - # """Forward decoder. - - # Args: - # hs_pad: encoded memory, float32 (batch, maxlen_in, feat) - # hlens: (batch) - # ys_in_pad: - # input token ids, int64 (batch, maxlen_out) - # if input_layer == "embed" - # input tensor (batch, maxlen_out, #mels) in the other cases - # ys_in_lens: (batch) - # Returns: - # (tuple): tuple containing: - - # x: decoded token score before softmax (batch, maxlen_out, token) - # if use_output_layer is True, - # olens: (batch, ) - # """ - # x = tgt - # if cache["decode_fsmn"] is None: - # cache_layer_num = len(self.decoders) - # if self.decoders2 is not None: - # cache_layer_num += len(self.decoders2) - # new_cache = [None] * cache_layer_num - # else: - # new_cache = cache["decode_fsmn"] - # for i in range(self.att_layer_num): - # decoder = self.decoders[i] - # x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk( - # x, None, memory, None, cache=new_cache[i] - # ) - # new_cache[i] = c_ret - - # if self.num_blocks - self.att_layer_num > 1: - # for i in range(self.num_blocks - self.att_layer_num): - # j = i + self.att_layer_num - # decoder = self.decoders2[i] - # x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk( - # x, None, memory, None, cache=new_cache[j] - # ) - # new_cache[j] = c_ret - - # for decoder in self.decoders3: - - # x, tgt_mask, memory, memory_mask, _ = decoder.forward_chunk( - # x, None, memory, None, cache=None - # ) - # if self.normalize_before: - # x = self.after_norm(x) - # if self.output_layer is not None: - # x = self.output_layer(x) - # cache["decode_fsmn"] = new_cache - # return x - def forward_chunk( self, memory: torch.Tensor, diff --git a/funasr/models/encoder/sanm_encoder.py b/funasr/models/encoder/sanm_encoder.py index e04b9e716..c15343efb 100644 --- a/funasr/models/encoder/sanm_encoder.py +++ b/funasr/models/encoder/sanm_encoder.py @@ -873,52 +873,6 @@ class SANMEncoderChunkOpt(AbsEncoder): cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]):, :] return overlap_feats - #def forward_chunk(self, - # xs_pad: torch.Tensor, - # ilens: torch.Tensor, - # cache: dict = None, - # ctc: CTC = None, - # ): - # xs_pad *= self.output_size() ** 0.5 - # if self.embed is None: - # xs_pad = xs_pad - # else: - # xs_pad = self.embed(xs_pad, cache) - # if cache["tail_chunk"]: - # xs_pad = to_device(cache["feats"], device=xs_pad.device) - # else: - # xs_pad = self._add_overlap_chunk(xs_pad, cache) - # encoder_outs = self.encoders0(xs_pad, None, None, None, None) - # xs_pad, masks = encoder_outs[0], encoder_outs[1] - # intermediate_outs = [] - # if len(self.interctc_layer_idx) == 0: - # encoder_outs = self.encoders(xs_pad, None, None, None, None) - # xs_pad, masks = encoder_outs[0], encoder_outs[1] - # else: - # for layer_idx, encoder_layer in enumerate(self.encoders): - # encoder_outs = encoder_layer(xs_pad, None, None, None, None) - # xs_pad, masks = encoder_outs[0], encoder_outs[1] - # if layer_idx + 1 in self.interctc_layer_idx: - # encoder_out = xs_pad - - # # intermediate outputs are also normalized - # if self.normalize_before: - # encoder_out = self.after_norm(encoder_out) - - # intermediate_outs.append((layer_idx + 1, encoder_out)) - - # if self.interctc_use_conditioning: - # ctc_out = ctc.softmax(encoder_out) - # xs_pad = xs_pad + self.conditioning_layer(ctc_out) - - # if self.normalize_before: - # xs_pad = self.after_norm(xs_pad) - - # if len(intermediate_outs) > 0: - # return (xs_pad, intermediate_outs), None, None - # return xs_pad, ilens, None - - def forward_chunk(self, xs_pad: torch.Tensor, ilens: torch.Tensor,