mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
Fix bug for SD (speaker diarization) and SV (speaker verification)
This commit is contained in:
parent
267dddcdbb
commit
3b42ace3d4
67
egs/alimeeting/diarization/sond/unit_test_modelscope.py
Normal file
67
egs/alimeeting/diarization/sond/unit_test_modelscope.py
Normal file
@ -0,0 +1,67 @@
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.utils.constant import Tasks
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
|
||||
def test_wav_cpu_infer():
    """Run SOND speaker diarization on CPU over scp-listed inputs.

    Feeds wav paths (``speech``/``sound``) and precomputed speaker profiles
    (``profile``/``kaldi_ark``) from scp files through the ModelScope
    pipeline and writes results to ``./outputs``.
    """
    output_dir = "./outputs"
    data_path_and_name_and_type = [
        "data/unit_test/test_wav.scp,speech,sound",
        "data/unit_test/test_profile.scp,profile,kaldi_ark",
    ]
    diar_pipeline = pipeline(
        task=Tasks.speaker_diarization,
        model='damo/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch',
        mode="sond",
        output_dir=output_dir,
        num_workers=0,
        log_level="WARNING",
        # Force CPU inference. Without this the function is identical to
        # test_wav_gpu_infer and silently runs on GPU whenever CUDA is
        # visible (main() sets CUDA_VISIBLE_DEVICES="0" before calling it).
        ngpu=0,
    )
    results = diar_pipeline(data_path_and_name_and_type)
    print(results)
|
||||
|
||||
|
||||
def test_wav_gpu_infer():
    """Run SOND speaker diarization over scp-listed wavs and speaker profiles.

    The wav scp provides the meeting audio (``speech``/``sound``); the
    profile scp provides speaker embeddings (``profile``/``kaldi_ark``).
    Results are written to ``./outputs`` and echoed to stdout.
    """
    out_dir = "./outputs"
    scp_inputs = [
        "data/unit_test/test_wav.scp,speech,sound",
        "data/unit_test/test_profile.scp,profile,kaldi_ark",
    ]
    sond_pipeline = pipeline(
        task=Tasks.speaker_diarization,
        model='damo/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch',
        mode="sond",
        output_dir=out_dir,
        num_workers=0,
        log_level="WARNING",
    )
    print(sond_pipeline(scp_inputs))
|
||||
|
||||
|
||||
def test_without_profile_gpu_infer():
    """Run SOND in demo mode from raw wav paths, without precomputed profiles.

    Inputs are a meeting recording followed by four single-speaker wavs
    (presumably used to derive speaker profiles on the fly in
    ``sond_demo`` mode — confirm against the pipeline implementation).
    """
    raw_inputs = ["data/unit_test/raw_inputs/record.wav"] + [
        "data/unit_test/raw_inputs/spk%d.wav" % idx for idx in range(1, 5)
    ]
    demo_pipeline = pipeline(
        task=Tasks.speaker_diarization,
        model='damo/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch',
        mode="sond_demo",
        num_workers=0,
        log_level="WARNING",
        param_dict={},
    )
    print(demo_pipeline(raw_inputs=raw_inputs))
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Expose only the first CUDA device to the GPU test cases.
    # NOTE(review): this is set before the CPU test too, so on a CUDA
    # machine test_wav_cpu_infer may actually run on GPU — confirm the
    # pipeline's device selection.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    test_wav_cpu_infer()
    test_wav_gpu_infer()
    test_without_profile_gpu_infer()
|
||||
@ -312,7 +312,7 @@ def inference_modelscope(
|
||||
|
||||
def _forward(
|
||||
data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None,
|
||||
raw_inputs: List[List[Union[np.ndarray, torch.Tensor, str]]] = None,
|
||||
raw_inputs: List[List[Union[np.ndarray, torch.Tensor, str, bytes]]] = None,
|
||||
output_dir_v2: Optional[str] = None,
|
||||
param_dict: Optional[dict] = None,
|
||||
):
|
||||
@ -321,6 +321,8 @@ def inference_modelscope(
|
||||
if isinstance(raw_inputs, (list, tuple)):
|
||||
assert all([len(example) >= 2 for example in raw_inputs]), \
|
||||
"The length of test case in raw_inputs must larger than 1 (>=2)."
|
||||
if not isinstance(raw_inputs, List):
|
||||
raw_inputs = [raw_inputs]
|
||||
|
||||
def prepare_dataset():
|
||||
for idx, example in enumerate(raw_inputs):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user