From fee3fd32fd2af544da864a20a73bb26e16aa2217 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=BF=97=E6=B5=A9?=
Date: Tue, 14 Mar 2023 10:58:25 +0800
Subject: [PATCH 1/2] add unit_test for speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch

---
NOTE(review): unit_test.py is added under the
speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch directory, but the
pipeline below loads the en-us-callhome-8k diarization model and the
en-us-callhome-8k speaker-verification model -- confirm this is intended.
NOTE(review): the comments inside unit_test.py were translated to English, so
the new-file blob no longer corresponds to the recorded index hash 3cb31cfb7.

 .../unit_test.py                              | 26 +++++++++++++++++++
 1 file changed, 26 insertions(+)
 create mode 100644 egs_modelscope/speaker_diarization/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/unit_test.py

diff --git a/egs_modelscope/speaker_diarization/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/unit_test.py b/egs_modelscope/speaker_diarization/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/unit_test.py
new file mode 100644
index 000000000..3cb31cfb7
--- /dev/null
+++ b/egs_modelscope/speaker_diarization/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/unit_test.py
@@ -0,0 +1,26 @@
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+
+# Initialize the inference pipeline
+# When raw audio is the input, use the sond.yaml config file and set mode to sond_demo
+inference_diar_pipline = pipeline(
+    mode="sond_demo",
+    num_workers=0,
+    task=Tasks.speaker_diarization,
+    diar_model_config="sond.yaml",
+    model='damo/speech_diarization_sond-en-us-callhome-8k-n16k4-pytorch',
+    sv_model="damo/speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch",
+    sv_model_revision="master",
+)
+
+# audio_list is the input: the first audio is the speech to analyze; the rest are voiceprint enrollment audio for the different speakers
+audio_list = [[
+    "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/record.wav",
+    "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/spk_A.wav",
+    "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/spk_B.wav",
+    "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/spk_B1.wav"
+]]
+
+results = inference_diar_pipline(audio_in=audio_list)
+for rst in results:
+    print(rst["value"])

From f191f4c868af7ca73e4fae6242339309ae15d88c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=BF=97=E6=B5=A9?=
Date: Tue, 14 Mar 2023 14:31:27 +0800
Subject: [PATCH 2/2] modify statistic pooling layer

---
 funasr/models/pooling/statistic_pooling.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/funasr/models/pooling/statistic_pooling.py b/funasr/models/pooling/statistic_pooling.py
index 97f8a24f5..8f85de99d 100644
--- a/funasr/models/pooling/statistic_pooling.py
+++ b/funasr/models/pooling/statistic_pooling.py
@@ -82,7 +82,7 @@ def windowed_statistic_pooling(
     tt = xs_pad.shape[2]
     num_chunk = int(math.ceil(tt / pooling_stride))
     pad = pooling_size // 2
-    if xs_pad.shape == 4:
+    if len(xs_pad.shape) == 4:
         features = F.pad(xs_pad, (0, 0, pad, pad), "reflect")
     else:
         features = F.pad(xs_pad, (pad, pad), "reflect")