mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
Dev gzf (#1469)
* qwenaudio qwenaudiochat * qwenaudio qwenaudiochat * whisper * whisper * llm * llm * llm * llm * llm * llm * llm * llm * export onnx * export onnx * export onnx * dingding * dingding * llm * doc * onnx * onnx * onnx
This commit is contained in:
parent
e847f85a14
commit
cc59310dbf
@ -157,7 +157,6 @@ class AutoModel:
|
||||
kwargs["device"] = device
|
||||
|
||||
torch.set_num_threads(kwargs.get("ncpu", 4))
|
||||
|
||||
|
||||
# build tokenizer
|
||||
tokenizer = kwargs.get("tokenizer", None)
|
||||
|
||||
@ -367,7 +367,7 @@ class BiCifParaformer(Paraformer):
|
||||
else:
|
||||
self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
|
||||
|
||||
self.forward = self._export_forward
|
||||
self.forward = self.export_forward
|
||||
|
||||
return self
|
||||
|
||||
|
||||
@ -373,7 +373,7 @@ class CTTransformer(torch.nn.Module):
|
||||
encoder_class = tables.encoder_classes.get(kwargs["encoder"]+"Export")
|
||||
self.encoder = encoder_class(self.encoder, onnx=is_onnx)
|
||||
|
||||
self.forward = self._export_forward
|
||||
self.forward = self.export_forward
|
||||
|
||||
return self
|
||||
|
||||
|
||||
@ -182,7 +182,7 @@ class CTTransformerStreaming(CTTransformer):
|
||||
encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export")
|
||||
self.encoder = encoder_class(self.encoder, onnx=is_onnx)
|
||||
|
||||
self.forward = self._export_forward
|
||||
self.forward = self.export_forward
|
||||
|
||||
return self
|
||||
|
||||
|
||||
@ -647,7 +647,7 @@ class FsmnVADStreaming(nn.Module):
|
||||
is_onnx = kwargs.get("type", "onnx") == "onnx"
|
||||
encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export")
|
||||
self.encoder = encoder_class(self.encoder, onnx=is_onnx)
|
||||
self.forward = self._export_forward
|
||||
self.forward = self.export_forward
|
||||
|
||||
return self
|
||||
|
||||
|
||||
@ -575,7 +575,7 @@ class Paraformer(torch.nn.Module):
|
||||
else:
|
||||
self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
|
||||
|
||||
self.forward = self._export_forward
|
||||
self.forward = self.export_forward
|
||||
|
||||
return self
|
||||
|
||||
|
||||
@ -587,7 +587,6 @@ class ParaformerStreaming(Paraformer):
|
||||
else:
|
||||
self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
|
||||
|
||||
self.forward = self._export_forward
|
||||
|
||||
import copy
|
||||
import types
|
||||
@ -595,7 +594,7 @@ class ParaformerStreaming(Paraformer):
|
||||
decoder_model = copy.copy(self)
|
||||
|
||||
# encoder
|
||||
encoder_model.forward = types.MethodType(ParaformerStreaming._export_encoder_forward, encoder_model)
|
||||
encoder_model.forward = types.MethodType(ParaformerStreaming.export_encoder_forward, encoder_model)
|
||||
encoder_model.export_dummy_inputs = types.MethodType(ParaformerStreaming.export_encoder_dummy_inputs, encoder_model)
|
||||
encoder_model.export_input_names = types.MethodType(ParaformerStreaming.export_encoder_input_names, encoder_model)
|
||||
encoder_model.export_output_names = types.MethodType(ParaformerStreaming.export_encoder_output_names, encoder_model)
|
||||
@ -603,7 +602,7 @@ class ParaformerStreaming(Paraformer):
|
||||
encoder_model.export_name = types.MethodType(ParaformerStreaming.export_encoder_name, encoder_model)
|
||||
|
||||
# decoder
|
||||
decoder_model.forward = types.MethodType(ParaformerStreaming._export_decoder_forward, decoder_model)
|
||||
decoder_model.forward = types.MethodType(ParaformerStreaming.export_decoder_forward, decoder_model)
|
||||
decoder_model.export_dummy_inputs = types.MethodType(ParaformerStreaming.export_decoder_dummy_inputs, decoder_model)
|
||||
decoder_model.export_input_names = types.MethodType(ParaformerStreaming.export_decoder_input_names, decoder_model)
|
||||
decoder_model.export_output_names = types.MethodType(ParaformerStreaming.export_decoder_output_names, decoder_model)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user