From 93f9a424f2bc0607d31ef66b0c7c58dfac15ce25 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Mon, 24 Jun 2024 10:07:31 +0800 Subject: [PATCH] fixbug for cif --- funasr/models/paraformer/cif_predictor.py | 27 ++++++++++++++++++----- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/funasr/models/paraformer/cif_predictor.py b/funasr/models/paraformer/cif_predictor.py index 83ca464da..535131f39 100644 --- a/funasr/models/paraformer/cif_predictor.py +++ b/funasr/models/paraformer/cif_predictor.py @@ -506,7 +506,10 @@ def cif_v1_export(hidden, alphas, threshold: float): frames = torch.zeros(batch_size, len_time, hidden_size, dtype=dtype, device=device) fires = torch.zeros(batch_size, len_time, dtype=dtype, device=device) - prefix_sum = torch.cumsum(alphas, dim=1) + # prefix_sum = torch.cumsum(alphas, dim=1) + prefix_sum = torch.cumsum(alphas, dim=1, dtype=torch.float64).to( + torch.float32 + ) # cumsum precision degradation cause wrong result in extreme prefix_sum_floor = torch.floor(prefix_sum) dislocation_prefix_sum = torch.roll(prefix_sum, 1, dims=1) dislocation_prefix_sum_floor = torch.floor(dislocation_prefix_sum) @@ -518,8 +521,8 @@ def cif_v1_export(hidden, alphas, threshold: float): fires[fire_idxs] = 1 fires = fires + prefix_sum - prefix_sum_floor + # prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1) prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1) - frames = prefix_sum_hidden[fire_idxs] shift_frames = torch.roll(frames, 1, dims=0) @@ -530,6 +533,7 @@ def cif_v1_export(hidden, alphas, threshold: float): shift_frames[shift_batch_idxs] = 0 remains = fires - torch.floor(fires) + # remain_frames = remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs] remain_frames = remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs] shift_remain_frames = torch.roll(remain_frames, 1, dims=0) @@ -537,8 +541,11 @@ def cif_v1_export(hidden, alphas, threshold: float): frames = frames - shift_frames + shift_remain_frames - remain_frames - max_label_len = batch_len.max() + # max_label_len = batch_len.max() + max_label_len = alphas.sum(dim=-1) + max_label_len = torch.floor(max_label_len).max().to(dtype=torch.int64) + # frame_fires = torch.zeros(batch_size, max_label_len, hidden_size, dtype=dtype, device=device) frame_fires = torch.zeros(batch_size, max_label_len, hidden_size, dtype=dtype, device=device) indices = torch.arange(max_label_len, device=device).expand(batch_size, -1) frame_fires_idxs = indices < batch_len.unsqueeze(1) @@ -667,7 +674,10 @@ def cif_wo_hidden_v1(alphas, threshold, return_fire_idxs=False): fires = torch.zeros(batch_size, len_time, dtype=dtype, device=device) - prefix_sum = torch.cumsum(alphas, dim=1) + # prefix_sum = torch.cumsum(alphas, dim=1) + prefix_sum = torch.cumsum(alphas, dim=1, dtype=torch.float64).to( + torch.float32 + ) # cumsum precision degradation cause wrong result in extreme prefix_sum_floor = torch.floor(prefix_sum) dislocation_prefix_sum = torch.roll(prefix_sum, 1, dims=1) dislocation_prefix_sum_floor = torch.floor(dislocation_prefix_sum) @@ -689,6 +699,8 @@ def cif_v1(hidden, alphas, threshold): device = hidden.device dtype = hidden.dtype batch_size, len_time, hidden_size = hidden.size() + # frames = torch.zeros(batch_size, len_time, hidden_size, dtype=dtype, device=device) + # prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1) frames = torch.zeros(batch_size, len_time, hidden_size, dtype=dtype, device=device) prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1) @@ -702,6 +714,7 @@ def cif_v1(hidden, alphas, threshold): shift_frames[shift_batch_idxs] = 0 remains = fires - torch.floor(fires) + # remain_frames = remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs] remain_frames = remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs] shift_remain_frames = torch.roll(remain_frames, 1, dims=0) @@ -709,8 +722,12 @@ def cif_v1(hidden, alphas, threshold): frames = frames - shift_frames + shift_remain_frames - remain_frames - max_label_len = batch_len.max() + # max_label_len = batch_len.max() + max_label_len = ( + torch.round(alphas.sum(-1)).int().max() + ) # torch.round to calculate the max length + # frame_fires = torch.zeros(batch_size, max_label_len, hidden_size, dtype=dtype, device=device) frame_fires = torch.zeros(batch_size, max_label_len, hidden_size, dtype=dtype, device=device) indices = torch.arange(max_label_len, device=device).expand(batch_size, -1) frame_fires_idxs = indices < batch_len.unsqueeze(1)