This commit is contained in:
yhliang 2023-05-18 11:27:24 +08:00
parent db77a41e29
commit 1e650fac78

View File

@ -102,8 +102,8 @@ class DefaultFrontend(AbsFrontend):
if input_stft.dim() == 4:
# h: (B, T, C, F) -> h: (B, T, F)
if self.training:
if self.use_channel == None:
input_stft = input_stft[:, :, 0, :]
if self.use_channel is not None:
input_stft = input_stft[:, :, self.use_channel, :]
else:
# Select 1ch randomly
ch = np.random.randint(input_stft.size(2))