From 3530688e0a1b1dfbb22dcd3324db97be5bbc0d9b Mon Sep 17 00:00:00 2001
From: takipipo <69394786+takipipo@users.noreply.github.com>
Date: Thu, 16 Jan 2025 09:33:23 +0700
Subject: [PATCH] Make Emotion2vec support ONNX (#2359)

* Make emotion2vec exportable to ONNX

* Make export_meta of emotion2vec consistent with other models

* Include layer norm in the exported ONNX model
---
 .gitignore                               |  1 +
 export.py                                |  5 +-
 funasr/models/emotion2vec/export_meta.py | 72 ++++++++++++++++++++++++
 funasr/models/emotion2vec/model.py       |  6 ++
 4 files changed, 82 insertions(+), 2 deletions(-)
 create mode 100644 funasr/models/emotion2vec/export_meta.py

diff --git a/.gitignore b/.gitignore
index 37802ac4f..8a2584693 100644
--- a/.gitignore
+++ b/.gitignore
@@ -27,3 +27,4 @@ GPT-SoVITS*
 modelscope_models
 examples/aishell/llm_asr_nar/*
 *egg-info
+env/
diff --git a/export.py b/export.py
index 2ea6dea28..a891e60ae 100644
--- a/export.py
+++ b/export.py
@@ -2,8 +2,9 @@
 from funasr import AutoModel
 
 model = AutoModel(
-    model="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
+    model="iic/emotion2vec_base",
+    hub="ms"
 )
 
-res = model.export(type="onnx", quantize=False, opset_version=13, device='cuda') # fp32 onnx-gpu
+res = model.export(type="onnx", quantize=False, opset_version=13, device='cpu') # fp32 onnx-cpu
 # res = model.export(type="onnx_fp16", quantize=False, opset_version=13, device='cuda') # fp16 onnx-gpu
diff --git a/funasr/models/emotion2vec/export_meta.py b/funasr/models/emotion2vec/export_meta.py
new file mode 100644
index 000000000..2954e5fc6
--- /dev/null
+++ b/funasr/models/emotion2vec/export_meta.py
@@ -0,0 +1,72 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+import types
+import torch
+
+
+def export_rebuild_model(model, **kwargs):
+    model.device = kwargs.get("device")
+
+    # keep a reference to the original forward: extract_features() calls
+    # self.forward, which is about to be replaced by export_forward
+    model._original_forward = model.forward
+
+    model.forward = types.MethodType(export_forward, model)
+    model.export_dummy_inputs = types.MethodType(export_dummy_inputs, model)
+    model.export_input_names = types.MethodType(export_input_names, model)
+    model.export_output_names = types.MethodType(export_output_names, model)
+    model.export_dynamic_axes = types.MethodType(export_dynamic_axes, model)
+
+    # the export utilities read export_name as a plain attribute, so assign
+    # the string directly rather than binding a method
+    model.export_name = "emotion2vec"
+    return model
+
+
+def export_forward(self, x: torch.Tensor):
+    with torch.no_grad():
+        if self.cfg.normalize:
+            # inline layer norm so it is captured in the exported ONNX graph
+            mean = torch.mean(x, dim=1, keepdim=True)
+            var = torch.var(x, dim=1, keepdim=True, unbiased=False)
+            x = (x - mean) / torch.sqrt(var + 1e-5)
+        x = x.view(x.shape[0], -1)
+
+        # Call the original forward directly, mirroring extract_features().
+        # extract_features() itself cannot be used here: it calls self.forward,
+        # which now points to export_forward and would recurse.
+        res = self._original_forward(
+            source=x,
+            padding_mask=None,
+            mask=False,
+            features_only=True,
+            remove_extra_tokens=True,
+        )
+
+        return res["x"]
+
+
+def export_dummy_inputs(self):
+    # one second of 16 kHz audio
+    return (torch.randn(1, 16000),)
+
+
+def export_input_names(self):
+    return ["input"]
+
+
+def export_output_names(self):
+    return ["output"]
+
+
+def export_dynamic_axes(self):
+    return {
+        "input": {
+            0: "batch_size",
+            1: "sequence_length",
+        },
+        "output": {0: "batch_size", 1: "sequence_length"},
+    }
diff --git a/funasr/models/emotion2vec/model.py b/funasr/models/emotion2vec/model.py
index d18e1844c..75ede4c0c 100644
--- a/funasr/models/emotion2vec/model.py
+++ b/funasr/models/emotion2vec/model.py
@@ -265,3 +265,9 @@ class Emotion2vec(torch.nn.Module):
             results.append(result_i)
 
         return results, meta_data
+
+    def export(self, **kwargs):
+        from .export_meta import export_rebuild_model
+
+        models = export_rebuild_model(model=self, **kwargs)
+        return models
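-- 
Reviewer note (not part of the patch): a minimal sketch of running the
exported model with onnxruntime. It assumes onnxruntime is installed and
that the export above produced a file named emotion2vec.onnx (the path is an
assumption based on export_name). The "input"/"output" tensor names, the
(batch, samples) input layout, and the dynamic axes come from
export_input_names, export_output_names, and export_dummy_inputs in this
patch; random noise stands in for a real 16 kHz mono waveform.

import numpy as np
import onnxruntime as ort

# load the exported graph; the file name follows export_name above
session = ort.InferenceSession("emotion2vec.onnx")

# batch of one 16 kHz waveform; both axes were exported as dynamic,
# so other batch sizes and sequence lengths are accepted
waveform = np.random.randn(1, 16000).astype(np.float32)

# tensor names match export_input_names/export_output_names
(features,) = session.run(["output"], {"input": waveform})
print(features.shape)  # frame-level emotion2vec features

Normalization is baked into the graph by export_forward, so the session can
be fed the raw waveform without applying layer norm on the caller's side.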