mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
update func cif_wo_hidden
This commit is contained in:
parent
74eb3b105d
commit
c73d1a8e81
@ -288,7 +288,7 @@ def cif_wo_hidden(alphas, threshold: float):
|
||||
|
||||
fire_place = integrate >= threshold
|
||||
integrate = torch.where(fire_place,
|
||||
integrate - torch.ones([batch_size], device=alphas.device),
|
||||
integrate - torch.ones([batch_size], device=alphas.device)*threshold,
|
||||
integrate)
|
||||
|
||||
fires = torch.stack(list_fires, 1)
|
||||
|
||||
@ -499,7 +499,7 @@ def cif_wo_hidden(alphas, threshold):
|
||||
|
||||
fire_place = integrate >= threshold
|
||||
integrate = torch.where(fire_place,
|
||||
integrate - torch.ones([batch_size], device=alphas.device),
|
||||
integrate - torch.ones([batch_size], device=alphas.device)*threshold,
|
||||
integrate)
|
||||
|
||||
fires = torch.stack(list_fires, 1)
|
||||
|
||||
@ -19,7 +19,7 @@ def cif_wo_hidden(alphas, threshold):
|
||||
list_fires.append(integrate)
|
||||
fire_place = integrate >= threshold
|
||||
integrate = torch.where(fire_place,
|
||||
integrate - torch.ones([batch_size], device=alphas.device),
|
||||
integrate - torch.ones([batch_size], device=alphas.device)*threshold,
|
||||
integrate)
|
||||
fires = torch.stack(list_fires, 1)
|
||||
return fires
|
||||
|
||||
Loading…
Reference in New Issue
Block a user