update ola

speech_asr 2023-03-13 18:45:27 +08:00
parent 229efa6250
commit e27de5aa6b
3 changed files with 335 additions and 27 deletions

View File

@@ -11,7 +11,8 @@ import torch
import torch.nn as nn
from typeguard import check_argument_types
from funasr.modules.eend_ola.encoder import TransformerEncoder
from funasr.models.frontend.wav_frontend import WavFrontendMel23
from funasr.modules.eend_ola.encoder import EENDOLATransformerEncoder
from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor
from funasr.modules.eend_ola.utils.power import generate_mapping_dict
from funasr.torch_utils.device_funcs import force_gatherable
@@ -34,12 +35,13 @@ def pad_attractor(att, max_n_speakers):
class DiarEENDOLAModel(AbsESPnetModel):
"""CTC-attention hybrid Encoder-Decoder model"""
"""EEND-OLA diarization model"""
def __init__(
self,
encoder: TransformerEncoder,
eda: EncoderDecoderAttractor,
frontend: WavFrontendMel23,
encoder: EENDOLATransformerEncoder,
encoder_decoder_attractor: EncoderDecoderAttractor,
n_units: int = 256,
max_n_speaker: int = 8,
attractor_loss_weight: float = 1.0,
@@ -49,8 +51,9 @@ class DiarEENDOLAModel(AbsESPnetModel):
assert check_argument_types()
super().__init__()
self.frontend = frontend
self.encoder = encoder
self.eda = eda
self.encoder_decoder_attractor = encoder_decoder_attractor
self.attractor_loss_weight = attractor_loss_weight
self.max_n_speaker = max_n_speaker
if mapping_dict is None:
@@ -187,16 +190,18 @@ class DiarEENDOLAModel(AbsESPnetModel):
shuffle: bool = True,
threshold: float = 0.5,
**kwargs):
if self.frontend is not None:
speech = self.frontend(speech)
speech = [s[:s_len] for s, s_len in zip(speech, speech_lengths)]
emb = self.forward_encoder(speech, speech_lengths)
if shuffle:
orders = [np.arange(e.shape[0]) for e in emb]
for order in orders:
np.random.shuffle(order)
attractors, probs = self.eda.estimate(
attractors, probs = self.encoder_decoder_attractor.estimate(
[e[torch.from_numpy(order).to(torch.long).to(speech[0].device)] for e, order in zip(emb, orders)])
else:
attractors, probs = self.eda.estimate(emb)
attractors, probs = self.encoder_decoder_attractor.estimate(emb)
attractors_active = []
for p, att, e in zip(probs, attractors, emb):
if n_speakers and n_speakers >= 0:
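The shuffle branch above follows the EEND-EDA recipe: the frame order of each utterance's embeddings is randomly permuted before attractor estimation, since the LSTM-based attractor should not depend on temporal order. A minimal sketch of that step, assuming emb is a list of per-utterance embedding tensors of shape (T, D) and eda exposes the estimate method used above:

import numpy as np
import torch

def shuffled_estimate(eda, emb):
    # Permute each utterance's frames, then estimate attractors and
    # their existence probabilities from the shuffled embeddings.
    orders = [np.random.permutation(e.shape[0]) for e in emb]
    shuffled = [e[torch.from_numpy(o).long().to(e.device)] for e, o in zip(emb, orders)]
    return eda.estimate(shuffled)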

View File

@@ -1,5 +1,5 @@
import math
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
@@ -81,10 +81,16 @@ class PositionalEncoding(torch.nn.Module):
return self.dropout(x)
class TransformerEncoder(nn.Module):
def __init__(self, idim, n_layers, n_units,
e_units=2048, h=8, dropout_rate=0.1, use_pos_emb=False):
super(TransformerEncoder, self).__init__()
class EENDOLATransformerEncoder(nn.Module):
def __init__(self,
idim: int,
n_layers: int,
n_units: int,
e_units: int = 2048,
h: int = 8,
dropout_rate: float = 0.1,
use_pos_emb: bool = False):
super(EENDOLATransformerEncoder, self).__init__()
self.lnorm_in = nn.LayerNorm(n_units)
self.n_layers = n_layers
self.dropout = nn.Dropout(dropout_rate)
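A minimal construction sketch for the renamed encoder, mirroring the signature above. The sizes are illustrative only; EEND-style systems typically stack context frames around 23-dimensional log-mel features, and the real values come from the training config:

from funasr.modules.eend_ola.encoder import EENDOLATransformerEncoder

# Hypothetical sizes, not verified defaults.
encoder = EENDOLATransformerEncoder(idim=345, n_layers=4, n_units=256)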

View File

@@ -20,19 +20,18 @@ from funasr.datasets.collate_fn import CommonCollateFn
from funasr.datasets.preprocessor import CommonPreprocessor
from funasr.layers.abs_normalize import AbsNormalize
from funasr.layers.global_mvn import GlobalMVN
from funasr.layers.utterance_mvn import UtteranceMVN
from funasr.layers.label_aggregation import LabelAggregate
from funasr.models.ctc import CTC
from funasr.models.encoder.resnet34_encoder import ResNet34Diar, ResNet34SpL2RegDiar
from funasr.models.encoder.ecapa_tdnn_encoder import ECAPA_TDNN
from funasr.models.encoder.opennmt_encoders.conv_encoder import ConvEncoder
from funasr.models.encoder.opennmt_encoders.fsmn_encoder import FsmnEncoder
from funasr.models.encoder.opennmt_encoders.self_attention_encoder import SelfAttentionEncoder
from funasr.models.encoder.opennmt_encoders.ci_scorers import DotScorer, CosScorer
from funasr.layers.utterance_mvn import UtteranceMVN
from funasr.models.e2e_diar_sond import DiarSondModel
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.encoder.conformer_encoder import ConformerEncoder
from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
from funasr.models.encoder.ecapa_tdnn_encoder import ECAPA_TDNN
from funasr.models.encoder.opennmt_encoders.ci_scorers import DotScorer, CosScorer
from funasr.models.encoder.opennmt_encoders.conv_encoder import ConvEncoder
from funasr.models.encoder.opennmt_encoders.fsmn_encoder import FsmnEncoder
from funasr.models.encoder.opennmt_encoders.self_attention_encoder import SelfAttentionEncoder
from funasr.models.encoder.resnet34_encoder import ResNet34Diar, ResNet34SpL2RegDiar
from funasr.models.encoder.rnn_encoder import RNNEncoder
from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
from funasr.models.encoder.transformer_encoder import TransformerEncoder
@@ -41,17 +40,13 @@ from funasr.models.frontend.default import DefaultFrontend
from funasr.models.frontend.fused import FusedFrontends
from funasr.models.frontend.s3prl import S3prlFrontend
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.models.frontend.wav_frontend import WavFrontendMel23
from funasr.models.frontend.windowing import SlidingWindow
from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
from funasr.models.postencoder.hugging_face_transformers_postencoder import (
HuggingFaceTransformersPostEncoder, # noqa: H301
)
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr.models.preencoder.linear import LinearProjection
from funasr.models.preencoder.sinc import LightweightSincConvs
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.models.specaug.specaug import SpecAug
from funasr.models.specaug.specaug import SpecAugLFR
from funasr.modules.eend_ola.encoder import EENDOLATransformerEncoder
from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor
from funasr.tasks.abs_task import AbsTask
from funasr.torch_utils.initialize import initialize
from funasr.train.abs_espnet_model import AbsESPnetModel
@@ -70,6 +65,7 @@ frontend_choices = ClassChoices(
s3prl=S3prlFrontend,
fused=FusedFrontends,
wav_frontend=WavFrontend,
wav_frontend_mel23=WavFrontendMel23,
),
type_check=AbsFrontend,
default="default",
@@ -126,6 +122,7 @@ encoder_choices = ClassChoices(
sanm_chunk_opt=SANMEncoderChunkOpt,
data2vec_encoder=Data2VecEncoder,
ecapa_tdnn=ECAPA_TDNN,
eend_ola_transformer=EENDOLATransformerEncoder,
),
type_check=torch.nn.Module,
default="resnet34",
@@ -177,6 +174,15 @@ decoder_choices = ClassChoices(
type_check=torch.nn.Module,
default="fsmn",
)
# encoder_decoder_attractor is used for EEND-OLA
encoder_decoder_attractor_choices = ClassChoices(
"encoder_decoder_attractor",
classes=dict(
eda=EncoderDecoderAttractor,
),
type_check=torch.nn.Module,
default="eda",
)
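A sketch of how this registry is consumed at build time: ClassChoices maps the string passed on the command line (--encoder_decoder_attractor eda) to the registered class, and --encoder_decoder_attractor_conf supplies its keyword arguments. The empty conf here is a placeholder, not a verified default:

attractor_class = encoder_decoder_attractor_choices.get_class("eda")  # -> EncoderDecoderAttractor
conf = {}  # normally parsed from --encoder_decoder_attractor_conf
attractor = attractor_class(**conf)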
class DiarTask(AbsTask):
@@ -594,3 +600,294 @@ class DiarTask(AbsTask):
var_dict_torch_update.update(var_dict_torch_update_local)
return var_dict_torch_update
class EENDOLADiarTask(AbsTask):
# If you need more than 1 optimizer, change this value
num_optimizers: int = 1
# Add variable objects configurations
class_choices_list = [
# --frontend and --frontend_conf
frontend_choices,
# --model and --model_conf
model_choices,
# --encoder and --encoder_conf
encoder_choices,
# --encoder_decoder_attractor and --encoder_decoder_attractor_conf
encoder_decoder_attractor_choices,
]
# If you need to modify train() or eval() procedures, change Trainer class here
trainer = Trainer
@classmethod
def add_task_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(description="Task related")
# NOTE(kamo): add_arguments(..., required=True) can't be used
# to provide --print_config mode. Instead of it, do as
# required = parser.get_default("required")
# required += ["token_list"]
group.add_argument(
"--token_list",
type=str_or_none,
default=None,
help="A text mapping int-id to token",
)
group.add_argument(
"--split_with_space",
type=str2bool,
default=True,
help="whether to split text using <space>",
)
group.add_argument(
"--seg_dict_file",
type=str,
default=None,
help="seg_dict_file for text processing",
)
group.add_argument(
"--init",
type=lambda x: str_or_none(x.lower()),
default=None,
help="The initialization method",
choices=[
"chainer",
"xavier_uniform",
"xavier_normal",
"kaiming_uniform",
"kaiming_normal",
None,
],
)
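The choices above select the weight-initialization scheme applied near the end of build_model. A rough sketch of what a choice such as "xavier_uniform" typically does; the actual funasr.torch_utils.initialize implementation may differ:

import torch

def xavier_uniform_init(model: torch.nn.Module):
    # Initialize weight matrices with Xavier-uniform values and leave
    # biases and other 1-D parameters at their framework defaults.
    for p in model.parameters():
        if p.dim() > 1:
            torch.nn.init.xavier_uniform_(p)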
group.add_argument(
"--input_size",
type=int_or_none,
default=None,
help="The number of input dimension of the feature",
)
group = parser.add_argument_group(description="Preprocess related")
group.add_argument(
"--use_preprocessor",
type=str2bool,
default=True,
help="Apply preprocessing to data or not",
)
group.add_argument(
"--token_type",
type=str,
default="char",
choices=["char"],
help="The text will be tokenized in the specified level token",
)
parser.add_argument(
"--speech_volume_normalize",
type=float_or_none,
default=None,
help="Scale the maximum amplitude to the given value.",
)
parser.add_argument(
"--rir_scp",
type=str_or_none,
default=None,
help="The file path of rir scp file.",
)
parser.add_argument(
"--rir_apply_prob",
type=float,
default=1.0,
help="THe probability for applying RIR convolution.",
)
parser.add_argument(
"--cmvn_file",
type=str_or_none,
default=None,
help="The file path of noise scp file.",
)
parser.add_argument(
"--noise_scp",
type=str_or_none,
default=None,
help="The file path of noise scp file.",
)
parser.add_argument(
"--noise_apply_prob",
type=float,
default=1.0,
help="The probability applying Noise adding.",
)
parser.add_argument(
"--noise_db_range",
type=str,
default="13_15",
help="The range of noise decibel level.",
)
for class_choices in cls.class_choices_list:
# Append --<name> and --<name>_conf.
# e.g. --encoder and --encoder_conf
class_choices.add_arguments(group)
@classmethod
def build_collate_fn(
cls, args: argparse.Namespace, train: bool
) -> Callable[
[Collection[Tuple[str, Dict[str, np.ndarray]]]],
Tuple[List[str], Dict[str, torch.Tensor]],
]:
assert check_argument_types()
# NOTE(kamo): int value = 0 is reserved by CTC-blank symbol
return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
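To illustrate the padding values chosen above: float features are padded with 0.0 and integer label sequences with -1, so padded positions remain distinguishable from a real label 0. An equivalent, self-contained padding sketch that assumes nothing about CommonCollateFn internals:

import torch

labels = [torch.tensor([0, 1, 1]), torch.tensor([1])]
padded = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-1)
# padded -> tensor([[ 0,  1,  1],
#                   [ 1, -1, -1]])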
@classmethod
def build_preprocess_fn(
cls, args: argparse.Namespace, train: bool
) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
assert check_argument_types()
if args.use_preprocessor:
retval = CommonPreprocessor(
train=train,
token_type=args.token_type,
token_list=args.token_list,
bpemodel=None,
non_linguistic_symbols=None,
text_cleaner=None,
g2p_type=None,
split_with_space=args.split_with_space if hasattr(args, "split_with_space") else False,
seg_dict_file=args.seg_dict_file if hasattr(args, "seg_dict_file") else None,
# NOTE(kamo): Check attribute existence for backward compatibility
rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None,
rir_apply_prob=args.rir_apply_prob
if hasattr(args, "rir_apply_prob")
else 1.0,
noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None,
noise_apply_prob=args.noise_apply_prob
if hasattr(args, "noise_apply_prob")
else 1.0,
noise_db_range=args.noise_db_range
if hasattr(args, "noise_db_range")
else "13_15",
speech_volume_normalize=args.speech_volume_normalize
if hasattr(args, "rir_scp")
else None,
)
else:
retval = None
assert check_return_type(retval)
return retval
@classmethod
def required_data_names(
cls, train: bool = True, inference: bool = False
) -> Tuple[str, ...]:
if not inference:
retval = ("speech", "profile", "binary_labels")
else:
# Inference mode
retval = ("speech",)
return retval
@classmethod
def optional_data_names(
cls, train: bool = True, inference: bool = False
) -> Tuple[str, ...]:
retval = ()
assert check_return_type(retval)
return retval
@classmethod
def build_model(cls, args: argparse.Namespace):
assert check_argument_types()
# 1. frontend
if args.input_size is None or args.frontend == "wav_frontend_mel23":
# Extract features in the model
frontend_class = frontend_choices.get_class(args.frontend)
if args.frontend == 'wav_frontend':
frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
else:
frontend = frontend_class(**args.frontend_conf)
input_size = frontend.output_size()
else:
# Give features from data-loader
args.frontend = None
args.frontend_conf = {}
frontend = None
input_size = args.input_size
# 2. Encoder
encoder_class = encoder_choices.get_class(args.encoder)
encoder = encoder_class(input_size=input_size, **args.encoder_conf)
# 3. EncoderDecoderAttractor
encoder_decoder_attractor_class = encoder_decoder_attractor_choices.get_class(args.encoder_decoder_attractor)
encoder_decoder_attractor = encoder_decoder_attractor_class(**args.encoder_decoder_attractor_conf)
# 4. Build model
model_class = model_choices.get_class(args.model)
model = model_class(
frontend=frontend,
encoder=encoder,
encoder_decoder_attractor=encoder_decoder_attractor,
**args.model_conf,
)
# 5. Initialize
if args.init is not None:
initialize(model, args.init)
assert check_return_type(model)
return model
# ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~
@classmethod
def build_model_from_file(
cls,
config_file: Union[Path, str] = None,
model_file: Union[Path, str] = None,
cmvn_file: Union[Path, str] = None,
device: str = "cpu",
):
"""Build model from the files.
This method is used for inference or fine-tuning.
Args:
config_file: The yaml file saved when training.
model_file: The model file saved when training.
cmvn_file: The cmvn file for front-end
device: Device type, "cpu", "cuda", or "cuda:N".
"""
assert check_argument_types()
if config_file is None:
assert model_file is not None, (
"The argument 'model_file' must be provided "
"if the argument 'config_file' is not specified."
)
config_file = Path(model_file).parent / "config.yaml"
else:
config_file = Path(config_file)
with config_file.open("r", encoding="utf-8") as f:
args = yaml.safe_load(f)
args = argparse.Namespace(**args)
model = cls.build_model(args)
if not isinstance(model, AbsESPnetModel):
raise RuntimeError(
f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
)
if model_file is not None:
if device == "cuda":
device = f"cuda:{torch.cuda.current_device()}"
checkpoint = torch.load(model_file, map_location=device)
if "state_dict" in checkpoint.keys():
model.load_state_dict(checkpoint["state_dict"])
else:
model.load_state_dict(checkpoint)
model.to(device)
return model, args
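A hypothetical end-to-end use of the loader above; the paths are placeholders, and the device handling mirrors the "cuda" branch in the code:

model, args = EENDOLADiarTask.build_model_from_file(
    config_file="exp/diar_eend_ola/config.yaml",  # placeholder path
    model_file="exp/diar_eend_ola/model.pb",      # placeholder path
    device="cpu",
)
model.eval()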