mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
export
This commit is contained in:
parent
e382250465
commit
795b6e0486
@ -151,12 +151,7 @@ class SANMVadEncoder(nn.Module):
|
||||
|
||||
def prepare_mask(self, mask, sub_masks):
|
||||
mask_3d_btd = mask[:, :, None]
|
||||
# sub_masks = subsequent_mask(mask.size(-1)).type(torch.float32)
|
||||
if len(mask.shape) == 2:
|
||||
mask_4d_bhlt = 1 - sub_masks[:, None, None, :]
|
||||
elif len(mask.shape) == 3:
|
||||
mask_4d_bhlt = 1 - sub_masks[:, None, :]
|
||||
mask_4d_bhlt = mask_4d_bhlt * -10000.0
|
||||
mask_4d_bhlt = (1 - sub_masks) * -10000.0
|
||||
|
||||
return mask_3d_btd, mask_4d_bhlt
|
||||
|
||||
|
||||
@ -63,11 +63,11 @@ class VadRealtimeTransformer(nn.Module):
|
||||
text_lengths = torch.tensor([length], dtype=torch.int32)
|
||||
vad_mask = torch.ones(length, length, dtype=torch.float32)[None, None, :, :]
|
||||
sub_masks = torch.ones(length, length, dtype=torch.float32)
|
||||
sub_masks = torch.tril(sub_masks)
|
||||
return (text_indexes, text_lengths, vad_mask, sub_masks)
|
||||
sub_masks = torch.tril(sub_masks).type(torch.float32)
|
||||
return (text_indexes, text_lengths, vad_mask, sub_masks[None, None, :, :])
|
||||
|
||||
def get_input_names(self):
|
||||
return ['input', 'text_lengths', 'vad_mask']
|
||||
return ['input', 'text_lengths', 'vad_mask', 'sub_masks']
|
||||
|
||||
def get_output_names(self):
|
||||
return ['logits']
|
||||
@ -81,6 +81,10 @@ class VadRealtimeTransformer(nn.Module):
|
||||
2: 'feats_length1',
|
||||
3: 'feats_length2'
|
||||
},
|
||||
'sub_masks': {
|
||||
2: 'feats_length1',
|
||||
3: 'feats_length2'
|
||||
},
|
||||
'logits': {
|
||||
1: 'logits_length'
|
||||
},
|
||||
|
||||
Loading…
Reference in New Issue
Block a user