FunASR/funasr/models/emotion2vec/export_meta.py
2025-01-16 11:25:36 +08:00

70 lines
2.0 KiB
Python

#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
import types
import torch
import torch.nn.functional as F
def export_rebuild_model(model, **kwargs):
    """Patch ``model`` in place with the export hooks defined in this module and return it.

    Replaces ``model.forward`` with ``export_forward``. Binds
    ``export_dummy_inputs``, ``export_input_names``, ``export_output_names``,
    ``export_dynamic_axes`` and ``export_name`` as methods of ``model``.

    Args:
        model: the model instance to patch (mutated in place).
        **kwargs: only ``"device"`` is read; it is stored on ``model.device``
            (``None`` when absent).

    Returns:
        The same ``model`` object, now carrying the export hooks.
    """
    model.device = kwargs.get("device")
    # Store the genuine forward, since self.extract_features calls it and
    # self.forward is about to be replaced by export_forward.
    # Capture it only once: if this function runs again on an already-patched
    # model, model.forward is export_forward itself. Storing that as the
    # "original" would make export_forward call itself with the original
    # forward's keyword arguments and fail.
    if getattr(model, "_original_forward", None) is None:
        model._original_forward = model.forward
    model.forward = types.MethodType(export_forward, model)
    model.export_dummy_inputs = types.MethodType(export_dummy_inputs, model)
    model.export_input_names = types.MethodType(export_input_names, model)
    model.export_output_names = types.MethodType(export_output_names, model)
    model.export_dynamic_axes = types.MethodType(export_dynamic_axes, model)
    model.export_name = types.MethodType(export_name, model)
    return model
def export_forward(self, x: torch.Tensor):
    """Export-time replacement for ``forward``: return encoder features for ``x``.

    Optionally standardises ``x`` along dim 1 (when ``self.cfg.normalize``).
    It then flattens ``x`` to ``(x.shape[0], -1)`` and calls the saved
    original forward in feature-only mode.

    Args:
        x: input tensor, presumably a batch of raw waveforms shaped
            ``(batch, num_samples)`` -- TODO confirm against the caller.

    Returns:
        The ``"x"`` entry of the dict returned by the original forward.
    """
    with torch.no_grad():
        wav = x
        if self.cfg.normalize:
            # Zero mean / unit (biased) variance along dim 1; the 1e-5 keeps
            # the denominator non-zero for constant input.
            wav_mean = torch.mean(wav, dim=1, keepdim=True)
            wav_var = torch.var(wav, dim=1, keepdim=True, unbiased=False)
            wav = (wav - wav_mean) / torch.sqrt(wav_var + 1e-5)
        batch = wav.shape[0]
        wav = wav.view(batch, -1)
        # self.extract_features cannot be used here: it goes through
        # self.forward, which is this function. Call the saved original instead.
        # NOTE: no padding mask is passed, so every sample of every row is
        # treated as valid input.
        outputs = self._original_forward(
            source=wav,
            padding_mask=None,
            mask=False,
            features_only=True,
            remove_extra_tokens=True,
        )
        return outputs["x"]
def export_dummy_inputs(self):
    """Return a one-element tuple holding a random example input for export.

    The tensor has shape ``(1, 16000)``. It is presumably one second of 16 kHz
    audio -- TODO confirm the expected sample rate.
    """
    dummy_wave = torch.randn(1, 16000)
    return (dummy_wave,)
def export_input_names(self):
    """Return the name list for the exported graph's input tensor(s)."""
    input_names = ["input"]
    return input_names
def export_output_names(self):
    """Return the name list for the exported graph's output tensor(s)."""
    output_names = ["output"]
    return output_names
def export_dynamic_axes(self):
    """Return the dynamic-axis spec for the exported graph's input and output.

    Keys match the tensor names returned by ``export_input_names`` and
    ``export_output_names``. Each maps an axis index to a symbolic dimension
    name.

    Returns:
        ``{"input": {0: ..., 1: ...}, "output": {0: ..., 1: ...}}``.
    """
    return {
        "input": {
            0: "batch_size",
            1: "sequence_length",
        },
        # The output's time axis gets its own symbol. In ONNX, a dimension
        # variable with a given name denotes the same value everywhere in the
        # graph. Reusing "sequence_length" here would wrongly assert that the
        # number of output frames equals the number of input samples; it
        # presumably does not if the encoder downsamples.
        "output": {0: "batch_size", 1: "output_sequence_length"},
    }
def export_name(self):
    """Return the identifier ``"emotion2vec"`` for the exported model.

    Presumably used by the caller to name the exported artifact -- confirm
    against the export driver.
    """
    model_name = "emotion2vec"
    return model_name