diff --git a/funasr/frontends/default.py b/funasr/frontends/default.py index 364c8bbb9..c4bdbd774 100644 --- a/funasr/frontends/default.py +++ b/funasr/frontends/default.py @@ -85,8 +85,12 @@ class DefaultFrontend(nn.Module): return self.n_mels def forward( - self, input: torch.Tensor, input_lengths: torch.Tensor + self, input: torch.Tensor, input_lengths: Union[torch.Tensor, list] ) -> Tuple[torch.Tensor, torch.Tensor]: + if isinstance(input_lengths, list): + input_lengths = torch.tensor(input_lengths) + if input.dtype == torch.float64: + input = input.float() # 1. Domain-conversion: e.g. Stft: time -> time-freq if self.stft is not None: input_stft, feats_lens = self._compute_stft(input, input_lengths)