mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
Merge pull request #1791 from dtlzhuangz/zhuangzhong_dev
Accelerate cif
This commit is contained in:
commit
7a9c0414b6
@ -80,7 +80,7 @@ class CifPredictor(torch.nn.Module):
|
||||
hidden, alphas, token_num, mask=mask
|
||||
)
|
||||
|
||||
acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
|
||||
acoustic_embeds, cif_peak = cif_v1(hidden, alphas, self.threshold)
|
||||
|
||||
if target_length is None and self.tail_threshold > 0.0:
|
||||
token_num_int = torch.max(token_num).type(torch.int32).item()
|
||||
@ -245,7 +245,7 @@ class CifPredictorV2(torch.nn.Module):
|
||||
hidden, alphas, token_num, mask=None
|
||||
)
|
||||
|
||||
acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
|
||||
acoustic_embeds, cif_peak = cif_v1(hidden, alphas, self.threshold)
|
||||
if target_length is None and self.tail_threshold > 0.0:
|
||||
token_num_int = torch.max(token_num).type(torch.int32).item()
|
||||
acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
|
||||
@ -449,7 +449,7 @@ class CifPredictorV2Export(torch.nn.Module):
|
||||
mask = mask.transpose(-1, -2).float()
|
||||
mask = mask.squeeze(-1)
|
||||
hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, mask=mask)
|
||||
acoustic_embeds, cif_peak = cif_export(hidden, alphas, self.threshold)
|
||||
acoustic_embeds, cif_peak = cif_v1_export(hidden, alphas, self.threshold)
|
||||
|
||||
return acoustic_embeds, token_num, alphas, cif_peak
|
||||
|
||||
@ -494,7 +494,60 @@ class CifPredictorV2Export(torch.nn.Module):
|
||||
token_num_floor = torch.floor(token_num)
|
||||
|
||||
return hidden, alphas, token_num_floor
|
||||
@torch.jit.script
|
||||
def cif_v1_export(hidden, alphas, threshold: float):
|
||||
device = hidden.device
|
||||
dtype = hidden.dtype
|
||||
batch_size, len_time, hidden_size = hidden.size()
|
||||
threshold = torch.tensor([threshold], dtype=alphas.dtype).to(alphas.device)
|
||||
|
||||
frames = torch.zeros(batch_size, len_time, hidden_size, dtype=dtype, device=device)
|
||||
fires = torch.zeros(batch_size, len_time, dtype=dtype, device=device)
|
||||
|
||||
prefix_sum = torch.cumsum(alphas, dim=1)
|
||||
prefix_sum_floor = torch.floor(prefix_sum)
|
||||
dislocation_prefix_sum = torch.roll(prefix_sum, 1, dims=1)
|
||||
dislocation_prefix_sum_floor = torch.floor(dislocation_prefix_sum)
|
||||
|
||||
dislocation_prefix_sum_floor[:, 0] = 0
|
||||
dislocation_diff = prefix_sum_floor - dislocation_prefix_sum_floor
|
||||
|
||||
fire_idxs = dislocation_diff > 0
|
||||
fires[fire_idxs] = 1
|
||||
fires = fires + prefix_sum - prefix_sum_floor
|
||||
|
||||
prefix_sum_hidden = torch.cumsum(
|
||||
alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1
|
||||
)
|
||||
|
||||
frames = prefix_sum_hidden[fire_idxs]
|
||||
shift_frames = torch.roll(frames, 1, dims=0)
|
||||
|
||||
batch_len = fire_idxs.sum(1)
|
||||
batch_idxs = torch.cumsum(batch_len, dim=0)
|
||||
shift_batch_idxs = torch.roll(batch_idxs, 1, dims=0)
|
||||
shift_batch_idxs[0] = 0
|
||||
shift_frames[shift_batch_idxs] = 0
|
||||
|
||||
remains = fires - torch.floor(fires)
|
||||
remain_frames = (
|
||||
remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs]
|
||||
)
|
||||
|
||||
shift_remain_frames = torch.roll(remain_frames, 1, dims=0)
|
||||
shift_remain_frames[shift_batch_idxs] = 0
|
||||
|
||||
frames = frames - shift_frames + shift_remain_frames - remain_frames
|
||||
|
||||
max_label_len = batch_len.max()
|
||||
|
||||
frame_fires = torch.zeros(
|
||||
batch_size, max_label_len, hidden_size, dtype=dtype, device=device
|
||||
)
|
||||
indices = torch.arange(max_label_len, device=device).expand(batch_size, -1)
|
||||
frame_fires_idxs = indices < batch_len.unsqueeze(1)
|
||||
frame_fires[frame_fires_idxs] = frames
|
||||
return frame_fires, fires
|
||||
|
||||
@torch.jit.script
|
||||
def cif_export(hidden, alphas, threshold: float):
|
||||
@ -608,6 +661,74 @@ def cif(hidden, alphas, threshold):
|
||||
return torch.stack(list_ls, 0), fires
|
||||
|
||||
|
||||
def cif_wo_hidden_v1(alphas, threshold, return_fire_idxs=False):
|
||||
batch_size, len_time = alphas.size()
|
||||
device = alphas.device
|
||||
dtype = alphas.dtype
|
||||
|
||||
threshold = torch.tensor([threshold], dtype=alphas.dtype).to(alphas.device)
|
||||
|
||||
fires = torch.zeros(batch_size, len_time, dtype=dtype, device=device)
|
||||
|
||||
prefix_sum = torch.cumsum(alphas, dim=1)
|
||||
prefix_sum_floor = torch.floor(prefix_sum)
|
||||
dislocation_prefix_sum = torch.roll(prefix_sum, 1, dims=1)
|
||||
dislocation_prefix_sum_floor = torch.floor(dislocation_prefix_sum)
|
||||
|
||||
dislocation_prefix_sum_floor[:, 0] = 0
|
||||
dislocation_diff = prefix_sum_floor - dislocation_prefix_sum_floor
|
||||
|
||||
fire_idxs = dislocation_diff > 0
|
||||
fires[fire_idxs] = 1
|
||||
fires = fires + prefix_sum - prefix_sum_floor
|
||||
if return_fire_idxs:
|
||||
return fires, fire_idxs
|
||||
return fires
|
||||
|
||||
|
||||
def cif_v1(hidden, alphas, threshold):
|
||||
fires, fire_idxs = cif_wo_hidden_v1(alphas, threshold, return_fire_idxs=True)
|
||||
|
||||
device = hidden.device
|
||||
dtype = hidden.dtype
|
||||
batch_size, len_time, hidden_size = hidden.size()
|
||||
frames = torch.zeros(batch_size, len_time, hidden_size,
|
||||
dtype=dtype, device=device)
|
||||
prefix_sum_hidden = torch.cumsum(
|
||||
alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1
|
||||
)
|
||||
|
||||
frames = prefix_sum_hidden[fire_idxs]
|
||||
shift_frames = torch.roll(frames, 1, dims=0)
|
||||
|
||||
batch_len = fire_idxs.sum(1)
|
||||
batch_idxs = torch.cumsum(batch_len, dim=0)
|
||||
shift_batch_idxs = torch.roll(batch_idxs, 1, dims=0)
|
||||
shift_batch_idxs[0] = 0
|
||||
shift_frames[shift_batch_idxs] = 0
|
||||
|
||||
remains = fires - torch.floor(fires)
|
||||
remain_frames = (
|
||||
remains[fire_idxs].unsqueeze(-1).tile((1,
|
||||
hidden_size)) * hidden[fire_idxs]
|
||||
)
|
||||
|
||||
shift_remain_frames = torch.roll(remain_frames, 1, dims=0)
|
||||
shift_remain_frames[shift_batch_idxs] = 0
|
||||
|
||||
frames = frames - shift_frames + shift_remain_frames - remain_frames
|
||||
|
||||
max_label_len = batch_len.max()
|
||||
|
||||
frame_fires = torch.zeros(
|
||||
batch_size, max_label_len, hidden_size, dtype=dtype, device=device
|
||||
)
|
||||
indices = torch.arange(max_label_len, device=device).expand(batch_size, -1)
|
||||
frame_fires_idxs = indices < batch_len.unsqueeze(1)
|
||||
frame_fires[frame_fires_idxs] = frames
|
||||
return frame_fires, fires
|
||||
|
||||
|
||||
def cif_wo_hidden(alphas, threshold):
|
||||
batch_size, len_time = alphas.size()
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user