FunASR/funasr/models/specaug/abs_profileaug.py
2023-08-01 17:03:39 +08:00

23 lines
670 B
Python

from typing import Optional
from typing import Tuple
import torch
class AbsProfileAug(torch.nn.Module):
    """Abstract class for the augmentation of profile.

    The process-flow:

        Frontend --> SpecAug -> Normalization -> Encoder -> Decoder
                `-> ProfileAug -> Speaker Encoder --'

    Subclasses must override ``forward``; this base implementation only
    raises ``NotImplementedError``.
    """

    def forward(
        self,
        x: torch.Tensor,
        x_lengths: Optional[torch.Tensor] = None,
        profile: Optional[torch.Tensor] = None,
        profile_lengths: Optional[torch.Tensor] = None,
        binary_labels: Optional[torch.Tensor] = None,
        labels_length: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Apply the profile augmentation.

        Args:
            x: Input tensor. Its shape is not fixed by this base class;
                presumably the frontend features (TODO confirm against
                concrete subclasses).
            x_lengths: Optional lengths associated with ``x``.
            profile: Optional profile tensor to be augmented.
            profile_lengths: Optional lengths associated with ``profile``.
            binary_labels: Optional label tensor.
            labels_length: Optional lengths associated with ``binary_labels``.

        Returns:
            A 3-tuple whose first element is a tensor and whose remaining two
            elements may be ``None``. The meaning of each element is defined
            by the concrete subclass.

        Raises:
            NotImplementedError: Always, in this abstract base class.
        """
        raise NotImplementedError