From 795b6e04864d7a8ea1cb8e41a412152651c47eed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Thu, 30 Mar 2023 17:14:29 +0800 Subject: [PATCH] export --- funasr/export/models/encoder/sanm_encoder.py | 7 +------ funasr/export/models/vad_realtime_transformer.py | 10 +++++++--- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/funasr/export/models/encoder/sanm_encoder.py b/funasr/export/models/encoder/sanm_encoder.py index 8198d18a3..8390f6822 100644 --- a/funasr/export/models/encoder/sanm_encoder.py +++ b/funasr/export/models/encoder/sanm_encoder.py @@ -151,12 +151,7 @@ class SANMVadEncoder(nn.Module): def prepare_mask(self, mask, sub_masks): mask_3d_btd = mask[:, :, None] - # sub_masks = subsequent_mask(mask.size(-1)).type(torch.float32) - if len(mask.shape) == 2: - mask_4d_bhlt = 1 - sub_masks[:, None, None, :] - elif len(mask.shape) == 3: - mask_4d_bhlt = 1 - sub_masks[:, None, :] - mask_4d_bhlt = mask_4d_bhlt * -10000.0 + mask_4d_bhlt = (1 - sub_masks) * -10000.0 return mask_3d_btd, mask_4d_bhlt diff --git a/funasr/export/models/vad_realtime_transformer.py b/funasr/export/models/vad_realtime_transformer.py index a3d486432..093e71de1 100644 --- a/funasr/export/models/vad_realtime_transformer.py +++ b/funasr/export/models/vad_realtime_transformer.py @@ -63,11 +63,11 @@ class VadRealtimeTransformer(nn.Module): text_lengths = torch.tensor([length], dtype=torch.int32) vad_mask = torch.ones(length, length, dtype=torch.float32)[None, None, :, :] sub_masks = torch.ones(length, length, dtype=torch.float32) - sub_masks = torch.tril(sub_masks) - return (text_indexes, text_lengths, vad_mask, sub_masks) + sub_masks = torch.tril(sub_masks).type(torch.float32) + return (text_indexes, text_lengths, vad_mask, sub_masks[None, None, :, :]) def get_input_names(self): - return ['input', 'text_lengths', 'vad_mask'] + return ['input', 'text_lengths', 'vad_mask', 'sub_masks'] def get_output_names(self): return ['logits'] @@ -81,6 +81,10 @@ class VadRealtimeTransformer(nn.Module): 2: 'feats_length1', 3: 'feats_length2' }, + 'sub_masks': { + 2: 'feats_length1', + 3: 'feats_length2' + }, 'logits': { 1: 'logits_length' },