mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
funasr1.0.4 emotion2vec finetuned
This commit is contained in:
parent
f1c1cb0773
commit
cf7f9a06c8
@ -5,8 +5,9 @@
|
||||
|
||||
from funasr import AutoModel
|
||||
|
||||
model = AutoModel(model="damo/emotion2vec_base", model_revision="v2.0.4")
|
||||
# model="damo/emotion2vec_base"
|
||||
model = AutoModel(model="iic/emotion2vec_base_finetuned", model_revision="v2.0.4")
|
||||
|
||||
wav_file = f"{model.model_path}/example/test.wav"
|
||||
res = model.generate(wav_file, output_dir="./outputs", granularity="utterance")
|
||||
res = model.generate(wav_file, output_dir="./outputs", granularity="utterance", extract_embedding=False)
|
||||
print(res)
|
||||
@ -1,5 +1,6 @@
|
||||
|
||||
model="damo/emotion2vec_base"
|
||||
#model="damo/emotion2vec_base"
|
||||
model="iic/emotion2vec_base_finetuned"
|
||||
model_revision="v2.0.4"
|
||||
|
||||
python funasr/bin/inference.py \
|
||||
@ -7,4 +8,5 @@ python funasr/bin/inference.py \
|
||||
+model_revision=${model_revision} \
|
||||
+input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav" \
|
||||
+output_dir="./outputs/debug" \
|
||||
+extract_embedding=False \
|
||||
+device="cpu" \
|
||||
|
||||
@ -93,7 +93,10 @@ class Emotion2vec(torch.nn.Module):
|
||||
if cfg.get("layer_norm_first"):
|
||||
self.norm = make_layer_norm(cfg.get("embed_dim"))
|
||||
|
||||
|
||||
vocab_size = kwargs.get("vocab_size", -1)
|
||||
self.proj = None
|
||||
if vocab_size > 0:
|
||||
self.proj = torch.nn.Linear(cfg.get("embed_dim"), vocab_size)
|
||||
|
||||
|
||||
def forward(
|
||||
@ -204,6 +207,9 @@ class Emotion2vec(torch.nn.Module):
|
||||
# assert sr == 16e3, "Sample rate should be 16kHz, but got {}in file {}".format(sr, source_file)
|
||||
# assert channel == 1, "Channel should be 1, but got {} in file {}".format(channel, source_file)
|
||||
granularity = kwargs.get("granularity", "utterance")
|
||||
extract_embedding = kwargs.get("extract_embedding", True)
|
||||
if self.proj is None:
|
||||
extract_embedding = True
|
||||
meta_data = {}
|
||||
# extract fbank feats
|
||||
time1 = time.perf_counter()
|
||||
@ -211,6 +217,8 @@ class Emotion2vec(torch.nn.Module):
|
||||
data_type=kwargs.get("data_type", "sound"), tokenizer=tokenizer)
|
||||
time2 = time.perf_counter()
|
||||
meta_data["load_data"] = f"{time2 - time1:0.3f}"
|
||||
meta_data["batch_data_time"] = len(audio_sample_list[0])/kwargs.get("fs", 16000)
|
||||
|
||||
results = []
|
||||
output_dir = kwargs.get("output_dir")
|
||||
if output_dir:
|
||||
@ -222,15 +230,28 @@ class Emotion2vec(torch.nn.Module):
|
||||
source = source.view(1, -1)
|
||||
|
||||
feats = self.extract_features(source, padding_mask=None)
|
||||
x = feats['x']
|
||||
feats = feats['x'].squeeze(0).cpu().numpy()
|
||||
if granularity == 'frame':
|
||||
feats = feats
|
||||
elif granularity == 'utterance':
|
||||
feats = np.mean(feats, axis=0)
|
||||
|
||||
result_i = {"key": key[i], "feats": feats}
|
||||
results.append(result_i)
|
||||
if output_dir:
|
||||
if output_dir and extract_embedding:
|
||||
np.save(os.path.join(output_dir, "{}.npy".format(key[i])), feats)
|
||||
|
||||
labels = tokenizer.token_list if tokenizer is not None else []
|
||||
scores = []
|
||||
if self.proj:
|
||||
x = x.mean(dim=1)
|
||||
x = self.proj(x)
|
||||
x = torch.softmax(x, dim=-1)
|
||||
scores = x[0].tolist()
|
||||
|
||||
result_i = {"key": key[i], "labels": labels, "scores": scores}
|
||||
if extract_embedding:
|
||||
result_i["feats"] = feats
|
||||
results.append(result_i)
|
||||
|
||||
|
||||
return results, meta_data
|
||||
Loading…
Reference in New Issue
Block a user