diff --git a/funasr/bin/data2vec_train.py b/funasr/bin/data2vec_train.py
new file mode 100755
index 000000000..b9dbdff10
--- /dev/null
+++ b/funasr/bin/data2vec_train.py
@@ -0,0 +1,45 @@
+#!/usr/bin/env python3
+
+import os
+
+from funasr.tasks.data2vec import Data2VecTask
+
+
+def parse_args():
+    parser = Data2VecTask.get_parser()
+    parser.add_argument(
+        "--gpu_id",
+        type=int,
+        default=0,
+        help="Local GPU id.",
+    )
+    args = parser.parse_args()
+    return args
+
+
+def main(args=None, cmd=None):
+    # run data2vec pre-training
+    Data2VecTask.main(args=args, cmd=cmd)
+
+
+if __name__ == '__main__':
+    args = parse_args()
+
+    # set up the local gpu_id
+    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
+
+    # DDP settings
+    if args.ngpu > 1:
+        args.distributed = True
+    else:
+        args.distributed = False
+        assert args.num_worker_count == 1
+
+    # re-compute the batch size when the dataset type is small
+    if args.dataset_type == "small":
+        if args.batch_size is not None:
+            args.batch_size = args.batch_size * args.ngpu
+        if args.batch_bins is not None:
+            args.batch_bins = args.batch_bins * args.ngpu
+
+    main(args=args)
diff --git a/funasr/models/data2vec.py b/funasr/models/data2vec.py
new file mode 100644
index 000000000..fcd6bd2cf
--- /dev/null
+++ b/funasr/models/data2vec.py
@@ -0,0 +1,160 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from contextlib import contextmanager
+from distutils.version import LooseVersion
+from typing import Dict
+from typing import Optional
+from typing import Tuple
+
+import torch
+from typeguard import check_argument_types
+
+from funasr.layers.abs_normalize import AbsNormalize
+from funasr.models.encoder.abs_encoder import AbsEncoder
+from funasr.models.frontend.abs_frontend import AbsFrontend
+from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
+from funasr.models.specaug.abs_specaug import AbsSpecAug
+from funasr.torch_utils.device_funcs import force_gatherable
+from funasr.train.abs_espnet_model import AbsESPnetModel
+
+if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
+    from torch.cuda.amp import autocast
+else:
+    # Nothing to do if torch<1.6.0
+    @contextmanager
+    def autocast(enabled=True):
+        yield
+
+
+class Data2VecPretrainModel(AbsESPnetModel):
+    """Data2Vec pre-training model."""
+
+    def __init__(
+        self,
+        frontend: Optional[AbsFrontend],
+        specaug: Optional[AbsSpecAug],
+        normalize: Optional[AbsNormalize],
+        preencoder: Optional[AbsPreEncoder],
+        encoder: AbsEncoder,
+    ):
+        assert check_argument_types()
+
+        super().__init__()
+
+        self.frontend = frontend
+        self.specaug = specaug
+        self.normalize = normalize
+        self.preencoder = preencoder
+        self.encoder = encoder
+        self.num_updates = 0
+
+    def forward(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+        """Frontend + Encoder + Calc loss
+
+        Args:
+            speech: (Batch, Length, ...)
+            speech_lengths: (Batch, )
+        """
+        # Check that batch_size is unified
+        assert (
+            speech.shape[0]
+            == speech_lengths.shape[0]
+        ), (speech.shape, speech_lengths.shape)
+
+        self.encoder.set_num_updates(self.num_updates)
+
+        # 1. Encoder (returns the masked-prediction losses and EMA diagnostics)
+        encoder_out = self.encode(speech, speech_lengths)
+
+        losses = encoder_out["losses"]
+        loss = sum(losses.values())
+        sample_size = encoder_out["sample_size"]
+        loss = loss.sum() / sample_size
+
+        target_var = float(encoder_out["target_var"])
+        pred_var = float(encoder_out["pred_var"])
+        ema_decay = float(encoder_out["ema_decay"])
+
+        stats = dict(
+            loss=torch.clone(loss.detach()),
+            target_var=target_var,
+            pred_var=pred_var,
+            ema_decay=ema_decay,
+        )
+
+        loss, stats, weight = force_gatherable((loss, stats, sample_size), loss.device)
+        return loss, stats, weight
+
+    def collect_feats(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor
+    ) -> Dict[str, torch.Tensor]:
+        feats, feats_lengths = self._extract_feats(speech, speech_lengths)
+        return {"feats": feats, "feats_lengths": feats_lengths}
+
+    def encode(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+    ):
+        """Frontend + Encoder.
+
+        Args:
+            speech: (Batch, Length, ...)
+            speech_lengths: (Batch, )
+        """
+        with autocast(False):
+            # 1. Extract feats
+            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
+
+            # 2. Data augmentation
+            if self.specaug is not None and self.training:
+                feats, feats_lengths = self.specaug(feats, feats_lengths)
+
+            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
+            if self.normalize is not None:
+                feats, feats_lengths = self.normalize(feats, feats_lengths)
+
+        # Pre-encoder, e.g. used for raw input data
+        if self.preencoder is not None:
+            feats, feats_lengths = self.preencoder(feats, feats_lengths)
+
+        # 4. Forward encoder
+        if min(speech_lengths) == max(speech_lengths):  # the batch was clipped to equal lengths, so set speech_lengths to None
+            speech_lengths = None
+        encoder_out = self.encoder(feats, speech_lengths, mask=True, features_only=False)
+
+        return encoder_out
+
+    def _extract_feats(
+        self, speech: torch.Tensor, speech_lengths: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        assert speech_lengths.dim() == 1, speech_lengths.shape
+
+        # for data-parallel
+        speech = speech[:, : speech_lengths.max()]
+
+        if self.frontend is not None:
+            # Frontend
+            #  e.g. STFT and feature extraction
+            #       the data_loader may send a time-domain signal in this case
+            # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
+            feats, feats_lengths = self.frontend(speech, speech_lengths)
+        else:
+            # No frontend and no feature extraction
+            feats, feats_lengths = speech, speech_lengths
+        return feats, feats_lengths
+
+    def set_num_updates(self, num_updates):
+        self.num_updates = num_updates
+
+    def get_num_updates(self):
+        return self.num_updates
diff --git a/funasr/tasks/data2vec.py b/funasr/tasks/data2vec.py
new file mode 100644
index 000000000..9a64e1f58
--- /dev/null
+++ b/funasr/tasks/data2vec.py
@@ -0,0 +1,376 @@
+import argparse
+from typing import Callable
+from typing import Collection
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Tuple
+
+import numpy as np
+import torch
+from typeguard import check_argument_types
+from typeguard import check_return_type
+
+from funasr.datasets.collate_fn import CommonCollateFn
+from funasr.datasets.preprocessor import CommonPreprocessor
+from funasr.layers.abs_normalize import AbsNormalize
+from funasr.layers.global_mvn import GlobalMVN
+from funasr.layers.utterance_mvn import UtteranceMVN
+from funasr.models.data2vec import Data2VecPretrainModel
+from funasr.models.encoder.abs_encoder import AbsEncoder
+from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
+from funasr.models.frontend.abs_frontend import AbsFrontend
+from funasr.models.frontend.default import DefaultFrontend
+from funasr.models.frontend.windowing import SlidingWindow
+from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
+from funasr.models.preencoder.sinc import LightweightSincConvs
+from funasr.models.specaug.abs_specaug import AbsSpecAug
+from funasr.models.specaug.specaug import SpecAug
+from funasr.tasks.abs_task import AbsTask
+from funasr.text.phoneme_tokenizer import g2p_choices
+from funasr.torch_utils.initialize import initialize
+from funasr.train.class_choices import ClassChoices
+from funasr.train.trainer import Trainer
+from funasr.utils.types import float_or_none
+from funasr.utils.types import int_or_none
+from funasr.utils.types import str2bool
+from funasr.utils.types import str_or_none
+
+frontend_choices = ClassChoices(
+    name="frontend",
+    classes=dict(default=DefaultFrontend, sliding_window=SlidingWindow),
+    type_check=AbsFrontend,
+    default="default",
+)
+specaug_choices = ClassChoices(
+    name="specaug",
+    classes=dict(specaug=SpecAug),
+    type_check=AbsSpecAug,
+    default=None,
+    optional=True,
+)
+normalize_choices = ClassChoices(
+    "normalize",
+    classes=dict(
+        global_mvn=GlobalMVN,
+        utterance_mvn=UtteranceMVN,
+    ),
+    type_check=AbsNormalize,
+    default=None,
+    optional=True,
+)
+preencoder_choices = ClassChoices(
+    name="preencoder",
+    classes=dict(
+        sinc=LightweightSincConvs,
+    ),
+    type_check=AbsPreEncoder,
+    default=None,
+    optional=True,
+)
+encoder_choices = ClassChoices(
+    "encoder",
+    classes=dict(
+        data2vec_encoder=Data2VecEncoder,
+    ),
+    type_check=AbsEncoder,
+    default="data2vec_encoder",
+)
+model_choices = ClassChoices(
+    "model",
+    classes=dict(
+        data2vec=Data2VecPretrainModel,
+    ),
+    default="data2vec",
+)
+
+
+class Data2VecTask(AbsTask):
+    # If you need more than one optimizer, change this value
+    num_optimizers: int = 1
+
+    # Add variable objects configurations
+    class_choices_list = [
+        # --frontend and --frontend_conf
+        frontend_choices,
+        # --specaug and --specaug_conf
+        specaug_choices,
+        # --normalize and --normalize_conf
+        normalize_choices,
+        # --preencoder and --preencoder_conf
+        preencoder_choices,
+        # --encoder and --encoder_conf
+        encoder_choices,
+        # --model and --model_conf
+        model_choices,
+    ]
+
+    # If you need to modify the train() or eval() procedures, change the Trainer class here
+    trainer = Trainer
+
+    @classmethod
+    def add_task_arguments(cls, parser: argparse.ArgumentParser):
+        group = parser.add_argument_group(description="Task related")
+
+        # NOTE(kamo): add_arguments(..., required=True) can't be used
+        # to provide the --print_config mode, so required arguments are not enforced here
+        group.add_argument(
+            "--token_list",
+            type=str_or_none,
+            default=None,
+            help="A text file mapping int-ids to tokens",
+        )
+        group.add_argument(
+            "--init",
+            type=lambda x: str_or_none(x.lower()),
+            default=None,
+            help="The initialization method",
+            choices=[
+                "chainer",
+                "xavier_uniform",
+                "xavier_normal",
+                "kaiming_uniform",
+                "kaiming_normal",
+                None,
+            ],
+        )
+
+        group.add_argument(
+            "--input_size",
+            type=int_or_none,
+            default=None,
+            help="The number of input dimensions of the feature",
+        )
+
+        group = parser.add_argument_group(description="Preprocess related")
+        group.add_argument(
+            "--use_preprocessor",
+            type=str2bool,
+            default=True,
+            help="Whether to apply preprocessing to the data",
+        )
+        group.add_argument(
+            "--token_type",
+            type=str,
+            default=None,
+            choices=["bpe", "char", "word", "phn"],
+            help="The text will be tokenized at the specified token level",
+        )
+
+        group.add_argument(
+            "--feats_type",
+            type=str,
+            default="fbank",
+            help="Feature type, e.g. fbank, wav, ark_wav (which requires scale normalization)",
+        )
+
+        group.add_argument(
+            "--bpemodel",
+            type=str_or_none,
+            default=None,
+            help="The model file of sentencepiece",
+        )
+        parser.add_argument(
+            "--non_linguistic_symbols",
+            type=str_or_none,
+            help="The file path of the non_linguistic_symbols file",
+        )
+        parser.add_argument(
+            "--cleaner",
+            type=str_or_none,
+            choices=[None, "tacotron", "jaconv", "vietnamese"],
+            default=None,
+            help="Apply text cleaning",
+        )
+        parser.add_argument(
+            "--g2p",
+            type=str_or_none,
+            choices=g2p_choices,
+            default=None,
+            help="Specify the g2p method if --token_type=phn",
+        )
+        parser.add_argument(
+            "--speech_volume_normalize",
+            type=float_or_none,
+            default=None,
+            help="Scale the maximum amplitude to the given value.",
+        )
+        parser.add_argument(
+            "--rir_scp",
+            type=str_or_none,
+            default=None,
+            help="The file path of the rir scp file.",
+        )
+        parser.add_argument(
+            "--rir_apply_prob",
+            type=float,
+            default=1.0,
+            help="The probability of applying RIR convolution.",
+        )
+        parser.add_argument(
+            "--noise_scp",
+            type=str_or_none,
+            default=None,
+            help="The file path of the noise scp file.",
+        )
+        parser.add_argument(
+            "--noise_apply_prob",
+            type=float,
+            default=1.0,
+            help="The probability of applying noise addition.",
+        )
+        parser.add_argument(
+            "--noise_db_range",
+            type=str,
+            default="13_15",
+            help="The range of the noise decibel level.",
+        )
+        parser.add_argument(
+            "--pred_masked_weight",
+            type=float,
+            default=1.0,
+            help="Weight of the predictive loss on masked frames",
+        )
+        parser.add_argument(
+            "--pred_nomask_weight",
+            type=float,
+            default=0.0,
+            help="Weight of the predictive loss on unmasked frames",
+        )
+        parser.add_argument(
+            "--loss_weights",
+            type=float,
+            default=0.0,
+            help="Weights of additional loss terms (excluding the first one)",
+        )
+
+        for class_choices in cls.class_choices_list:
+            # Append --<name> and --<name>_conf,
+            # e.g. --encoder and --encoder_conf
+            class_choices.add_arguments(group)
+
+    @classmethod
+    def build_collate_fn(
+        cls, args: argparse.Namespace, train: bool
+    ) -> Callable[
+        [Collection[Tuple[str, Dict[str, np.ndarray]]]],
+        Tuple[List[str], Dict[str, torch.Tensor]],
+    ]:
+        assert check_argument_types()
+        return CommonCollateFn(clipping=True)
+
+    @classmethod
+    def build_preprocess_fn(
+        cls, args: argparse.Namespace, train: bool
+    ) -> Optional[Callable[[str, Dict[str, np.ndarray]], Dict[str, np.ndarray]]]:
+        assert check_argument_types()
+        if args.use_preprocessor:
+            retval = CommonPreprocessor(
+                train=train,
+                bpemodel=args.bpemodel,
+                non_linguistic_symbols=args.non_linguistic_symbols,
+                text_cleaner=args.cleaner,
+                g2p_type=args.g2p,
+                # NOTE(kamo): Check attribute existence for backward compatibility
+                rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None,
+                rir_apply_prob=args.rir_apply_prob
+                if hasattr(args, "rir_apply_prob")
+                else 1.0,
+                noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None,
+                noise_apply_prob=args.noise_apply_prob
+                if hasattr(args, "noise_apply_prob")
+                else 1.0,
+                noise_db_range=args.noise_db_range
+                if hasattr(args, "noise_db_range")
+                else "13_15",
+                speech_volume_normalize=args.speech_volume_normalize
+                if hasattr(args, "speech_volume_normalize")
+                else None,
+            )
+        else:
+            retval = None
+        assert check_return_type(retval)
+        return retval
+
+    @classmethod
+    def required_data_names(
+        cls, train: bool = True, inference: bool = False
+    ) -> Tuple[str, ...]:
+        # only speech is required for pre-training
+        retval = ("speech",)
+        return retval
+
+    @classmethod
+    def optional_data_names(
+        cls, train: bool = True, inference: bool = False
+    ) -> Tuple[str, ...]:
+        retval = ()
+        assert check_return_type(retval)
+        return retval
+
+    @classmethod
+    def build_model(cls, args: argparse.Namespace):
+        assert check_argument_types()
+
+        # 1. Frontend
+        if args.input_size is None:
+            # Extract features in the model
+            frontend_class = frontend_choices.get_class(args.frontend)
+            frontend = frontend_class(**args.frontend_conf)
+            input_size = frontend.output_size()
+        else:
+            # Give features from the data-loader
+            args.frontend = None
+            args.frontend_conf = {}
+            frontend = None
+            input_size = args.input_size
+
+        # 2. Data augmentation for spectrogram
+        if args.specaug is not None:
+            specaug_class = specaug_choices.get_class(args.specaug)
+            specaug = specaug_class(**args.specaug_conf)
+        else:
+            specaug = None
+
+        # 3. Normalization layer
+        if args.normalize is not None:
+            normalize_class = normalize_choices.get_class(args.normalize)
+            normalize = normalize_class(**args.normalize_conf)
+        else:
+            normalize = None
+
+        # 4. Pre-encoder input block
+        # NOTE(kan-bayashi): Use getattr to keep backward compatibility
+        if getattr(args, "preencoder", None) is not None:
+            preencoder_class = preencoder_choices.get_class(args.preencoder)
+            preencoder = preencoder_class(**args.preencoder_conf)
+            input_size = preencoder.output_size()
+        else:
+            preencoder = None
+
+        # 5. Encoder
+        encoder_class = encoder_choices.get_class(args.encoder)
+        encoder = encoder_class(
+            input_size=input_size,
+            **args.encoder_conf,
+        )
+
+        # 6. Build model
+        try:
+            model_class = model_choices.get_class(args.model)
+        except AttributeError:  # older configs may not define args.model
+            model_class = model_choices.get_class("data2vec")
+        model = model_class(
+            frontend=frontend,
+            specaug=specaug,
+            normalize=normalize,
+            preencoder=preencoder,
+            encoder=encoder,
+        )
+
+        # 7. Initialize
+        if args.init is not None:
+            initialize(model, args.init)
+
+        assert check_return_type(model)
+        return model
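For reference, a minimal sketch of how the new task could be driven programmatically, mirroring what data2vec_train.py does from the command line. This is an assumed usage rather than part of the patch: the config path and output directory below are placeholders, and --config, --ngpu, and --output_dir are presumed to come from the shared AbsTask argument parser.

    from funasr.tasks.data2vec import Data2VecTask

    # Hypothetical invocation; the config path and output directory are placeholders.
    Data2VecTask.main(cmd=[
        "--config", "conf/train_data2vec.yaml",
        "--ngpu", "1",
        "--output_dir", "exp/data2vec_pretrain",
    ])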