general punc model conversion to onnx

九耳 2023-03-29 16:48:57 +08:00
parent 1f8b46402c
commit 6ebd267648
2 changed files with 171 additions and 1 deletion


@@ -174,7 +174,10 @@ class ModelExport:
         json_file = os.path.join(model_dir, 'configuration.json')
         with open(json_file, 'r') as f:
             config_data = json.load(f)
-        mode = config_data['model']['model_config']['mode']
+        if config_data['task'] == "punctuation":
+            mode = config_data['model']['punc_model_config']['mode']
+        else:
+            mode = config_data['model']['model_config']['mode']
         if mode.startswith('paraformer'):
             from funasr.tasks.asr import ASRTaskParaformer as ASRTask
             config = os.path.join(model_dir, 'config.yaml')
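
A punctuation model directory is therefore expected to ship a configuration.json shaped roughly as in the sketch below; the key names come from the hunk above, while the concrete values are illustrative assumptions, not taken from the repo:

# Sketch of the configuration.json fields the new dispatch reads.
# Key names are from the diff; the values shown are only illustrative.
config_data = {
    "task": "punctuation",            # selects the punc_model_config branch
    "model": {
        "punc_model_config": {
            "mode": "punc",           # mode.startswith('punc') -> punctuation export
        },
    },
}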
@@ -195,6 +198,13 @@ class ModelExport:
             )
             self.export_config["feats_dim"] = 400
             self.frontend = model.frontend
+        elif mode.startswith('punc'):
+            from funasr.tasks.punctuation import PunctuationTask as PUNCTask
+            punc_train_config = os.path.join(model_dir, 'config.yaml')
+            punc_model_file = os.path.join(model_dir, 'punc.pb')
+            model, punc_train_args = PUNCTask.build_model_from_file(
+                punc_train_config, punc_model_file, 'cpu'
+            )
         self._export(model, tag_name)
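
In isolation, the new branch loads the punctuation model from its training config and checkpoint before handing it to the shared export path; a minimal standalone sketch, with an illustrative model directory:

# Standalone sketch of the new 'punc' branch (paths are illustrative).
import os
from funasr.tasks.punctuation import PunctuationTask as PUNCTask

model_dir = 'path/to/punc_model_dir'              # assumed local model dir
model, punc_train_args = PUNCTask.build_model_from_file(
    os.path.join(model_dir, 'config.yaml'),       # punc_train_config
    os.path.join(model_dir, 'punc.pb'),           # punc_model_file
    'cpu',
)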


@@ -0,0 +1,160 @@
from typing import Any
from typing import List
from typing import Tuple

import torch
import torch.nn as nn

from funasr.export.utils.torch_function import MakePadMask
from funasr.export.utils.torch_function import sequence_mask
# from funasr.models.encoder.sanm_encoder import SANMEncoder as Encoder
from funasr.punctuation.sanm_encoder import SANMEncoder
from funasr.export.models.encoder.sanm_encoder import SANMEncoder as SANMEncoder_export
from funasr.punctuation.abs_model import AbsPunctuation


class TargetDelayTransformer(nn.Module):
    def __init__(
        self,
        model,
        max_seq_len=512,
        model_name='punc_model',
        **kwargs,
    ):
        super().__init__()
        onnx = False
        if "onnx" in kwargs:
            onnx = kwargs["onnx"]
        self.embed = model.embed
        self.decoder = model.decoder
        self.model = model
        self.feats_dim = self.embed.embedding_dim
        self.num_embeddings = self.embed.num_embeddings
        self.model_name = model_name
        if isinstance(model.encoder, SANMEncoder):
            self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)
        else:
            assert False, "Only SANMEncoder is supported."

    def forward(self, input: torch.Tensor, text_lengths: torch.Tensor) -> torch.Tensor:
        """Compute punctuation logits from token id sequences.

        Args:
            input (torch.Tensor): Input ids. (batch, len)
            text_lengths (torch.Tensor): Valid length of each sequence. (batch,)
        """
        x = self.embed(input)
        # mask = self._target_mask(input)
        h, _ = self.encoder(x, text_lengths)
        y = self.decoder(h)
        return y

    def get_dummy_inputs(self):
        length = 120
        text_indexes = torch.randint(0, self.embed.num_embeddings, (2, length))
        text_lengths = torch.tensor([length - 20, length], dtype=torch.int32)
        return (text_indexes, text_lengths)

    def get_input_names(self):
        return ['input', 'text_lengths']

    def get_output_names(self):
        return ['logits']

    def get_dynamic_axes(self):
        return {
            'input': {
                0: 'batch_size',
                1: 'feats_length'
            },
            'text_lengths': {
                0: 'batch_size',
            },
            'logits': {
                0: 'batch_size',
                1: 'logits_length'
            },
        }
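
The exporter's _export routine is not part of this diff, so the following is only a hedged sketch of how the wrapper's helper methods could map onto torch.onnx.export; the output path and opset are assumptions:

# Hedged sketch only: how the helpers above plug into torch.onnx.export.
# 'model' is assumed to be a punctuation model loaded via
# PUNCTask.build_model_from_file, as in the first file's new branch.
import torch

wrapper = TargetDelayTransformer(model, onnx=True)
torch.onnx.export(
    wrapper,                                  # module to trace
    wrapper.get_dummy_inputs(),               # (text_indexes, text_lengths)
    f"{wrapper.model_name}.onnx",             # e.g. punc_model.onnx (assumed path)
    input_names=wrapper.get_input_names(),    # ['input', 'text_lengths']
    output_names=wrapper.get_output_names(),  # ['logits']
    dynamic_axes=wrapper.get_dynamic_axes(),  # batch/length axes stay dynamic
    opset_version=13,                         # illustrative opset choice
)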