This commit is contained in:
语帆 2024-02-28 14:25:00 +08:00
parent 0d32e02c79
commit ab4a31201c

View File

@ -85,8 +85,12 @@ class DefaultFrontend(nn.Module):
return self.n_mels
def forward(
self, input: torch.Tensor, input_lengths: torch.Tensor
self, input: torch.Tensor, input_lengths: Union[torch.Tensor, list]
) -> Tuple[torch.Tensor, torch.Tensor]:
if isinstance(input_lengths, list):
input_lengths = torch.tensor(input_lengths)
if input.dtype == torch.float64:
input = input.float()
# 1. Domain-conversion: e.g. Stft: time -> time-freq
if self.stft is not None:
input_stft, feats_lens = self._compute_stft(input, input_lengths)