FunASR/egs/alimeeting/diarization/sond/unit_test.py
speech_asr 2ba4683eb2 update
2023-03-16 11:14:42 +08:00

98 lines
2.8 KiB
Python

from funasr.bin.diar_inference_launch import inference_launch
import os
def test_fbank_cpu_infer():
    """Run SOND diarization on pre-extracted fbank features on the CPU.

    Feeds the Kaldi-ark feature list ``data/unit_test/test_feats.scp`` (as
    ``speech``) and the speaker-profile list ``data/unit_test/test_profile.scp``
    (as ``profile``) to a ``mode="sond"`` pipeline built from
    ``config_fbank.yaml`` / ``sond.pb``, with ``./outputs`` passed as the output
    directory. The pipeline's return value is printed; nothing is asserted.
    """
    diar_config_path = "config_fbank.yaml"
    diar_model_path = "sond.pb"
    output_dir = "./outputs"
    data_path_and_name_and_type = [
        ("data/unit_test/test_feats.scp", "speech", "kaldi_ark"),
        ("data/unit_test/test_profile.scp", "profile", "kaldi_ark"),
    ]
    pipeline = inference_launch(
        mode="sond",
        diar_train_config=diar_config_path,
        diar_model_file=diar_model_path,
        output_dir=output_dir,
        # Pin the CPU path explicitly. When run as a script, GPU 0 is made
        # visible, so relying on the launcher's default ``ngpu`` could silently
        # turn this "CPU" test into a GPU run.
        ngpu=0,
        num_workers=1,
        log_level="WARNING",
    )
    results = pipeline(data_path_and_name_and_type)
    print(results)
def test_fbank_gpu_infer():
    """Run SOND diarization on pre-extracted fbank features with one GPU.

    Same inputs as the CPU fbank test: Kaldi-ark features and speaker
    profiles listed under ``data/unit_test/``. The model is ``sond.pb``
    configured by ``config_fbank.yaml``. The pipeline's return value is printed.
    """
    launch_kwargs = dict(
        mode="sond",
        diar_train_config="config_fbank.yaml",
        diar_model_file="sond.pb",
        output_dir="./outputs",
        ngpu=1,
        num_workers=1,
        log_level="WARNING",
    )
    inputs = [
        ("data/unit_test/test_feats.scp", "speech", "kaldi_ark"),
        ("data/unit_test/test_profile.scp", "profile", "kaldi_ark"),
    ]
    diarize = inference_launch(**launch_kwargs)
    print(diarize(inputs))
def test_wav_gpu_infer():
    """Run SOND diarization directly on wav input with one GPU.

    The speech stream comes from ``data/unit_test/test_wav.scp`` (type
    ``sound``). Speaker profiles come from ``data/unit_test/test_profile.scp``
    (Kaldi ark). The model is ``sond.pb`` configured by ``config.yaml``, and
    the pipeline's return value is printed.
    """
    speech_entry = ("data/unit_test/test_wav.scp", "speech", "sound")
    profile_entry = ("data/unit_test/test_profile.scp", "profile", "kaldi_ark")

    diarize = inference_launch(
        mode="sond",
        diar_train_config="config.yaml",
        diar_model_file="sond.pb",
        output_dir="./outputs",
        ngpu=1,
        num_workers=1,
        log_level="WARNING",
    )
    outcome = diarize([speech_entry, profile_entry])
    print(outcome)
def test_without_profile_gpu_infer():
    """Run the ``sond_demo`` pipeline on raw wav files with one GPU.

    A single input group is passed. It contains ``record.wav`` followed by
    ``spk1.wav`` .. ``spk4.wav`` from ``data/unit_test/raw_inputs/``; the
    latter are presumably per-speaker reference clips (confirm against the
    sond_demo pipeline). The model is ``sond.pb`` configured by
    ``config.yaml``, and the pipeline's return value is printed.
    """
    wav_dir = "data/unit_test/raw_inputs/"
    wav_names = ["record.wav", "spk1.wav", "spk2.wav", "spk3.wav", "spk4.wav"]
    # One group: outer list = batch of groups, inner list = files of a group.
    raw_inputs = [[wav_dir + name for name in wav_names]]

    demo = inference_launch(
        mode="sond_demo",
        diar_train_config="config.yaml",
        diar_model_file="sond.pb",
        output_dir="./outputs",
        ngpu=1,
        num_workers=1,
        log_level="WARNING",
        param_dict={},
    )
    outcome = demo(raw_inputs=raw_inputs)
    print(outcome)
if __name__ == '__main__':
    # Default the GPU tests to device 0, but keep a device selection the
    # caller already made (e.g. `CUDA_VISIBLE_DEVICES=2 python ...`) rather
    # than silently overwriting it.
    # NOTE(review): this runs after funasr (presumably torch too) is imported.
    # It only takes effect because CUDA is initialised lazily; confirm that
    # nothing at import time touches the GPU.
    os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")
    test_fbank_cpu_infer()
    test_fbank_gpu_infer()
    test_wav_gpu_infer()
    test_without_profile_gpu_infer()