From b3fb4c0acd5f52a313f024b6f69b8f025c6eddfe Mon Sep 17 00:00:00 2001 From: ming030890 <67713085+ming030890@users.noreply.github.com> Date: Tue, 5 Aug 2025 10:48:10 +0100 Subject: [PATCH] Allow one to set a custom progress callback (#2609) * Allow one to set a custom progress callback so that they can show it own progrss bar * Uncomment an existing test * restore indentation --------- Co-authored-by: Tony Mak --- funasr/auto/auto_model.py | 26 ++++++++++++++++++++++---- tests/test_auto_model.py | 36 ++++++++++++++++++++++++++++++++++-- 2 files changed, 56 insertions(+), 6 deletions(-) diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py index 10d2ef6c0..a864dadd7 100644 --- a/funasr/auto/auto_model.py +++ b/funasr/auto/auto_model.py @@ -301,14 +301,27 @@ class AutoModel: res = self.model(*args, kwargs) return res - def generate(self, input, input_len=None, **cfg): + def generate(self, input, input_len=None, progress_callback=None, **cfg): if self.vad_model is None: - return self.inference(input, input_len=input_len, **cfg) + return self.inference( + input, input_len=input_len, progress_callback=progress_callback, **cfg + ) else: - return self.inference_with_vad(input, input_len=input_len, **cfg) + return self.inference_with_vad( + input, input_len=input_len, progress_callback=progress_callback, **cfg + ) - def inference(self, input, input_len=None, model=None, kwargs=None, key=None, **cfg): + def inference( + self, + input, + input_len=None, + model=None, + kwargs=None, + key=None, + progress_callback=None, + **cfg, + ): kwargs = self.kwargs if kwargs is None else kwargs if "cache" in kwargs: kwargs.pop("cache") @@ -365,6 +378,11 @@ class AutoModel: if pbar: pbar.update(end_idx - beg_idx) pbar.set_description(description) + if progress_callback: + try: + progress_callback(end_idx, num_samples) + except Exception as e: + logging.error(f"progress_callback error: {e}") time_speech_total += batch_data_time time_escape_total += time_escape diff --git a/tests/test_auto_model.py b/tests/test_auto_model.py index 932376b1c..d17d9ab41 100644 --- a/tests/test_auto_model.py +++ b/tests/test_auto_model.py @@ -22,7 +22,39 @@ class TestAutoModel(unittest.TestCase): kwargs["spk_kwargs"] = {"cb_kwargs": {"merge_thr": merge_thr}} model = AutoModel(**kwargs) self.assertEqual(model.cb_model.model_config['merge_thr'], merge_thr) - # res = model.generate(input="/test.wav", + # res = model.generate(input="/test.wav", # batch_size_s=300) + + def test_progress_callback_called(self): + class DummyModel: + def __init__(self): + self.param = torch.nn.Parameter(torch.zeros(1)) + + def parameters(self): + return iter([self.param]) + + def eval(self): + pass + + def inference(self, data_in=None, **kwargs): + results = [{"text": str(d)} for d in data_in] + return results, {"batch_data_time": 1} + + am = AutoModel.__new__(AutoModel) + am.model = DummyModel() + am.kwargs = {"batch_size": 2, "disable_pbar": True} + + progress = [] + + res = AutoModel.inference( + am, + ["a", "b", "c"], + progress_callback=lambda idx, total: progress.append((idx, total)), + ) + + self.assertEqual(len(progress), 2) + self.assertEqual(progress, [(2, 3), (3, 3)]) + + if __name__ == '__main__': - unittest.main() \ No newline at end of file + unittest.main()