This commit is contained in:
speech_asr 2023-04-11 00:27:54 +08:00
parent 0a954637cb
commit dfa356a10c
4 changed files with 15 additions and 41 deletions

View File

@ -14,14 +14,8 @@ import torch
from torch.nn import functional as F
from typeguard import check_argument_types
from funasr.modules.nets_utils import to_device
from funasr.modules.nets_utils import make_pad_mask
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.models.base_model import FunASRModel
from funasr.layers.abs_normalize import AbsNormalize
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.losses.label_smoothing_loss import LabelSmoothingLoss, SequenceBinaryCrossEntropy
from funasr.utils.misc import int2vec
@ -43,9 +37,9 @@ class DiarSondModel(FunASRModel):
def __init__(
self,
vocab_size: int,
frontend: Optional[AbsFrontend],
specaug: Optional[AbsSpecAug],
normalize: Optional[AbsNormalize],
frontend: Optional[torch.nn.Module],
specaug: Optional[torch.nn.Module],
normalize: Optional[torch.nn.Module],
encoder: torch.nn.Module,
speaker_encoder: Optional[torch.nn.Module],
ci_scorer: torch.nn.Module,
@ -348,7 +342,7 @@ class DiarSondModel(FunASRModel):
cd_simi = torch.reshape(cd_simi, [bb, self.max_spk_num, tt, 1])
cd_simi = cd_simi.squeeze(dim=3).permute([0, 2, 1])
if isinstance(self.ci_scorer, AbsEncoder):
if isinstance(self.ci_scorer, torch.nn.Module):
ci_simi = self.ci_scorer(ge_in, ge_len)[0]
ci_simi = torch.reshape(ci_simi, [bb, self.max_spk_num, tt]).permute([0, 2, 1])
else:

View File

@ -10,21 +10,10 @@ from typing import Union
import torch
from typeguard import check_argument_types
from funasr.layers.abs_normalize import AbsNormalize
from funasr.losses.label_smoothing_loss import (
LabelSmoothingLoss, # noqa: H301
)
from funasr.models.ctc import CTC
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.models.base_model import FunASRModel
from funasr.modules.add_sos_eos import add_sos_eos
from funasr.modules.e2e_asr_common import ErrorCalculator
from funasr.modules.nets_utils import th_accuracy
from funasr.torch_utils.device_funcs import force_gatherable
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
@ -43,11 +32,11 @@ class ESPnetSVModel(FunASRModel):
self,
vocab_size: int,
token_list: Union[Tuple[str, ...], List[str]],
frontend: Optional[AbsFrontend],
specaug: Optional[AbsSpecAug],
normalize: Optional[AbsNormalize],
frontend: Optional[torch.nn.Module],
specaug: Optional[torch.nn.Module],
normalize: Optional[torch.nn.Module],
preencoder: Optional[AbsPreEncoder],
encoder: AbsEncoder,
encoder: torch.nn.Module,
postencoder: Optional[AbsPostEncoder],
pooling_layer: torch.nn.Module,
decoder: AbsDecoder,

View File

@ -2,17 +2,12 @@ import logging
from contextlib import contextmanager
from distutils.version import LooseVersion
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import torch
import numpy as np
from typeguard import check_argument_types
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.predictor.cif import mae_loss
from funasr.models.base_model import FunASRModel
from funasr.modules.add_sos_eos import add_sos_eos
@ -37,8 +32,8 @@ class TimestampPredictor(FunASRModel):
def __init__(
self,
frontend: Optional[AbsFrontend],
encoder: AbsEncoder,
frontend: Optional[torch.nn.Module],
encoder: torch.nn.Module,
predictor: CifPredictorV3,
predictor_bias: int = 0,
token_list=None,

View File

@ -18,15 +18,11 @@ from funasr.losses.label_smoothing_loss import (
)
from funasr.models.ctc import CTC
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.models.base_model import FunASRModel
from funasr.layers.abs_normalize import AbsNormalize
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.train.abs_espnet_model import AbsESPnetModel
from funasr.models.base_model import FunASRModel
from funasr.modules.streaming_utils.chunk_utilis import sequence_mask
from funasr.models.predictor.cif import mae_loss
@ -48,11 +44,11 @@ class UniASR(FunASRModel):
self,
vocab_size: int,
token_list: Union[Tuple[str, ...], List[str]],
frontend: Optional[AbsFrontend],
specaug: Optional[AbsSpecAug],
normalize: Optional[AbsNormalize],
frontend: Optional[torch.nn.Module],
specaug: Optional[torch.nn.Module],
normalize: Optional[torch.nn.Module],
preencoder: Optional[AbsPreEncoder],
encoder: AbsEncoder,
encoder: torch.nn.Module,
postencoder: Optional[AbsPostEncoder],
decoder: AbsDecoder,
ctc: CTC,