From d72a4497a57f3b415753a7e7c5d4b8d367cf951f Mon Sep 17 00:00:00 2001 From: "shixian.shi" Date: Thu, 11 Jan 2024 19:16:51 +0800 Subject: [PATCH] support oracle num for asr with spk --- funasr/bin/inference.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/funasr/bin/inference.py b/funasr/bin/inference.py index cf29d91da..515170b5f 100644 --- a/funasr/bin/inference.py +++ b/funasr/bin/inference.py @@ -137,6 +137,9 @@ class AutoModel: if spk_mode not in ["default", "vad_segment", "punc_segment"]: logging.error("spk_mode should be one of default, vad_segment and punc_segment.") self.spk_mode = spk_mode + self.preset_spk_num = kwargs.get("preset_spk_num", None) + if self.preset_spk_num: + logging.warning("Using preset speaker number: {}".format(self.preset_spk_num)) logging.warning("Many to print when using speaker model...") self.kwargs = kwargs @@ -397,7 +400,7 @@ class AutoModel: if self.spk_model is not None: all_segments = sorted(all_segments, key=lambda x: x[0]) spk_embedding = result['spk_embedding'] - labels = self.cb_model(spk_embedding) + labels = self.cb_model(spk_embedding, oracle_num=self.preset_spk_num) del result['spk_embedding'] sv_output = postprocess(all_segments, None, labels, spk_embedding) if self.spk_mode == 'vad_segment':