from typing import Optional
from typing import Tuple

import torch


class AbsProfileAug(torch.nn.Module):
    """Abstract class for the augmentation of profile.

    The process-flow::

        Frontend --> SpecAug -> Normalization -> Encoder -> Decoder
                  `-> ProfileAug -> Speaker Encoder --'

    Subclasses must override :meth:`forward`; this base class only defines
    the call signature and return structure.
    """

    def forward(
        self,
        x: torch.Tensor,
        x_lengths: Optional[torch.Tensor] = None,
        profile: Optional[torch.Tensor] = None,
        profile_lengths: Optional[torch.Tensor] = None,
        binary_labels: Optional[torch.Tensor] = None,
        labels_length: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Augment the speaker profile (to be implemented by subclasses).

        Args:
            x: Input tensor. NOTE(review): presumably batched features;
                shape is defined by the concrete subclass — confirm.
            x_lengths: Optional per-item valid lengths of ``x``.
            profile: Optional profile tensor to be augmented.
            profile_lengths: Optional per-item valid lengths of ``profile``.
            binary_labels: Optional label tensor. NOTE(review): assumed
                to be per-speaker activity labels — confirm with callers.
            labels_length: Optional per-item valid lengths of
                ``binary_labels``.

        Returns:
            A 3-tuple of one tensor followed by two optional tensors; the
            exact meaning of each element is defined by the subclass.

        Raises:
            NotImplementedError: Always, in this abstract base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement forward()"
        )