diff --git a/funasr/models/extract_tokens/model.py b/funasr/models/extract_tokens/model.py index e383f0d44..faf088bad 100644 --- a/funasr/models/extract_tokens/model.py +++ b/funasr/models/extract_tokens/model.py @@ -1082,7 +1082,7 @@ class SenseVoiceQuantizedEncoderPitch(nn.Module): self.normalized_quant_input = normalized_quant_input self.quantizer = self.build_quantizer(quantizer_config) - self.pitch_predictor = torch.Linear(units, 1) + self.pitch_predictor = torch.nn.Linear(units, 1) self.pitch_act = torch.nn.ReLU() def build_quantizer(self, vq_config):