diff --git a/funasr/frontends/wav_frontend.py b/funasr/frontends/wav_frontend.py index 332420898..da23f9c7c 100644 --- a/funasr/frontends/wav_frontend.py +++ b/funasr/frontends/wav_frontend.py @@ -62,17 +62,16 @@ def apply_lfr(inputs, lfr_m, lfr_n): left_padding = inputs[0].repeat((lfr_m - 1) // 2, 1) inputs = torch.vstack((left_padding, inputs)) T = T + (lfr_m - 1) // 2 - for i in range(T_lfr): - if lfr_m <= T - i * lfr_n: - LFR_inputs.append((inputs[i * lfr_n : i * lfr_n + lfr_m]).view(1, -1)) - else: # process last LFR frame - num_padding = lfr_m - (T - i * lfr_n) - frame = (inputs[i * lfr_n :]).view(-1) - for _ in range(num_padding): - frame = torch.hstack((frame, inputs[-1])) - LFR_inputs.append(frame) - LFR_outputs = torch.vstack(LFR_inputs) - return LFR_outputs.type(torch.float32) + feat_dim = inputs.shape[-1] + strides = (lfr_n * feat_dim, 1) + sizes = (T_lfr, lfr_m * feat_dim) + last_idx = (T - lfr_m) // lfr_n + 1 + num_padding = lfr_m - (T - last_idx * lfr_n) + if num_padding > 0: + num_padding = (2 * lfr_m - 2 * T + (T_lfr - 1 + last_idx) * lfr_n) / 2 * (T_lfr - last_idx) + inputs = torch.vstack([inputs] + [inputs[-1:]] * int(num_padding)) + LFR_outputs = inputs.as_strided(sizes, strides) + return LFR_outputs.clone().type(torch.float32) @tables.register("frontend_classes", "wav_frontend") @@ -289,24 +288,24 @@ class WavFrontendOnline(nn.Module): np.ceil((T - (lfr_m - 1) // 2) / lfr_n) ) # minus the right context: (lfr_m - 1) // 2 splice_idx = T_lfr - for i in range(T_lfr): - if lfr_m <= T - i * lfr_n: - LFR_inputs.append((inputs[i * lfr_n : i * lfr_n + lfr_m]).view(1, -1)) - else: # process last LFR frame - if is_final: - num_padding = lfr_m - (T - i * lfr_n) - frame = (inputs[i * lfr_n :]).view(-1) - for _ in range(num_padding): - frame = torch.hstack((frame, inputs[-1])) - LFR_inputs.append(frame) - else: - # update splice_idx and break the circle - splice_idx = i - break + feat_dim = inputs.shape[-1] + ori_inputs = inputs + strides = (lfr_n * feat_dim, 1) + sizes = (T_lfr, lfr_m * feat_dim) + last_idx = (T - lfr_m) // lfr_n + 1 + num_padding = lfr_m - (T - last_idx * lfr_n) + if is_final: + if num_padding > 0: + num_padding = (2 * lfr_m - 2 * T + (T_lfr - 1 + last_idx) * lfr_n) / 2 * (T_lfr - last_idx) + inputs = torch.vstack([inputs] + [inputs[-1:]] * int(num_padding)) + else: + if num_padding > 0: + sizes = (last_idx, lfr_m * feat_dim) + splice_idx = last_idx splice_idx = min(T - 1, splice_idx * lfr_n) - lfr_splice_cache = inputs[splice_idx:, :] - LFR_outputs = torch.vstack(LFR_inputs) - return LFR_outputs.type(torch.float32), lfr_splice_cache, splice_idx + LFR_outputs = inputs[:splice_idx].as_strided(sizes, strides) + lfr_splice_cache = ori_inputs[splice_idx:, :] + return LFR_outputs.clone().type(torch.float32), lfr_splice_cache, splice_idx @staticmethod def compute_frame_num(