TOLD/SOND: add support in build_diar_model.py

This commit is contained in:
志浩 2023-08-01 17:47:33 +08:00
parent edd41f5a30
commit db1495e24c

View File

@ -3,7 +3,7 @@ import logging
import torch
from funasr.layers.global_mvn import GlobalMVN
from funasr.layers.label_aggregation import LabelAggregate
from funasr.layers.label_aggregation import LabelAggregate, LabelAggregateMaxPooling
from funasr.layers.utterance_mvn import UtteranceMVN
from funasr.models.e2e_diar_eend_ola import DiarEENDOLAModel
from funasr.models.e2e_diar_sond import DiarSondModel
@ -26,6 +26,8 @@ from funasr.models.frontend.wav_frontend import WavFrontendMel23
from funasr.models.frontend.windowing import SlidingWindow
from funasr.models.specaug.specaug import SpecAug
from funasr.models.specaug.specaug import SpecAugLFR
from funasr.models.specaug.abs_profileaug import AbsProfileAug
from funasr.models.specaug.profileaug import ProfileAug
from funasr.modules.eend_ola.encoder import EENDOLATransformerEncoder
from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor
from funasr.torch_utils.initialize import initialize
@ -52,6 +54,15 @@ specaug_choices = ClassChoices(
default=None,
optional=True,
)
profileaug_choices = ClassChoices(
name="profileaug",
classes=dict(
profileaug=ProfileAug,
),
type_check=AbsProfileAug,
default=None,
optional=True,
)
normalize_choices = ClassChoices(
"normalize",
classes=dict(
@ -64,7 +75,8 @@ normalize_choices = ClassChoices(
label_aggregator_choices = ClassChoices(
"label_aggregator",
classes=dict(
label_aggregator=LabelAggregate
label_aggregator=LabelAggregate,
label_aggregator_max_pool=LabelAggregateMaxPooling,
),
default=None,
optional=True,
@ -155,6 +167,8 @@ class_choices_list = [
frontend_choices,
# --specaug and --specaug_conf
specaug_choices,
# --profileaug and --profileaug_conf
profileaug_choices,
# --normalize and --normalize_conf
normalize_choices,
# --label_aggregator and --label_aggregator_conf
@ -217,6 +231,13 @@ def build_diar_model(args):
else:
specaug = None
# Data augmentation for Profiles
if hasattr(args, "profileaug") and args.profileaug is not None:
profileaug_class = profileaug_choices.get_class(args.profileaug)
profileaug = profileaug_class(**args.profileaug_conf)
else:
profileaug = None
# normalization layer
if args.normalize is not None:
normalize_class = normalize_choices.get_class(args.normalize)
@ -261,6 +282,7 @@ def build_diar_model(args):
vocab_size=vocab_size,
frontend=frontend,
specaug=specaug,
profileaug=profileaug,
normalize=normalize,
label_aggregator=label_aggregator,
encoder=encoder,