From ab4a31201c218b212ac52cbd529024c5858a9f87 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=AF=AD=E5=B8=86?= Date: Wed, 28 Feb 2024 14:25:00 +0800 Subject: [PATCH] test --- funasr/frontends/default.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/funasr/frontends/default.py b/funasr/frontends/default.py index 364c8bbb9..c4bdbd774 100644 --- a/funasr/frontends/default.py +++ b/funasr/frontends/default.py @@ -85,8 +85,12 @@ class DefaultFrontend(nn.Module): return self.n_mels def forward( - self, input: torch.Tensor, input_lengths: torch.Tensor + self, input: torch.Tensor, input_lengths: Union[torch.Tensor, list] ) -> Tuple[torch.Tensor, torch.Tensor]: + if isinstance(input_lengths, list): + input_lengths = torch.tensor(input_lengths) + if input.dtype == torch.float64: + input = input.float() # 1. Domain-conversion: e.g. Stft: time -> time-freq if self.stft is not None: input_stft, feats_lens = self._compute_stft(input, input_lengths)