mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
Merge pull request #205 from alibaba-damo-academy/dev_onnx
export bug fix
This commit is contained in:
commit
ef9fb63162
@ -61,7 +61,6 @@ class ConformerEncoder(nn.Module):
|
||||
speech: torch.Tensor,
|
||||
speech_lengths: torch.Tensor,
|
||||
):
|
||||
speech = speech * self._output_size ** 0.5
|
||||
mask = self.make_pad_mask(speech_lengths)
|
||||
mask = self.prepare_mask(mask)
|
||||
if self.embed is None:
|
||||
|
||||
@ -54,6 +54,7 @@ class DecoderLayer(nn.Module):
|
||||
|
||||
def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None):
|
||||
residual = tgt
|
||||
tgt = self.norm1(tgt)
|
||||
tgt_q = tgt
|
||||
tgt_q_mask = tgt_mask
|
||||
x = residual + self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)
|
||||
|
||||
@ -61,7 +61,7 @@ class EncoderLayerConformer(nn.Module):
|
||||
if self.feed_forward_macaron is not None:
|
||||
residual = x
|
||||
x = self.norm_ff_macaron(x)
|
||||
x = residual + self.feed_forward_macaron(x)
|
||||
x = residual + self.feed_forward_macaron(x) * 0.5
|
||||
|
||||
residual = x
|
||||
x = self.norm_mha(x)
|
||||
@ -81,7 +81,7 @@ class EncoderLayerConformer(nn.Module):
|
||||
|
||||
residual = x
|
||||
x = self.norm_ff(x)
|
||||
x = residual + self.feed_forward(x)
|
||||
x = residual + self.feed_forward(x) * 0.5
|
||||
|
||||
x = self.norm_final(x)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user