diff --git a/funasr/export/models/encoder/sanm_encoder.py b/funasr/export/models/encoder/sanm_encoder.py index 118e24008..a4b112fb8 100644 --- a/funasr/export/models/encoder/sanm_encoder.py +++ b/funasr/export/models/encoder/sanm_encoder.py @@ -165,7 +165,7 @@ class SANMVadEncoder(nn.Module): def prepare_mask(self, mask): mask_3d_btd = mask[:, :, None] - sub_masks = subsequent_mask(mask.size(-1)) + sub_masks = subsequent_mask(mask.size(-1)).type(torch.float32) if len(mask.shape) == 2: mask_4d_bhlt = 1 - sub_masks[:, None, None, :] elif len(mask.shape) == 3: diff --git a/funasr/export/models/vad_realtime_transformer.py b/funasr/export/models/vad_realtime_transformer.py index 381d02d35..de7c721eb 100644 --- a/funasr/export/models/vad_realtime_transformer.py +++ b/funasr/export/models/vad_realtime_transformer.py @@ -32,6 +32,7 @@ class VadRealtimeTransformer(AbsPunctuation): assert False, "Only support samn encode." # self.encoder = model.encoder self.decoder = model.decoder + self.model_name = model_name @@ -46,7 +47,7 @@ class VadRealtimeTransformer(AbsPunctuation): """ x = self.embed(input) # mask = self._target_mask(input) - h, _, _ = self.encoder(x, text_lengths, vad_indexes) + h, _ = self.encoder(x, text_lengths, vad_indexes) y = self.decoder(h) return y @@ -57,7 +58,7 @@ class VadRealtimeTransformer(AbsPunctuation): length = 120 text_indexes = torch.randint(0, self.embed.num_embeddings, (1, length)) text_lengths = torch.tensor([length], dtype=torch.int32) - vad_mask = torch.ones(length, length)[None, None, :, :] + vad_mask = torch.ones(length, length, dtype=torch.float32)[None, None, :, :] return (text_indexes, text_lengths, vad_mask) def get_input_names(self): diff --git a/funasr/export/test/test_onnx_punc_vadrealtime.py b/funasr/export/test/test_onnx_punc_vadrealtime.py new file mode 100644 index 000000000..c5cc17ea1 --- /dev/null +++ b/funasr/export/test/test_onnx_punc_vadrealtime.py @@ -0,0 +1,18 @@ +import onnxruntime +import numpy as np + + +if __name__ == '__main__': + onnx_path = "./export/damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727/model.onnx" + sess = onnxruntime.InferenceSession(onnx_path) + input_name = [nd.name for nd in sess.get_inputs()] + output_name = [nd.name for nd in sess.get_outputs()] + + def _get_feed_dict(text_length): + return {'input': np.ones((1, text_length), dtype=np.int64), 'text_lengths': np.array([text_length,], dtype=np.int32), 'vad_mask': np.ones((1, 1, text_length, text_length), dtype=np.float32)} + + def _run(feed_dict): + output = sess.run(output_name, input_feed=feed_dict) + for name, value in zip(output_name, output): + print('{}: {}'.format(name, value)) + _run(_get_feed_dict(10))