游雁 2023-04-07 14:08:34 +08:00
parent 8b802ea8a0
commit eb82674d88
2 changed files with 31 additions and 30 deletions

View File

@@ -158,13 +158,14 @@ class SANMVadEncoder(nn.Module):
     def forward(self,
                 speech: torch.Tensor,
                 speech_lengths: torch.Tensor,
-                vad_mask: torch.Tensor,
+                vad_masks: torch.Tensor,
                 sub_masks: torch.Tensor,
                 ):
         speech = speech * self._output_size ** 0.5
         mask = self.make_pad_mask(speech_lengths)
-        vad_mask = self.prepare_mask(mask, vad_mask)
+        vad_masks = self.prepare_mask(mask, vad_masks)
+        mask = self.prepare_mask(mask, sub_masks)
         if self.embed is None:
             xs_pad = speech
         else:
@@ -176,7 +177,7 @@ class SANMVadEncoder(nn.Module):
         # encoder_outs = self.model.encoders(xs_pad, mask)
         for layer_idx, encoder_layer in enumerate(self.model.encoders):
             if layer_idx == len(self.model.encoders) - 1:
-                mask = vad_mask
+                mask = vad_masks
             encoder_outs = encoder_layer(xs_pad, mask)
             xs_pad, masks = encoder_outs[0], encoder_outs[1]
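
Taken together, the two hunks above route the renamed tensor end to end: vad_masks is prepared from the raw padding mask before mask is overwritten with the sub_masks variant, and the encoder loop then swaps vad_masks in for the final layer only. prepare_mask itself is not shown in this diff; the following is a minimal sketch of one plausible behaviour (an assumption, not the repository's definition), where a (B, T) padding mask zeroes out padded key positions in a (B, 1, T, T) attention pattern:

    import torch

    def prepare_mask_sketch(pad_mask: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor:
        # Hypothetical stand-in for SANMVadEncoder.prepare_mask.
        # pad_mask: (B, T), 1.0 for real frames, 0.0 for padding.
        # attn_mask: (B, 1, T, T) attention pattern (vad_masks or sub_masks).
        # Broadcast the padding mask over the key axis so padded frames
        # can never be attended to.
        return attn_mask * pad_mask[:, None, None, :]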
@@ -187,26 +188,26 @@ class SANMVadEncoder(nn.Module):
     def get_output_size(self):
         return self.model.encoders[0].size

-    def get_dummy_inputs(self):
-        feats = torch.randn(1, 100, self.feats_dim)
-        return (feats)
-
-    def get_input_names(self):
-        return ['feats']
-
-    def get_output_names(self):
-        return ['encoder_out', 'encoder_out_lens', 'predictor_weight']
-
-    def get_dynamic_axes(self):
-        return {
-            'feats': {
-                1: 'feats_length'
-            },
-            'encoder_out': {
-                1: 'enc_out_length'
-            },
-            'predictor_weight': {
-                1: 'pre_out_length'
-            }
-
-        }
+    # def get_dummy_inputs(self):
+    #     feats = torch.randn(1, 100, self.feats_dim)
+    #     return (feats)
+    #
+    # def get_input_names(self):
+    #     return ['feats']
+    #
+    # def get_output_names(self):
+    #     return ['encoder_out', 'encoder_out_lens', 'predictor_weight']
+    #
+    # def get_dynamic_axes(self):
+    #     return {
+    #         'feats': {
+    #             1: 'feats_length'
+    #         },
+    #         'encoder_out': {
+    #             1: 'enc_out_length'
+    #         },
+    #         'predictor_weight': {
+    #             1: 'pre_out_length'
+    #         }
+    #
+    #     }
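
The helpers above are disabled rather than deleted. For context, get_* methods of this shape conventionally feed a torch.onnx.export call; a minimal sketch under that assumption follows (export_onnx, the output path, and the opset choice are illustrative, not part of this commit):

    import torch

    def export_onnx(model, out_path: str = "sanm_vad_encoder.onnx", opset: int = 13):
        # Wires the diff's helper methods into the standard export API.
        torch.onnx.export(
            model,
            model.get_dummy_inputs(),   # e.g. a (1, 100, feats_dim) feature tensor
            out_path,
            input_names=model.get_input_names(),
            output_names=model.get_output_names(),
            dynamic_axes=model.get_dynamic_axes(),  # marks time axes as variable
            opset_version=opset,
        )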

View File

@@ -66,13 +66,13 @@ class VadRealtimeTransformer(nn.Module):
         length = 10
         text_indexes = torch.tensor([[266757, 266757, 266757, 266757, 266757, 266757, 266757, 266757, 266757, 266757]], dtype=torch.int32)
         text_lengths = torch.tensor([length], dtype=torch.int32)
-        vad_mask = vad_mask(10, 3, dtype=torch.float32)[None, None, :, :]
+        vad_masks = vad_mask(10, 3, dtype=torch.float32)[None, None, :, :]
         sub_masks = torch.ones(length, length, dtype=torch.float32)
         sub_masks = torch.tril(sub_masks).type(torch.float32)
-        return (text_indexes, text_lengths, vad_mask, sub_masks[None, None, :, :])
+        return (text_indexes, text_lengths, vad_masks, sub_masks[None, None, :, :])

     def get_input_names(self):
-        return ['input', 'text_lengths', 'vad_mask', 'sub_masks']
+        return ['input', 'text_lengths', 'vad_masks', 'sub_masks']

     def get_output_names(self):
         return ['logits']
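
The vad_mask helper called in get_dummy_inputs is defined elsewhere in the repository; note that the rename also stops the old line from rebinding the function name vad_mask to its own return value. As a rough illustration only (an assumption about its semantics, not the actual definition), a (size, vad_pos) mask builder might look like:

    import torch

    def vad_mask_sketch(size: int, vad_pos: int, dtype=torch.float32) -> torch.Tensor:
        # Hypothetical: start from full attention, then block positions
        # before vad_pos from attending to positions at or after it.
        ret = torch.ones(size, size, dtype=dtype)
        if 0 < vad_pos < size:
            ret[: vad_pos - 1, vad_pos:] = 0.0
        return ret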
@@ -82,7 +82,7 @@ class VadRealtimeTransformer(nn.Module):
             'input': {
                 1: 'feats_length'
             },
-            'vad_mask': {
+            'vad_masks': {
                 2: 'feats_length1',
                 3: 'feats_length2'
             },
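
With these dynamic axes, the exported graph accepts any sequence length at run time under the renamed vad_masks input. A usage sketch with onnxruntime (the model path is a placeholder; shapes and dtypes mirror the dummy inputs above):

    import numpy as np
    import onnxruntime as ort

    sess = ort.InferenceSession("vad_realtime_transformer.onnx")  # placeholder path
    length = 10
    feeds = {
        "input": np.full((1, length), 266757, dtype=np.int32),
        "text_lengths": np.array([length], dtype=np.int32),
        "vad_masks": np.ones((1, 1, length, length), dtype=np.float32),  # stand-in mask
        "sub_masks": np.tril(np.ones((length, length), dtype=np.float32))[None, None],
    }
    logits = sess.run(["logits"], feeds)[0]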