From db1495e24c1a376ea340ce4dc3a269dbe588d392 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BF=97=E6=B5=A9?= Date: Tue, 1 Aug 2023 17:47:33 +0800 Subject: [PATCH] TOLD/SOND: add support in build_diar_model.py --- funasr/build_utils/build_diar_model.py | 26 ++++++++++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/funasr/build_utils/build_diar_model.py b/funasr/build_utils/build_diar_model.py index 0ea31270e..1aa0701f9 100644 --- a/funasr/build_utils/build_diar_model.py +++ b/funasr/build_utils/build_diar_model.py @@ -3,7 +3,7 @@ import logging import torch from funasr.layers.global_mvn import GlobalMVN -from funasr.layers.label_aggregation import LabelAggregate +from funasr.layers.label_aggregation import LabelAggregate, LabelAggregateMaxPooling from funasr.layers.utterance_mvn import UtteranceMVN from funasr.models.e2e_diar_eend_ola import DiarEENDOLAModel from funasr.models.e2e_diar_sond import DiarSondModel @@ -26,6 +26,8 @@ from funasr.models.frontend.wav_frontend import WavFrontendMel23 from funasr.models.frontend.windowing import SlidingWindow from funasr.models.specaug.specaug import SpecAug from funasr.models.specaug.specaug import SpecAugLFR +from funasr.models.specaug.abs_profileaug import AbsProfileAug +from funasr.models.specaug.profileaug import ProfileAug from funasr.modules.eend_ola.encoder import EENDOLATransformerEncoder from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor from funasr.torch_utils.initialize import initialize @@ -52,6 +54,15 @@ specaug_choices = ClassChoices( default=None, optional=True, ) +profileaug_choices = ClassChoices( + name="profileaug", + classes=dict( + profileaug=ProfileAug, + ), + type_check=AbsProfileAug, + default=None, + optional=True, +) normalize_choices = ClassChoices( "normalize", classes=dict( @@ -64,7 +75,8 @@ normalize_choices = ClassChoices( label_aggregator_choices = ClassChoices( "label_aggregator", classes=dict( - label_aggregator=LabelAggregate + label_aggregator=LabelAggregate, + label_aggregator_max_pool=LabelAggregateMaxPooling, ), default=None, optional=True, @@ -155,6 +167,8 @@ class_choices_list = [ frontend_choices, # --specaug and --specaug_conf specaug_choices, + # --profileaug and --profileaug_conf + profileaug_choices, # --normalize and --normalize_conf normalize_choices, # --label_aggregator and --label_aggregator_conf @@ -217,6 +231,13 @@ def build_diar_model(args): else: specaug = None + # Data augmentation for Profiles + if hasattr(args, "profileaug") and args.profileaug is not None: + profileaug_class = profileaug_choices.get_class(args.profileaug) + profileaug = profileaug_class(**args.profileaug_conf) + else: + profileaug = None + # normalization layer if args.normalize is not None: normalize_class = normalize_choices.get_class(args.normalize) @@ -261,6 +282,7 @@ def build_diar_model(args): vocab_size=vocab_size, frontend=frontend, specaug=specaug, + profileaug=profileaug, normalize=normalize, label_aggregator=label_aggregator, encoder=encoder,