From 5e59904fd49ff1fcb0d6869d297e05a59707bf58 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Tue, 28 Mar 2023 20:34:53 +0800 Subject: [PATCH] export --- funasr/export/models/e2e_vad.py | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/funasr/export/models/e2e_vad.py b/funasr/export/models/e2e_vad.py index b4236e0b5..d3e8f30e5 100644 --- a/funasr/export/models/e2e_vad.py +++ b/funasr/export/models/e2e_vad.py @@ -24,19 +24,10 @@ class E2EVadModel(nn.Module): raise "unsupported encoder" - def forward(self, feats: torch.Tensor, - in_cache0: torch.Tensor, - in_cache1: torch.Tensor, - in_cache2: torch.Tensor, - in_cache3: torch.Tensor, - ): + def forward(self, feats: torch.Tensor, *args, ): - scores, (cache0, cache1, cache2, cache3) = self.encoder(feats, - in_cache0, - in_cache1, - in_cache2, - in_cache3) # return B * T * D - return scores, cache0, cache1, cache2, cache3 + scores, out_caches = self.encoder(feats, *args) + return scores, out_caches def get_dummy_inputs(self, frame=30): speech = torch.randn(1, frame, self.feats_dim)