Mirror of https://github.com/modelscope/FunASR, synced 2025-09-15 14:48:36 +08:00
Make Emotion2vec support onnx (#2359)
* Make emotion2vec exportable to onnx
* Make export_meta of emotion2vec consistent with other models
* Include layer norm in the exported onnx model
This commit is contained in:
parent d4f13c2e44
commit 3530688e0a
.gitignore (vendored): 1 line changed
@@ -27,3 +27,4 @@ GPT-SoVITS*
 modelscope_models
 examples/aishell/llm_asr_nar/*
 *egg-info
+env/
@@ -2,8 +2,9 @@
 from funasr import AutoModel
 
 model = AutoModel(
-    model="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
+    model="iic/emotion2vec_base",
+    hub="ms"
 )
 
-res = model.export(type="onnx", quantize=False, opset_version=13, device='cuda') # fp32 onnx-gpu
+res = model.export(type="onnx", quantize=False, opset_version=13, device='cpu') # fp32 onnx-cpu
 # res = model.export(type="onnx_fp16", quantize=False, opset_version=13, device='cuda') # fp16 onnx-gpu
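To sanity-check the exported graph, it can be loaded with onnxruntime. A minimal sketch, assuming the exported file is named "model.onnx" (the actual location comes from the export call above) and a 16 kHz mono float32 waveform as input:

import numpy as np
import onnxruntime as ort

# "model.onnx" is an assumed path; use the location produced by model.export()
session = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
waveform = np.random.randn(1, 16000).astype(np.float32)  # one second of 16 kHz audio
features, = session.run(["output"], {"input": waveform})  # names set by export_input_names/export_output_names
print(features.shape)  # frame-level features with dynamic batch and sequence axes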
funasr/models/emotion2vec/export_meta.py (new file): 76 lines added
@@ -0,0 +1,76 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+import types
+import torch
+import torch.nn.functional as F
+
+
+def export_rebuild_model(model, **kwargs):
+    model.device = kwargs.get("device")
+
+    # store original forward since self.extract_features is calling it
+    model._original_forward = model.forward
+
+    model.forward = types.MethodType(export_forward, model)
+    model.export_dummy_inputs = types.MethodType(export_dummy_inputs, model)
+    model.export_input_names = types.MethodType(export_input_names, model)
+    model.export_output_names = types.MethodType(export_output_names, model)
+    model.export_dynamic_axes = types.MethodType(export_dynamic_axes, model)
+    model.export_name = types.MethodType(export_name, model)
+
+    model.export_name = "emotion2vec"
+    return model
+
+
+def export_forward(
+    self, x: torch.Tensor
+):
+    with torch.no_grad():
+        if self.cfg.normalize:
+            mean = torch.mean(x, dim=1, keepdim=True)
+            var = torch.var(x, dim=1, keepdim=True, unbiased=False)
+            x = (x - mean) / torch.sqrt(var + 1e-5)
+        x = x.view(x.shape[0], -1)
+
+        # Call the original forward directly just like extract_features
+        # Cannot directly use self.extract_features since it is being replaced by export_forward
+        res = self._original_forward(
+            source=x,
+            padding_mask=None,
+            mask=False,
+            features_only=True,
+            remove_extra_tokens=True
+        )
+
+        x = res["x"]
+
+    return x
+
+
+def export_dummy_inputs(self):
+    return (torch.randn(1, 16000),)
+
+
+def export_input_names(self):
+    return ["input"]
+
+
+def export_output_names(self):
+    return ["output"]
+
+
+def export_dynamic_axes(self):
+    return {
+        "input": {
+            0: "batch_size",
+            1: "sequence_length",
+        },
+        "output": {0: "batch_size", 1: "sequence_length"},
+    }
+
+
+def export_name(self):
+    return "emotion2vec"
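The mean/variance arithmetic in export_forward is layer normalization over the waveform dimension, written out with primitive ops so that the normalization is traced into the ONNX graph (the "include layer norm" bullet in the commit message). A quick illustrative equivalence check, not part of the commit:

import torch
import torch.nn.functional as F

x = torch.randn(2, 16000)
mean = torch.mean(x, dim=1, keepdim=True)
var = torch.var(x, dim=1, keepdim=True, unbiased=False)  # biased variance, as layer norm uses
manual = (x - mean) / torch.sqrt(var + 1e-5)
# F.layer_norm normalizes over the trailing dimension with the same default eps of 1e-5
assert torch.allclose(manual, F.layer_norm(x, x.shape[1:]), atol=1e-5)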
funasr/models/emotion2vec/model.py: 6 lines added

@@ -265,3 +265,9 @@ class Emotion2vec(torch.nn.Module):
             results.append(result_i)
 
         return results, meta_data
+
+    def export(self, **kwargs):
+        from .export_meta import export_rebuild_model
+
+        models = export_rebuild_model(model=self, **kwargs)
+        return models
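This export() hook returns the rebuilt module, and FunASR's generic export pipeline then feeds the export_* attributes to the ONNX exporter. Roughly, a simplified sketch of that step (the direct torch.onnx.export call and the output path are illustrative, not the exact funasr export code):

import torch

rebuilt = model.export(device="cpu")  # the patched module from export_rebuild_model
torch.onnx.export(
    rebuilt,
    rebuilt.export_dummy_inputs(),               # (torch.randn(1, 16000),)
    "model.onnx",                                # assumed output path
    input_names=rebuilt.export_input_names(),    # ["input"]
    output_names=rebuilt.export_output_names(),  # ["output"]
    dynamic_axes=rebuilt.export_dynamic_axes(),
    opset_version=13,
)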