Make Emotion2vec support onnx (#2359)

* Make emotion2vec exportable to onnx

* Make export_meta of emotion2vec consistence with other models

* Include layer norm in the exported onnx model
This commit is contained in:
takipipo 2025-01-16 09:33:23 +07:00 committed by GitHub
parent d4f13c2e44
commit 3530688e0a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 86 additions and 2 deletions

1
.gitignore vendored
View File

@ -27,3 +27,4 @@ GPT-SoVITS*
modelscope_models
examples/aishell/llm_asr_nar/*
*egg-info
env/

View File

@ -2,8 +2,9 @@
from funasr import AutoModel
model = AutoModel(
model="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
model="iic/emotion2vec_base",
hub="ms"
)
res = model.export(type="onnx", quantize=False, opset_version=13, device='cuda') # fp32 onnx-gpu
res = model.export(type="onnx", quantize=False, opset_version=13, device='cpu') # fp32 onnx-gpu
# res = model.export(type="onnx_fp16", quantize=False, opset_version=13, device='cuda') # fp16 onnx-gpu

View File

@ -0,0 +1,76 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
import types
import torch
import torch.nn.functional as F
def export_rebuild_model(model, **kwargs):
model.device = kwargs.get("device")
# store original forward since self.extract_features is calling it
model._original_forward = model.forward
model.forward = types.MethodType(export_forward, model)
model.export_dummy_inputs = types.MethodType(export_dummy_inputs, model)
model.export_input_names = types.MethodType(export_input_names, model)
model.export_output_names = types.MethodType(export_output_names, model)
model.export_dynamic_axes = types.MethodType(export_dynamic_axes, model)
model.export_name = types.MethodType(export_name, model)
model.export_name = "emotion2vec"
return model
def export_forward(
self, x: torch.Tensor
):
with torch.no_grad():
if self.cfg.normalize:
mean = torch.mean(x, dim=1, keepdim=True)
var = torch.var(x, dim=1, keepdim=True, unbiased=False)
x = (x - mean) / torch.sqrt(var + 1e-5)
x = x.view(x.shape[0], -1)
# Call the original forward directly just like extract_features
# Cannot directly use self.extract_features since it is being replaced by export_forward
res = self._original_forward(
source=x,
padding_mask=None,
mask=False,
features_only=True,
remove_extra_tokens=True
)
x = res["x"]
return x
def export_dummy_inputs(self):
return (torch.randn(1, 16000),)
def export_input_names(self):
return ["input"]
def export_output_names(self):
return ["output"]
def export_dynamic_axes(self):
return {
"input": {
0: "batch_size",
1: "sequence_length",
},
"output": {0: "batch_size", 1: "sequence_length"},
}
def export_name(self):
return "emotion2vec"

View File

@ -265,3 +265,9 @@ class Emotion2vec(torch.nn.Module):
results.append(result_i)
return results, meta_data
def export(self, **kwargs):
from .export_meta import export_rebuild_model
models = export_rebuild_model(model=self, **kwargs)
return models