update repo

This commit is contained in:
嘉渊 2023-05-08 19:12:28 +08:00
parent 300caf84db
commit e9cafb55ce
6 changed files with 17 additions and 17 deletions

View File

@ -19,7 +19,7 @@ from funasr.models.decoder.transformer_decoder import (
)
from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
from funasr.models.decoder.transformer_decoder import TransformerDecoder
from funasr.models.e2e_asr import ESPnetASRModel
from funasr.models.e2e_asr import ASRModel
from funasr.models.e2e_asr_mfcca import MFCCA
from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerBert, BiCifParaformer, ContextualParaformer
from funasr.models.e2e_tp import TimestampPredictor
@ -76,7 +76,7 @@ normalize_choices = ClassChoices(
model_choices = ClassChoices(
"model",
classes=dict(
asr=ESPnetASRModel,
asr=ASRModel,
uniasr=UniASR,
paraformer=Paraformer,
paraformer_bert=ParaformerBert,

View File

@ -39,7 +39,7 @@ else:
yield
class ESPnetASRModel(FunASRModel):
class ASRModel(FunASRModel):
"""CTC-attention hybrid Encoder-Decoder model"""
def __init__(
@ -49,9 +49,7 @@ class ESPnetASRModel(FunASRModel):
frontend: Optional[AbsFrontend],
specaug: Optional[AbsSpecAug],
normalize: Optional[AbsNormalize],
preencoder: Optional[AbsPreEncoder],
encoder: AbsEncoder,
postencoder: Optional[AbsPostEncoder],
decoder: AbsDecoder,
ctc: CTC,
ctc_weight: float = 0.5,
@ -64,6 +62,8 @@ class ESPnetASRModel(FunASRModel):
sym_space: str = "<space>",
sym_blank: str = "<blank>",
extract_feats_in_collect_stats: bool = True,
preencoder: Optional[AbsPreEncoder] = None,
postencoder: Optional[AbsPostEncoder] = None,
):
assert check_argument_types()
assert 0.0 <= ctc_weight <= 1.0, ctc_weight

View File

@ -51,7 +51,6 @@ class MFCCA(FunASRModel):
frontend: Optional[AbsFrontend],
specaug: Optional[AbsSpecAug],
normalize: Optional[AbsNormalize],
preencoder: Optional[AbsPreEncoder],
encoder: AbsEncoder,
decoder: AbsDecoder,
ctc: CTC,
@ -65,6 +64,7 @@ class MFCCA(FunASRModel):
report_wer: bool = True,
sym_space: str = "<space>",
sym_blank: str = "<blank>",
preencoder: Optional[AbsPreEncoder] = None,
):
assert check_argument_types()
assert 0.0 <= ctc_weight <= 1.0, ctc_weight

View File

@ -55,9 +55,7 @@ class Paraformer(FunASRModel):
frontend: Optional[AbsFrontend],
specaug: Optional[AbsSpecAug],
normalize: Optional[AbsNormalize],
preencoder: Optional[AbsPreEncoder],
encoder: AbsEncoder,
postencoder: Optional[AbsPostEncoder],
decoder: AbsDecoder,
ctc: CTC,
ctc_weight: float = 0.5,
@ -78,6 +76,8 @@ class Paraformer(FunASRModel):
predictor_bias: int = 0,
sampling_ratio: float = 0.2,
share_embedding: bool = False,
preencoder: Optional[AbsPreEncoder] = None,
postencoder: Optional[AbsPostEncoder] = None,
):
assert check_argument_types()
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
@ -732,9 +732,7 @@ class ParaformerBert(Paraformer):
frontend: Optional[AbsFrontend],
specaug: Optional[AbsSpecAug],
normalize: Optional[AbsNormalize],
preencoder: Optional[AbsPreEncoder],
encoder: AbsEncoder,
postencoder: Optional[AbsPostEncoder],
decoder: AbsDecoder,
ctc: CTC,
ctc_weight: float = 0.5,
@ -757,6 +755,8 @@ class ParaformerBert(Paraformer):
embeds_id: int = 2,
embeds_loss_weight: float = 0.0,
embed_dims: int = 768,
preencoder: Optional[AbsPreEncoder] = None,
postencoder: Optional[AbsPostEncoder] = None,
):
assert check_argument_types()
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
@ -1008,9 +1008,7 @@ class BiCifParaformer(Paraformer):
frontend: Optional[AbsFrontend],
specaug: Optional[AbsSpecAug],
normalize: Optional[AbsNormalize],
preencoder: Optional[AbsPreEncoder],
encoder: AbsEncoder,
postencoder: Optional[AbsPostEncoder],
decoder: AbsDecoder,
ctc: CTC,
ctc_weight: float = 0.5,
@ -1030,6 +1028,8 @@ class BiCifParaformer(Paraformer):
predictor_weight: float = 0.0,
predictor_bias: int = 0,
sampling_ratio: float = 0.2,
preencoder: Optional[AbsPreEncoder] = None,
postencoder: Optional[AbsPostEncoder] = None,
):
assert check_argument_types()
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
@ -1277,9 +1277,7 @@ class ContextualParaformer(Paraformer):
frontend: Optional[AbsFrontend],
specaug: Optional[AbsSpecAug],
normalize: Optional[AbsNormalize],
preencoder: Optional[AbsPreEncoder],
encoder: AbsEncoder,
postencoder: Optional[AbsPostEncoder],
decoder: AbsDecoder,
ctc: CTC,
ctc_weight: float = 0.5,
@ -1309,6 +1307,8 @@ class ContextualParaformer(Paraformer):
bias_encoder_type: str = 'lstm',
label_bracket: bool = False,
use_decoder_embedding: bool = False,
preencoder: Optional[AbsPreEncoder] = None,
postencoder: Optional[AbsPostEncoder] = None,
):
assert check_argument_types()
assert 0.0 <= ctc_weight <= 1.0, ctc_weight

View File

@ -38,7 +38,7 @@ from funasr.models.decoder.transformer_decoder import (
from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
from funasr.models.decoder.transformer_decoder import TransformerDecoder
from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
from funasr.models.e2e_asr import ESPnetASRModel
from funasr.models.e2e_asr import ASRModel
from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerBert, BiCifParaformer, ContextualParaformer
from funasr.models.e2e_tp import TimestampPredictor
from funasr.models.e2e_asr_mfcca import MFCCA
@ -118,7 +118,7 @@ normalize_choices = ClassChoices(
model_choices = ClassChoices(
"model",
classes=dict(
asr=ESPnetASRModel,
asr=ASRModel,
uniasr=UniASR,
paraformer=Paraformer,
paraformer_bert=ParaformerBert,

View File

@ -21,7 +21,7 @@ from funasr.datasets.preprocessor import CommonPreprocessor
from funasr.layers.abs_normalize import AbsNormalize
from funasr.layers.global_mvn import GlobalMVN
from funasr.layers.utterance_mvn import UtteranceMVN
from funasr.models.e2e_asr import ESPnetASRModel
from funasr.models.e2e_asr import ASRModel
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.encoder.rnn_encoder import RNNEncoder