diff --git a/egs_modelscope/vad/speech_fsmn_vad_zh-cn-16k-common/README.md b/egs_modelscope/vad/speech_fsmn_vad_zh-cn-16k-common/README.md
new file mode 100644
index 000000000..6d9cd3024
--- /dev/null
+++ b/egs_modelscope/vad/speech_fsmn_vad_zh-cn-16k-common/README.md
@@ -0,0 +1,24 @@
+# ModelScope Model
+
+## How to finetune and infer using a pretrained ModelScope Model
+
+### Inference
+
+You can use the pretrained ModelScope model for inference directly.
+
+- Set the parameters in `infer.py`:
+    - `audio_in`: the input audio; wav files, URLs, bytes, and parsed audio formats are supported, as well as a wav.scp list.
+    - `output_dir`: the output directory; it must be set when the input is a wav.scp list (see the sketch after the run command below).
+
+- Then run the inference pipeline with:
+```shell
+python infer.py
+```
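+
+The `infer.py` in this directory takes a single wav URL. The sketch below is illustrative rather than part of this repository: it shows how the same pipeline could be pointed at a wav.scp list, with the file names `wav.scp` and `./results` being assumptions.
+
+```python
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+
+if __name__ == '__main__':
+    # assumed input: a Kaldi-style wav.scp with one "utt_id /path/to/utt.wav" entry per line
+    audio_in = 'wav.scp'
+    # output_dir must be set when the input is a wav.scp list
+    output_dir = './results'
+    inference_pipeline = pipeline(
+        task=Tasks.voice_activity_detection,
+        model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
+        output_dir=output_dir,
+        batch_size=1,
+    )
+    segments_result = inference_pipeline(audio_in=audio_in)
+    print(segments_result)
+```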
+
+
+Inference-related parameters can be modified in `vad.yaml`:
+
+- `max_end_silence_time`: the trailing-silence duration used to decide that a sentence has ended; the valid range is 500 ms to 6000 ms and the default is 800 ms.
+- `speech_noise_thres`: the threshold balancing the speech and noise scores; the valid range is (-1, 1).
+    - The closer the value is to -1, the more likely noise is classified as speech.
+    - The closer the value is to 1, the more likely speech is classified as noise.
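+
+As a hedged sketch of how these entries might appear in `vad.yaml` (the `vad_post_conf` nesting and the values shown are assumptions; only the parameter names come from the list above):
+
+```yaml
+vad_post_conf:
+    max_end_silence_time: 800   # ms; valid range 500-6000, default 800
+    speech_noise_thres: 0.6     # valid range (-1, 1)
+```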
diff --git a/egs_modelscope/vad/speech_fsmn_vad_zh-cn-16k-common/infer.py b/egs_modelscope/vad/speech_fsmn_vad_zh-cn-16k-common/infer.py
new file mode 100644
index 000000000..c255474b8
--- /dev/null
+++ b/egs_modelscope/vad/speech_fsmn_vad_zh-cn-16k-common/infer.py
@@ -0,0 +1,15 @@
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+
+if __name__ == '__main__':
+ audio_in = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav'
+ output_dir = None
+ inference_pipeline = pipeline(
+ task=Tasks.voice_activity_detection,
+ model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
+ model_revision=None,
+ output_dir=output_dir,
+ batch_size=1,
+ )
+ segments_result = inference_pipeline(audio_in=audio_in)
+ print(segments_result)
diff --git a/egs_modelscope/vad/speech_fsmn_vad_zh-cn-8k-common/README.md b/egs_modelscope/vad/speech_fsmn_vad_zh-cn-8k-common/README.md
new file mode 100644
index 000000000..6d9cd3024
--- /dev/null
+++ b/egs_modelscope/vad/speech_fsmn_vad_zh-cn-8k-common/README.md
@@ -0,0 +1,24 @@
+# ModelScope Model
+
+## How to finetune and infer using a pretrained ModelScope Model
+
+### Inference
+
+You can use the pretrained ModelScope model for inference directly.
+
+- Set the parameters in `infer.py`:
+    - `audio_in`: the input audio; wav files, URLs, bytes, and parsed audio formats are supported, as well as a wav.scp list.
+    - `output_dir`: the output directory; it must be set when the input is a wav.scp list (see the sketch after the run command below).
+
+- Then run the inference pipeline with:
+```shell
+python infer.py
+```
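+
+The `infer.py` in this directory takes a single wav URL. The sketch below is illustrative rather than part of this repository: it shows how the same pipeline could be pointed at a wav.scp list, with the file names `wav.scp` and `./results` being assumptions.
+
+```python
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+
+if __name__ == '__main__':
+    # assumed input: a Kaldi-style wav.scp with one "utt_id /path/to/utt.wav" entry per line
+    audio_in = 'wav.scp'
+    # output_dir must be set when the input is a wav.scp list
+    output_dir = './results'
+    inference_pipeline = pipeline(
+        task=Tasks.voice_activity_detection,
+        model="damo/speech_fsmn_vad_zh-cn-8k-common",
+        model_revision='v1.1.1',
+        output_dir=output_dir,
+        batch_size=1,
+    )
+    segments_result = inference_pipeline(audio_in=audio_in)
+    print(segments_result)
+```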
+
+
+Inference-related parameters can be modified in `vad.yaml`:
+
+- `max_end_silence_time`: the trailing-silence duration used to decide that a sentence has ended; the valid range is 500 ms to 6000 ms and the default is 800 ms.
+- `speech_noise_thres`: the threshold balancing the speech and noise scores; the valid range is (-1, 1).
+    - The closer the value is to -1, the more likely noise is classified as speech.
+    - The closer the value is to 1, the more likely speech is classified as noise.
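+
+A hedged sketch of how these entries might appear in `vad.yaml` (the `vad_post_conf` nesting and the values shown are assumptions; only the parameter names come from the list above):
+
+```yaml
+vad_post_conf:
+    max_end_silence_time: 800   # ms; valid range 500-6000, default 800
+    speech_noise_thres: 0.6     # valid range (-1, 1)
+```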
diff --git a/egs_modelscope/vad/speech_fsmn_vad_zh-cn-8k-common/infer.py b/egs_modelscope/vad/speech_fsmn_vad_zh-cn-8k-common/infer.py
new file mode 100644
index 000000000..6061413e5
--- /dev/null
+++ b/egs_modelscope/vad/speech_fsmn_vad_zh-cn-8k-common/infer.py
@@ -0,0 +1,15 @@
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+
+if __name__ == '__main__':
+ audio_in = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example_8k.wav'
+ output_dir = './output_dir'
+ inference_pipeline = pipeline(
+ task=Tasks.voice_activity_detection,
+ model="damo/speech_fsmn_vad_zh-cn-8k-common",
+ model_revision='v1.1.1',
+ output_dir=output_dir,
+ batch_size=1,
+ )
+ segments_result = inference_pipeline(audio_in=audio_in)
+ print(segments_result)
diff --git a/funasr/bin/vad_inference.py b/funasr/bin/vad_inference.py
index 1cdb582e5..b0f8a77b3 100644
--- a/funasr/bin/vad_inference.py
+++ b/funasr/bin/vad_inference.py
@@ -81,6 +81,7 @@ class Speech2VadSegment:
self.device = device
self.dtype = dtype
self.frontend = frontend
+ self.batch_size = batch_size
@torch.no_grad()
def __call__(
@@ -110,10 +111,9 @@ class Speech2VadSegment:
# segments = self.vad_model(**batch)
# b. Forward Encoder sreaming
- segments = []
- segments_tmp = []
- step = 6000
t_offset = 0
+ step = min(feats_len, 6000)
+ segments = [[] for _ in range(self.batch_size)]  # one independent list per batch item
for t_offset in range(0, feats_len, min(step, feats_len - t_offset)):
if t_offset + step >= feats_len - 1:
step = feats_len - t_offset
@@ -129,8 +129,8 @@ class Speech2VadSegment:
batch = to_device(batch, device=self.device)
segments_part = self.vad_model(**batch)
if segments_part:
- segments_tmp += segments_part[0]
- segments.append(segments_tmp)
+ for batch_num in range(0, self.batch_size):
+ segments[batch_num] += segments_part[batch_num]
return segments
@@ -254,7 +254,6 @@ def inference_modelscope(
assert all(isinstance(s, str) for s in keys), keys
_bs = len(next(iter(batch.values())))
assert len(keys) == _bs, f"{len(keys)} != {_bs}"
- # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
# do vad segment
results = speech2vadsegment(**batch)
diff --git a/funasr/models/e2e_vad.py b/funasr/models/e2e_vad.py
index 8afc8db6d..b64c677f3 100755
--- a/funasr/models/e2e_vad.py
+++ b/funasr/models/e2e_vad.py
@@ -192,7 +192,7 @@ class WindowDetector(object):
class E2EVadModel(nn.Module):
- def __init__(self, encoder: FSMN, vad_post_args: Dict[str, Any], streaming=False):
+ def __init__(self, encoder: FSMN, vad_post_args: Dict[str, Any]):
super(E2EVadModel, self).__init__()
self.vad_opts = VADXOptions(**vad_post_args)
self.windows_detector = WindowDetector(self.vad_opts.window_size_ms,
@@ -227,7 +227,6 @@ class E2EVadModel(nn.Module):
self.data_buf = None
self.data_buf_all = None
self.waveform = None
- self.streaming = streaming
self.ResetDetection()
def AllResetDetection(self):
@@ -451,11 +450,7 @@ class E2EVadModel(nn.Module):
if not is_final_send:
self.DetectCommonFrames()
else:
- if self.streaming:
- self.DetectLastFrames()
- else:
- self.AllResetDetection()
- self.DetectAllFrames() # offline decode and is_final_send == True
+ self.DetectLastFrames()
segments = []
for batch_num in range(0, feats.shape[0]): # only support batch_size = 1 now
segment_batch = []
@@ -468,7 +463,8 @@ class E2EVadModel(nn.Module):
self.output_data_buf_offset += 1 # need update this parameter
if segment_batch:
segments.append(segment_batch)
-
+ if is_final_send:
+ self.AllResetDetection()
return segments
def DetectCommonFrames(self) -> int:
@@ -494,18 +490,6 @@ class E2EVadModel(nn.Module):
return 0
- def DetectAllFrames(self) -> int:
- if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
- return 0
- if self.vad_opts.nn_eval_block_size != self.vad_opts.dcd_block_size:
- frame_state = FrameState.kFrameStateInvalid
- for t in range(0, self.frm_cnt):
- frame_state = self.GetFrameState(t)
- self.DetectOneFrame(frame_state, t, t == self.frm_cnt - 1)
- else:
- pass
- return 0
-
def DetectOneFrame(self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_frame: bool) -> None:
tmp_cur_frm_state = FrameState.kFrameStateInvalid
if cur_frm_state == FrameState.kFrameStateSpeech:
diff --git a/funasr/tasks/vad.py b/funasr/tasks/vad.py
index e2a912394..22a5cb3d3 100644
--- a/funasr/tasks/vad.py
+++ b/funasr/tasks/vad.py
@@ -291,8 +291,7 @@ class VADTask(AbsTask):
model_class = model_choices.get_class(args.model)
except AttributeError:
model_class = model_choices.get_class("e2evad")
- model = model_class(encoder=encoder, vad_post_args=args.vad_post_conf,
- streaming=args.encoder_conf.get('streaming', False))
+ model = model_class(encoder=encoder, vad_post_args=args.vad_post_conf)
return model