Allow one to set a custom progress callback (#2609)

* Allow one to set a custom progress callback

so that they can show it own progrss bar

* Uncomment an existing test

* restore indentation

---------

Co-authored-by: Tony Mak <tony@Tonys-MacBook-Air-1802.local>
This commit is contained in:
ming030890 2025-08-05 10:48:10 +01:00 committed by GitHub
parent 8316fc4197
commit b3fb4c0acd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 56 additions and 6 deletions

View File

@ -301,14 +301,27 @@ class AutoModel:
res = self.model(*args, kwargs) res = self.model(*args, kwargs)
return res return res
def generate(self, input, input_len=None, **cfg): def generate(self, input, input_len=None, progress_callback=None, **cfg):
if self.vad_model is None: if self.vad_model is None:
return self.inference(input, input_len=input_len, **cfg) return self.inference(
input, input_len=input_len, progress_callback=progress_callback, **cfg
)
else: else:
return self.inference_with_vad(input, input_len=input_len, **cfg) return self.inference_with_vad(
input, input_len=input_len, progress_callback=progress_callback, **cfg
)
def inference(self, input, input_len=None, model=None, kwargs=None, key=None, **cfg): def inference(
self,
input,
input_len=None,
model=None,
kwargs=None,
key=None,
progress_callback=None,
**cfg,
):
kwargs = self.kwargs if kwargs is None else kwargs kwargs = self.kwargs if kwargs is None else kwargs
if "cache" in kwargs: if "cache" in kwargs:
kwargs.pop("cache") kwargs.pop("cache")
@ -365,6 +378,11 @@ class AutoModel:
if pbar: if pbar:
pbar.update(end_idx - beg_idx) pbar.update(end_idx - beg_idx)
pbar.set_description(description) pbar.set_description(description)
if progress_callback:
try:
progress_callback(end_idx, num_samples)
except Exception as e:
logging.error(f"progress_callback error: {e}")
time_speech_total += batch_data_time time_speech_total += batch_data_time
time_escape_total += time_escape time_escape_total += time_escape

View File

@ -24,5 +24,37 @@ class TestAutoModel(unittest.TestCase):
self.assertEqual(model.cb_model.model_config['merge_thr'], merge_thr) self.assertEqual(model.cb_model.model_config['merge_thr'], merge_thr)
# res = model.generate(input="/test.wav", # res = model.generate(input="/test.wav",
# batch_size_s=300) # batch_size_s=300)
def test_progress_callback_called(self):
class DummyModel:
def __init__(self):
self.param = torch.nn.Parameter(torch.zeros(1))
def parameters(self):
return iter([self.param])
def eval(self):
pass
def inference(self, data_in=None, **kwargs):
results = [{"text": str(d)} for d in data_in]
return results, {"batch_data_time": 1}
am = AutoModel.__new__(AutoModel)
am.model = DummyModel()
am.kwargs = {"batch_size": 2, "disable_pbar": True}
progress = []
res = AutoModel.inference(
am,
["a", "b", "c"],
progress_callback=lambda idx, total: progress.append((idx, total)),
)
self.assertEqual(len(progress), 2)
self.assertEqual(progress, [(2, 3), (3, 3)])
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()