funasr1.0 emotion2vec

游雁 2024-01-08 16:40:43 +08:00
parent fb176404cf
commit e8590bb1e9
2 changed files with 18 additions and 3 deletions


@@ -5,7 +5,7 @@
 from funasr import AutoModel
-model = AutoModel(model="/Users/zhifu/Downloads/modelscope_models/emotion2vec_base")
-res = model(input="/Users/zhifu/Downloads/modelscope_models/emotion2vec_base/example/test.wav")
+model = AutoModel(model="../modelscope_models/emotion2vec_base")
+res = model(input="../modelscope_models/emotion2vec_base/example/test.wav")
 print(res)
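
The change above swaps developer-local absolute paths for relative ones. As a usage sketch of the resulting demo, assuming funasr 1.0 is installed, that ../modelscope_models/emotion2vec_base holds a downloaded emotion2vec_base checkpoint (paths are placeholders), and that the call returns the per-utterance result list built by the model code below:

# Paths are placeholders; adjust to wherever the checkpoint lives.
from funasr import AutoModel

model = AutoModel(model="../modelscope_models/emotion2vec_base")
res = model(input="../modelscope_models/emotion2vec_base/example/test.wav")

# Each result carries the utterance key and the extracted feature matrix
# (see result_i = {"key": ..., "feats": ...} in the model code below).
for r in res:
    print(r["key"], r["feats"].shape)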


@@ -1,5 +1,11 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+# Modified from https://github.com/ddlBoJack/emotion2vec/tree/main
 import logging
+import os
 from functools import partial
 import numpy as np
@@ -21,7 +27,11 @@ from funasr.register import tables
 @tables.register("model_classes", "Emotion2vec")
 class Emotion2vec(nn.Module):
+    """
+    Author: Ziyang Ma, Zhisheng Zheng, Jiaxin Ye, Jinchao Li, Zhifu Gao, Shiliang Zhang, Xie Chen
+    emotion2vec: Self-Supervised Pre-Training for Speech Emotion Representation
+    https://arxiv.org/abs/2312.15185
+    """
     def __init__(self, **kwargs):
         super().__init__()
         # import pdb; pdb.set_trace()
@@ -196,6 +206,9 @@ class Emotion2vec(nn.Module):
         time2 = time.perf_counter()
         meta_data["load_data"] = f"{time2 - time1:0.3f}"
         results = []
+        output_dir = kwargs.get("output_dir")
+        if output_dir:
+            os.makedirs(output_dir, exist_ok=True)
         for i, wav in enumerate(audio_sample_list):
             source = wav.to(device=kwargs["device"])
             if self.cfg.normalize:
@@ -211,5 +224,7 @@ class Emotion2vec(nn.Module):
             result_i = {"key": key[i], "feats": feats}
             results.append(result_i)
+            if output_dir:
+                np.save(os.path.join(output_dir, "{}.npy".format(key[i])), feats)

         return results, meta_data
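
Taken together, the last two hunks add an optional output_dir keyword: the directory is created once before the inference loop, and each utterance's features are additionally dumped to <key>.npy via np.save. A sketch of driving this from the demo and reading a dump back; it assumes extra keywords passed at call time are forwarded into the model's inference kwargs, and all paths are placeholders:

import numpy as np
from funasr import AutoModel

model = AutoModel(model="../modelscope_models/emotion2vec_base")
res = model(
    input="../modelscope_models/emotion2vec_base/example/test.wav",
    output_dir="./emotion2vec_feats",  # triggers the new np.save branch
)

# np.save wrote one <key>.npy file per utterance; reload the first to
# verify it matches the in-memory feats (assumed to be a NumPy array).
key = res[0]["key"]
feats = np.load("./emotion2vec_feats/{}.npy".format(key))
print(key, feats.shape)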