diff --git a/funasr/export/models/predictor/cif.py b/funasr/export/models/predictor/cif.py index dd5dd3669..03c443302 100644 --- a/funasr/export/models/predictor/cif.py +++ b/funasr/export/models/predictor/cif.py @@ -288,7 +288,7 @@ def cif_wo_hidden(alphas, threshold: float): fire_place = integrate >= threshold integrate = torch.where(fire_place, - integrate - torch.ones([batch_size], device=alphas.device), + integrate - torch.ones([batch_size], device=alphas.device)*threshold, integrate) fires = torch.stack(list_fires, 1) diff --git a/funasr/models/predictor/cif.py b/funasr/models/predictor/cif.py index c66af94e5..5f18c4d1d 100644 --- a/funasr/models/predictor/cif.py +++ b/funasr/models/predictor/cif.py @@ -499,7 +499,7 @@ def cif_wo_hidden(alphas, threshold): fire_place = integrate >= threshold integrate = torch.where(fire_place, - integrate - torch.ones([batch_size], device=alphas.device), + integrate - torch.ones([batch_size], device=alphas.device)*threshold, integrate) fires = torch.stack(list_fires, 1) diff --git a/funasr/utils/timestamp_tools.py b/funasr/utils/timestamp_tools.py index c194179ab..6594273f2 100644 --- a/funasr/utils/timestamp_tools.py +++ b/funasr/utils/timestamp_tools.py @@ -19,7 +19,7 @@ def cif_wo_hidden(alphas, threshold): list_fires.append(integrate) fire_place = integrate >= threshold integrate = torch.where(fire_place, - integrate - torch.ones([batch_size], device=alphas.device), + integrate - torch.ones([batch_size], device=alphas.device)*threshold, integrate) fires = torch.stack(list_fires, 1) return fires