mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
emotion2vec
This commit is contained in:
parent
3530688e0a
commit
23c6d67288
@ -21,13 +21,10 @@ def export_rebuild_model(model, **kwargs):
|
||||
model.export_dynamic_axes = types.MethodType(export_dynamic_axes, model)
|
||||
model.export_name = types.MethodType(export_name, model)
|
||||
|
||||
model.export_name = "emotion2vec"
|
||||
return model
|
||||
|
||||
|
||||
def export_forward(
|
||||
self, x: torch.Tensor
|
||||
):
|
||||
def export_forward(self, x: torch.Tensor):
|
||||
with torch.no_grad():
|
||||
if self.cfg.normalize:
|
||||
mean = torch.mean(x, dim=1, keepdim=True)
|
||||
@ -38,13 +35,9 @@ def export_forward(
|
||||
# Call the original forward directly just like extract_features
|
||||
# Cannot directly use self.extract_features since it is being replaced by export_forward
|
||||
res = self._original_forward(
|
||||
source=x,
|
||||
padding_mask=None,
|
||||
mask=False,
|
||||
features_only=True,
|
||||
remove_extra_tokens=True
|
||||
source=x, padding_mask=None, mask=False, features_only=True, remove_extra_tokens=True
|
||||
)
|
||||
|
||||
|
||||
x = res["x"]
|
||||
|
||||
return x
|
||||
|
||||
Loading…
Reference in New Issue
Block a user