# Patch summary: unified diff against funasr/models/frontend/wav_frontend.py (class WavFrontend).
#   * Hunk 1: the parameter list preceding `super().__init__()` (presumably `__init__`) gains
#     `snip_edges: bool = True` and `upsacle_samples: bool = True`; `dither` gets a trailing comma.
#   * Hunk 2: both new arguments are stored on the instance as `self.snip_edges` and
#     `self.upsacle_samples`.
#   * Hunk 3: inside the `for i in range(batch_size)` loop, `waveform * (1 << 15)` is now applied
#     only when `self.upsacle_samples` is true; before this patch it was unconditional.
#   * Hunk 4: the `kaldi.fbank(...)` call now also passes `snip_edges=self.snip_edges`.
# With both defaults True the scaling step still runs for existing callers.
# NOTE(review): `snip_edges` was not passed to `kaldi.fbank` before; presumably the library
#   default is also True so behaviour is unchanged by default - confirm against the fbank docs.
# NOTE(review): `upsacle_samples` looks like a misspelling of "upscale_samples"; it becomes part
#   of the public signature, so decide on the spelling before merging.
# NOTE(review): the patch text below is on a single physical line in this copy; the original
#   line breaks and indentation of the hunks are not recoverable from here.
diff --git a/funasr/models/frontend/wav_frontend.py b/funasr/models/frontend/wav_frontend.py index 57c5976f1..7a6425be3 100644 --- a/funasr/models/frontend/wav_frontend.py +++ b/funasr/models/frontend/wav_frontend.py @@ -90,7 +90,9 @@ class WavFrontend(AbsFrontend): filter_length_max: int = -1, lfr_m: int = 1, lfr_n: int = 1, - dither: float = 1.0 + dither: float = 1.0, + snip_edges: bool = True, + upsacle_samples: bool = True, ): assert check_argument_types() super().__init__() @@ -105,6 +107,8 @@ class WavFrontend(AbsFrontend): self.lfr_n = lfr_n self.cmvn_file = cmvn_file self.dither = dither + self.snip_edges = snip_edges + self.upsacle_samples = upsacle_samples def output_size(self) -> int: return self.n_mels * self.lfr_m @@ -119,7 +123,8 @@ class WavFrontend(AbsFrontend): for i in range(batch_size): waveform_length = input_lengths[i] waveform = input[i][:waveform_length] - waveform = waveform * (1 << 15) + if self.upsacle_samples: + waveform = waveform * (1 << 15) waveform = waveform.unsqueeze(0) mat = kaldi.fbank(waveform, num_mel_bins=self.n_mels, @@ -128,7 +133,8 @@ class WavFrontend(AbsFrontend): dither=self.dither, energy_floor=0.0, window_type=self.window, - sample_frequency=self.fs) + sample_frequency=self.fs, + snip_edges=self.snip_edges) if self.lfr_m != 1 or self.lfr_n != 1: mat = apply_lfr(mat, self.lfr_m, self.lfr_n)