mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
update
This commit is contained in:
parent
c238436e07
commit
a6889a3170
@ -280,8 +280,8 @@ class NeatContextualParaformer(Paraformer):
|
||||
decoder_outs = self.decoder(
|
||||
encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=contextual_info
|
||||
)
|
||||
decoder_out, _, attn = decoder_outs[0], decoder_outs[1], decoder_outs[2]
|
||||
|
||||
decoder_out, _ = decoder_outs[0], decoder_outs[1]
|
||||
'''
|
||||
if self.crit_attn_weight > 0 and attn.shape[-1] > 1:
|
||||
ideal_attn = ideal_attn + self.crit_attn_smooth / (self.crit_attn_smooth + 1.0)
|
||||
attn_non_blank = attn[:,:,:,:-1]
|
||||
@ -289,7 +289,9 @@ class NeatContextualParaformer(Paraformer):
|
||||
loss_ideal = self.attn_loss(attn_non_blank.max(1)[0], ideal_attn_non_blank.to(attn.device))
|
||||
else:
|
||||
loss_ideal = None
|
||||
|
||||
'''
|
||||
loss_ideal = None
|
||||
|
||||
if decoder_out_1st is None:
|
||||
decoder_out_1st = decoder_out
|
||||
# 2. Compute attention loss
|
||||
|
||||
Loading…
Reference in New Issue
Block a user