mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
TOLD/SOND: add support in build_diar_model.py
This commit is contained in:
parent
edd41f5a30
commit
db1495e24c
@ -3,7 +3,7 @@ import logging
|
||||
import torch
|
||||
|
||||
from funasr.layers.global_mvn import GlobalMVN
|
||||
from funasr.layers.label_aggregation import LabelAggregate
|
||||
from funasr.layers.label_aggregation import LabelAggregate, LabelAggregateMaxPooling
|
||||
from funasr.layers.utterance_mvn import UtteranceMVN
|
||||
from funasr.models.e2e_diar_eend_ola import DiarEENDOLAModel
|
||||
from funasr.models.e2e_diar_sond import DiarSondModel
|
||||
@ -26,6 +26,8 @@ from funasr.models.frontend.wav_frontend import WavFrontendMel23
|
||||
from funasr.models.frontend.windowing import SlidingWindow
|
||||
from funasr.models.specaug.specaug import SpecAug
|
||||
from funasr.models.specaug.specaug import SpecAugLFR
|
||||
from funasr.models.specaug.abs_profileaug import AbsProfileAug
|
||||
from funasr.models.specaug.profileaug import ProfileAug
|
||||
from funasr.modules.eend_ola.encoder import EENDOLATransformerEncoder
|
||||
from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor
|
||||
from funasr.torch_utils.initialize import initialize
|
||||
@ -52,6 +54,15 @@ specaug_choices = ClassChoices(
|
||||
default=None,
|
||||
optional=True,
|
||||
)
|
||||
profileaug_choices = ClassChoices(
|
||||
name="profileaug",
|
||||
classes=dict(
|
||||
profileaug=ProfileAug,
|
||||
),
|
||||
type_check=AbsProfileAug,
|
||||
default=None,
|
||||
optional=True,
|
||||
)
|
||||
normalize_choices = ClassChoices(
|
||||
"normalize",
|
||||
classes=dict(
|
||||
@ -64,7 +75,8 @@ normalize_choices = ClassChoices(
|
||||
label_aggregator_choices = ClassChoices(
|
||||
"label_aggregator",
|
||||
classes=dict(
|
||||
label_aggregator=LabelAggregate
|
||||
label_aggregator=LabelAggregate,
|
||||
label_aggregator_max_pool=LabelAggregateMaxPooling,
|
||||
),
|
||||
default=None,
|
||||
optional=True,
|
||||
@ -155,6 +167,8 @@ class_choices_list = [
|
||||
frontend_choices,
|
||||
# --specaug and --specaug_conf
|
||||
specaug_choices,
|
||||
# --profileaug and --profileaug_conf
|
||||
profileaug_choices,
|
||||
# --normalize and --normalize_conf
|
||||
normalize_choices,
|
||||
# --label_aggregator and --label_aggregator_conf
|
||||
@ -217,6 +231,13 @@ def build_diar_model(args):
|
||||
else:
|
||||
specaug = None
|
||||
|
||||
# Data augmentation for Profiles
|
||||
if hasattr(args, "profileaug") and args.profileaug is not None:
|
||||
profileaug_class = profileaug_choices.get_class(args.profileaug)
|
||||
profileaug = profileaug_class(**args.profileaug_conf)
|
||||
else:
|
||||
profileaug = None
|
||||
|
||||
# normalization layer
|
||||
if args.normalize is not None:
|
||||
normalize_class = normalize_choices.get_class(args.normalize)
|
||||
@ -261,6 +282,7 @@ def build_diar_model(args):
|
||||
vocab_size=vocab_size,
|
||||
frontend=frontend,
|
||||
specaug=specaug,
|
||||
profileaug=profileaug,
|
||||
normalize=normalize,
|
||||
label_aggregator=label_aggregator,
|
||||
encoder=encoder,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user