feat: speed up fbank's lfr (#2246)

Co-authored-by: linjie.tang <linjie.tang@sophgo.com>
This commit is contained in:
Tang Linjie 2024-11-30 13:05:39 +08:00 committed by GitHub
parent ae49b2a8e1
commit 8b1be8c3cb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -62,17 +62,16 @@ def apply_lfr(inputs, lfr_m, lfr_n):
left_padding = inputs[0].repeat((lfr_m - 1) // 2, 1)
inputs = torch.vstack((left_padding, inputs))
T = T + (lfr_m - 1) // 2
for i in range(T_lfr):
if lfr_m <= T - i * lfr_n:
LFR_inputs.append((inputs[i * lfr_n : i * lfr_n + lfr_m]).view(1, -1))
else: # process last LFR frame
num_padding = lfr_m - (T - i * lfr_n)
frame = (inputs[i * lfr_n :]).view(-1)
for _ in range(num_padding):
frame = torch.hstack((frame, inputs[-1]))
LFR_inputs.append(frame)
LFR_outputs = torch.vstack(LFR_inputs)
return LFR_outputs.type(torch.float32)
feat_dim = inputs.shape[-1]
strides = (lfr_n * feat_dim, 1)
sizes = (T_lfr, lfr_m * feat_dim)
last_idx = (T - lfr_m) // lfr_n + 1
num_padding = lfr_m - (T - last_idx * lfr_n)
if num_padding > 0:
num_padding = (2 * lfr_m - 2 * T + (T_lfr - 1 + last_idx) * lfr_n) / 2 * (T_lfr - last_idx)
inputs = torch.vstack([inputs] + [inputs[-1:]] * int(num_padding))
LFR_outputs = inputs.as_strided(sizes, strides)
return LFR_outputs.clone().type(torch.float32)
@tables.register("frontend_classes", "wav_frontend")
@ -289,24 +288,24 @@ class WavFrontendOnline(nn.Module):
np.ceil((T - (lfr_m - 1) // 2) / lfr_n)
) # minus the right context: (lfr_m - 1) // 2
splice_idx = T_lfr
for i in range(T_lfr):
if lfr_m <= T - i * lfr_n:
LFR_inputs.append((inputs[i * lfr_n : i * lfr_n + lfr_m]).view(1, -1))
else: # process last LFR frame
if is_final:
num_padding = lfr_m - (T - i * lfr_n)
frame = (inputs[i * lfr_n :]).view(-1)
for _ in range(num_padding):
frame = torch.hstack((frame, inputs[-1]))
LFR_inputs.append(frame)
else:
# update splice_idx and break the circle
splice_idx = i
break
feat_dim = inputs.shape[-1]
ori_inputs = inputs
strides = (lfr_n * feat_dim, 1)
sizes = (T_lfr, lfr_m * feat_dim)
last_idx = (T - lfr_m) // lfr_n + 1
num_padding = lfr_m - (T - last_idx * lfr_n)
if is_final:
if num_padding > 0:
num_padding = (2 * lfr_m - 2 * T + (T_lfr - 1 + last_idx) * lfr_n) / 2 * (T_lfr - last_idx)
inputs = torch.vstack([inputs] + [inputs[-1:]] * int(num_padding))
else:
if num_padding > 0:
sizes = (last_idx, lfr_m * feat_dim)
splice_idx = last_idx
splice_idx = min(T - 1, splice_idx * lfr_n)
lfr_splice_cache = inputs[splice_idx:, :]
LFR_outputs = torch.vstack(LFR_inputs)
return LFR_outputs.type(torch.float32), lfr_splice_cache, splice_idx
LFR_outputs = inputs[:splice_idx].as_strided(sizes, strides)
lfr_splice_cache = ori_inputs[splice_idx:, :]
return LFR_outputs.clone().type(torch.float32), lfr_splice_cache, splice_idx
@staticmethod
def compute_frame_num(