From db308e75357ae686f5103123f157a7f79887a103 Mon Sep 17 00:00:00 2001 From: Kun Lu <71560661+CSLukkun@users.noreply.github.com> Date: Tue, 15 Oct 2024 17:52:10 +0800 Subject: [PATCH] feat: add campplus merge_thr (#2135) --- funasr/auto/auto_model.py | 3 ++- funasr/models/campplus/cluster_backend.py | 4 ++-- tests/test_auto_model.py | 28 +++++++++++++++++++++++ 3 files changed, 32 insertions(+), 3 deletions(-) create mode 100644 tests/test_auto_model.py diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py index 71f44b46e..08308a244 100644 --- a/funasr/auto/auto_model.py +++ b/funasr/auto/auto_model.py @@ -147,13 +147,14 @@ class AutoModel: # if spk_model is not None, build spk model else None spk_model = kwargs.get("spk_model", None) spk_kwargs = {} if kwargs.get("spk_kwargs", {}) is None else kwargs.get("spk_kwargs", {}) + cb_kwargs = {} if spk_kwargs.get("cb_kwargs", {}) is None else spk_kwargs.get("cb_kwargs", {}) if spk_model is not None: logging.info("Building SPK model.") spk_kwargs["model"] = spk_model spk_kwargs["model_revision"] = kwargs.get("spk_model_revision", "master") spk_kwargs["device"] = kwargs["device"] spk_model, spk_kwargs = self.build_model(**spk_kwargs) - self.cb_model = ClusterBackend().to(kwargs["device"]) + self.cb_model = ClusterBackend(**cb_kwargs).to(kwargs["device"]) spk_mode = kwargs.get("spk_mode", "punc_segment") if spk_mode not in ["default", "vad_segment", "punc_segment"]: logging.error("spk_mode should be one of default, vad_segment and punc_segment.") diff --git a/funasr/models/campplus/cluster_backend.py b/funasr/models/campplus/cluster_backend.py index b98721eb8..9e7042153 100644 --- a/funasr/models/campplus/cluster_backend.py +++ b/funasr/models/campplus/cluster_backend.py @@ -139,9 +139,9 @@ class ClusterBackend(torch.nn.Module): model_config: The model config. """ - def __init__(self): + def __init__(self, merge_thr=0.78): super().__init__() - self.model_config = {"merge_thr": 0.78} + self.model_config = {"merge_thr": merge_thr} # self.other_config = kwargs self.spectral_cluster = SpectralCluster() diff --git a/tests/test_auto_model.py b/tests/test_auto_model.py new file mode 100644 index 000000000..932376b1c --- /dev/null +++ b/tests/test_auto_model.py @@ -0,0 +1,28 @@ +import unittest +import torch +import numpy as np +from funasr.auto.auto_model import AutoModel + +class TestAutoModel(unittest.TestCase): + + def setUp(self): + self.base_kwargs = { + "model": "damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch", + "vad_model": "fsmn-vad", + "punc_model":"ct-punc", + "device": "cpu", + "batch_size": 1, + "disable_update": True, + } + + def test_merge_thr_in_cb_model(self): + kwargs = self.base_kwargs.copy() + kwargs["spk_model"] = "cam++" + merge_thr = 0.5 + kwargs["spk_kwargs"] = {"cb_kwargs": {"merge_thr": merge_thr}} + model = AutoModel(**kwargs) + self.assertEqual(model.cb_model.model_config['merge_thr'], merge_thr) + # res = model.generate(input="/test.wav", + # batch_size_s=300) +if __name__ == '__main__': + unittest.main() \ No newline at end of file