diff --git a/funasr/models/frontend/default.py b/funasr/models/frontend/default.py index 2e1b0c454..c4dd7c537 100644 --- a/funasr/models/frontend/default.py +++ b/funasr/models/frontend/default.py @@ -102,8 +102,8 @@ class DefaultFrontend(AbsFrontend): if input_stft.dim() == 4: # h: (B, T, C, F) -> h: (B, T, F) if self.training: - if self.use_channel == None: - input_stft = input_stft[:, :, 0, :] + if self.use_channel is not None: + input_stft = input_stft[:, :, self.use_channel, :] else: # Select 1ch randomly ch = np.random.randint(input_stft.size(2))