emotion2vec

This commit is contained in:
游雁 2025-01-16 11:25:36 +08:00
parent 3530688e0a
commit 23c6d67288

View File

@ -21,13 +21,10 @@ def export_rebuild_model(model, **kwargs):
model.export_dynamic_axes = types.MethodType(export_dynamic_axes, model) model.export_dynamic_axes = types.MethodType(export_dynamic_axes, model)
model.export_name = types.MethodType(export_name, model) model.export_name = types.MethodType(export_name, model)
model.export_name = "emotion2vec"
return model return model
def export_forward( def export_forward(self, x: torch.Tensor):
self, x: torch.Tensor
):
with torch.no_grad(): with torch.no_grad():
if self.cfg.normalize: if self.cfg.normalize:
mean = torch.mean(x, dim=1, keepdim=True) mean = torch.mean(x, dim=1, keepdim=True)
@ -38,13 +35,9 @@ def export_forward(
# Call the original forward directly just like extract_features # Call the original forward directly just like extract_features
# Cannot directly use self.extract_features since it is being replaced by export_forward # Cannot directly use self.extract_features since it is being replaced by export_forward
res = self._original_forward( res = self._original_forward(
source=x, source=x, padding_mask=None, mask=False, features_only=True, remove_extra_tokens=True
padding_mask=None,
mask=False,
features_only=True,
remove_extra_tokens=True
) )
x = res["x"] x = res["x"]
return x return x