punc onnx

This commit is contained in:
游雁 2023-04-28 14:21:45 +08:00
parent f4710b4180
commit ea6903101b

View File

@ -53,7 +53,7 @@ class CT_Transformer(nn.Module):
def get_dummy_inputs(self):
length = 120
text_indexes = torch.randint(0, self.embed.num_embeddings, (2, length))
text_indexes = torch.randint(0, self.embed.num_embeddings, (2, length)).type(torch.int32)
text_lengths = torch.tensor([length-20, length], dtype=torch.int32)
return (text_indexes, text_lengths)
@ -130,7 +130,7 @@ class CT_Transformer_VadRealtime(nn.Module):
def get_dummy_inputs(self):
length = 120
text_indexes = torch.randint(0, self.embed.num_embeddings, (1, length))
text_indexes = torch.randint(0, self.embed.num_embeddings, (1, length)).type(torch.int32)
text_lengths = torch.tensor([length], dtype=torch.int32)
vad_mask = torch.ones(length, length, dtype=torch.float32)[None, None, :, :]
sub_masks = torch.ones(length, length, dtype=torch.float32)