mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
Compare commits
16 Commits
81d8cf72a9
...
e06a737112
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e06a737112 | ||
|
|
4402e95b0f | ||
|
|
f5051c55cd | ||
|
|
bd4340fdfc | ||
|
|
dc06a80dbc | ||
|
|
c529ac9b45 | ||
|
|
9b423d3d6a | ||
|
|
3ad0599437 | ||
|
|
b407b4c345 | ||
|
|
2add00c614 | ||
|
|
be015ec75d | ||
|
|
0528806aa7 | ||
|
|
ebb0940f2b | ||
|
|
b31592acd7 | ||
|
|
b1836414b5 | ||
|
|
fa74a6e26c |
10
export.py
10
export.py
@ -1,10 +0,0 @@
|
||||
# method2, inference from local path
|
||||
from funasr import AutoModel
|
||||
|
||||
model = AutoModel(
|
||||
model="iic/emotion2vec_base",
|
||||
hub="ms"
|
||||
)
|
||||
|
||||
res = model.export(type="onnx", quantize=False, opset_version=13, device='cpu') # fp32 onnx-gpu
|
||||
# res = model.export(type="onnx_fp16", quantize=False, opset_version=13, device='cuda') # fp16 onnx-gpu
|
||||
@ -108,6 +108,24 @@ def prepare_data_iterator(data_in, input_len=None, data_type=None, key=None):
|
||||
return key_list, data_list
|
||||
|
||||
|
||||
def distribute_emotion(sentence_list, ser_time_list):
    """Assign an emotion label to each sentence from timed SER segments.

    Each sentence dict (with millisecond ``start``/``end`` keys) gets an
    ``emotion`` key chosen from the SER segment list by an overlap score:
    a segment whose overlap beats the running best takes the label, and
    further segments carrying the currently-winning label add their
    overlap to the running score. Sentences overlapping no segment get
    ``"EMO_UNKNOWN"``.

    Args:
        sentence_list: list of dicts with ``start``/``end`` in milliseconds;
            mutated in place (an ``emotion`` key is added to each).
        ser_time_list: iterable of ``(start_sec, end_sec, emotion)`` tuples
            with times in seconds.

    Returns:
        The same ``sentence_list`` object, for call-chaining convenience.
    """
    # Convert SER segment times from seconds to milliseconds to match sentences.
    segments_ms = [(beg * 1000, end * 1000, label) for beg, end, label in ser_time_list]
    for sent in sentence_list:
        sent_beg, sent_end = sent['start'], sent['end']
        best_label = "EMO_UNKNOWN"
        best_score = 0
        for seg_beg, seg_end, label in segments_ms:
            # Length of the intersection of [sent_beg, sent_end] and
            # [seg_beg, seg_end]; clamped to zero for disjoint spans.
            ov = min(sent_end, seg_end) - max(sent_beg, seg_beg)
            if ov < 0:
                ov = 0
            if ov > best_score:
                best_score = ov
                best_label = label
            # Segments sharing the current winning label reinforce its score
            # (note: this also re-adds the overlap of a segment that just won
            # the branch above — preserved from the original logic).
            if ov > 0 and best_label == label:
                best_score += ov
        sent['emotion'] = best_label
    return sentence_list
|
||||
|
||||
|
||||
class AutoModel:
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
@ -161,7 +179,11 @@ class AutoModel:
|
||||
if spk_mode not in ["default", "vad_segment", "punc_segment"]:
|
||||
logging.error("spk_mode should be one of default, vad_segment and punc_segment.")
|
||||
self.spk_mode = spk_mode
|
||||
|
||||
ser_model = kwargs.get("ser_model", None)
|
||||
ser_kwargs = {} if kwargs.get("ser_kwargs", {}) is None else kwargs.get("ser_kwargs", {})
|
||||
if ser_model is not None:
|
||||
logging.info("Building SER model.")
|
||||
ser_model, ser_kwargs = self.build_model(**ser_kwargs)
|
||||
self.kwargs = kwargs
|
||||
self.model = model
|
||||
self.vad_model = vad_model
|
||||
@ -170,6 +192,8 @@ class AutoModel:
|
||||
self.punc_kwargs = punc_kwargs
|
||||
self.spk_model = spk_model
|
||||
self.spk_kwargs = spk_kwargs
|
||||
self.ser_model = ser_model
|
||||
self.ser_kwargs = ser_kwargs
|
||||
self.model_path = kwargs.get("model_path")
|
||||
|
||||
@staticmethod
|
||||
@ -502,6 +526,16 @@ class AutoModel:
|
||||
speech_b, input_len=None, model=self.spk_model, kwargs=kwargs, **cfg
|
||||
)
|
||||
results[_b]["spk_embedding"] = spk_res[0]["spk_embedding"]
|
||||
if self.ser_model is not None:
|
||||
ser_res = self.inference(speech_b, input_len=None, model=self.ser_model,
|
||||
kwargs=self.ser_kwargs, **cfg)
|
||||
if "SenseVoiceSmall" in kwargs.get("ser_model", None):
|
||||
results[_b]["ser_type"] = [i['text'].split("|><|")[1] for i in ser_res]
|
||||
elif "emotion2vec" in kwargs.get("ser_model", None):
|
||||
results[_b]["ser_type"] = [i['labels'][i["scores"].index(max(i["scores"]))] for i in ser_res]
|
||||
|
||||
|
||||
|
||||
beg_idx = end_idx
|
||||
end_idx += 1
|
||||
max_len_in_batch = sample_length
|
||||
@ -594,6 +628,7 @@ class AutoModel:
|
||||
"end": vadsegment[1],
|
||||
"sentence": rest["text"],
|
||||
"timestamp": rest["timestamp"],
|
||||
"emotion": rest["ser_type"],
|
||||
}
|
||||
)
|
||||
elif self.spk_mode == "punc_segment":
|
||||
@ -617,6 +652,13 @@ class AutoModel:
|
||||
raw_text,
|
||||
return_raw_text=return_raw_text,
|
||||
)
|
||||
if "ser_type" in result:
|
||||
if len(sentence_list) == len(result["ser_type"]):
|
||||
for i in range(len(sentence_list)):
|
||||
sentence_list[i]["emotion"] = result["ser_type"][i]
|
||||
else:
|
||||
merged_list = [[x[0], x[1], y] for x, y in zip(all_segments, result["ser_type"])]
|
||||
distribute_emotion(sentence_list, merged_list)
|
||||
distribute_spk(sentence_list, sv_output)
|
||||
result["sentence_info"] = sentence_list
|
||||
elif kwargs.get("sentence_timestamp", False):
|
||||
@ -640,6 +682,8 @@ class AutoModel:
|
||||
result["sentence_info"] = sentence_list
|
||||
if "spk_embedding" in result:
|
||||
del result["spk_embedding"]
|
||||
if "ser_type" in result:
|
||||
del result["ser_type"]
|
||||
|
||||
result["key"] = key
|
||||
results_ret_list.append(result)
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import logging
|
||||
import os
|
||||
import json
|
||||
from omegaconf import OmegaConf, DictConfig
|
||||
@ -79,7 +80,10 @@ def download_from_ms(**kwargs):
|
||||
kwargs["jieba_usr_dict"] = os.path.join(model_or_path, "jieba_usr_dict")
|
||||
if isinstance(kwargs, DictConfig):
|
||||
kwargs = OmegaConf.to_container(kwargs, resolve=True)
|
||||
if os.path.exists(os.path.join(model_or_path, "requirements.txt")):
|
||||
logging.warning(f'trust_remote_code: {kwargs.get("trust_remote_code", False)}')
|
||||
if os.path.exists(os.path.join(model_or_path, "requirements.txt")) and kwargs.get(
|
||||
"trust_remote_code", False
|
||||
):
|
||||
requirements = os.path.join(model_or_path, "requirements.txt")
|
||||
print(f"Detect model requirements, begin to install it: {requirements}")
|
||||
from funasr.utils.install_model_requirements import install_requirements
|
||||
|
||||
@ -1 +1 @@
|
||||
1.2.6
|
||||
1.2.7
|
||||
Loading…
Reference in New Issue
Block a user