From e27de5aa6bd9af2a82e80604978b50aa538493ec Mon Sep 17 00:00:00 2001 From: speech_asr Date: Mon, 13 Mar 2023 18:45:27 +0800 Subject: [PATCH] update ola --- funasr/models/e2e_diar_eend_ola.py | 19 +- funasr/modules/eend_ola/encoder.py | 16 +- funasr/tasks/diar.py | 327 +++++++++++++++++++++++++++-- 3 files changed, 335 insertions(+), 27 deletions(-) diff --git a/funasr/models/e2e_diar_eend_ola.py b/funasr/models/e2e_diar_eend_ola.py index 2960b23ca..f589269c5 100644 --- a/funasr/models/e2e_diar_eend_ola.py +++ b/funasr/models/e2e_diar_eend_ola.py @@ -11,7 +11,8 @@ import torch import torch.nn as nn from typeguard import check_argument_types -from funasr.modules.eend_ola.encoder import TransformerEncoder +from funasr.models.frontend.wav_frontend import WavFrontendMel23 +from funasr.modules.eend_ola.encoder import EENDOLATransformerEncoder from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor from funasr.modules.eend_ola.utils.power import generate_mapping_dict from funasr.torch_utils.device_funcs import force_gatherable @@ -34,12 +35,13 @@ def pad_attractor(att, max_n_speakers): class DiarEENDOLAModel(AbsESPnetModel): - """CTC-attention hybrid Encoder-Decoder model""" + """EEND-OLA diarization model""" def __init__( self, - encoder: TransformerEncoder, - eda: EncoderDecoderAttractor, + frontend: WavFrontendMel23, + encoder: EENDOLATransformerEncoder, + encoder_decoder_attractor: EncoderDecoderAttractor, n_units: int = 256, max_n_speaker: int = 8, attractor_loss_weight: float = 1.0, @@ -49,8 +51,9 @@ class DiarEENDOLAModel(AbsESPnetModel): assert check_argument_types() super().__init__() + self.frontend = frontend self.encoder = encoder - self.eda = eda + self.encoder_decoder_attractor = encoder_decoder_attractor self.attractor_loss_weight = attractor_loss_weight self.max_n_speaker = max_n_speaker if mapping_dict is None: @@ -187,16 +190,18 @@ class DiarEENDOLAModel(AbsESPnetModel): shuffle: bool = True, threshold: float = 0.5, **kwargs): + if self.frontend is not None: + speech = self.frontend(speech) speech = [s[:s_len] for s, s_len in zip(speech, speech_lengths)] emb = self.forward_encoder(speech, speech_lengths) if shuffle: orders = [np.arange(e.shape[0]) for e in emb] for order in orders: np.random.shuffle(order) - attractors, probs = self.eda.estimate( + attractors, probs = self.encoder_decoder_attractor.estimate( [e[torch.from_numpy(order).to(torch.long).to(speech[0].device)] for e, order in zip(emb, orders)]) else: - attractors, probs = self.eda.estimate(emb) + attractors, probs = self.encoder_decoder_attractor.estimate(emb) attractors_active = [] for p, att, e in zip(probs, attractors, emb): if n_speakers and n_speakers >= 0: diff --git a/funasr/modules/eend_ola/encoder.py b/funasr/modules/eend_ola/encoder.py index 17d11ace7..4999031b1 100644 --- a/funasr/modules/eend_ola/encoder.py +++ b/funasr/modules/eend_ola/encoder.py @@ -1,5 +1,5 @@ import math -import numpy as np + import torch import torch.nn.functional as F from torch import nn @@ -81,10 +81,16 @@ class PositionalEncoding(torch.nn.Module): return self.dropout(x) -class TransformerEncoder(nn.Module): - def __init__(self, idim, n_layers, n_units, - e_units=2048, h=8, dropout_rate=0.1, use_pos_emb=False): - super(TransformerEncoder, self).__init__() +class EENDOLATransformerEncoder(nn.Module): + def __init__(self, + idim: int, + n_layers: int, + n_units: int, + e_units: int = 2048, + h: int = 8, + dropout_rate: float = 0.1, + use_pos_emb: bool = False): + super(EENDOLATransformerEncoder, 
self).__init__() self.lnorm_in = nn.LayerNorm(n_units) self.n_layers = n_layers self.dropout = nn.Dropout(dropout_rate) diff --git a/funasr/tasks/diar.py b/funasr/tasks/diar.py index e699dccb0..953ab82c8 100644 --- a/funasr/tasks/diar.py +++ b/funasr/tasks/diar.py @@ -20,19 +20,18 @@ from funasr.datasets.collate_fn import CommonCollateFn from funasr.datasets.preprocessor import CommonPreprocessor from funasr.layers.abs_normalize import AbsNormalize from funasr.layers.global_mvn import GlobalMVN -from funasr.layers.utterance_mvn import UtteranceMVN from funasr.layers.label_aggregation import LabelAggregate -from funasr.models.ctc import CTC -from funasr.models.encoder.resnet34_encoder import ResNet34Diar, ResNet34SpL2RegDiar -from funasr.models.encoder.ecapa_tdnn_encoder import ECAPA_TDNN -from funasr.models.encoder.opennmt_encoders.conv_encoder import ConvEncoder -from funasr.models.encoder.opennmt_encoders.fsmn_encoder import FsmnEncoder -from funasr.models.encoder.opennmt_encoders.self_attention_encoder import SelfAttentionEncoder -from funasr.models.encoder.opennmt_encoders.ci_scorers import DotScorer, CosScorer +from funasr.layers.utterance_mvn import UtteranceMVN from funasr.models.e2e_diar_sond import DiarSondModel from funasr.models.encoder.abs_encoder import AbsEncoder from funasr.models.encoder.conformer_encoder import ConformerEncoder from funasr.models.encoder.data2vec_encoder import Data2VecEncoder +from funasr.models.encoder.ecapa_tdnn_encoder import ECAPA_TDNN +from funasr.models.encoder.opennmt_encoders.ci_scorers import DotScorer, CosScorer +from funasr.models.encoder.opennmt_encoders.conv_encoder import ConvEncoder +from funasr.models.encoder.opennmt_encoders.fsmn_encoder import FsmnEncoder +from funasr.models.encoder.opennmt_encoders.self_attention_encoder import SelfAttentionEncoder +from funasr.models.encoder.resnet34_encoder import ResNet34Diar, ResNet34SpL2RegDiar from funasr.models.encoder.rnn_encoder import RNNEncoder from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt from funasr.models.encoder.transformer_encoder import TransformerEncoder @@ -41,17 +40,13 @@ from funasr.models.frontend.default import DefaultFrontend from funasr.models.frontend.fused import FusedFrontends from funasr.models.frontend.s3prl import S3prlFrontend from funasr.models.frontend.wav_frontend import WavFrontend +from funasr.models.frontend.wav_frontend import WavFrontendMel23 from funasr.models.frontend.windowing import SlidingWindow -from funasr.models.postencoder.abs_postencoder import AbsPostEncoder -from funasr.models.postencoder.hugging_face_transformers_postencoder import ( - HuggingFaceTransformersPostEncoder, # noqa: H301 -) -from funasr.models.preencoder.abs_preencoder import AbsPreEncoder -from funasr.models.preencoder.linear import LinearProjection -from funasr.models.preencoder.sinc import LightweightSincConvs from funasr.models.specaug.abs_specaug import AbsSpecAug from funasr.models.specaug.specaug import SpecAug from funasr.models.specaug.specaug import SpecAugLFR +from funasr.modules.eend_ola.encoder import EENDOLATransformerEncoder +from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor from funasr.tasks.abs_task import AbsTask from funasr.torch_utils.initialize import initialize from funasr.train.abs_espnet_model import AbsESPnetModel @@ -70,6 +65,7 @@ frontend_choices = ClassChoices( s3prl=S3prlFrontend, fused=FusedFrontends, wav_frontend=WavFrontend, + wav_frontend_mel23=WavFrontendMel23, ), 
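+    # wav_frontend_mel23 (WavFrontendMel23) is the front-end used by EEND-OLA;
+    # the class name suggests 23-dimensional log-mel features (an assumption
+    # based on the name, not confirmed elsewhere in this diff).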
     type_check=AbsFrontend,
     default="default",
 )
@@ -126,6 +122,7 @@ encoder_choices = ClassChoices(
         sanm_chunk_opt=SANMEncoderChunkOpt,
         data2vec_encoder=Data2VecEncoder,
         ecapa_tdnn=ECAPA_TDNN,
+        eend_ola_transformer=EENDOLATransformerEncoder,
     ),
     type_check=torch.nn.Module,
     default="resnet34",
 )
@@ -177,6 +174,15 @@ decoder_choices = ClassChoices(
     type_check=torch.nn.Module,
     default="fsmn",
 )
+# encoder_decoder_attractor is used for EEND-OLA
+encoder_decoder_attractor_choices = ClassChoices(
+    "encoder_decoder_attractor",
+    classes=dict(
+        eda=EncoderDecoderAttractor,
+    ),
+    type_check=torch.nn.Module,
+    default="eda",
+)
 
 
 class DiarTask(AbsTask):
@@ -594,3 +600,294 @@ class DiarTask(AbsTask):
         var_dict_torch_update.update(var_dict_torch_update_local)
 
         return var_dict_torch_update
+
+
+class EENDOLADiarTask(AbsTask):
+    # If you need more than 1 optimizer, change this value
+    num_optimizers: int = 1
+
+    # Add variable objects configurations
+    class_choices_list = [
+        # --frontend and --frontend_conf
+        frontend_choices,
+        # --model and --model_conf
+        model_choices,
+        # --encoder and --encoder_conf
+        encoder_choices,
+        # --encoder_decoder_attractor and --encoder_decoder_attractor_conf
+        encoder_decoder_attractor_choices,
+    ]
+
+    # If you need to modify train() or eval() procedures, change Trainer class here
+    trainer = Trainer
+
+    @classmethod
+    def add_task_arguments(cls, parser: argparse.ArgumentParser):
+        group = parser.add_argument_group(description="Task related")
+
+        # NOTE(kamo): add_arguments(..., required=True) can't be used
+        # to provide --print_config mode. Instead of it, do as
+        # required = parser.get_default("required")
+        # required += ["token_list"]
+
+        group.add_argument(
+            "--token_list",
+            type=str_or_none,
+            default=None,
+            help="A text mapping int-id to token",
+        )
+        group.add_argument(
+            "--split_with_space",
+            type=str2bool,
+            default=True,
+            help="whether to split text using <space>",
+        )
+        group.add_argument(
+            "--seg_dict_file",
+            type=str,
+            default=None,
+            help="seg_dict_file for text processing",
+        )
+        group.add_argument(
+            "--init",
+            type=lambda x: str_or_none(x.lower()),
+            default=None,
+            help="The initialization method",
+            choices=[
+                "chainer",
+                "xavier_uniform",
+                "xavier_normal",
+                "kaiming_uniform",
+                "kaiming_normal",
+                None,
+            ],
+        )
+
+        group.add_argument(
+            "--input_size",
+            type=int_or_none,
+            default=None,
+            help="The number of input dimension of the feature",
+        )
+
+        group = parser.add_argument_group(description="Preprocess related")
+        group.add_argument(
+            "--use_preprocessor",
+            type=str2bool,
+            default=True,
+            help="Apply preprocessing to data or not",
+        )
+        group.add_argument(
+            "--token_type",
+            type=str,
+            default="char",
+            choices=["char"],
+            help="The text will be tokenized at the specified token level",
+        )
+        parser.add_argument(
+            "--speech_volume_normalize",
+            type=float_or_none,
+            default=None,
+            help="Scale the maximum amplitude to the given value.",
+        )
+        parser.add_argument(
+            "--rir_scp",
+            type=str_or_none,
+            default=None,
+            help="The file path of rir scp file.",
+        )
+        parser.add_argument(
+            "--rir_apply_prob",
+            type=float,
+            default=1.0,
+            help="The probability of applying RIR convolution.",
+        )
+        parser.add_argument(
+            "--cmvn_file",
+            type=str_or_none,
+            default=None,
+            help="The file path of cmvn file.",
+        )
+        parser.add_argument(
+            "--noise_scp",
+            type=str_or_none,
+            default=None,
+            help="The file path of noise scp file.",
+        )
+        parser.add_argument(
+            "--noise_apply_prob",
+            type=float,
+            default=1.0,
+            help="The probability of applying noise addition.",
+        )
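+        # The "13_15"-style default below is read as a "<low>_<high>" dB
+        # range (an assumption based on how ESPnet-style recipes use
+        # noise_db_range; it is not specified elsewhere in this diff).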
+        parser.add_argument(
+            "--noise_db_range",
+            type=str,
+            default="13_15",
+            help="The range of noise decibel level.",
+        )
+
+        for class_choices in cls.class_choices_list:
+            # Append --<name> and --<name>_conf.
+            # e.g. --encoder and --encoder_conf
+            class_choices.add_arguments(group)
+
+    @classmethod
+    def build_collate_fn(
+        cls, args: argparse.Namespace, train: bool
+    ) -> Callable[
+        [Collection[Tuple[str, Dict[str, np.ndarray]]]],
+        Tuple[List[str], Dict[str, torch.Tensor]],
+    ]:
+        assert check_argument_types()
+        # NOTE(kamo): int value = 0 is reserved by CTC-blank symbol
+        return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
+
+    @classmethod
+    def build_preprocess_fn(
+        cls, args: argparse.Namespace, train: bool
+    ) -> Optional[Callable[[str, Dict[str, np.ndarray]], Dict[str, np.ndarray]]]:
+        assert check_argument_types()
+        if args.use_preprocessor:
+            retval = CommonPreprocessor(
+                train=train,
+                token_type=args.token_type,
+                token_list=args.token_list,
+                bpemodel=None,
+                non_linguistic_symbols=None,
+                text_cleaner=None,
+                g2p_type=None,
+                split_with_space=args.split_with_space if hasattr(args, "split_with_space") else False,
+                seg_dict_file=args.seg_dict_file if hasattr(args, "seg_dict_file") else None,
+                # NOTE(kamo): Check attribute existence for backward compatibility
+                rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None,
+                rir_apply_prob=args.rir_apply_prob
+                if hasattr(args, "rir_apply_prob")
+                else 1.0,
+                noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None,
+                noise_apply_prob=args.noise_apply_prob
+                if hasattr(args, "noise_apply_prob")
+                else 1.0,
+                noise_db_range=args.noise_db_range
+                if hasattr(args, "noise_db_range")
+                else "13_15",
+                speech_volume_normalize=args.speech_volume_normalize
+                if hasattr(args, "speech_volume_normalize")
+                else None,
+            )
+        else:
+            retval = None
+        assert check_return_type(retval)
+        return retval
+
+    @classmethod
+    def required_data_names(
+        cls, train: bool = True, inference: bool = False
+    ) -> Tuple[str, ...]:
+        if not inference:
+            retval = ("speech", "profile", "binary_labels")
+        else:
+            # Inference mode
+            retval = ("speech",)
+        return retval
+
+    @classmethod
+    def optional_data_names(
+        cls, train: bool = True, inference: bool = False
+    ) -> Tuple[str, ...]:
+        retval = ()
+        assert check_return_type(retval)
+        return retval
+
+    @classmethod
+    def build_model(cls, args: argparse.Namespace):
+        assert check_argument_types()
+
+        # 1. Frontend
+        if args.input_size is None or args.frontend == "wav_frontend_mel23":
+            # Extract features in the model
+            frontend_class = frontend_choices.get_class(args.frontend)
+            if args.frontend == "wav_frontend":
+                frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
+            else:
+                frontend = frontend_class(**args.frontend_conf)
+            input_size = frontend.output_size()
+        else:
+            # Give features from data-loader
+            args.frontend = None
+            args.frontend_conf = {}
+            frontend = None
+            input_size = args.input_size
+
+        # 2. Encoder
+        encoder_class = encoder_choices.get_class(args.encoder)
+        encoder = encoder_class(input_size=input_size, **args.encoder_conf)
+
+        # 3. EncoderDecoderAttractor
+        encoder_decoder_attractor_class = encoder_decoder_attractor_choices.get_class(args.encoder_decoder_attractor)
+        encoder_decoder_attractor = encoder_decoder_attractor_class(**args.encoder_decoder_attractor_conf)
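+        # NOTE: at inference time DiarEENDOLAModel queries this attractor
+        # module for per-attractor existence probabilities (see the
+        # encoder_decoder_attractor.estimate() calls in e2e_diar_eend_ola.py
+        # above); attractors are kept while the probability exceeds the
+        # decoding threshold, unless n_speakers is given explicitly.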
+        # 4. Build model
+        model_class = model_choices.get_class(args.model)
+        model = model_class(
+            frontend=frontend,
+            encoder=encoder,
+            encoder_decoder_attractor=encoder_decoder_attractor,
+            **args.model_conf,
+        )
+
+        # 5. Initialize
+        if args.init is not None:
+            initialize(model, args.init)
+
+        assert check_return_type(model)
+        return model
+
+    # ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~
+    @classmethod
+    def build_model_from_file(
+        cls,
+        config_file: Union[Path, str] = None,
+        model_file: Union[Path, str] = None,
+        cmvn_file: Union[Path, str] = None,
+        device: str = "cpu",
+    ):
+        """Build model from the files.
+
+        This method is used for inference or fine-tuning.
+
+        Args:
+            config_file: The yaml file saved when training.
+            model_file: The model file saved when training.
+            cmvn_file: The cmvn file for the front-end.
+            device: Device type, "cpu", "cuda", or "cuda:N".
+
+        """
+        assert check_argument_types()
+        if config_file is None:
+            assert model_file is not None, (
+                "The argument 'model_file' must be provided "
+                "if the argument 'config_file' is not specified."
+            )
+            config_file = Path(model_file).parent / "config.yaml"
+        else:
+            config_file = Path(config_file)
+
+        with config_file.open("r", encoding="utf-8") as f:
+            args = yaml.safe_load(f)
+        args = argparse.Namespace(**args)
+        model = cls.build_model(args)
+        if not isinstance(model, AbsESPnetModel):
+            raise RuntimeError(
+                f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
+            )
+        if model_file is not None:
+            if device == "cuda":
+                device = f"cuda:{torch.cuda.current_device()}"
+            checkpoint = torch.load(model_file, map_location=device)
+            if "state_dict" in checkpoint.keys():
+                model.load_state_dict(checkpoint["state_dict"])
+            else:
+                model.load_state_dict(checkpoint)
+        model.to(device)
+
+        return model, args
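
Note for reviewers: below is a minimal usage sketch of the new task class, assuming
a config and checkpoint produced by EENDOLADiarTask. The paths are hypothetical,
and the final diarization call is abbreviated in the e2e_diar_eend_ola.py hunk
above (the method name is cut off), so `estimate_sequential` is an assumed name,
not the confirmed API; only build_model_from_file and its return value are taken
from this patch.

    import torch
    from funasr.tasks.diar import EENDOLADiarTask

    # build_model_from_file restores the frontend/encoder/EDA stack from
    # config.yaml and loads the checkpoint weights (see the diff above).
    model, args = EENDOLADiarTask.build_model_from_file(
        config_file="exp/diar_eend_ola/config.yaml",  # hypothetical path
        model_file="exp/diar_eend_ola/model.pb",      # hypothetical path
        device="cpu",
    )
    model.eval()

    # One 1-second, 16 kHz waveform; WavFrontendMel23 converts it to log-mel
    # features inside the model, so raw speech is passed directly.
    speech = torch.randn(1, 16000)
    speech_lengths = torch.tensor([16000], dtype=torch.long)
    with torch.no_grad():
        # Hypothetical entry point; the kwargs mirror the signature shown in
        # the e2e_diar_eend_ola.py hunk (n_speakers, shuffle, threshold).
        results = model.estimate_sequential(
            speech, speech_lengths, n_speakers=None, shuffle=True, threshold=0.5
        )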