FunASR/funasr/models/ct_transformer/export_meta.py
zhifu gao 675b4605e8
Dev gzf llm (#1506)
* update

* update

* update

* update onnx

* update with main (#1492)

* contextual&seaco ONNX export (#1481)

* contextual&seaco ONNX export

* update ContextualEmbedderExport2

* update ContextualEmbedderExport2

* update code

* onnx (#1482)

* qwenaudio qwenaudiochat

* qwenaudio qwenaudiochat

* whisper

* whisper

* llm

* llm

* llm

* llm

* llm

* llm

* llm

* llm

* export onnx

* export onnx

* export onnx

* dingding

* dingding

* llm

* doc

* onnx

* onnx

* onnx

* onnx

* onnx

* onnx

* v1.0.15

* qwenaudio

* qwenaudio

* issue doc

* update

* update

* bugfix

* onnx

* update export calling

* update codes

* remove useless code

* update code

---------

Co-authored-by: zhifu gao <zhifu.gzf@alibaba-inc.com>

* acknowledge

---------

Co-authored-by: Shi Xian <40013335+R1ckShi@users.noreply.github.com>

* update onnx

* update onnx

* train update

* train update

* train update

* train update

* punc update

---------

Co-authored-by: Shi Xian <40013335+R1ckShi@users.noreply.github.com>
2024-03-15 21:14:08 +08:00

67 lines
2.0 KiB
Python

#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
import types
import torch
from funasr.register import tables
def export_rebuild_model(model, **kwargs):
    """Prepare ``model`` for export by swapping its encoder and binding export helpers.

    The model's ``encoder`` is wrapped by the registered ``<encoder>Export``
    class. The export helper functions defined in this module are then bound
    to ``model`` as instance methods via ``types.MethodType``; this includes
    ``export_forward``, which replaces ``forward`` on the instance. The model
    is modified in place.

    Args:
        model: the model to rebuild. It must expose ``encoder``. The bound
            forward also reads ``embed`` and ``decoder``.
        **kwargs: export options.
            ``encoder`` (required): registered encoder name. ``"Export"`` is
            appended to it and the result is looked up in
            ``tables.encoder_classes``.
            ``type`` (default ``"onnx"``): export format. Any value other than
            ``"onnx"`` passes ``onnx=False`` to the export encoder.

    Returns:
        The same ``model`` object, now carrying the export methods.

    Raises:
        KeyError: if ``encoder`` is not provided in ``kwargs``.
        ValueError: if no ``<encoder>Export`` class is found in
            ``tables.encoder_classes``.
    """
    is_onnx = kwargs.get("type", "onnx") == "onnx"
    export_encoder_name = kwargs["encoder"] + "Export"
    encoder_class = tables.encoder_classes.get(export_encoder_name)
    if encoder_class is None:
        # Without this guard the failure surfaces below as an opaque
        # "'NoneType' object is not callable" TypeError.
        raise ValueError(
            f"No export encoder registered as '{export_encoder_name}' in "
            "tables.encoder_classes; cannot rebuild the model for export."
        )
    model.encoder = encoder_class(model.encoder, onnx=is_onnx)

    # Bind the module-level export helpers as methods of this instance only.
    model.forward = types.MethodType(export_forward, model)
    model.export_dummy_inputs = types.MethodType(export_dummy_inputs, model)
    model.export_input_names = types.MethodType(export_input_names, model)
    model.export_output_names = types.MethodType(export_output_names, model)
    model.export_dynamic_axes = types.MethodType(export_dynamic_axes, model)
    model.export_name = types.MethodType(export_name, model)
    return model
def export_forward(self, inputs: torch.Tensor, text_lengths: torch.Tensor):
    """Export-time forward pass: embed token ids, encode them, then decode.

    Args:
        inputs: token-id tensor passed to ``self.embed``. The export dummy
            inputs build it with shape (batch, len).
        text_lengths: per-row lengths, shape (batch,) in the dummy inputs.
            They are forwarded to ``self.encoder`` together with the
            embeddings.

    Returns:
        The result of ``self.decoder`` applied to the first element returned
        by ``self.encoder``. This output is named ``logits`` in the export
        output names. Its exact shape depends on the decoder; presumably it
        is (batch, len, num_classes) — TODO confirm.
    """
    x = self.embed(inputs)
    # The encoder returns a pair; only the hidden states are needed here.
    h, _ = self.encoder(x, text_lengths)
    y = self.decoder(h)
    return y
def export_dummy_inputs(self):
    """Build example tensors used to trace the model for export.

    Returns:
        A tuple ``(ids, lengths)``:
        ``ids`` is an int32 tensor of shape (2, 120) holding random values in
        ``[0, self.embed.num_embeddings)``;
        ``lengths`` is the int32 tensor ``[100, 120]``, one length per row.
    """
    max_len = 120
    vocab_size = self.embed.num_embeddings
    ids = torch.randint(0, vocab_size, (2, max_len)).to(torch.int32)
    # The first row is deliberately shorter so that lengths differ across the batch.
    lengths = torch.tensor([max_len - 20, max_len], dtype=torch.int32)
    return ids, lengths
def export_input_names(self):
    """Return the export graph's input names.

    The order matches the positional arguments of the export forward.
    """
    names = ["inputs", "text_lengths"]
    return names
def export_output_names(self):
    """Return the export graph's output names: the single ``logits`` tensor."""
    names = ["logits"]
    return names
def export_dynamic_axes(self):
    """Return the dynamic-axis mapping for export.

    Keys are the input/output names. Each value maps an axis index to the
    symbolic name of that axis.
    """
    batch = "batch_size"
    # NOTE: the id input's length axis is named "feats_length". The name is
    # kept unchanged because it ends up in the exported graph.
    return {
        "inputs": {0: batch, 1: "feats_length"},
        "text_lengths": {0: batch},
        "logits": {0: batch, 1: "logits_length"},
    }
def export_name(self):
    """Return the default file name for the exported model.

    Presumably this is used as the output file name by the export driver;
    verify against the caller.
    """
    file_name = "model.onnx"
    return file_name