From c73d1a8e81582b91a9bdd6e82fce2e84f8d9d94b Mon Sep 17 00:00:00 2001 From: "shixian.shi" Date: Mon, 14 Aug 2023 19:31:55 +0800 Subject: [PATCH] update func cif_wo_hidden --- funasr/export/models/predictor/cif.py | 2 +- funasr/models/predictor/cif.py | 2 +- funasr/utils/timestamp_tools.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/funasr/export/models/predictor/cif.py b/funasr/export/models/predictor/cif.py index dd5dd3669..03c443302 100644 --- a/funasr/export/models/predictor/cif.py +++ b/funasr/export/models/predictor/cif.py @@ -288,7 +288,7 @@ def cif_wo_hidden(alphas, threshold: float): fire_place = integrate >= threshold integrate = torch.where(fire_place, - integrate - torch.ones([batch_size], device=alphas.device), + integrate - torch.ones([batch_size], device=alphas.device)*threshold, integrate) fires = torch.stack(list_fires, 1) diff --git a/funasr/models/predictor/cif.py b/funasr/models/predictor/cif.py index c66af94e5..5f18c4d1d 100644 --- a/funasr/models/predictor/cif.py +++ b/funasr/models/predictor/cif.py @@ -499,7 +499,7 @@ def cif_wo_hidden(alphas, threshold): fire_place = integrate >= threshold integrate = torch.where(fire_place, - integrate - torch.ones([batch_size], device=alphas.device), + integrate - torch.ones([batch_size], device=alphas.device)*threshold, integrate) fires = torch.stack(list_fires, 1) diff --git a/funasr/utils/timestamp_tools.py b/funasr/utils/timestamp_tools.py index c194179ab..6594273f2 100644 --- a/funasr/utils/timestamp_tools.py +++ b/funasr/utils/timestamp_tools.py @@ -19,7 +19,7 @@ def cif_wo_hidden(alphas, threshold): list_fires.append(integrate) fire_place = integrate >= threshold integrate = torch.where(fire_place, - integrate - torch.ones([batch_size], device=alphas.device), + integrate - torch.ones([batch_size], device=alphas.device)*threshold, integrate) fires = torch.stack(list_fires, 1) return fires