|
|
@@ -102,8 +102,8 @@ class DefaultFrontend(AbsFrontend):
|
|
|
if input_stft.dim() == 4:
|
|
|
# h: (B, T, C, F) -> h: (B, T, F)
|
|
|
if self.training:
|
|
|
- if self.use_channel == None:
|
|
|
- input_stft = input_stft[:, :, 0, :]
|
|
|
+ if self.use_channel is not None:
|
|
|
+ input_stft = input_stft[:, :, self.use_channel, :]
|
|
|
else:
|
|
|
# Select 1ch randomly
|
|
|
ch = np.random.randint(input_stft.size(2))
|