diff --git a/egs_modelscope/speaker_diarization/speech_diarization_eend-ola-en-us-callhome-8k/infer.py b/egs_modelscope/speaker_diarization/speech_diarization_eend-ola-en-us-callhome-8k/infer.py index dfcb8e649..e0ac08ced 100644 --- a/egs_modelscope/speaker_diarization/speech_diarization_eend-ola-en-us-callhome-8k/infer.py +++ b/egs_modelscope/speaker_diarization/speech_diarization_eend-ola-en-us-callhome-8k/infer.py @@ -2,8 +2,9 @@ from modelscope.pipelines import pipeline from modelscope.utils.constant import Tasks inference_diar_pipline = pipeline( - task=Tasks.speaker_diarization, + task=Tasks.auto_speech_recognition, model='damo/speech_diarization_eend-ola-en-us-callhome-8k', model_revision="v1.0.0", ) -results = inference_diar_pipline(audio_in=["https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/record.wav"]) \ No newline at end of file +results = inference_diar_pipline(audio_in=["https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/record2.wav"]) +print(results) \ No newline at end of file diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py index 1fae766ea..0ab6b1ad3 100644 --- a/funasr/bin/asr_inference_launch.py +++ b/funasr/bin/asr_inference_launch.py @@ -234,6 +234,9 @@ def inference_launch(**kwargs): elif mode == "rnnt": from funasr.bin.asr_inference_rnnt import inference_modelscope return inference_modelscope(**kwargs) + elif mode == "eend-ola": + from funasr.bin.eend_ola_inference import inference_modelscope + return inference_modelscope(mode=mode, **kwargs) else: logging.info("Unknown decoding mode: {}".format(mode)) return None diff --git a/funasr/bin/eend_ola_inference.py b/funasr/bin/eend_ola_inference.py index b35824aaa..048327856 100755 --- a/funasr/bin/eend_ola_inference.py +++ b/funasr/bin/eend_ola_inference.py @@ -16,8 +16,8 @@ from typing import Union import numpy as np import torch -from typeguard import check_argument_types from scipy.signal import medfilt +from typeguard import check_argument_types from funasr.models.frontend.wav_frontend import WavFrontendMel23 from funasr.tasks.diar import EENDOLADiarTask @@ -28,6 +28,7 @@ from funasr.utils.types import str2bool from funasr.utils.types import str2triple_str from funasr.utils.types import str_or_none + class Speech2Diarization: """Speech2Diarlization class @@ -237,7 +238,7 @@ def inference_modelscope( results = speech2diar(**batch) # post process - a = results[0].cpu().numpy() + a = results[0][0].cpu().numpy() a = medfilt(a, (11, 1)) rst = [] for spkid, frames in enumerate(a.T): @@ -246,8 +247,8 @@ def inference_modelscope( fmt = "SPEAKER {:s} 1 {:7.2f} {:7.2f} {:s} " for s, e in zip(changes[::2], changes[1::2]): st = s / 10. - ed = e / 10. - rst.append(fmt.format(keys[0], st, ed, "{}_{}".format(keys[0],str(spkid)))) + dur = (e - s) / 10. + rst.append(fmt.format(keys[0], st, dur, "{}_{}".format(keys[0], str(spkid)))) # Only supporting batch_size==1 value = "\n".join(rst)