diff --git a/funasr/models/emotion2vec/export_meta.py b/funasr/models/emotion2vec/export_meta.py index 2954e5fc6..f0c5f434b 100644 --- a/funasr/models/emotion2vec/export_meta.py +++ b/funasr/models/emotion2vec/export_meta.py @@ -21,13 +21,10 @@ def export_rebuild_model(model, **kwargs): model.export_dynamic_axes = types.MethodType(export_dynamic_axes, model) model.export_name = types.MethodType(export_name, model) - model.export_name = "emotion2vec" return model -def export_forward( - self, x: torch.Tensor -): +def export_forward(self, x: torch.Tensor): with torch.no_grad(): if self.cfg.normalize: mean = torch.mean(x, dim=1, keepdim=True) @@ -38,13 +35,9 @@ def export_forward( # Call the original forward directly just like extract_features # Cannot directly use self.extract_features since it is being replaced by export_forward res = self._original_forward( - source=x, - padding_mask=None, - mask=False, - features_only=True, - remove_extra_tokens=True + source=x, padding_mask=None, mask=False, features_only=True, remove_extra_tokens=True ) - + x = res["x"] return x