diff --git a/funasr/bin/eend_ola_inference.py b/funasr/bin/eend_ola_inference.py index 79e93a863..b35824aaa 100755 --- a/funasr/bin/eend_ola_inference.py +++ b/funasr/bin/eend_ola_inference.py @@ -237,7 +237,8 @@ def inference_modelscope( results = speech2diar(**batch) # post process - a = medfilt(results[0], (11, 1)) + a = results[0].cpu().numpy() + a = medfilt(a, (11, 1)) rst = [] for spkid, frames in enumerate(a.T): frames = np.pad(frames, (1, 1), 'constant')