mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
fixbug for cif
This commit is contained in:
parent
fdac68e1d0
commit
93f9a424f2
@ -506,7 +506,10 @@ def cif_v1_export(hidden, alphas, threshold: float):
|
||||
frames = torch.zeros(batch_size, len_time, hidden_size, dtype=dtype, device=device)
|
||||
fires = torch.zeros(batch_size, len_time, dtype=dtype, device=device)
|
||||
|
||||
prefix_sum = torch.cumsum(alphas, dim=1)
|
||||
# prefix_sum = torch.cumsum(alphas, dim=1)
|
||||
prefix_sum = torch.cumsum(alphas, dim=1, dtype=torch.float64).to(
|
||||
torch.float32
|
||||
) # cumsum precision degradation cause wrong result in extreme
|
||||
prefix_sum_floor = torch.floor(prefix_sum)
|
||||
dislocation_prefix_sum = torch.roll(prefix_sum, 1, dims=1)
|
||||
dislocation_prefix_sum_floor = torch.floor(dislocation_prefix_sum)
|
||||
@ -518,8 +521,8 @@ def cif_v1_export(hidden, alphas, threshold: float):
|
||||
fires[fire_idxs] = 1
|
||||
fires = fires + prefix_sum - prefix_sum_floor
|
||||
|
||||
# prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1)
|
||||
prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1)
|
||||
|
||||
frames = prefix_sum_hidden[fire_idxs]
|
||||
shift_frames = torch.roll(frames, 1, dims=0)
|
||||
|
||||
@ -530,6 +533,7 @@ def cif_v1_export(hidden, alphas, threshold: float):
|
||||
shift_frames[shift_batch_idxs] = 0
|
||||
|
||||
remains = fires - torch.floor(fires)
|
||||
# remain_frames = remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs]
|
||||
remain_frames = remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs]
|
||||
|
||||
shift_remain_frames = torch.roll(remain_frames, 1, dims=0)
|
||||
@ -537,8 +541,11 @@ def cif_v1_export(hidden, alphas, threshold: float):
|
||||
|
||||
frames = frames - shift_frames + shift_remain_frames - remain_frames
|
||||
|
||||
max_label_len = batch_len.max()
|
||||
# max_label_len = batch_len.max()
|
||||
max_label_len = alphas.sum(dim=-1)
|
||||
max_label_len = torch.floor(max_label_len).max().to(dtype=torch.int64)
|
||||
|
||||
# frame_fires = torch.zeros(batch_size, max_label_len, hidden_size, dtype=dtype, device=device)
|
||||
frame_fires = torch.zeros(batch_size, max_label_len, hidden_size, dtype=dtype, device=device)
|
||||
indices = torch.arange(max_label_len, device=device).expand(batch_size, -1)
|
||||
frame_fires_idxs = indices < batch_len.unsqueeze(1)
|
||||
@ -667,7 +674,10 @@ def cif_wo_hidden_v1(alphas, threshold, return_fire_idxs=False):
|
||||
|
||||
fires = torch.zeros(batch_size, len_time, dtype=dtype, device=device)
|
||||
|
||||
prefix_sum = torch.cumsum(alphas, dim=1)
|
||||
# prefix_sum = torch.cumsum(alphas, dim=1)
|
||||
prefix_sum = torch.cumsum(alphas, dim=1, dtype=torch.float64).to(
|
||||
torch.float32
|
||||
) # cumsum precision degradation cause wrong result in extreme
|
||||
prefix_sum_floor = torch.floor(prefix_sum)
|
||||
dislocation_prefix_sum = torch.roll(prefix_sum, 1, dims=1)
|
||||
dislocation_prefix_sum_floor = torch.floor(dislocation_prefix_sum)
|
||||
@ -689,6 +699,8 @@ def cif_v1(hidden, alphas, threshold):
|
||||
device = hidden.device
|
||||
dtype = hidden.dtype
|
||||
batch_size, len_time, hidden_size = hidden.size()
|
||||
# frames = torch.zeros(batch_size, len_time, hidden_size, dtype=dtype, device=device)
|
||||
# prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1)
|
||||
frames = torch.zeros(batch_size, len_time, hidden_size, dtype=dtype, device=device)
|
||||
prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1)
|
||||
|
||||
@ -702,6 +714,7 @@ def cif_v1(hidden, alphas, threshold):
|
||||
shift_frames[shift_batch_idxs] = 0
|
||||
|
||||
remains = fires - torch.floor(fires)
|
||||
# remain_frames = remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs]
|
||||
remain_frames = remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs]
|
||||
|
||||
shift_remain_frames = torch.roll(remain_frames, 1, dims=0)
|
||||
@ -709,8 +722,12 @@ def cif_v1(hidden, alphas, threshold):
|
||||
|
||||
frames = frames - shift_frames + shift_remain_frames - remain_frames
|
||||
|
||||
max_label_len = batch_len.max()
|
||||
# max_label_len = batch_len.max()
|
||||
max_label_len = (
|
||||
torch.round(alphas.sum(-1)).int().max()
|
||||
) # torch.round to calculate the max length
|
||||
|
||||
# frame_fires = torch.zeros(batch_size, max_label_len, hidden_size, dtype=dtype, device=device)
|
||||
frame_fires = torch.zeros(batch_size, max_label_len, hidden_size, dtype=dtype, device=device)
|
||||
indices = torch.arange(max_label_len, device=device).expand(batch_size, -1)
|
||||
frame_fires_idxs = indices < batch_len.unsqueeze(1)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user