mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
Emo2Vec限定选择的情感类别 (#1730)
* 限定选择的情感类别 * 使用none来禁用情感标签输出 * 修改输出接口 * 使用unuse来禁用token --------- Co-authored-by: 常材 <gaochangfeng.gcf@alibaba-inc.com>
This commit is contained in:
parent
2f27b16555
commit
2f7dcbad90
@ -249,10 +249,17 @@ class Emotion2vec(torch.nn.Module):
|
||||
if self.proj:
|
||||
x = x.mean(dim=1)
|
||||
x = self.proj(x)
|
||||
for idx, lab in enumerate(labels):
|
||||
x[:,idx] = -np.inf if lab.startswith("unuse") else x[:,idx]
|
||||
x = torch.softmax(x, dim=-1)
|
||||
scores = x[0].tolist()
|
||||
|
||||
result_i = {"key": key[i], "labels": labels, "scores": scores}
|
||||
select_label = [lb for lb in labels if not lb.startswith("unuse")]
|
||||
select_score = [scores[idx] for idx, lb in enumerate(labels) if not lb.startswith("unuse")]
|
||||
|
||||
# result_i = {"key": key[i], "labels": labels, "scores": scores}
|
||||
result_i = {"key": key[i], "labels": select_label, "scores": select_score}
|
||||
|
||||
if extract_embedding:
|
||||
result_i["feats"] = feats
|
||||
results.append(result_i)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user