mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
onnx
This commit is contained in:
parent
8b802ea8a0
commit
eb82674d88
@@ -158,13 +158,14 @@ class SANMVadEncoder(nn.Module):
     def forward(self,
                 speech: torch.Tensor,
                 speech_lengths: torch.Tensor,
-                vad_mask: torch.Tensor,
+                vad_masks: torch.Tensor,
                 sub_masks: torch.Tensor,
                 ):
         speech = speech * self._output_size ** 0.5
         mask = self.make_pad_mask(speech_lengths)
+        vad_masks = self.prepare_mask(mask, vad_masks)
         mask = self.prepare_mask(mask, sub_masks)
-        vad_mask = self.prepare_mask(mask, vad_mask)
+
         if self.embed is None:
             xs_pad = speech
         else:
@@ -176,7 +177,7 @@ class SANMVadEncoder(nn.Module):
         # encoder_outs = self.model.encoders(xs_pad, mask)
         for layer_idx, encoder_layer in enumerate(self.model.encoders):
             if layer_idx == len(self.model.encoders) - 1:
-                mask = vad_mask
+                mask = vad_masks
             encoder_outs = encoder_layer(xs_pad, mask)
             xs_pad, masks = encoder_outs[0], encoder_outs[1]
 
@@ -187,26 +188,26 @@ class SANMVadEncoder(nn.Module):
     def get_output_size(self):
         return self.model.encoders[0].size
 
-    def get_dummy_inputs(self):
-        feats = torch.randn(1, 100, self.feats_dim)
-        return (feats)
-
-    def get_input_names(self):
-        return ['feats']
-
-    def get_output_names(self):
-        return ['encoder_out', 'encoder_out_lens', 'predictor_weight']
-
-    def get_dynamic_axes(self):
-        return {
-            'feats': {
-                1: 'feats_length'
-            },
-            'encoder_out': {
-                1: 'enc_out_length'
-            },
-            'predictor_weight': {
-                1: 'pre_out_length'
-            }
-
-        }
+    # def get_dummy_inputs(self):
+    #     feats = torch.randn(1, 100, self.feats_dim)
+    #     return (feats)
+    #
+    # def get_input_names(self):
+    #     return ['feats']
+    #
+    # def get_output_names(self):
+    #     return ['encoder_out', 'encoder_out_lens', 'predictor_weight']
+    #
+    # def get_dynamic_axes(self):
+    #     return {
+    #         'feats': {
+    #             1: 'feats_length'
+    #         },
+    #         'encoder_out': {
+    #             1: 'enc_out_length'
+    #         },
+    #         'predictor_weight': {
+    #             1: 'pre_out_length'
+    #         }
+    #
+    #     }
@@ -66,13 +66,13 @@ class VadRealtimeTransformer(nn.Module):
         length = 10
         text_indexes = torch.tensor([[266757, 266757, 266757, 266757, 266757, 266757, 266757, 266757, 266757, 266757]], dtype=torch.int32)
         text_lengths = torch.tensor([length], dtype=torch.int32)
-        vad_mask = vad_mask(10, 3, dtype=torch.float32)[None, None, :, :]
+        vad_masks = vad_mask(10, 3, dtype=torch.float32)[None, None, :, :]
         sub_masks = torch.ones(length, length, dtype=torch.float32)
         sub_masks = torch.tril(sub_masks).type(torch.float32)
-        return (text_indexes, text_lengths, vad_mask, sub_masks[None, None, :, :])
+        return (text_indexes, text_lengths, vad_masks, sub_masks[None, None, :, :])
 
     def get_input_names(self):
-        return ['input', 'text_lengths', 'vad_mask', 'sub_masks']
+        return ['input', 'text_lengths', 'vad_masks', 'sub_masks']
 
     def get_output_names(self):
         return ['logits']
@@ -82,7 +82,7 @@ class VadRealtimeTransformer(nn.Module):
             'input': {
                 1: 'feats_length'
             },
-            'vad_mask': {
+            'vad_masks': {
                 2: 'feats_length1',
                 3: 'feats_length2'
             },
Loading…
Reference in New Issue
Block a user