mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
feat: speed up fbank's lfr (#2246)
Co-authored-by: linjie.tang <linjie.tang@sophgo.com>
This commit is contained in:
parent
ae49b2a8e1
commit
8b1be8c3cb
@ -62,17 +62,16 @@ def apply_lfr(inputs, lfr_m, lfr_n):
|
||||
left_padding = inputs[0].repeat((lfr_m - 1) // 2, 1)
|
||||
inputs = torch.vstack((left_padding, inputs))
|
||||
T = T + (lfr_m - 1) // 2
|
||||
for i in range(T_lfr):
|
||||
if lfr_m <= T - i * lfr_n:
|
||||
LFR_inputs.append((inputs[i * lfr_n : i * lfr_n + lfr_m]).view(1, -1))
|
||||
else: # process last LFR frame
|
||||
num_padding = lfr_m - (T - i * lfr_n)
|
||||
frame = (inputs[i * lfr_n :]).view(-1)
|
||||
for _ in range(num_padding):
|
||||
frame = torch.hstack((frame, inputs[-1]))
|
||||
LFR_inputs.append(frame)
|
||||
LFR_outputs = torch.vstack(LFR_inputs)
|
||||
return LFR_outputs.type(torch.float32)
|
||||
feat_dim = inputs.shape[-1]
|
||||
strides = (lfr_n * feat_dim, 1)
|
||||
sizes = (T_lfr, lfr_m * feat_dim)
|
||||
last_idx = (T - lfr_m) // lfr_n + 1
|
||||
num_padding = lfr_m - (T - last_idx * lfr_n)
|
||||
if num_padding > 0:
|
||||
num_padding = (2 * lfr_m - 2 * T + (T_lfr - 1 + last_idx) * lfr_n) / 2 * (T_lfr - last_idx)
|
||||
inputs = torch.vstack([inputs] + [inputs[-1:]] * int(num_padding))
|
||||
LFR_outputs = inputs.as_strided(sizes, strides)
|
||||
return LFR_outputs.clone().type(torch.float32)
|
||||
|
||||
|
||||
@tables.register("frontend_classes", "wav_frontend")
|
||||
@ -289,24 +288,24 @@ class WavFrontendOnline(nn.Module):
|
||||
np.ceil((T - (lfr_m - 1) // 2) / lfr_n)
|
||||
) # minus the right context: (lfr_m - 1) // 2
|
||||
splice_idx = T_lfr
|
||||
for i in range(T_lfr):
|
||||
if lfr_m <= T - i * lfr_n:
|
||||
LFR_inputs.append((inputs[i * lfr_n : i * lfr_n + lfr_m]).view(1, -1))
|
||||
else: # process last LFR frame
|
||||
if is_final:
|
||||
num_padding = lfr_m - (T - i * lfr_n)
|
||||
frame = (inputs[i * lfr_n :]).view(-1)
|
||||
for _ in range(num_padding):
|
||||
frame = torch.hstack((frame, inputs[-1]))
|
||||
LFR_inputs.append(frame)
|
||||
else:
|
||||
# update splice_idx and break out of the loop
|
||||
splice_idx = i
|
||||
break
|
||||
feat_dim = inputs.shape[-1]
|
||||
ori_inputs = inputs
|
||||
strides = (lfr_n * feat_dim, 1)
|
||||
sizes = (T_lfr, lfr_m * feat_dim)
|
||||
last_idx = (T - lfr_m) // lfr_n + 1
|
||||
num_padding = lfr_m - (T - last_idx * lfr_n)
|
||||
if is_final:
|
||||
if num_padding > 0:
|
||||
num_padding = (2 * lfr_m - 2 * T + (T_lfr - 1 + last_idx) * lfr_n) / 2 * (T_lfr - last_idx)
|
||||
inputs = torch.vstack([inputs] + [inputs[-1:]] * int(num_padding))
|
||||
else:
|
||||
if num_padding > 0:
|
||||
sizes = (last_idx, lfr_m * feat_dim)
|
||||
splice_idx = last_idx
|
||||
splice_idx = min(T - 1, splice_idx * lfr_n)
|
||||
lfr_splice_cache = inputs[splice_idx:, :]
|
||||
LFR_outputs = torch.vstack(LFR_inputs)
|
||||
return LFR_outputs.type(torch.float32), lfr_splice_cache, splice_idx
|
||||
LFR_outputs = inputs[:splice_idx].as_strided(sizes, strides)
|
||||
lfr_splice_cache = ori_inputs[splice_idx:, :]
|
||||
return LFR_outputs.clone().type(torch.float32), lfr_splice_cache, splice_idx
|
||||
|
||||
@staticmethod
|
||||
def compute_frame_num(
|
||||
|
||||
Loading…
Reference in New Issue
Block a user