mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
emotion2vec
This commit is contained in:
parent
3530688e0a
commit
23c6d67288
@ -21,13 +21,10 @@ def export_rebuild_model(model, **kwargs):
|
|||||||
model.export_dynamic_axes = types.MethodType(export_dynamic_axes, model)
|
model.export_dynamic_axes = types.MethodType(export_dynamic_axes, model)
|
||||||
model.export_name = types.MethodType(export_name, model)
|
model.export_name = types.MethodType(export_name, model)
|
||||||
|
|
||||||
model.export_name = "emotion2vec"
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def export_forward(
|
def export_forward(self, x: torch.Tensor):
|
||||||
self, x: torch.Tensor
|
|
||||||
):
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
if self.cfg.normalize:
|
if self.cfg.normalize:
|
||||||
mean = torch.mean(x, dim=1, keepdim=True)
|
mean = torch.mean(x, dim=1, keepdim=True)
|
||||||
@ -38,13 +35,9 @@ def export_forward(
|
|||||||
# Call the original forward directly just like extract_features
|
# Call the original forward directly just like extract_features
|
||||||
# Cannot directly use self.extract_features since it is being replaced by export_forward
|
# Cannot directly use self.extract_features since it is being replaced by export_forward
|
||||||
res = self._original_forward(
|
res = self._original_forward(
|
||||||
source=x,
|
source=x, padding_mask=None, mask=False, features_only=True, remove_extra_tokens=True
|
||||||
padding_mask=None,
|
|
||||||
mask=False,
|
|
||||||
features_only=True,
|
|
||||||
remove_extra_tokens=True
|
|
||||||
)
|
)
|
||||||
|
|
||||||
x = res["x"]
|
x = res["x"]
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user