From 3388361d3b4b3123a02ab8db254e3641bb9c9fd4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=87=8C=E5=8C=80?= Date: Tue, 28 Feb 2023 14:33:40 +0800 Subject: [PATCH] update vad inference --- funasr/bin/vad_inference.py | 8 +++++--- funasr/models/e2e_vad.py | 5 ++--- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/funasr/bin/vad_inference.py b/funasr/bin/vad_inference.py index 258b38b33..aaa38b3fe 100644 --- a/funasr/bin/vad_inference.py +++ b/funasr/bin/vad_inference.py @@ -86,7 +86,8 @@ class Speech2VadSegment: @torch.no_grad() def __call__( - self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None + self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None, + in_cache: Dict[str, torch.Tensor] = dict() ) -> Tuple[List[List[int]], Dict[str, torch.Tensor]]: """Inference @@ -125,11 +126,12 @@ class Speech2VadSegment: batch = { "feats": feats[:, t_offset:t_offset + step, :], "waveform": speech[:, t_offset * 160:min(speech.shape[-1], (t_offset + step - 1) * 160 + 400)], - "is_final": is_final + "is_final": is_final, + "in_cache": in_cache } # a. To device batch = to_device(batch, device=self.device) - segments_part = self.vad_model(**batch) + segments_part, in_cache = self.vad_model(**batch) if segments_part: for batch_num in range(0, self.batch_size): segments[batch_num] += segments_part[batch_num] diff --git a/funasr/models/e2e_vad.py b/funasr/models/e2e_vad.py index c21be1b95..b9be89aaa 100755 --- a/funasr/models/e2e_vad.py +++ b/funasr/models/e2e_vad.py @@ -444,7 +444,7 @@ class E2EVadModel(nn.Module): def forward(self, feats: torch.Tensor, waveform: torch.tensor, in_cache: Dict[str, torch.Tensor] = dict(), is_final: bool = False - ) -> List[List[List[int]]]: + ) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]: self.waveform = waveform # compute decibel for each frame self.ComputeDecibel() self.ComputeScores(feats, in_cache) @@ -468,8 +468,7 @@ class E2EVadModel(nn.Module): if is_final: # reset class variables and clear the dict for the next query self.AllResetDetection() - in_cache.clear() - return segments + return segments, in_cache def DetectCommonFrames(self) -> int: if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected: