diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py index c8cd30c0c..edcede5bf 100644 --- a/funasr/auto/auto_model.py +++ b/funasr/auto/auto_model.py @@ -157,7 +157,6 @@ class AutoModel: kwargs["device"] = device torch.set_num_threads(kwargs.get("ncpu", 4)) - # build tokenizer tokenizer = kwargs.get("tokenizer", None) diff --git a/funasr/models/bicif_paraformer/model.py b/funasr/models/bicif_paraformer/model.py index 4802da07d..b93f93a88 100644 --- a/funasr/models/bicif_paraformer/model.py +++ b/funasr/models/bicif_paraformer/model.py @@ -367,7 +367,7 @@ class BiCifParaformer(Paraformer): else: self.make_pad_mask = sequence_mask(max_seq_len, flip=False) - self.forward = self._export_forward + self.forward = self.export_forward return self diff --git a/funasr/models/ct_transformer/model.py b/funasr/models/ct_transformer/model.py index 88ee86734..9f680fdc8 100644 --- a/funasr/models/ct_transformer/model.py +++ b/funasr/models/ct_transformer/model.py @@ -373,7 +373,7 @@ class CTTransformer(torch.nn.Module): encoder_class = tables.encoder_classes.get(kwargs["encoder"]+"Export") self.encoder = encoder_class(self.encoder, onnx=is_onnx) - self.forward = self._export_forward + self.forward = self.export_forward return self diff --git a/funasr/models/ct_transformer_streaming/model.py b/funasr/models/ct_transformer_streaming/model.py index 4752c4b5a..129cc95cc 100644 --- a/funasr/models/ct_transformer_streaming/model.py +++ b/funasr/models/ct_transformer_streaming/model.py @@ -182,7 +182,7 @@ class CTTransformerStreaming(CTTransformer): encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export") self.encoder = encoder_class(self.encoder, onnx=is_onnx) - self.forward = self._export_forward + self.forward = self.export_forward return self diff --git a/funasr/models/fsmn_vad_streaming/model.py b/funasr/models/fsmn_vad_streaming/model.py index d06db2093..602cf23d3 100644 --- a/funasr/models/fsmn_vad_streaming/model.py +++ b/funasr/models/fsmn_vad_streaming/model.py @@ -647,7 +647,7 @@ class FsmnVADStreaming(nn.Module): is_onnx = kwargs.get("type", "onnx") == "onnx" encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export") self.encoder = encoder_class(self.encoder, onnx=is_onnx) - self.forward = self._export_forward + self.forward = self.export_forward return self diff --git a/funasr/models/paraformer/model.py b/funasr/models/paraformer/model.py index f5f0e4e22..2e2a36e06 100644 --- a/funasr/models/paraformer/model.py +++ b/funasr/models/paraformer/model.py @@ -575,7 +575,7 @@ class Paraformer(torch.nn.Module): else: self.make_pad_mask = sequence_mask(max_seq_len, flip=False) - self.forward = self._export_forward + self.forward = self.export_forward return self diff --git a/funasr/models/paraformer_streaming/model.py b/funasr/models/paraformer_streaming/model.py index 63dba5d88..518fe9369 100644 --- a/funasr/models/paraformer_streaming/model.py +++ b/funasr/models/paraformer_streaming/model.py @@ -587,7 +587,6 @@ class ParaformerStreaming(Paraformer): else: self.make_pad_mask = sequence_mask(max_seq_len, flip=False) - self.forward = self._export_forward import copy import types @@ -595,7 +594,7 @@ class ParaformerStreaming(Paraformer): decoder_model = copy.copy(self) # encoder - encoder_model.forward = types.MethodType(ParaformerStreaming._export_encoder_forward, encoder_model) + encoder_model.forward = types.MethodType(ParaformerStreaming.export_encoder_forward, encoder_model) encoder_model.export_dummy_inputs = types.MethodType(ParaformerStreaming.export_encoder_dummy_inputs, encoder_model) encoder_model.export_input_names = types.MethodType(ParaformerStreaming.export_encoder_input_names, encoder_model) encoder_model.export_output_names = types.MethodType(ParaformerStreaming.export_encoder_output_names, encoder_model) @@ -603,7 +602,7 @@ class ParaformerStreaming(Paraformer): encoder_model.export_name = types.MethodType(ParaformerStreaming.export_encoder_name, encoder_model) # decoder - decoder_model.forward = types.MethodType(ParaformerStreaming._export_decoder_forward, decoder_model) + decoder_model.forward = types.MethodType(ParaformerStreaming.export_decoder_forward, decoder_model) decoder_model.export_dummy_inputs = types.MethodType(ParaformerStreaming.export_decoder_dummy_inputs, decoder_model) decoder_model.export_input_names = types.MethodType(ParaformerStreaming.export_decoder_input_names, decoder_model) decoder_model.export_output_names = types.MethodType(ParaformerStreaming.export_decoder_output_names, decoder_model)