diff --git a/funasr/bin/eend_ola_inference.py b/funasr/bin/eend_ola_inference.py new file mode 100755 index 000000000..d65895f30 --- /dev/null +++ b/funasr/bin/eend_ola_inference.py @@ -0,0 +1,413 @@ +#!/usr/bin/env python3 +# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. +# MIT License (https://opensource.org/licenses/MIT) + +import argparse +import logging +import os +import sys +from pathlib import Path +from typing import Any +from typing import List +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Union + +import numpy as np +import torch +from typeguard import check_argument_types + +from funasr.models.frontend.wav_frontend import WavFrontendMel23 +from funasr.tasks.diar import EENDOLADiarTask +from funasr.torch_utils.device_funcs import to_device +from funasr.utils import config_argparse +from funasr.utils.cli_utils import get_commandline_args +from funasr.utils.types import str2bool +from funasr.utils.types import str2triple_str +from funasr.utils.types import str_or_none + + +class Speech2Diarization: + """Speech2Diarlization class + + Examples: + >>> import soundfile + >>> import numpy as np + >>> speech2diar = Speech2Diarization("diar_sond_config.yml", "diar_sond.pth") + >>> profile = np.load("profiles.npy") + >>> audio, rate = soundfile.read("speech.wav") + >>> speech2diar(audio, profile) + {"spk1": [(int, int), ...], ...} + + """ + + def __init__( + self, + diar_train_config: Union[Path, str] = None, + diar_model_file: Union[Path, str] = None, + device: str = "cpu", + dtype: str = "float32", + ): + assert check_argument_types() + + # 1. Build Diarization model + diar_model, diar_train_args = EENDOLADiarTask.build_model_from_file( + config_file=diar_train_config, + model_file=diar_model_file, + device=device + ) + frontend = None + if diar_train_args.frontend is not None and diar_train_args.frontend_conf is not None: + frontend = WavFrontendMel23(**diar_train_args.frontend_conf) + + # set up seed for eda + np.random.seed(diar_train_args.seed) + torch.manual_seed(diar_train_args.seed) + torch.cuda.manual_seed(diar_train_args.seed) + os.environ['PYTORCH_SEED'] = str(diar_train_args.seed) + logging.info("diar_model: {}".format(diar_model)) + logging.info("diar_train_args: {}".format(diar_train_args)) + diar_model.to(dtype=getattr(torch, dtype)).eval() + + self.diar_model = diar_model + self.diar_train_args = diar_train_args + self.device = device + self.dtype = dtype + self.frontend = frontend + + @torch.no_grad() + def __call__( + self, + speech: Union[torch.Tensor, np.ndarray], + speech_lengths: Union[torch.Tensor, np.ndarray] = None + ): + """Inference + + Args: + speech: Input speech data + Returns: + diarization results + + """ + assert check_argument_types() + # Input as audio signal + if isinstance(speech, np.ndarray): + speech = torch.tensor(speech) + + if self.frontend is not None: + feats, feats_len = self.frontend.forward(speech, speech_lengths) + feats = to_device(feats, device=self.device) + feats_len = feats_len.int() + self.diar_model.frontend = None + else: + feats = speech + feats_len = speech_lengths + batch = {"speech": feats, "speech_lengths": feats_len} + batch = to_device(batch, device=self.device) + results = self.diar_model.estimate_sequential(**batch) + + return results + + @staticmethod + def from_pretrained( + model_tag: Optional[str] = None, + **kwargs: Optional[Any], + ): + """Build Speech2Diarization instance from the pretrained model. + + Args: + model_tag (Optional[str]): Model tag of the pretrained models. + Currently, the tags of espnet_model_zoo are supported. + + Returns: + Speech2Diarization: Speech2Diarization instance. + + """ + if model_tag is not None: + try: + from espnet_model_zoo.downloader import ModelDownloader + + except ImportError: + logging.error( + "`espnet_model_zoo` is not installed. " + "Please install via `pip install -U espnet_model_zoo`." + ) + raise + d = ModelDownloader() + kwargs.update(**d.download_and_unpack(model_tag)) + + return Speech2Diarization(**kwargs) + + +def inference_modelscope( + diar_train_config: str, + diar_model_file: str, + output_dir: Optional[str] = None, + batch_size: int = 1, + dtype: str = "float32", + ngpu: int = 0, + num_workers: int = 0, + log_level: Union[int, str] = "INFO", + key_file: Optional[str] = None, + model_tag: Optional[str] = None, + allow_variable_data_keys: bool = True, + streaming: bool = False, + param_dict: Optional[dict] = None, + **kwargs, +): + assert check_argument_types() + if batch_size > 1: + raise NotImplementedError("batch decoding is not implemented") + if ngpu > 1: + raise NotImplementedError("only single GPU decoding is supported") + + logging.basicConfig( + level=log_level, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + logging.info("param_dict: {}".format(param_dict)) + + if ngpu >= 1 and torch.cuda.is_available(): + device = "cuda" + else: + device = "cpu" + + # 1. Build speech2diar + speech2diar_kwargs = dict( + diar_train_config=diar_train_config, + diar_model_file=diar_model_file, + device=device, + dtype=dtype, + streaming=streaming, + ) + logging.info("speech2diarization_kwargs: {}".format(speech2diar_kwargs)) + speech2diar = Speech2Diarization.from_pretrained( + model_tag=model_tag, + **speech2diar_kwargs, + ) + speech2diar.diar_model.eval() + + def output_results_str(results: dict, uttid: str): + rst = [] + mid = uttid.rsplit("-", 1)[0] + for key in results: + results[key] = [(x[0] / 100, x[1] / 100) for x in results[key]] + template = "SPEAKER {} 0 {:.2f} {:.2f} {} " + for spk, segs in results.items(): + rst.extend([template.format(mid, st, ed, spk) for st, ed in segs]) + + return "\n".join(rst) + + def _forward( + data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None, + raw_inputs: List[List[Union[np.ndarray, torch.Tensor, str, bytes]]] = None, + output_dir_v2: Optional[str] = None, + param_dict: Optional[dict] = None, + ): + # 2. Build data-iterator + if data_path_and_name_and_type is None and raw_inputs is not None: + if isinstance(raw_inputs, torch.Tensor): + raw_inputs = raw_inputs.numpy() + data_path_and_name_and_type = [raw_inputs, "speech", "waveform"] + loader = EENDOLADiarTask.build_streaming_iterator( + data_path_and_name_and_type, + dtype=dtype, + batch_size=batch_size, + key_file=key_file, + num_workers=num_workers, + preprocess_fn=EENDOLADiarTask.build_preprocess_fn(speech2diar.diar_train_args, False), + collate_fn=EENDOLADiarTask.build_collate_fn(speech2diar.diar_train_args, False), + allow_variable_data_keys=allow_variable_data_keys, + inference=True, + ) + + # 3. Start for-loop + output_path = output_dir_v2 if output_dir_v2 is not None else output_dir + if output_path is not None: + os.makedirs(output_path, exist_ok=True) + output_writer = open("{}/result.txt".format(output_path), "w") + result_list = [] + for keys, batch in loader: + assert isinstance(batch, dict), type(batch) + assert all(isinstance(s, str) for s in keys), keys + _bs = len(next(iter(batch.values()))) + assert len(keys) == _bs, f"{len(keys)} != {_bs}" + # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")} + + results = speech2diar(**batch) + # Only supporting batch_size==1 + key, value = keys[0], output_results_str(results, keys[0]) + item = {"key": key, "value": value} + result_list.append(item) + if output_path is not None: + output_writer.write(value) + output_writer.flush() + + if output_path is not None: + output_writer.close() + + return result_list + + return _forward + + +def inference( + data_path_and_name_and_type: Sequence[Tuple[str, str, str]], + diar_train_config: Optional[str], + diar_model_file: Optional[str], + output_dir: Optional[str] = None, + batch_size: int = 1, + dtype: str = "float32", + ngpu: int = 0, + seed: int = 0, + num_workers: int = 1, + log_level: Union[int, str] = "INFO", + key_file: Optional[str] = None, + model_tag: Optional[str] = None, + allow_variable_data_keys: bool = True, + streaming: bool = False, + smooth_size: int = 83, + dur_threshold: int = 10, + out_format: str = "vad", + **kwargs, +): + inference_pipeline = inference_modelscope( + diar_train_config=diar_train_config, + diar_model_file=diar_model_file, + output_dir=output_dir, + batch_size=batch_size, + dtype=dtype, + ngpu=ngpu, + seed=seed, + num_workers=num_workers, + log_level=log_level, + key_file=key_file, + model_tag=model_tag, + allow_variable_data_keys=allow_variable_data_keys, + streaming=streaming, + smooth_size=smooth_size, + dur_threshold=dur_threshold, + out_format=out_format, + **kwargs, + ) + + return inference_pipeline(data_path_and_name_and_type, raw_inputs=None) + + +def get_parser(): + parser = config_argparse.ArgumentParser( + description="Speaker verification/x-vector extraction", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + # Note(kamo): Use '_' instead of '-' as separator. + # '-' is confusing if written in yaml. + parser.add_argument( + "--log_level", + type=lambda x: x.upper(), + default="INFO", + choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"), + help="The verbose level of logging", + ) + + parser.add_argument("--output_dir", type=str, required=False) + parser.add_argument( + "--ngpu", + type=int, + default=0, + help="The number of gpus. 0 indicates CPU mode", + ) + parser.add_argument( + "--gpuid_list", + type=str, + default="", + help="The visible gpus", + ) + parser.add_argument("--seed", type=int, default=0, help="Random seed") + parser.add_argument( + "--dtype", + default="float32", + choices=["float16", "float32", "float64"], + help="Data type", + ) + parser.add_argument( + "--num_workers", + type=int, + default=1, + help="The number of workers used for DataLoader", + ) + + group = parser.add_argument_group("Input data related") + group.add_argument( + "--data_path_and_name_and_type", + type=str2triple_str, + required=False, + action="append", + ) + group.add_argument("--key_file", type=str_or_none) + group.add_argument("--allow_variable_data_keys", type=str2bool, default=False) + + group = parser.add_argument_group("The model configuration related") + group.add_argument( + "--diar_train_config", + type=str, + help="diarization training configuration", + ) + group.add_argument( + "--diar_model_file", + type=str, + help="diarization model parameter file", + ) + group.add_argument( + "--dur_threshold", + type=int, + default=10, + help="The threshold for short segments in number frames" + ) + parser.add_argument( + "--smooth_size", + type=int, + default=83, + help="The smoothing window length in number frames" + ) + group.add_argument( + "--model_tag", + type=str, + help="Pretrained model tag. If specify this option, *_train_config and " + "*_file will be overwritten", + ) + parser.add_argument( + "--batch_size", + type=int, + default=1, + help="The batch size for inference", + ) + parser.add_argument("--streaming", type=str2bool, default=False) + + return parser + + +def main(cmd=None): + print(get_commandline_args(), file=sys.stderr) + parser = get_parser() + args = parser.parse_args(cmd) + kwargs = vars(args) + kwargs.pop("config", None) + logging.info("args: {}".format(kwargs)) + if args.output_dir is None: + jobid, n_gpu = 1, 1 + gpuid = args.gpuid_list.split(",")[jobid - 1] + else: + jobid = int(args.output_dir.split(".")[-1]) + n_gpu = len(args.gpuid_list.split(",")) + gpuid = args.gpuid_list.split(",")[(jobid - 1) % n_gpu] + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + os.environ["CUDA_VISIBLE_DEVICES"] = gpuid + results_list = inference(**kwargs) + for results in results_list: + print("{} {}".format(results["key"], results["value"])) + + +if __name__ == "__main__": + main() diff --git a/funasr/models/e2e_diar_eend_ola.py b/funasr/models/e2e_diar_eend_ola.py new file mode 100644 index 000000000..f589269c5 --- /dev/null +++ b/funasr/models/e2e_diar_eend_ola.py @@ -0,0 +1,242 @@ +# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved. +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +from contextlib import contextmanager +from distutils.version import LooseVersion +from typing import Dict +from typing import Tuple + +import numpy as np +import torch +import torch.nn as nn +from typeguard import check_argument_types + +from funasr.models.frontend.wav_frontend import WavFrontendMel23 +from funasr.modules.eend_ola.encoder import EENDOLATransformerEncoder +from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor +from funasr.modules.eend_ola.utils.power import generate_mapping_dict +from funasr.torch_utils.device_funcs import force_gatherable +from funasr.train.abs_espnet_model import AbsESPnetModel + +if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): + pass +else: + # Nothing to do if torch<1.6.0 + @contextmanager + def autocast(enabled=True): + yield + + +def pad_attractor(att, max_n_speakers): + C, D = att.shape + if C < max_n_speakers: + att = torch.cat([att, torch.zeros(max_n_speakers - C, D).to(torch.float32).to(att.device)], dim=0) + return att + + +class DiarEENDOLAModel(AbsESPnetModel): + """EEND-OLA diarization model""" + + def __init__( + self, + frontend: WavFrontendMel23, + encoder: EENDOLATransformerEncoder, + encoder_decoder_attractor: EncoderDecoderAttractor, + n_units: int = 256, + max_n_speaker: int = 8, + attractor_loss_weight: float = 1.0, + mapping_dict=None, + **kwargs, + ): + assert check_argument_types() + + super().__init__() + self.frontend = frontend + self.encoder = encoder + self.encoder_decoder_attractor = encoder_decoder_attractor + self.attractor_loss_weight = attractor_loss_weight + self.max_n_speaker = max_n_speaker + if mapping_dict is None: + mapping_dict = generate_mapping_dict(max_speaker_num=self.max_n_speaker) + self.mapping_dict = mapping_dict + # PostNet + self.PostNet = nn.LSTM(self.max_n_speaker, n_units, 1, batch_first=True) + self.output_layer = nn.Linear(n_units, mapping_dict['oov'] + 1) + + def forward_encoder(self, xs, ilens): + xs = nn.utils.rnn.pad_sequence(xs, batch_first=True, padding_value=-1) + pad_shape = xs.shape + xs_mask = [torch.ones(ilen).to(xs.device) for ilen in ilens] + xs_mask = torch.nn.utils.rnn.pad_sequence(xs_mask, batch_first=True, padding_value=0).unsqueeze(-2) + emb = self.encoder(xs, xs_mask) + emb = torch.split(emb.view(pad_shape[0], pad_shape[1], -1), 1, dim=0) + emb = [e[0][:ilen] for e, ilen in zip(emb, ilens)] + return emb + + def forward_post_net(self, logits, ilens): + maxlen = torch.max(ilens).to(torch.int).item() + logits = nn.utils.rnn.pad_sequence(logits, batch_first=True, padding_value=-1) + logits = nn.utils.rnn.pack_padded_sequence(logits, ilens, batch_first=True, enforce_sorted=False) + outputs, (_, _) = self.PostNet(logits) + outputs = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True, padding_value=-1, total_length=maxlen)[0] + outputs = [output[:ilens[i].to(torch.int).item()] for i, output in enumerate(outputs)] + outputs = [self.output_layer(output) for output in outputs] + return outputs + + def forward( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + text: torch.Tensor, + text_lengths: torch.Tensor, + ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: + """Frontend + Encoder + Decoder + Calc loss + + Args: + speech: (Batch, Length, ...) + speech_lengths: (Batch, ) + text: (Batch, Length) + text_lengths: (Batch,) + """ + assert text_lengths.dim() == 1, text_lengths.shape + # Check that batch_size is unified + assert ( + speech.shape[0] + == speech_lengths.shape[0] + == text.shape[0] + == text_lengths.shape[0] + ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape) + batch_size = speech.shape[0] + + # for data-parallel + text = text[:, : text_lengths.max()] + + # 1. Encoder + encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) + intermediate_outs = None + if isinstance(encoder_out, tuple): + intermediate_outs = encoder_out[1] + encoder_out = encoder_out[0] + + loss_att, acc_att, cer_att, wer_att = None, None, None, None + loss_ctc, cer_ctc = None, None + stats = dict() + + # 1. CTC branch + if self.ctc_weight != 0.0: + loss_ctc, cer_ctc = self._calc_ctc_loss( + encoder_out, encoder_out_lens, text, text_lengths + ) + + # Collect CTC branch stats + stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None + stats["cer_ctc"] = cer_ctc + + # Intermediate CTC (optional) + loss_interctc = 0.0 + if self.interctc_weight != 0.0 and intermediate_outs is not None: + for layer_idx, intermediate_out in intermediate_outs: + # we assume intermediate_out has the same length & padding + # as those of encoder_out + loss_ic, cer_ic = self._calc_ctc_loss( + intermediate_out, encoder_out_lens, text, text_lengths + ) + loss_interctc = loss_interctc + loss_ic + + # Collect Intermedaite CTC stats + stats["loss_interctc_layer{}".format(layer_idx)] = ( + loss_ic.detach() if loss_ic is not None else None + ) + stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic + + loss_interctc = loss_interctc / len(intermediate_outs) + + # calculate whole encoder loss + loss_ctc = ( + 1 - self.interctc_weight + ) * loss_ctc + self.interctc_weight * loss_interctc + + # 2b. Attention decoder branch + if self.ctc_weight != 1.0: + loss_att, acc_att, cer_att, wer_att = self._calc_att_loss( + encoder_out, encoder_out_lens, text, text_lengths + ) + + # 3. CTC-Att loss definition + if self.ctc_weight == 0.0: + loss = loss_att + elif self.ctc_weight == 1.0: + loss = loss_ctc + else: + loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + + # Collect Attn branch stats + stats["loss_att"] = loss_att.detach() if loss_att is not None else None + stats["acc"] = acc_att + stats["cer"] = cer_att + stats["wer"] = wer_att + + # Collect total loss stats + stats["loss"] = torch.clone(loss.detach()) + + # force_gatherable: to-device and to-tensor if scalar for DataParallel + loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) + return loss, stats, weight + + def estimate_sequential(self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + n_speakers: int = None, + shuffle: bool = True, + threshold: float = 0.5, + **kwargs): + if self.frontend is not None: + speech = self.frontend(speech) + speech = [s[:s_len] for s, s_len in zip(speech, speech_lengths)] + emb = self.forward_encoder(speech, speech_lengths) + if shuffle: + orders = [np.arange(e.shape[0]) for e in emb] + for order in orders: + np.random.shuffle(order) + attractors, probs = self.encoder_decoder_attractor.estimate( + [e[torch.from_numpy(order).to(torch.long).to(speech[0].device)] for e, order in zip(emb, orders)]) + else: + attractors, probs = self.encoder_decoder_attractor.estimate(emb) + attractors_active = [] + for p, att, e in zip(probs, attractors, emb): + if n_speakers and n_speakers >= 0: + att = att[:n_speakers, ] + attractors_active.append(att) + elif threshold is not None: + silence = torch.nonzero(p < threshold)[0] + n_spk = silence[0] if silence.size else None + att = att[:n_spk, ] + attractors_active.append(att) + else: + NotImplementedError('n_speakers or threshold has to be given.') + raw_n_speakers = [att.shape[0] for att in attractors_active] + attractors = [ + pad_attractor(att, self.max_n_speaker) if att.shape[0] <= self.max_n_speaker else att[:self.max_n_speaker] + for att in attractors_active] + ys = [torch.matmul(e, att.permute(1, 0)) for e, att in zip(emb, attractors)] + logits = self.forward_post_net(ys, speech_lengths) + ys = [self.recover_y_from_powerlabel(logit, raw_n_speaker) for logit, raw_n_speaker in + zip(logits, raw_n_speakers)] + + return ys, emb, attractors, raw_n_speakers + + def recover_y_from_powerlabel(self, logit, n_speaker): + pred = torch.argmax(torch.softmax(logit, dim=-1), dim=-1) + oov_index = torch.where(pred == self.mapping_dict['oov'])[0] + for i in oov_index: + if i > 0: + pred[i] = pred[i - 1] + else: + pred[i] = 0 + pred = [self.reporter.inv_mapping_func(i, self.mapping_dict) for i in pred] + decisions = [bin(num)[2:].zfill(self.max_n_speaker)[::-1] for num in pred] + decisions = torch.from_numpy( + np.stack([np.array([int(i) for i in dec]) for dec in decisions], axis=0)).to(logit.device).to( + torch.float32) + decisions = decisions[:, :n_speaker] + return decisions diff --git a/funasr/models/frontend/eend_ola_feature.py b/funasr/models/frontend/eend_ola_feature.py new file mode 100644 index 000000000..e15b71c25 --- /dev/null +++ b/funasr/models/frontend/eend_ola_feature.py @@ -0,0 +1,51 @@ +# Copyright 2019 Hitachi, Ltd. (author: Yusuke Fujita) +# Licensed under the MIT license. +# +# This module is for computing audio features + +import librosa +import numpy as np + + +def transform(Y, dtype=np.float32): + Y = np.abs(Y) + n_fft = 2 * (Y.shape[1] - 1) + sr = 8000 + n_mels = 23 + mel_basis = librosa.filters.mel(sr, n_fft, n_mels) + Y = np.dot(Y ** 2, mel_basis.T) + Y = np.log10(np.maximum(Y, 1e-10)) + mean = np.mean(Y, axis=0) + Y = Y - mean + return Y.astype(dtype) + + +def subsample(Y, T, subsampling=1): + Y_ss = Y[::subsampling] + T_ss = T[::subsampling] + return Y_ss, T_ss + + +def splice(Y, context_size=0): + Y_pad = np.pad( + Y, + [(context_size, context_size), (0, 0)], + 'constant') + Y_spliced = np.lib.stride_tricks.as_strided( + np.ascontiguousarray(Y_pad), + (Y.shape[0], Y.shape[1] * (2 * context_size + 1)), + (Y.itemsize * Y.shape[1], Y.itemsize), writeable=False) + return Y_spliced + + +def stft( + data, + frame_size=1024, + frame_shift=256): + fft_size = 1 << (frame_size - 1).bit_length() + if len(data) % frame_shift == 0: + return librosa.stft(data, n_fft=fft_size, win_length=frame_size, + hop_length=frame_shift).T[:-1] + else: + return librosa.stft(data, n_fft=fft_size, win_length=frame_size, + hop_length=frame_shift).T \ No newline at end of file diff --git a/funasr/modules/eend_ola/encoder.py b/funasr/modules/eend_ola/encoder.py index 17d11ace7..4999031b1 100644 --- a/funasr/modules/eend_ola/encoder.py +++ b/funasr/modules/eend_ola/encoder.py @@ -1,5 +1,5 @@ import math -import numpy as np + import torch import torch.nn.functional as F from torch import nn @@ -81,10 +81,16 @@ class PositionalEncoding(torch.nn.Module): return self.dropout(x) -class TransformerEncoder(nn.Module): - def __init__(self, idim, n_layers, n_units, - e_units=2048, h=8, dropout_rate=0.1, use_pos_emb=False): - super(TransformerEncoder, self).__init__() +class EENDOLATransformerEncoder(nn.Module): + def __init__(self, + idim: int, + n_layers: int, + n_units: int, + e_units: int = 2048, + h: int = 8, + dropout_rate: float = 0.1, + use_pos_emb: bool = False): + super(EENDOLATransformerEncoder, self).__init__() self.lnorm_in = nn.LayerNorm(n_units) self.n_layers = n_layers self.dropout = nn.Dropout(dropout_rate) diff --git a/funasr/tasks/diar.py b/funasr/tasks/diar.py index e699dccb0..ae7ee9b40 100644 --- a/funasr/tasks/diar.py +++ b/funasr/tasks/diar.py @@ -20,19 +20,19 @@ from funasr.datasets.collate_fn import CommonCollateFn from funasr.datasets.preprocessor import CommonPreprocessor from funasr.layers.abs_normalize import AbsNormalize from funasr.layers.global_mvn import GlobalMVN -from funasr.layers.utterance_mvn import UtteranceMVN from funasr.layers.label_aggregation import LabelAggregate -from funasr.models.ctc import CTC -from funasr.models.encoder.resnet34_encoder import ResNet34Diar, ResNet34SpL2RegDiar -from funasr.models.encoder.ecapa_tdnn_encoder import ECAPA_TDNN -from funasr.models.encoder.opennmt_encoders.conv_encoder import ConvEncoder -from funasr.models.encoder.opennmt_encoders.fsmn_encoder import FsmnEncoder -from funasr.models.encoder.opennmt_encoders.self_attention_encoder import SelfAttentionEncoder -from funasr.models.encoder.opennmt_encoders.ci_scorers import DotScorer, CosScorer +from funasr.layers.utterance_mvn import UtteranceMVN from funasr.models.e2e_diar_sond import DiarSondModel +from funasr.models.e2e_diar_eend_ola import DiarEENDOLAModel from funasr.models.encoder.abs_encoder import AbsEncoder from funasr.models.encoder.conformer_encoder import ConformerEncoder from funasr.models.encoder.data2vec_encoder import Data2VecEncoder +from funasr.models.encoder.ecapa_tdnn_encoder import ECAPA_TDNN +from funasr.models.encoder.opennmt_encoders.ci_scorers import DotScorer, CosScorer +from funasr.models.encoder.opennmt_encoders.conv_encoder import ConvEncoder +from funasr.models.encoder.opennmt_encoders.fsmn_encoder import FsmnEncoder +from funasr.models.encoder.opennmt_encoders.self_attention_encoder import SelfAttentionEncoder +from funasr.models.encoder.resnet34_encoder import ResNet34Diar, ResNet34SpL2RegDiar from funasr.models.encoder.rnn_encoder import RNNEncoder from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt from funasr.models.encoder.transformer_encoder import TransformerEncoder @@ -41,17 +41,13 @@ from funasr.models.frontend.default import DefaultFrontend from funasr.models.frontend.fused import FusedFrontends from funasr.models.frontend.s3prl import S3prlFrontend from funasr.models.frontend.wav_frontend import WavFrontend +from funasr.models.frontend.wav_frontend import WavFrontendMel23 from funasr.models.frontend.windowing import SlidingWindow -from funasr.models.postencoder.abs_postencoder import AbsPostEncoder -from funasr.models.postencoder.hugging_face_transformers_postencoder import ( - HuggingFaceTransformersPostEncoder, # noqa: H301 -) -from funasr.models.preencoder.abs_preencoder import AbsPreEncoder -from funasr.models.preencoder.linear import LinearProjection -from funasr.models.preencoder.sinc import LightweightSincConvs from funasr.models.specaug.abs_specaug import AbsSpecAug from funasr.models.specaug.specaug import SpecAug from funasr.models.specaug.specaug import SpecAugLFR +from funasr.modules.eend_ola.encoder import EENDOLATransformerEncoder +from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor from funasr.tasks.abs_task import AbsTask from funasr.torch_utils.initialize import initialize from funasr.train.abs_espnet_model import AbsESPnetModel @@ -70,6 +66,7 @@ frontend_choices = ClassChoices( s3prl=S3prlFrontend, fused=FusedFrontends, wav_frontend=WavFrontend, + wav_frontend_mel23=WavFrontendMel23, ), type_check=AbsFrontend, default="default", @@ -107,6 +104,7 @@ model_choices = ClassChoices( "model", classes=dict( sond=DiarSondModel, + eend_ola=DiarEENDOLAModel, ), type_check=AbsESPnetModel, default="sond", @@ -126,6 +124,7 @@ encoder_choices = ClassChoices( sanm_chunk_opt=SANMEncoderChunkOpt, data2vec_encoder=Data2VecEncoder, ecapa_tdnn=ECAPA_TDNN, + eend_ola_transformer=EENDOLATransformerEncoder, ), type_check=torch.nn.Module, default="resnet34", @@ -177,6 +176,15 @@ decoder_choices = ClassChoices( type_check=torch.nn.Module, default="fsmn", ) +# encoder_decoder_attractor is used for EEND-OLA +encoder_decoder_attractor_choices = ClassChoices( + "encoder_decoder_attractor", + classes=dict( + eda=EncoderDecoderAttractor, + ), + type_check=torch.nn.Module, + default="eda", +) class DiarTask(AbsTask): @@ -594,3 +602,294 @@ class DiarTask(AbsTask): var_dict_torch_update.update(var_dict_torch_update_local) return var_dict_torch_update + + +class EENDOLADiarTask(AbsTask): + # If you need more than 1 optimizer, change this value + num_optimizers: int = 1 + + # Add variable objects configurations + class_choices_list = [ + # --frontend and --frontend_conf + frontend_choices, + # --specaug and --specaug_conf + model_choices, + # --encoder and --encoder_conf + encoder_choices, + # --speaker_encoder and --speaker_encoder_conf + encoder_decoder_attractor_choices, + ] + + # If you need to modify train() or eval() procedures, change Trainer class here + trainer = Trainer + + @classmethod + def add_task_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group(description="Task related") + + # NOTE(kamo): add_arguments(..., required=True) can't be used + # to provide --print_config mode. Instead of it, do as + # required = parser.get_default("required") + # required += ["token_list"] + + group.add_argument( + "--token_list", + type=str_or_none, + default=None, + help="A text mapping int-id to token", + ) + group.add_argument( + "--split_with_space", + type=str2bool, + default=True, + help="whether to split text using ", + ) + group.add_argument( + "--seg_dict_file", + type=str, + default=None, + help="seg_dict_file for text processing", + ) + group.add_argument( + "--init", + type=lambda x: str_or_none(x.lower()), + default=None, + help="The initialization method", + choices=[ + "chainer", + "xavier_uniform", + "xavier_normal", + "kaiming_uniform", + "kaiming_normal", + None, + ], + ) + + group.add_argument( + "--input_size", + type=int_or_none, + default=None, + help="The number of input dimension of the feature", + ) + + group = parser.add_argument_group(description="Preprocess related") + group.add_argument( + "--use_preprocessor", + type=str2bool, + default=True, + help="Apply preprocessing to data or not", + ) + group.add_argument( + "--token_type", + type=str, + default="char", + choices=["char"], + help="The text will be tokenized in the specified level token", + ) + parser.add_argument( + "--speech_volume_normalize", + type=float_or_none, + default=None, + help="Scale the maximum amplitude to the given value.", + ) + parser.add_argument( + "--rir_scp", + type=str_or_none, + default=None, + help="The file path of rir scp file.", + ) + parser.add_argument( + "--rir_apply_prob", + type=float, + default=1.0, + help="THe probability for applying RIR convolution.", + ) + parser.add_argument( + "--cmvn_file", + type=str_or_none, + default=None, + help="The file path of noise scp file.", + ) + parser.add_argument( + "--noise_scp", + type=str_or_none, + default=None, + help="The file path of noise scp file.", + ) + parser.add_argument( + "--noise_apply_prob", + type=float, + default=1.0, + help="The probability applying Noise adding.", + ) + parser.add_argument( + "--noise_db_range", + type=str, + default="13_15", + help="The range of noise decibel level.", + ) + + for class_choices in cls.class_choices_list: + # Append -- and --_conf. + # e.g. --encoder and --encoder_conf + class_choices.add_arguments(group) + + @classmethod + def build_collate_fn( + cls, args: argparse.Namespace, train: bool + ) -> Callable[ + [Collection[Tuple[str, Dict[str, np.ndarray]]]], + Tuple[List[str], Dict[str, torch.Tensor]], + ]: + assert check_argument_types() + # NOTE(kamo): int value = 0 is reserved by CTC-blank symbol + return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1) + + @classmethod + def build_preprocess_fn( + cls, args: argparse.Namespace, train: bool + ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]: + assert check_argument_types() + if args.use_preprocessor: + retval = CommonPreprocessor( + train=train, + token_type=args.token_type, + token_list=args.token_list, + bpemodel=None, + non_linguistic_symbols=None, + text_cleaner=None, + g2p_type=None, + split_with_space=args.split_with_space if hasattr(args, "split_with_space") else False, + seg_dict_file=args.seg_dict_file if hasattr(args, "seg_dict_file") else None, + # NOTE(kamo): Check attribute existence for backward compatibility + rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None, + rir_apply_prob=args.rir_apply_prob + if hasattr(args, "rir_apply_prob") + else 1.0, + noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None, + noise_apply_prob=args.noise_apply_prob + if hasattr(args, "noise_apply_prob") + else 1.0, + noise_db_range=args.noise_db_range + if hasattr(args, "noise_db_range") + else "13_15", + speech_volume_normalize=args.speech_volume_normalize + if hasattr(args, "rir_scp") + else None, + ) + else: + retval = None + assert check_return_type(retval) + return retval + + @classmethod + def required_data_names( + cls, train: bool = True, inference: bool = False + ) -> Tuple[str, ...]: + if not inference: + retval = ("speech", "profile", "binary_labels") + else: + # Recognition mode + retval = ("speech") + return retval + + @classmethod + def optional_data_names( + cls, train: bool = True, inference: bool = False + ) -> Tuple[str, ...]: + retval = () + assert check_return_type(retval) + return retval + + @classmethod + def build_model(cls, args: argparse.Namespace): + assert check_argument_types() + + # 1. frontend + if args.input_size is None or args.frontend == "wav_frontend_mel23": + # Extract features in the model + frontend_class = frontend_choices.get_class(args.frontend) + if args.frontend == 'wav_frontend': + frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf) + else: + frontend = frontend_class(**args.frontend_conf) + input_size = frontend.output_size() + else: + # Give features from data-loader + args.frontend = None + args.frontend_conf = {} + frontend = None + input_size = args.input_size + + # 2. Encoder + encoder_class = encoder_choices.get_class(args.encoder) + encoder = encoder_class(input_size=input_size, **args.encoder_conf) + + # 3. EncoderDecoderAttractor + encoder_decoder_attractor_class = encoder_decoder_attractor_choices.get_class(args.encoder_decoder_attractor) + encoder_decoder_attractor = encoder_decoder_attractor_class(**args.encoder_decoder_attractor_conf) + + # 9. Build model + model_class = model_choices.get_class(args.model) + model = model_class( + frontend=frontend, + encoder=encoder, + encoder_decoder_attractor=encoder_decoder_attractor, + **args.model_conf, + ) + + # 10. Initialize + if args.init is not None: + initialize(model, args.init) + + assert check_return_type(model) + return model + + # ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~ + @classmethod + def build_model_from_file( + cls, + config_file: Union[Path, str] = None, + model_file: Union[Path, str] = None, + cmvn_file: Union[Path, str] = None, + device: str = "cpu", + ): + """Build model from the files. + + This method is used for inference or fine-tuning. + + Args: + config_file: The yaml file saved when training. + model_file: The model file saved when training. + cmvn_file: The cmvn file for front-end + device: Device type, "cpu", "cuda", or "cuda:N". + + """ + assert check_argument_types() + if config_file is None: + assert model_file is not None, ( + "The argument 'model_file' must be provided " + "if the argument 'config_file' is not specified." + ) + config_file = Path(model_file).parent / "config.yaml" + else: + config_file = Path(config_file) + + with config_file.open("r", encoding="utf-8") as f: + args = yaml.safe_load(f) + args = argparse.Namespace(**args) + model = cls.build_model(args) + if not isinstance(model, AbsESPnetModel): + raise RuntimeError( + f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}" + ) + if model_file is not None: + if device == "cuda": + device = f"cuda:{torch.cuda.current_device()}" + checkpoint = torch.load(model_file, map_location=device) + if "state_dict" in checkpoint.keys(): + model.load_state_dict(checkpoint["state_dict"]) + else: + model.load_state_dict(checkpoint) + model.to(device) + return model, args