Compare commits

...

3 Commits

Author SHA1 Message Date
Yu Cao
fda7e74134
Merge e5cc659f40 into b3fb4c0acd 2025-08-06 10:21:40 +08:00
ming030890
b3fb4c0acd
Allow one to set a custom progress callback (#2609)
* Allow one to set a custom progress callback

so that they can show their own progress bar

* Uncomment an existing test

* restore indentation

---------

Co-authored-by: Tony Mak <tony@Tonys-MacBook-Air-1802.local>
2025-08-05 17:48:10 +08:00
Yu Cao
e5cc659f40
fix "can not find model issue when running libtorch runtime" 2025-05-06 11:59:35 +08:00
3 changed files with 57 additions and 7 deletions

View File

@@ -301,14 +301,27 @@ class AutoModel:
res = self.model(*args, kwargs)
return res
def generate(self, input, input_len=None, **cfg):
def generate(self, input, input_len=None, progress_callback=None, **cfg):
if self.vad_model is None:
return self.inference(input, input_len=input_len, **cfg)
return self.inference(
input, input_len=input_len, progress_callback=progress_callback, **cfg
)
else:
return self.inference_with_vad(input, input_len=input_len, **cfg)
return self.inference_with_vad(
input, input_len=input_len, progress_callback=progress_callback, **cfg
)
def inference(self, input, input_len=None, model=None, kwargs=None, key=None, **cfg):
def inference(
self,
input,
input_len=None,
model=None,
kwargs=None,
key=None,
progress_callback=None,
**cfg,
):
kwargs = self.kwargs if kwargs is None else kwargs
if "cache" in kwargs:
kwargs.pop("cache")
@@ -365,6 +378,11 @@ class AutoModel:
if pbar:
pbar.update(end_idx - beg_idx)
pbar.set_description(description)
if progress_callback:
try:
progress_callback(end_idx, num_samples)
except Exception as e:
logging.error(f"progress_callback error: {e}")
time_speech_total += batch_data_time
time_escape_total += time_escape

View File

@@ -84,4 +84,4 @@ def export_dynamic_axes(self):
def export_name(
self,
):
return "model"
return "model.onnx"

View File

@@ -22,7 +22,39 @@ class TestAutoModel(unittest.TestCase):
kwargs["spk_kwargs"] = {"cb_kwargs": {"merge_thr": merge_thr}}
model = AutoModel(**kwargs)
self.assertEqual(model.cb_model.model_config['merge_thr'], merge_thr)
# res = model.generate(input="/test.wav",
# res = model.generate(input="/test.wav",
# batch_size_s=300)
def test_progress_callback_called(self):
class DummyModel:
def __init__(self):
self.param = torch.nn.Parameter(torch.zeros(1))
def parameters(self):
return iter([self.param])
def eval(self):
pass
def inference(self, data_in=None, **kwargs):
results = [{"text": str(d)} for d in data_in]
return results, {"batch_data_time": 1}
am = AutoModel.__new__(AutoModel)
am.model = DummyModel()
am.kwargs = {"batch_size": 2, "disable_pbar": True}
progress = []
res = AutoModel.inference(
am,
["a", "b", "c"],
progress_callback=lambda idx, total: progress.append((idx, total)),
)
self.assertEqual(len(progress), 2)
self.assertEqual(progress, [(2, 3), (3, 3)])
if __name__ == '__main__':
unittest.main()
unittest.main()