From 85b8628dbf3020e73580b73240804d587ead4eb6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Thu, 30 Mar 2023 17:03:50 +0800 Subject: [PATCH] export --- funasr/export/models/encoder/sanm_encoder.py | 5 +++-- funasr/export/models/vad_realtime_transformer.py | 15 ++++++++++----- funasr/export/test/test_onnx_punc_vadrealtime.py | 6 +++++- 3 files changed, 18 insertions(+), 8 deletions(-) diff --git a/funasr/export/models/encoder/sanm_encoder.py b/funasr/export/models/encoder/sanm_encoder.py index a4b112fb8..5437440a1 100644 --- a/funasr/export/models/encoder/sanm_encoder.py +++ b/funasr/export/models/encoder/sanm_encoder.py @@ -163,9 +163,9 @@ class SANMVadEncoder(nn.Module): self.num_heads = model.encoders[0].self_attn.h self.hidden_size = model.encoders[0].self_attn.linear_out.out_features - def prepare_mask(self, mask): + def prepare_mask(self, mask, sub_masks): mask_3d_btd = mask[:, :, None] - sub_masks = subsequent_mask(mask.size(-1)).type(torch.float32) + # sub_masks = subsequent_mask(mask.size(-1)).type(torch.float32) if len(mask.shape) == 2: mask_4d_bhlt = 1 - sub_masks[:, None, None, :] elif len(mask.shape) == 3: @@ -178,6 +178,7 @@ class SANMVadEncoder(nn.Module): speech: torch.Tensor, speech_lengths: torch.Tensor, vad_mask: torch.Tensor, + sub_masks: torch.Tensor, ): speech = speech * self._output_size ** 0.5 mask = self.make_pad_mask(speech_lengths) diff --git a/funasr/export/models/vad_realtime_transformer.py b/funasr/export/models/vad_realtime_transformer.py index de7c721eb..a3d486432 100644 --- a/funasr/export/models/vad_realtime_transformer.py +++ b/funasr/export/models/vad_realtime_transformer.py @@ -11,7 +11,7 @@ from funasr.punctuation.abs_model import AbsPunctuation from funasr.punctuation.sanm_encoder import SANMVadEncoder from funasr.export.models.encoder.sanm_encoder import SANMVadEncoder as SANMVadEncoder_export -class VadRealtimeTransformer(AbsPunctuation): +class VadRealtimeTransformer(nn.Module): def __init__( self, @@ -36,8 +36,11 @@ class VadRealtimeTransformer(AbsPunctuation): - def forward(self, input: torch.Tensor, text_lengths: torch.Tensor, - vad_indexes: torch.Tensor) -> Tuple[torch.Tensor, None]: + def forward(self, input: torch.Tensor, + text_lengths: torch.Tensor, + vad_indexes: torch.Tensor, + sub_masks: torch.Tensor, + ) -> Tuple[torch.Tensor, None]: """Compute loss value from buffer sequences. Args: @@ -47,7 +50,7 @@ class VadRealtimeTransformer(AbsPunctuation): """ x = self.embed(input) # mask = self._target_mask(input) - h, _ = self.encoder(x, text_lengths, vad_indexes) + h, _ = self.encoder(x, text_lengths, vad_indexes, sub_masks) y = self.decoder(h) return y @@ -59,7 +62,9 @@ class VadRealtimeTransformer(AbsPunctuation): text_indexes = torch.randint(0, self.embed.num_embeddings, (1, length)) text_lengths = torch.tensor([length], dtype=torch.int32) vad_mask = torch.ones(length, length, dtype=torch.float32)[None, None, :, :] - return (text_indexes, text_lengths, vad_mask) + sub_masks = torch.ones(length, length, dtype=torch.float32) + sub_masks = torch.tril(sub_masks) + return (text_indexes, text_lengths, vad_mask, sub_masks) def get_input_names(self): return ['input', 'text_lengths', 'vad_mask'] diff --git a/funasr/export/test/test_onnx_punc_vadrealtime.py b/funasr/export/test/test_onnx_punc_vadrealtime.py index c5cc17ea1..6544a898f 100644 --- a/funasr/export/test/test_onnx_punc_vadrealtime.py +++ b/funasr/export/test/test_onnx_punc_vadrealtime.py @@ -9,7 +9,11 @@ if __name__ == '__main__': output_name = [nd.name for nd in sess.get_outputs()] def _get_feed_dict(text_length): - return {'input': np.ones((1, text_length), dtype=np.int64), 'text_lengths': np.array([text_length,], dtype=np.int32), 'vad_mask': np.ones((1, 1, text_length, text_length), dtype=np.float32)} + return {'input': np.ones((1, text_length), dtype=np.int64), + 'text_lengths': np.array([text_length,], dtype=np.int32), + 'vad_mask': np.ones((1, 1, text_length, text_length), dtype=np.float32), + 'sub_masks': np.tril(np.ones((text_length, text_length), dtype=np.float32)) + } def _run(feed_dict): output = sess.run(output_name, input_feed=feed_dict)