mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
Compare commits
3 Commits
eb5b01c265
...
fda7e74134
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fda7e74134 | ||
|
|
b3fb4c0acd | ||
|
|
e5cc659f40 |
@ -301,14 +301,27 @@ class AutoModel:
|
||||
res = self.model(*args, kwargs)
|
||||
return res
|
||||
|
||||
def generate(self, input, input_len=None, **cfg):
|
||||
def generate(self, input, input_len=None, progress_callback=None, **cfg):
|
||||
if self.vad_model is None:
|
||||
return self.inference(input, input_len=input_len, **cfg)
|
||||
return self.inference(
|
||||
input, input_len=input_len, progress_callback=progress_callback, **cfg
|
||||
)
|
||||
|
||||
else:
|
||||
return self.inference_with_vad(input, input_len=input_len, **cfg)
|
||||
return self.inference_with_vad(
|
||||
input, input_len=input_len, progress_callback=progress_callback, **cfg
|
||||
)
|
||||
|
||||
def inference(self, input, input_len=None, model=None, kwargs=None, key=None, **cfg):
|
||||
def inference(
|
||||
self,
|
||||
input,
|
||||
input_len=None,
|
||||
model=None,
|
||||
kwargs=None,
|
||||
key=None,
|
||||
progress_callback=None,
|
||||
**cfg,
|
||||
):
|
||||
kwargs = self.kwargs if kwargs is None else kwargs
|
||||
if "cache" in kwargs:
|
||||
kwargs.pop("cache")
|
||||
@ -365,6 +378,11 @@ class AutoModel:
|
||||
if pbar:
|
||||
pbar.update(end_idx - beg_idx)
|
||||
pbar.set_description(description)
|
||||
if progress_callback:
|
||||
try:
|
||||
progress_callback(end_idx, num_samples)
|
||||
except Exception as e:
|
||||
logging.error(f"progress_callback error: {e}")
|
||||
time_speech_total += batch_data_time
|
||||
time_escape_total += time_escape
|
||||
|
||||
|
||||
@ -84,4 +84,4 @@ def export_dynamic_axes(self):
|
||||
def export_name(
|
||||
self,
|
||||
):
|
||||
return "model"
|
||||
return "model.onnx"
|
||||
|
||||
@ -22,7 +22,39 @@ class TestAutoModel(unittest.TestCase):
|
||||
kwargs["spk_kwargs"] = {"cb_kwargs": {"merge_thr": merge_thr}}
|
||||
model = AutoModel(**kwargs)
|
||||
self.assertEqual(model.cb_model.model_config['merge_thr'], merge_thr)
|
||||
# res = model.generate(input="/test.wav",
|
||||
# res = model.generate(input="/test.wav",
|
||||
# batch_size_s=300)
|
||||
|
||||
def test_progress_callback_called(self):
|
||||
class DummyModel:
|
||||
def __init__(self):
|
||||
self.param = torch.nn.Parameter(torch.zeros(1))
|
||||
|
||||
def parameters(self):
|
||||
return iter([self.param])
|
||||
|
||||
def eval(self):
|
||||
pass
|
||||
|
||||
def inference(self, data_in=None, **kwargs):
|
||||
results = [{"text": str(d)} for d in data_in]
|
||||
return results, {"batch_data_time": 1}
|
||||
|
||||
am = AutoModel.__new__(AutoModel)
|
||||
am.model = DummyModel()
|
||||
am.kwargs = {"batch_size": 2, "disable_pbar": True}
|
||||
|
||||
progress = []
|
||||
|
||||
res = AutoModel.inference(
|
||||
am,
|
||||
["a", "b", "c"],
|
||||
progress_callback=lambda idx, total: progress.append((idx, total)),
|
||||
)
|
||||
|
||||
self.assertEqual(len(progress), 2)
|
||||
self.assertEqual(progress, [(2, 3), (3, 3)])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
unittest.main()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user