This commit is contained in:
嘉渊 2023-04-24 14:56:46 +08:00
parent d1408466aa
commit 80aeac6edc
2 changed files with 16 additions and 2 deletions

View File

@ -3,8 +3,8 @@
. ./path.sh || exit 1;
# machines configuration
CUDA_VISIBLE_DEVICES="0,1"
gpu_num=2
CUDA_VISIBLE_DEVICES="0"
gpu_num=1
count=1
gpu_inference=true # Whether to perform gpu decoding, set false for cpu decoding
# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob

View File

@ -0,0 +1,14 @@
from abc import ABC
from abc import abstractmethod
from typing import Tuple
import torch
class AbsNormalize(torch.nn.Module, ABC):
@abstractmethod
def forward(
self, input: torch.Tensor, input_lengths: torch.Tensor = None
) -> Tuple[torch.Tensor, torch.Tensor]:
# return output, output_lengths
raise NotImplementedError