From 00d3f31915f4f1885bd3bd8d128bed8436f7ff8d Mon Sep 17 00:00:00 2001 From: "haoneng.lhn" Date: Wed, 12 Apr 2023 13:09:23 +0800 Subject: [PATCH] update loading cmvn_file --- funasr/models/frontend/wav_frontend.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/funasr/models/frontend/wav_frontend.py b/funasr/models/frontend/wav_frontend.py index 475a9398a..203f00e04 100644 --- a/funasr/models/frontend/wav_frontend.py +++ b/funasr/models/frontend/wav_frontend.py @@ -38,7 +38,7 @@ def load_cmvn(cmvn_file): return cmvn -def apply_cmvn(inputs, cmvn_file): # noqa +def apply_cmvn(inputs, cmvn): # noqa """ Apply CMVN with mvn data """ @@ -47,7 +47,6 @@ def apply_cmvn(inputs, cmvn_file): # noqa dtype = inputs.dtype frame, dim = inputs.shape - cmvn = load_cmvn(cmvn_file) means = np.tile(cmvn[0:1, :dim], (frame, 1)) vars = np.tile(cmvn[1:2, :dim], (frame, 1)) inputs += torch.from_numpy(means).type(dtype).to(device) @@ -111,6 +110,7 @@ class WavFrontend(AbsFrontend): self.dither = dither self.snip_edges = snip_edges self.upsacle_samples = upsacle_samples + self.cmvn = None if self.cmvn_file is None else load_cmvn(self.cmvn_file) def output_size(self) -> int: return self.n_mels * self.lfr_m @@ -140,8 +140,8 @@ class WavFrontend(AbsFrontend): if self.lfr_m != 1 or self.lfr_n != 1: mat = apply_lfr(mat, self.lfr_m, self.lfr_n) - if self.cmvn_file is not None: - mat = apply_cmvn(mat, self.cmvn_file) + if self.cmvn is not None: + mat = apply_cmvn(mat, self.cmvn) feat_length = mat.size(0) feats.append(mat) feats_lens.append(feat_length) @@ -194,8 +194,8 @@ class WavFrontend(AbsFrontend): mat = input[i, :input_lengths[i], :] if self.lfr_m != 1 or self.lfr_n != 1: mat = apply_lfr(mat, self.lfr_m, self.lfr_n) - if self.cmvn_file is not None: - mat = apply_cmvn(mat, self.cmvn_file) + if self.cmvn is not None: + mat = apply_cmvn(mat, self.cmvn) feat_length = mat.size(0) feats.append(mat) feats_lens.append(feat_length)