update data2vec pretrain

This commit is contained in:
jmwang66 2023-02-06 16:09:24 +08:00
parent 933d5afc02
commit 7bea618623
3 changed files with 581 additions and 0 deletions

45
funasr/bin/data2vec_train.py Executable file
View File

@ -0,0 +1,45 @@
#!/usr/bin/env python3
import os
from funasr.tasks.data2vec import Data2VecTask
def parse_args():
parser = Data2VecTask.get_parser()
parser.add_argument(
"--gpu_id",
type=int,
default=0,
help="local gpu id.",
)
args = parser.parse_args()
return args
def main(args=None, cmd=None):
# for data2vec Training
Data2VecTask.main(args=args, cmd=cmd)
if __name__ == '__main__':
args = parse_args()
# setup local gpu_id
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
# DDP settings
if args.ngpu > 1:
args.distributed = True
else:
args.distributed = False
assert args.num_worker_count == 1
# re-compute batch size: when dataset type is small
if args.dataset_type == "small":
if args.batch_size is not None:
args.batch_size = args.batch_size * args.ngpu
if args.batch_bins is not None:
args.batch_bins = args.batch_bins * args.ngpu
main(args=args)

160
funasr/models/data2vec.py Normal file
View File

@ -0,0 +1,160 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from contextlib import contextmanager
from distutils.version import LooseVersion
from typing import Dict
from typing import Optional
from typing import Tuple
import torch
from typeguard import check_argument_types
from funasr.layers.abs_normalize import AbsNormalize
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.train.abs_espnet_model import AbsESPnetModel
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
from torch.cuda.amp import autocast
else:
# Nothing to do if torch<1.6.0
@contextmanager
def autocast(enabled=True):
yield
class Data2VecPretrainModel(AbsESPnetModel):
"""Data2Vec Pretrain model"""
def __init__(
self,
frontend: Optional[AbsFrontend],
specaug: Optional[AbsSpecAug],
normalize: Optional[AbsNormalize],
preencoder: Optional[AbsPreEncoder],
encoder: AbsEncoder,
):
assert check_argument_types()
super().__init__()
self.frontend = frontend
self.specaug = specaug
self.normalize = normalize
self.preencoder = preencoder
self.encoder = encoder
self.num_updates = 0
def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
"""
# Check that batch_size is unified
assert (
speech.shape[0]
== speech_lengths.shape[0]
), (speech.shape, speech_lengths.shape)
self.encoder.set_num_updates(self.num_updates)
# 1. Encoder
encoder_out = self.encode(speech, speech_lengths)
losses = encoder_out["losses"]
loss = sum(losses.values())
sample_size = encoder_out["sample_size"]
loss = loss.sum() / sample_size
target_var = float(encoder_out["target_var"])
pred_var = float(encoder_out["pred_var"])
ema_decay = float(encoder_out["ema_decay"])
stats = dict(
loss=torch.clone(loss.detach()),
target_var=target_var,
pred_var=pred_var,
ema_decay=ema_decay,
)
loss, stats, weight = force_gatherable((loss, stats, sample_size), loss.device)
return loss, stats, weight
def collect_feats(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor
) -> Dict[str, torch.Tensor]:
feats, feats_lengths = self._extract_feats(speech, speech_lengths)
return {"feats": feats, "feats_lengths": feats_lengths}
def encode(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
):
"""Frontend + Encoder.
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
"""
with autocast(False):
# 1. Extract feats
feats, feats_lengths = self._extract_feats(speech, speech_lengths)
# 2. Data augmentation
if self.specaug is not None and self.training:
feats, feats_lengths = self.specaug(feats, feats_lengths)
# 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
if self.normalize is not None:
feats, feats_lengths = self.normalize(feats, feats_lengths)
# Pre-encoder, e.g. used for raw input data
if self.preencoder is not None:
feats, feats_lengths = self.preencoder(feats, feats_lengths)
# 4. Forward encoder
if min(speech_lengths) == max(speech_lengths): # for clipping, set speech_lengths as None
speech_lengths = None
encoder_out = self.encoder(feats, speech_lengths, mask=True, features_only=False)
return encoder_out
def _extract_feats(
self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
assert speech_lengths.dim() == 1, speech_lengths.shape
# for data-parallel
speech = speech[:, : speech_lengths.max()]
if self.frontend is not None:
# Frontend
# e.g. STFT and Feature extract
# data_loader may send time-domain signal in this case
# speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
feats, feats_lengths = self.frontend(speech, speech_lengths)
else:
# No frontend and no feature extract
feats, feats_lengths = speech, speech_lengths
return feats, feats_lengths
def set_num_updates(self, num_updates):
self.num_updates = num_updates
def get_num_updates(self):
return self.num_updates

376
funasr/tasks/data2vec.py Normal file
View File

@ -0,0 +1,376 @@
import argparse
from typing import Callable
from typing import Collection
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
import numpy as np
import torch
from typeguard import check_argument_types
from typeguard import check_return_type
from funasr.datasets.collate_fn import CommonCollateFn
from funasr.datasets.preprocessor import CommonPreprocessor
from funasr.layers.abs_normalize import AbsNormalize
from funasr.layers.global_mvn import GlobalMVN
from funasr.layers.utterance_mvn import UtteranceMVN
from funasr.models.data2vec import Data2VecPretrainModel
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.frontend.default import DefaultFrontend
from funasr.models.frontend.windowing import SlidingWindow
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr.models.preencoder.sinc import LightweightSincConvs
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.models.specaug.specaug import SpecAug
from funasr.tasks.abs_task import AbsTask
from funasr.text.phoneme_tokenizer import g2p_choices
from funasr.torch_utils.initialize import initialize
from funasr.train.class_choices import ClassChoices
from funasr.train.trainer import Trainer
from funasr.utils.types import float_or_none
from funasr.utils.types import int_or_none
from funasr.utils.types import str2bool
from funasr.utils.types import str_or_none
frontend_choices = ClassChoices(
name="frontend",
classes=dict(default=DefaultFrontend, sliding_window=SlidingWindow),
type_check=AbsFrontend,
default="default",
)
specaug_choices = ClassChoices(
name="specaug",
classes=dict(specaug=SpecAug),
type_check=AbsSpecAug,
default=None,
optional=True,
)
normalize_choices = ClassChoices(
"normalize",
classes=dict(
global_mvn=GlobalMVN,
utterance_mvn=UtteranceMVN,
),
type_check=AbsNormalize,
default=None,
optional=True,
)
preencoder_choices = ClassChoices(
name="preencoder",
classes=dict(
sinc=LightweightSincConvs,
),
type_check=AbsPreEncoder,
default=None,
optional=True,
)
encoder_choices = ClassChoices(
"encoder",
classes=dict(
data2vec_encoder=Data2VecEncoder,
),
type_check=AbsEncoder,
default="data2vec_encoder",
)
model_choices = ClassChoices(
"model",
classes=dict(
data2vec=Data2VecPretrainModel,
),
default="data2vec",
)
class Data2VecTask(AbsTask):
# If you need more than one optimizers, change this value
num_optimizers: int = 1
# Add variable objects configurations
class_choices_list = [
# --frontend and --frontend_conf
frontend_choices,
# --specaug and --specaug_conf
specaug_choices,
# --normalize and --normalize_conf
normalize_choices,
# --preencoder and --preencoder_conf
preencoder_choices,
# --encoder and --encoder_conf
encoder_choices,
# --model and --model_conf
model_choices,
]
# If you need to modify train() or eval() procedures, change Trainer class here
trainer = Trainer
@classmethod
def add_task_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(description="Task related")
# NOTE(kamo): add_arguments(..., required=True) can't be used
# to provide --print_config mode. Instead of it, do as
group.add_argument(
"--token_list",
type=str_or_none,
default=None,
help="A text mapping int-id to token",
)
group.add_argument(
"--init",
type=lambda x: str_or_none(x.lower()),
default=None,
help="The initialization method",
choices=[
"chainer",
"xavier_uniform",
"xavier_normal",
"kaiming_uniform",
"kaiming_normal",
None,
],
)
group.add_argument(
"--input_size",
type=int_or_none,
default=None,
help="The number of input dimension of the feature",
)
group = parser.add_argument_group(description="Preprocess related")
group.add_argument(
"--use_preprocessor",
type=str2bool,
default=True,
help="Apply preprocessing to data or not",
)
group.add_argument(
"--token_type",
type=str,
default=None,
choices=["bpe", "char", "word", "phn"],
help="The text will be tokenized " "in the specified level token",
)
group.add_argument(
"--feats_type",
type=str,
default='fbank',
help="feats type, e.g. fbank, wav, ark_wav(needed to be scale normalization)",
)
group.add_argument(
"--bpemodel",
type=str_or_none,
default=None,
help="The model file of sentencepiece",
)
parser.add_argument(
"--non_linguistic_symbols",
type=str_or_none,
help="non_linguistic_symbols file path",
)
parser.add_argument(
"--cleaner",
type=str_or_none,
choices=[None, "tacotron", "jaconv", "vietnamese"],
default=None,
help="Apply text cleaning",
)
parser.add_argument(
"--g2p",
type=str_or_none,
choices=g2p_choices,
default=None,
help="Specify g2p method if --token_type=phn",
)
parser.add_argument(
"--speech_volume_normalize",
type=float_or_none,
default=None,
help="Scale the maximum amplitude to the given value.",
)
parser.add_argument(
"--rir_scp",
type=str_or_none,
default=None,
help="The file path of rir scp file.",
)
parser.add_argument(
"--rir_apply_prob",
type=float,
default=1.0,
help="THe probability for applying RIR convolution.",
)
parser.add_argument(
"--noise_scp",
type=str_or_none,
default=None,
help="The file path of noise scp file.",
)
parser.add_argument(
"--noise_apply_prob",
type=float,
default=1.0,
help="The probability applying Noise adding.",
)
parser.add_argument(
"--noise_db_range",
type=str,
default="13_15",
help="The range of noise decibel level.",
)
parser.add_argument(
"--pred_masked_weight",
type=float,
default=1.0,
help="weight for predictive loss for masked frames",
)
parser.add_argument(
"--pred_nomask_weight",
type=float,
default=0.0,
help="weight for predictive loss for unmasked frames",
)
parser.add_argument(
"--loss_weights",
type=float,
default=0.0,
help="weights for additional loss terms (not first one)",
)
for class_choices in cls.class_choices_list:
# Append --<name> and --<name>_conf.
# e.g. --encoder and --encoder_conf
class_choices.add_arguments(group)
@classmethod
def build_collate_fn(
cls, args: argparse.Namespace, train: bool
) -> Callable[
[Collection[Tuple[str, Dict[str, np.ndarray]]]],
Tuple[List[str], Dict[str, torch.Tensor]],
]:
assert check_argument_types()
return CommonCollateFn(clipping=True)
@classmethod
def build_preprocess_fn(
cls, args: argparse.Namespace, train: bool
) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
assert check_argument_types()
if args.use_preprocessor:
retval = CommonPreprocessor(
train=train,
bpemodel=args.bpemodel,
non_linguistic_symbols=args.non_linguistic_symbols,
text_cleaner=args.cleaner,
g2p_type=args.g2p,
# NOTE(kamo): Check attribute existence for backward compatibility
rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None,
rir_apply_prob=args.rir_apply_prob
if hasattr(args, "rir_apply_prob")
else 1.0,
noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None,
noise_apply_prob=args.noise_apply_prob
if hasattr(args, "noise_apply_prob")
else 1.0,
noise_db_range=args.noise_db_range
if hasattr(args, "noise_db_range")
else "13_15",
speech_volume_normalize=args.speech_volume_normalize
if hasattr(args, "rir_scp")
else None,
)
else:
retval = None
assert check_return_type(retval)
return retval
@classmethod
def required_data_names(
cls, train: bool = True, inference: bool = False
) -> Tuple[str, ...]:
# for pre-training
retval = ("speech",)
return retval
@classmethod
def optional_data_names(
cls, train: bool = True, inference: bool = False
) -> Tuple[str, ...]:
retval = ()
assert check_return_type(retval)
return retval
@classmethod
def build_model(cls, args: argparse.Namespace):
assert check_argument_types()
# 1. frontend
if args.input_size is None:
# Extract features in the model
frontend_class = frontend_choices.get_class(args.frontend)
frontend = frontend_class(**args.frontend_conf)
input_size = frontend.output_size()
else:
# Give features from data-loader
args.frontend = None
args.frontend_conf = {}
frontend = None
input_size = args.input_size
# 2. Data augmentation for spectrogram
if args.specaug is not None:
specaug_class = specaug_choices.get_class(args.specaug)
specaug = specaug_class(**args.specaug_conf)
else:
specaug = None
# 3. Normalization layer
if args.normalize is not None:
normalize_class = normalize_choices.get_class(args.normalize)
normalize = normalize_class(**args.normalize_conf)
else:
normalize = None
# 4. Pre-encoder input block
# NOTE(kan-bayashi): Use getattr to keep the compatibility
if getattr(args, "preencoder", None) is not None:
preencoder_class = preencoder_choices.get_class(args.preencoder)
preencoder = preencoder_class(**args.preencoder_conf)
input_size = preencoder.output_size()
else:
preencoder = None
# 5. Encoder
encoder_class = encoder_choices.get_class(args.encoder)
encoder = encoder_class(
input_size=input_size,
**args.encoder_conf,
)
# 6. Build model
try:
model_class = model_choices.get_class(args.model)
except AttributeError:
model_class = model_choices.get_class("data2vec")
model = model_class(
frontend=frontend,
specaug=specaug,
normalize=normalize,
preencoder=preencoder,
encoder=encoder,
)
# 7. Initialize
if args.init is not None:
initialize(model, args.init)
assert check_return_type(model)
return model