This commit is contained in:
语帆 2024-02-28 15:21:32 +08:00
parent 19103386dc
commit eb92e79fb9

View File

@ -47,7 +47,7 @@ from funasr.models.transformer.utils.subsampling import check_short_utt
from funasr.models.transformer.utils.subsampling import Conv2dSubsamplingPad
from funasr.models.transformer.utils.subsampling import StreamingConvInput
from funasr.register import tables
import pdb
class ConvolutionModule(nn.Module):
"""ConvolutionModule in Conformer model.
@ -573,7 +573,7 @@ class ConformerEncoder(nn.Module):
xs_pad, masks = self.embed(xs_pad, masks)
else:
xs_pad = self.embed(xs_pad)
pdb.set_trace()
intermediate_outs = []
if len(self.interctc_layer_idx) == 0:
xs_pad, masks = self.encoders(xs_pad, masks)
@ -601,12 +601,12 @@ class ConformerEncoder(nn.Module):
xs_pad = (x, pos_emb)
else:
xs_pad = xs_pad + self.conditioning_layer(ctc_out)
pdb.set_trace()
if isinstance(xs_pad, tuple):
xs_pad = xs_pad[0]
if self.normalize_before:
xs_pad = self.after_norm(xs_pad)
pdb.set_trace()
olens = masks.squeeze(1).sum(1)
if len(intermediate_outs) > 0:
return (xs_pad, intermediate_outs), olens, None