From 8762d9973585fdceaaa886516a06e0ada303d3b5 Mon Sep 17 00:00:00 2001 From: speech_asr Date: Mon, 13 Mar 2023 15:30:17 +0800 Subject: [PATCH 01/11] update ola --- funasr/models/frontend/wav_frontend.py | 111 ++++++++++++++++++++----- 1 file changed, 88 insertions(+), 23 deletions(-) diff --git a/funasr/models/frontend/wav_frontend.py b/funasr/models/frontend/wav_frontend.py index ed8cb3646..4e52b90f1 100644 --- a/funasr/models/frontend/wav_frontend.py +++ b/funasr/models/frontend/wav_frontend.py @@ -1,14 +1,14 @@ # Copyright (c) Alibaba, Inc. and its affiliates. # Part of the implementation is borrowed from espnet/espnet. -from typing import Tuple - +import funasr.models.frontend.eend_ola_feature import numpy as np import torch import torchaudio.compliance.kaldi as kaldi from funasr.models.frontend.abs_frontend import AbsFrontend -from typeguard import check_argument_types from torch.nn.utils.rnn import pad_sequence +from typeguard import check_argument_types +from typing import Tuple def load_cmvn(cmvn_file): @@ -33,9 +33,9 @@ def load_cmvn(cmvn_file): means = np.array(means_list).astype(np.float) vars = np.array(vars_list).astype(np.float) cmvn = np.array([means, vars]) - cmvn = torch.as_tensor(cmvn) - return cmvn - + cmvn = torch.as_tensor(cmvn) + return cmvn + def apply_cmvn(inputs, cmvn_file): # noqa """ @@ -78,21 +78,22 @@ def apply_lfr(inputs, lfr_m, lfr_n): class WavFrontend(AbsFrontend): """Conventional frontend structure for ASR. """ + def __init__( - self, - cmvn_file: str = None, - fs: int = 16000, - window: str = 'hamming', - n_mels: int = 80, - frame_length: int = 25, - frame_shift: int = 10, - filter_length_min: int = -1, - filter_length_max: int = -1, - lfr_m: int = 1, - lfr_n: int = 1, - dither: float = 1.0, - snip_edges: bool = True, - upsacle_samples: bool = True, + self, + cmvn_file: str = None, + fs: int = 16000, + window: str = 'hamming', + n_mels: int = 80, + frame_length: int = 25, + frame_shift: int = 10, + filter_length_min: int = -1, + filter_length_max: int = -1, + lfr_m: int = 1, + lfr_n: int = 1, + dither: float = 1.0, + snip_edges: bool = True, + upsacle_samples: bool = True, ): assert check_argument_types() super().__init__() @@ -135,11 +136,11 @@ class WavFrontend(AbsFrontend): window_type=self.window, sample_frequency=self.fs, snip_edges=self.snip_edges) - + if self.lfr_m != 1 or self.lfr_n != 1: mat = apply_lfr(mat, self.lfr_m, self.lfr_n) if self.cmvn_file is not None: - mat = apply_cmvn(mat, self.cmvn_file) + mat = apply_cmvn(mat, self.cmvn_file) feat_length = mat.size(0) feats.append(mat) feats_lens.append(feat_length) @@ -171,7 +172,6 @@ class WavFrontend(AbsFrontend): window_type=self.window, sample_frequency=self.fs) - feat_length = mat.size(0) feats.append(mat) feats_lens.append(feat_length) @@ -204,3 +204,68 @@ class WavFrontend(AbsFrontend): batch_first=True, padding_value=0.0) return feats_pad, feats_lens + + +class WavFrontendMel23(AbsFrontend): + """Conventional frontend structure for ASR. + """ + + def __init__( + self, + fs: int = 16000, + window: str = 'hamming', + n_mels: int = 80, + frame_length: int = 25, + frame_shift: int = 10, + filter_length_min: int = -1, + filter_length_max: int = -1, + lfr_m: int = 1, + lfr_n: int = 1, + dither: float = 1.0, + snip_edges: bool = True, + upsacle_samples: bool = True, + ): + assert check_argument_types() + super().__init__() + self.fs = fs + self.window = window + self.n_mels = n_mels + self.frame_length = frame_length + self.frame_shift = frame_shift + self.filter_length_min = filter_length_min + self.filter_length_max = filter_length_max + self.lfr_m = lfr_m + self.lfr_n = lfr_n + self.cmvn_file = cmvn_file + self.dither = dither + self.snip_edges = snip_edges + self.upsacle_samples = upsacle_samples + + def output_size(self) -> int: + return self.n_mels * self.lfr_m + + def forward( + self, + input: torch.Tensor, + input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + batch_size = input.size(0) + feats = [] + feats_lens = [] + for i in range(batch_size): + waveform_length = input_lengths[i] + waveform = input[i][:waveform_length] + waveform = waveform.unsqueeze(0).numpy() + mat = eend_ola_feature.stft(waveform, self.frame_length, self.frame_shift) + mat = eend_ola_feature.transform(mat) + mat = mat.splice(mat, context_size=self.lfr_m) + mat = mat[::self.lfr_n] + mat = torch.from_numpy(mat) + feat_length = mat.size(0) + feats.append(mat) + feats_lens.append(feat_length) + + feats_lens = torch.as_tensor(feats_lens) + feats_pad = pad_sequence(feats, + batch_first=True, + padding_value=0.0) + return feats_pad, feats_lens From b7b65c844d6d7b88b76270f0c29841c6ea321175 Mon Sep 17 00:00:00 2001 From: speech_asr Date: Mon, 13 Mar 2023 15:33:23 +0800 Subject: [PATCH 02/11] update ola --- funasr/models/frontend/eend_ola_feature.py | 51 ++++++++++++++++++++++ 1 file changed, 51 insertions(+) create mode 100644 funasr/models/frontend/eend_ola_feature.py diff --git a/funasr/models/frontend/eend_ola_feature.py b/funasr/models/frontend/eend_ola_feature.py new file mode 100644 index 000000000..e15b71c25 --- /dev/null +++ b/funasr/models/frontend/eend_ola_feature.py @@ -0,0 +1,51 @@ +# Copyright 2019 Hitachi, Ltd. (author: Yusuke Fujita) +# Licensed under the MIT license. +# +# This module is for computing audio features + +import librosa +import numpy as np + + +def transform(Y, dtype=np.float32): + Y = np.abs(Y) + n_fft = 2 * (Y.shape[1] - 1) + sr = 8000 + n_mels = 23 + mel_basis = librosa.filters.mel(sr, n_fft, n_mels) + Y = np.dot(Y ** 2, mel_basis.T) + Y = np.log10(np.maximum(Y, 1e-10)) + mean = np.mean(Y, axis=0) + Y = Y - mean + return Y.astype(dtype) + + +def subsample(Y, T, subsampling=1): + Y_ss = Y[::subsampling] + T_ss = T[::subsampling] + return Y_ss, T_ss + + +def splice(Y, context_size=0): + Y_pad = np.pad( + Y, + [(context_size, context_size), (0, 0)], + 'constant') + Y_spliced = np.lib.stride_tricks.as_strided( + np.ascontiguousarray(Y_pad), + (Y.shape[0], Y.shape[1] * (2 * context_size + 1)), + (Y.itemsize * Y.shape[1], Y.itemsize), writeable=False) + return Y_spliced + + +def stft( + data, + frame_size=1024, + frame_shift=256): + fft_size = 1 << (frame_size - 1).bit_length() + if len(data) % frame_shift == 0: + return librosa.stft(data, n_fft=fft_size, win_length=frame_size, + hop_length=frame_shift).T[:-1] + else: + return librosa.stft(data, n_fft=fft_size, win_length=frame_size, + hop_length=frame_shift).T \ No newline at end of file From 3ff62dbb97684837f76c75f2defe6d0c77bc6d48 Mon Sep 17 00:00:00 2001 From: speech_asr Date: Mon, 13 Mar 2023 16:04:27 +0800 Subject: [PATCH 03/11] update ola --- funasr/models/e2e_diar_eend_ola.py | 395 +++++++++++++++++++++++++++++ 1 file changed, 395 insertions(+) create mode 100644 funasr/models/e2e_diar_eend_ola.py diff --git a/funasr/models/e2e_diar_eend_ola.py b/funasr/models/e2e_diar_eend_ola.py new file mode 100644 index 000000000..967c0d487 --- /dev/null +++ b/funasr/models/e2e_diar_eend_ola.py @@ -0,0 +1,395 @@ +# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved. +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +import logging +import torch +from contextlib import contextmanager +from distutils.version import LooseVersion +from funasr.layers.abs_normalize import AbsNormalize +from funasr.losses.label_smoothing_loss import ( + LabelSmoothingLoss, # noqa: H301 +) +from funasr.models.ctc import CTC +from funasr.models.decoder.abs_decoder import AbsDecoder +from funasr.models.encoder.abs_encoder import AbsEncoder +from funasr.models.frontend.abs_frontend import AbsFrontend +from funasr.models.postencoder.abs_postencoder import AbsPostEncoder +from funasr.models.preencoder.abs_preencoder import AbsPreEncoder +from funasr.models.specaug.abs_specaug import AbsSpecAug +from funasr.modules.add_sos_eos import add_sos_eos +from funasr.modules.e2e_asr_common import ErrorCalculator +from funasr.modules.eend_ola.encoder import TransformerEncoder +from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor +from funasr.modules.eend_ola.utils.power import generate_mapping_dict +from funasr.modules.nets_utils import th_accuracy +from funasr.torch_utils.device_funcs import force_gatherable +from funasr.train.abs_espnet_model import AbsESPnetModel +from typeguard import check_argument_types +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union + +if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): + from torch.cuda.amp import autocast +else: + # Nothing to do if torch<1.6.0 + @contextmanager + def autocast(enabled=True): + yield + + +class DiarEENDOLAModel(AbsESPnetModel): + """CTC-attention hybrid Encoder-Decoder model""" + + def __init__( + self, + encoder: TransformerEncoder, + eda: EncoderDecoderAttractor, + max_n_speaker: int = 8, + attractor_loss_weight: float = 1.0, + mapping_dict=None, + **kwargs, + ): + assert check_argument_types() + + super().__init__() + self.encoder = encoder + self.eda = eda + self.attractor_loss_weight = attractor_loss_weight + self.max_n_speaker = max_n_speaker + if mapping_dict is None: + mapping_dict = generate_mapping_dict(max_speaker_num=self.max_n_speaker) + self.mapping_dict = mapping_dict + + def forward( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + text: torch.Tensor, + text_lengths: torch.Tensor, + ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: + """Frontend + Encoder + Decoder + Calc loss + + Args: + speech: (Batch, Length, ...) + speech_lengths: (Batch, ) + text: (Batch, Length) + text_lengths: (Batch,) + """ + assert text_lengths.dim() == 1, text_lengths.shape + # Check that batch_size is unified + assert ( + speech.shape[0] + == speech_lengths.shape[0] + == text.shape[0] + == text_lengths.shape[0] + ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape) + batch_size = speech.shape[0] + + # for data-parallel + text = text[:, : text_lengths.max()] + + # 1. Encoder + encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) + intermediate_outs = None + if isinstance(encoder_out, tuple): + intermediate_outs = encoder_out[1] + encoder_out = encoder_out[0] + + loss_att, acc_att, cer_att, wer_att = None, None, None, None + loss_ctc, cer_ctc = None, None + stats = dict() + + # 1. CTC branch + if self.ctc_weight != 0.0: + loss_ctc, cer_ctc = self._calc_ctc_loss( + encoder_out, encoder_out_lens, text, text_lengths + ) + + # Collect CTC branch stats + stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None + stats["cer_ctc"] = cer_ctc + + # Intermediate CTC (optional) + loss_interctc = 0.0 + if self.interctc_weight != 0.0 and intermediate_outs is not None: + for layer_idx, intermediate_out in intermediate_outs: + # we assume intermediate_out has the same length & padding + # as those of encoder_out + loss_ic, cer_ic = self._calc_ctc_loss( + intermediate_out, encoder_out_lens, text, text_lengths + ) + loss_interctc = loss_interctc + loss_ic + + # Collect Intermedaite CTC stats + stats["loss_interctc_layer{}".format(layer_idx)] = ( + loss_ic.detach() if loss_ic is not None else None + ) + stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic + + loss_interctc = loss_interctc / len(intermediate_outs) + + # calculate whole encoder loss + loss_ctc = ( + 1 - self.interctc_weight + ) * loss_ctc + self.interctc_weight * loss_interctc + + # 2b. Attention decoder branch + if self.ctc_weight != 1.0: + loss_att, acc_att, cer_att, wer_att = self._calc_att_loss( + encoder_out, encoder_out_lens, text, text_lengths + ) + + # 3. CTC-Att loss definition + if self.ctc_weight == 0.0: + loss = loss_att + elif self.ctc_weight == 1.0: + loss = loss_ctc + else: + loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + + # Collect Attn branch stats + stats["loss_att"] = loss_att.detach() if loss_att is not None else None + stats["acc"] = acc_att + stats["cer"] = cer_att + stats["wer"] = wer_att + + # Collect total loss stats + stats["loss"] = torch.clone(loss.detach()) + + # force_gatherable: to-device and to-tensor if scalar for DataParallel + loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) + return loss, stats, weight + + def collect_feats( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + text: torch.Tensor, + text_lengths: torch.Tensor, + ) -> Dict[str, torch.Tensor]: + if self.extract_feats_in_collect_stats: + feats, feats_lengths = self._extract_feats(speech, speech_lengths) + else: + # Generate dummy stats if extract_feats_in_collect_stats is False + logging.warning( + "Generating dummy stats for feats and feats_lengths, " + "because encoder_conf.extract_feats_in_collect_stats is " + f"{self.extract_feats_in_collect_stats}" + ) + feats, feats_lengths = speech, speech_lengths + return {"feats": feats, "feats_lengths": feats_lengths} + + def encode( + self, speech: torch.Tensor, speech_lengths: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Frontend + Encoder. Note that this method is used by asr_inference.py + + Args: + speech: (Batch, Length, ...) + speech_lengths: (Batch, ) + """ + with autocast(False): + # 1. Extract feats + feats, feats_lengths = self._extract_feats(speech, speech_lengths) + + # 2. Data augmentation + if self.specaug is not None and self.training: + feats, feats_lengths = self.specaug(feats, feats_lengths) + + # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN + if self.normalize is not None: + feats, feats_lengths = self.normalize(feats, feats_lengths) + + # Pre-encoder, e.g. used for raw input data + if self.preencoder is not None: + feats, feats_lengths = self.preencoder(feats, feats_lengths) + + # 4. Forward encoder + # feats: (Batch, Length, Dim) + # -> encoder_out: (Batch, Length2, Dim2) + if self.encoder.interctc_use_conditioning: + encoder_out, encoder_out_lens, _ = self.encoder( + feats, feats_lengths, ctc=self.ctc + ) + else: + encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths) + intermediate_outs = None + if isinstance(encoder_out, tuple): + intermediate_outs = encoder_out[1] + encoder_out = encoder_out[0] + + # Post-encoder, e.g. NLU + if self.postencoder is not None: + encoder_out, encoder_out_lens = self.postencoder( + encoder_out, encoder_out_lens + ) + + assert encoder_out.size(0) == speech.size(0), ( + encoder_out.size(), + speech.size(0), + ) + assert encoder_out.size(1) <= encoder_out_lens.max(), ( + encoder_out.size(), + encoder_out_lens.max(), + ) + + if intermediate_outs is not None: + return (encoder_out, intermediate_outs), encoder_out_lens + + return encoder_out, encoder_out_lens + + def _extract_feats( + self, speech: torch.Tensor, speech_lengths: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + assert speech_lengths.dim() == 1, speech_lengths.shape + + # for data-parallel + speech = speech[:, : speech_lengths.max()] + + if self.frontend is not None: + # Frontend + # e.g. STFT and Feature extract + # data_loader may send time-domain signal in this case + # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim) + feats, feats_lengths = self.frontend(speech, speech_lengths) + else: + # No frontend and no feature extract + feats, feats_lengths = speech, speech_lengths + return feats, feats_lengths + + def nll( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + ys_pad: torch.Tensor, + ys_pad_lens: torch.Tensor, + ) -> torch.Tensor: + """Compute negative log likelihood(nll) from transformer-decoder + + Normally, this function is called in batchify_nll. + + Args: + encoder_out: (Batch, Length, Dim) + encoder_out_lens: (Batch,) + ys_pad: (Batch, Length) + ys_pad_lens: (Batch,) + """ + ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id) + ys_in_lens = ys_pad_lens + 1 + + # 1. Forward decoder + decoder_out, _ = self.decoder( + encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens + ) # [batch, seqlen, dim] + batch_size = decoder_out.size(0) + decoder_num_class = decoder_out.size(2) + # nll: negative log-likelihood + nll = torch.nn.functional.cross_entropy( + decoder_out.view(-1, decoder_num_class), + ys_out_pad.view(-1), + ignore_index=self.ignore_id, + reduction="none", + ) + nll = nll.view(batch_size, -1) + nll = nll.sum(dim=1) + assert nll.size(0) == batch_size + return nll + + def batchify_nll( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + ys_pad: torch.Tensor, + ys_pad_lens: torch.Tensor, + batch_size: int = 100, + ): + """Compute negative log likelihood(nll) from transformer-decoder + + To avoid OOM, this fuction seperate the input into batches. + Then call nll for each batch and combine and return results. + Args: + encoder_out: (Batch, Length, Dim) + encoder_out_lens: (Batch,) + ys_pad: (Batch, Length) + ys_pad_lens: (Batch,) + batch_size: int, samples each batch contain when computing nll, + you may change this to avoid OOM or increase + GPU memory usage + """ + total_num = encoder_out.size(0) + if total_num <= batch_size: + nll = self.nll(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens) + else: + nll = [] + start_idx = 0 + while True: + end_idx = min(start_idx + batch_size, total_num) + batch_encoder_out = encoder_out[start_idx:end_idx, :, :] + batch_encoder_out_lens = encoder_out_lens[start_idx:end_idx] + batch_ys_pad = ys_pad[start_idx:end_idx, :] + batch_ys_pad_lens = ys_pad_lens[start_idx:end_idx] + batch_nll = self.nll( + batch_encoder_out, + batch_encoder_out_lens, + batch_ys_pad, + batch_ys_pad_lens, + ) + nll.append(batch_nll) + start_idx = end_idx + if start_idx == total_num: + break + nll = torch.cat(nll) + assert nll.size(0) == total_num + return nll + + def _calc_att_loss( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + ys_pad: torch.Tensor, + ys_pad_lens: torch.Tensor, + ): + ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id) + ys_in_lens = ys_pad_lens + 1 + + # 1. Forward decoder + decoder_out, _ = self.decoder( + encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens + ) + + # 2. Compute attention loss + loss_att = self.criterion_att(decoder_out, ys_out_pad) + acc_att = th_accuracy( + decoder_out.view(-1, self.vocab_size), + ys_out_pad, + ignore_label=self.ignore_id, + ) + + # Compute cer/wer using attention-decoder + if self.training or self.error_calculator is None: + cer_att, wer_att = None, None + else: + ys_hat = decoder_out.argmax(dim=-1) + cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu()) + + return loss_att, acc_att, cer_att, wer_att + + def _calc_ctc_loss( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + ys_pad: torch.Tensor, + ys_pad_lens: torch.Tensor, + ): + # Calc CTC loss + loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens) + + # Calc CER using CTC + cer_ctc = None + if not self.training and self.error_calculator is not None: + ys_hat = self.ctc.argmax(encoder_out).data + cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True) + return loss_ctc, cer_ctc From b6126fd539df1be5f5e07993e68bd90e22a18e95 Mon Sep 17 00:00:00 2001 From: speech_asr Date: Mon, 13 Mar 2023 16:16:57 +0800 Subject: [PATCH 04/11] update ola --- funasr/models/e2e_diar_eend_ola.py | 322 +++++++---------------------- 1 file changed, 72 insertions(+), 250 deletions(-) diff --git a/funasr/models/e2e_diar_eend_ola.py b/funasr/models/e2e_diar_eend_ola.py index 967c0d487..5c1c9ce4e 100644 --- a/funasr/models/e2e_diar_eend_ola.py +++ b/funasr/models/e2e_diar_eend_ola.py @@ -1,38 +1,24 @@ # Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved. # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) -import logging -import torch from contextlib import contextmanager from distutils.version import LooseVersion -from funasr.layers.abs_normalize import AbsNormalize -from funasr.losses.label_smoothing_loss import ( - LabelSmoothingLoss, # noqa: H301 -) -from funasr.models.ctc import CTC -from funasr.models.decoder.abs_decoder import AbsDecoder -from funasr.models.encoder.abs_encoder import AbsEncoder -from funasr.models.frontend.abs_frontend import AbsFrontend -from funasr.models.postencoder.abs_postencoder import AbsPostEncoder -from funasr.models.preencoder.abs_preencoder import AbsPreEncoder -from funasr.models.specaug.abs_specaug import AbsSpecAug -from funasr.modules.add_sos_eos import add_sos_eos -from funasr.modules.e2e_asr_common import ErrorCalculator +from typing import Dict +from typing import Tuple + +import numpy as np +import torch +import torch.nn as nn +from typeguard import check_argument_types + from funasr.modules.eend_ola.encoder import TransformerEncoder from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor from funasr.modules.eend_ola.utils.power import generate_mapping_dict -from funasr.modules.nets_utils import th_accuracy from funasr.torch_utils.device_funcs import force_gatherable from funasr.train.abs_espnet_model import AbsESPnetModel -from typeguard import check_argument_types -from typing import Dict -from typing import List -from typing import Optional -from typing import Tuple -from typing import Union if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): - from torch.cuda.amp import autocast + pass else: # Nothing to do if torch<1.6.0 @contextmanager @@ -47,6 +33,7 @@ class DiarEENDOLAModel(AbsESPnetModel): self, encoder: TransformerEncoder, eda: EncoderDecoderAttractor, + n_units: int = 256, max_n_speaker: int = 8, attractor_loss_weight: float = 1.0, mapping_dict=None, @@ -62,6 +49,9 @@ class DiarEENDOLAModel(AbsESPnetModel): if mapping_dict is None: mapping_dict = generate_mapping_dict(max_speaker_num=self.max_n_speaker) self.mapping_dict = mapping_dict + # PostNet + self.PostNet = nn.LSTM(self.max_n_speaker, n_units, 1, batch_first=True) + self.output_layer = nn.Linear(n_units, mapping_dict['oov'] + 1) def forward( self, @@ -163,233 +153,65 @@ class DiarEENDOLAModel(AbsESPnetModel): loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) return loss, stats, weight - def collect_feats( - self, - speech: torch.Tensor, - speech_lengths: torch.Tensor, - text: torch.Tensor, - text_lengths: torch.Tensor, - ) -> Dict[str, torch.Tensor]: - if self.extract_feats_in_collect_stats: - feats, feats_lengths = self._extract_feats(speech, speech_lengths) + def estimate_sequential(self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + n_speakers: int, + shuffle: bool, + threshold: float, + **kwargs): + speech = [s[:s_len] for s, s_len in zip(speech, speech_lengths)] + emb = self.forward_core(speech) # list, [(T1, C1), ..., (T1, C1)] + if shuffle: + orders = [np.arange(e.shape[0]) for e in emb] + for order in orders: + np.random.shuffle(order) + # e[order]: shuffle后的embeddings, list, [(T1, C1), ..., (T1, C1)] 每个sample的T维度已进行随机顺序交换 + # attractors, list, hts(论文里的as), [(max_n_speakers, n_units), ..., (max_n_speakers, n_units)] + # probs, list, [(max_n_speakers, ), ..., (max_n_speakers, ] + attractors, probs = self.eda.estimate( + [e[torch.from_numpy(order).to(torch.long).to(xs[0].device)] for e, order in zip(emb, orders)]) else: - # Generate dummy stats if extract_feats_in_collect_stats is False - logging.warning( - "Generating dummy stats for feats and feats_lengths, " - "because encoder_conf.extract_feats_in_collect_stats is " - f"{self.extract_feats_in_collect_stats}" - ) - feats, feats_lengths = speech, speech_lengths - return {"feats": feats, "feats_lengths": feats_lengths} + attractors, probs = self.eda.estimate(emb) + attractors_active = [] + for p, att, e in zip(probs, attractors, emb): + if n_speakers and n_speakers >= 0: # 根据指定说话人数, 选择对应数量的ys + # TODO:在测试有不同数量speaker数的数据集时,考虑改成根据sample来确定具体的speaker数,而不是直接指定 + # raise NotImplementedError + att = att[:n_speakers, ] + attractors_active.append(att) + elif threshold is not None: + silence = torch.nonzero(p < threshold)[0] # 找到第一个输出概率小于阈值的索引, 作为结束, 且值刚好等于说话人数 + n_spk = silence[0] if silence.size else None + att = att[:n_spk, ] + attractors_active.append(att) + else: + NotImplementedError('n_speakers or th has to be given.') + raw_n_speakers = [att.shape[0] for att in attractors_active] # [C1, C2, ..., CB] + attractors = [ + pad_attractor(att, self.max_n_speaker) if att.shape[0] <= self.max_n_speaker else att[:self.max_n_speaker] + for att in attractors_active] + ys = [torch.matmul(e, att.permute(1, 0)) for e, att in zip(emb, attractors)] + # ys_eda = [torch.sigmoid(y[:, :n_spk]) for y,n_spk in zip(ys, raw_n_speakers)] + logits = self.cal_postnet(ys, self.max_n_speaker) + ys = [self.recover_y_from_powerlabel(logit, raw_n_speaker) for logit, raw_n_speaker in + zip(logits, raw_n_speakers)] - def encode( - self, speech: torch.Tensor, speech_lengths: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Frontend + Encoder. Note that this method is used by asr_inference.py + return ys, emb, attractors, raw_n_speakers - Args: - speech: (Batch, Length, ...) - speech_lengths: (Batch, ) - """ - with autocast(False): - # 1. Extract feats - feats, feats_lengths = self._extract_feats(speech, speech_lengths) - - # 2. Data augmentation - if self.specaug is not None and self.training: - feats, feats_lengths = self.specaug(feats, feats_lengths) - - # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN - if self.normalize is not None: - feats, feats_lengths = self.normalize(feats, feats_lengths) - - # Pre-encoder, e.g. used for raw input data - if self.preencoder is not None: - feats, feats_lengths = self.preencoder(feats, feats_lengths) - - # 4. Forward encoder - # feats: (Batch, Length, Dim) - # -> encoder_out: (Batch, Length2, Dim2) - if self.encoder.interctc_use_conditioning: - encoder_out, encoder_out_lens, _ = self.encoder( - feats, feats_lengths, ctc=self.ctc - ) - else: - encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths) - intermediate_outs = None - if isinstance(encoder_out, tuple): - intermediate_outs = encoder_out[1] - encoder_out = encoder_out[0] - - # Post-encoder, e.g. NLU - if self.postencoder is not None: - encoder_out, encoder_out_lens = self.postencoder( - encoder_out, encoder_out_lens - ) - - assert encoder_out.size(0) == speech.size(0), ( - encoder_out.size(), - speech.size(0), - ) - assert encoder_out.size(1) <= encoder_out_lens.max(), ( - encoder_out.size(), - encoder_out_lens.max(), - ) - - if intermediate_outs is not None: - return (encoder_out, intermediate_outs), encoder_out_lens - - return encoder_out, encoder_out_lens - - def _extract_feats( - self, speech: torch.Tensor, speech_lengths: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: - assert speech_lengths.dim() == 1, speech_lengths.shape - - # for data-parallel - speech = speech[:, : speech_lengths.max()] - - if self.frontend is not None: - # Frontend - # e.g. STFT and Feature extract - # data_loader may send time-domain signal in this case - # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim) - feats, feats_lengths = self.frontend(speech, speech_lengths) - else: - # No frontend and no feature extract - feats, feats_lengths = speech, speech_lengths - return feats, feats_lengths - - def nll( - self, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - ys_pad: torch.Tensor, - ys_pad_lens: torch.Tensor, - ) -> torch.Tensor: - """Compute negative log likelihood(nll) from transformer-decoder - - Normally, this function is called in batchify_nll. - - Args: - encoder_out: (Batch, Length, Dim) - encoder_out_lens: (Batch,) - ys_pad: (Batch, Length) - ys_pad_lens: (Batch,) - """ - ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id) - ys_in_lens = ys_pad_lens + 1 - - # 1. Forward decoder - decoder_out, _ = self.decoder( - encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens - ) # [batch, seqlen, dim] - batch_size = decoder_out.size(0) - decoder_num_class = decoder_out.size(2) - # nll: negative log-likelihood - nll = torch.nn.functional.cross_entropy( - decoder_out.view(-1, decoder_num_class), - ys_out_pad.view(-1), - ignore_index=self.ignore_id, - reduction="none", - ) - nll = nll.view(batch_size, -1) - nll = nll.sum(dim=1) - assert nll.size(0) == batch_size - return nll - - def batchify_nll( - self, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - ys_pad: torch.Tensor, - ys_pad_lens: torch.Tensor, - batch_size: int = 100, - ): - """Compute negative log likelihood(nll) from transformer-decoder - - To avoid OOM, this fuction seperate the input into batches. - Then call nll for each batch and combine and return results. - Args: - encoder_out: (Batch, Length, Dim) - encoder_out_lens: (Batch,) - ys_pad: (Batch, Length) - ys_pad_lens: (Batch,) - batch_size: int, samples each batch contain when computing nll, - you may change this to avoid OOM or increase - GPU memory usage - """ - total_num = encoder_out.size(0) - if total_num <= batch_size: - nll = self.nll(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens) - else: - nll = [] - start_idx = 0 - while True: - end_idx = min(start_idx + batch_size, total_num) - batch_encoder_out = encoder_out[start_idx:end_idx, :, :] - batch_encoder_out_lens = encoder_out_lens[start_idx:end_idx] - batch_ys_pad = ys_pad[start_idx:end_idx, :] - batch_ys_pad_lens = ys_pad_lens[start_idx:end_idx] - batch_nll = self.nll( - batch_encoder_out, - batch_encoder_out_lens, - batch_ys_pad, - batch_ys_pad_lens, - ) - nll.append(batch_nll) - start_idx = end_idx - if start_idx == total_num: - break - nll = torch.cat(nll) - assert nll.size(0) == total_num - return nll - - def _calc_att_loss( - self, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - ys_pad: torch.Tensor, - ys_pad_lens: torch.Tensor, - ): - ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id) - ys_in_lens = ys_pad_lens + 1 - - # 1. Forward decoder - decoder_out, _ = self.decoder( - encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens - ) - - # 2. Compute attention loss - loss_att = self.criterion_att(decoder_out, ys_out_pad) - acc_att = th_accuracy( - decoder_out.view(-1, self.vocab_size), - ys_out_pad, - ignore_label=self.ignore_id, - ) - - # Compute cer/wer using attention-decoder - if self.training or self.error_calculator is None: - cer_att, wer_att = None, None - else: - ys_hat = decoder_out.argmax(dim=-1) - cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu()) - - return loss_att, acc_att, cer_att, wer_att - - def _calc_ctc_loss( - self, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - ys_pad: torch.Tensor, - ys_pad_lens: torch.Tensor, - ): - # Calc CTC loss - loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens) - - # Calc CER using CTC - cer_ctc = None - if not self.training and self.error_calculator is not None: - ys_hat = self.ctc.argmax(encoder_out).data - cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True) - return loss_ctc, cer_ctc + def recover_y_from_powerlabel(self, logit, n_speaker): + pred = torch.argmax(torch.softmax(logit, dim=-1), dim=-1) # (T, ) + oov_index = torch.where(pred == self.mapping_dict['oov'])[0] + for i in oov_index: + if i > 0: + pred[i] = pred[i - 1] + else: + pred[i] = 0 + pred = [self.reporter.inv_mapping_func(i, self.mapping_dict) for i in pred] + # print(pred) + decisions = [bin(num)[2:].zfill(self.max_n_speaker)[::-1] for num in pred] + decisions = torch.from_numpy( + np.stack([np.array([int(i) for i in dec]) for dec in decisions], axis=0)).to(logit.device).to( + torch.float32) + decisions = decisions[:, :n_speaker] + return decisions From 229efa6250927485a3e1018548630b6348a24d1c Mon Sep 17 00:00:00 2001 From: speech_asr Date: Mon, 13 Mar 2023 16:25:53 +0800 Subject: [PATCH 05/11] update ola --- funasr/models/e2e_diar_eend_ola.py | 56 ++++++++++++++++++++---------- 1 file changed, 38 insertions(+), 18 deletions(-) diff --git a/funasr/models/e2e_diar_eend_ola.py b/funasr/models/e2e_diar_eend_ola.py index 5c1c9ce4e..2960b23ca 100644 --- a/funasr/models/e2e_diar_eend_ola.py +++ b/funasr/models/e2e_diar_eend_ola.py @@ -26,6 +26,13 @@ else: yield +def pad_attractor(att, max_n_speakers): + C, D = att.shape + if C < max_n_speakers: + att = torch.cat([att, torch.zeros(max_n_speakers - C, D).to(torch.float32).to(att.device)], dim=0) + return att + + class DiarEENDOLAModel(AbsESPnetModel): """CTC-attention hybrid Encoder-Decoder model""" @@ -53,6 +60,26 @@ class DiarEENDOLAModel(AbsESPnetModel): self.PostNet = nn.LSTM(self.max_n_speaker, n_units, 1, batch_first=True) self.output_layer = nn.Linear(n_units, mapping_dict['oov'] + 1) + def forward_encoder(self, xs, ilens): + xs = nn.utils.rnn.pad_sequence(xs, batch_first=True, padding_value=-1) + pad_shape = xs.shape + xs_mask = [torch.ones(ilen).to(xs.device) for ilen in ilens] + xs_mask = torch.nn.utils.rnn.pad_sequence(xs_mask, batch_first=True, padding_value=0).unsqueeze(-2) + emb = self.encoder(xs, xs_mask) + emb = torch.split(emb.view(pad_shape[0], pad_shape[1], -1), 1, dim=0) + emb = [e[0][:ilen] for e, ilen in zip(emb, ilens)] + return emb + + def forward_post_net(self, logits, ilens): + maxlen = torch.max(ilens).to(torch.int).item() + logits = nn.utils.rnn.pad_sequence(logits, batch_first=True, padding_value=-1) + logits = nn.utils.rnn.pack_padded_sequence(logits, ilens, batch_first=True, enforce_sorted=False) + outputs, (_, _) = self.PostNet(logits) + outputs = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True, padding_value=-1, total_length=maxlen)[0] + outputs = [output[:ilens[i].to(torch.int).item()] for i, output in enumerate(outputs)] + outputs = [self.output_layer(output) for output in outputs] + return outputs + def forward( self, speech: torch.Tensor, @@ -156,51 +183,45 @@ class DiarEENDOLAModel(AbsESPnetModel): def estimate_sequential(self, speech: torch.Tensor, speech_lengths: torch.Tensor, - n_speakers: int, - shuffle: bool, - threshold: float, + n_speakers: int = None, + shuffle: bool = True, + threshold: float = 0.5, **kwargs): speech = [s[:s_len] for s, s_len in zip(speech, speech_lengths)] - emb = self.forward_core(speech) # list, [(T1, C1), ..., (T1, C1)] + emb = self.forward_encoder(speech, speech_lengths) if shuffle: orders = [np.arange(e.shape[0]) for e in emb] for order in orders: np.random.shuffle(order) - # e[order]: shuffle后的embeddings, list, [(T1, C1), ..., (T1, C1)] 每个sample的T维度已进行随机顺序交换 - # attractors, list, hts(论文里的as), [(max_n_speakers, n_units), ..., (max_n_speakers, n_units)] - # probs, list, [(max_n_speakers, ), ..., (max_n_speakers, ] attractors, probs = self.eda.estimate( - [e[torch.from_numpy(order).to(torch.long).to(xs[0].device)] for e, order in zip(emb, orders)]) + [e[torch.from_numpy(order).to(torch.long).to(speech[0].device)] for e, order in zip(emb, orders)]) else: attractors, probs = self.eda.estimate(emb) attractors_active = [] for p, att, e in zip(probs, attractors, emb): - if n_speakers and n_speakers >= 0: # 根据指定说话人数, 选择对应数量的ys - # TODO:在测试有不同数量speaker数的数据集时,考虑改成根据sample来确定具体的speaker数,而不是直接指定 - # raise NotImplementedError + if n_speakers and n_speakers >= 0: att = att[:n_speakers, ] attractors_active.append(att) elif threshold is not None: - silence = torch.nonzero(p < threshold)[0] # 找到第一个输出概率小于阈值的索引, 作为结束, 且值刚好等于说话人数 + silence = torch.nonzero(p < threshold)[0] n_spk = silence[0] if silence.size else None att = att[:n_spk, ] attractors_active.append(att) else: - NotImplementedError('n_speakers or th has to be given.') - raw_n_speakers = [att.shape[0] for att in attractors_active] # [C1, C2, ..., CB] + NotImplementedError('n_speakers or threshold has to be given.') + raw_n_speakers = [att.shape[0] for att in attractors_active] attractors = [ pad_attractor(att, self.max_n_speaker) if att.shape[0] <= self.max_n_speaker else att[:self.max_n_speaker] for att in attractors_active] ys = [torch.matmul(e, att.permute(1, 0)) for e, att in zip(emb, attractors)] - # ys_eda = [torch.sigmoid(y[:, :n_spk]) for y,n_spk in zip(ys, raw_n_speakers)] - logits = self.cal_postnet(ys, self.max_n_speaker) + logits = self.forward_post_net(ys, speech_lengths) ys = [self.recover_y_from_powerlabel(logit, raw_n_speaker) for logit, raw_n_speaker in zip(logits, raw_n_speakers)] return ys, emb, attractors, raw_n_speakers def recover_y_from_powerlabel(self, logit, n_speaker): - pred = torch.argmax(torch.softmax(logit, dim=-1), dim=-1) # (T, ) + pred = torch.argmax(torch.softmax(logit, dim=-1), dim=-1) oov_index = torch.where(pred == self.mapping_dict['oov'])[0] for i in oov_index: if i > 0: @@ -208,7 +229,6 @@ class DiarEENDOLAModel(AbsESPnetModel): else: pred[i] = 0 pred = [self.reporter.inv_mapping_func(i, self.mapping_dict) for i in pred] - # print(pred) decisions = [bin(num)[2:].zfill(self.max_n_speaker)[::-1] for num in pred] decisions = torch.from_numpy( np.stack([np.array([int(i) for i in dec]) for dec in decisions], axis=0)).to(logit.device).to( From e27de5aa6bd9af2a82e80604978b50aa538493ec Mon Sep 17 00:00:00 2001 From: speech_asr Date: Mon, 13 Mar 2023 18:45:27 +0800 Subject: [PATCH 06/11] update ola --- funasr/models/e2e_diar_eend_ola.py | 19 +- funasr/modules/eend_ola/encoder.py | 16 +- funasr/tasks/diar.py | 327 +++++++++++++++++++++++++++-- 3 files changed, 335 insertions(+), 27 deletions(-) diff --git a/funasr/models/e2e_diar_eend_ola.py b/funasr/models/e2e_diar_eend_ola.py index 2960b23ca..f589269c5 100644 --- a/funasr/models/e2e_diar_eend_ola.py +++ b/funasr/models/e2e_diar_eend_ola.py @@ -11,7 +11,8 @@ import torch import torch.nn as nn from typeguard import check_argument_types -from funasr.modules.eend_ola.encoder import TransformerEncoder +from funasr.models.frontend.wav_frontend import WavFrontendMel23 +from funasr.modules.eend_ola.encoder import EENDOLATransformerEncoder from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor from funasr.modules.eend_ola.utils.power import generate_mapping_dict from funasr.torch_utils.device_funcs import force_gatherable @@ -34,12 +35,13 @@ def pad_attractor(att, max_n_speakers): class DiarEENDOLAModel(AbsESPnetModel): - """CTC-attention hybrid Encoder-Decoder model""" + """EEND-OLA diarization model""" def __init__( self, - encoder: TransformerEncoder, - eda: EncoderDecoderAttractor, + frontend: WavFrontendMel23, + encoder: EENDOLATransformerEncoder, + encoder_decoder_attractor: EncoderDecoderAttractor, n_units: int = 256, max_n_speaker: int = 8, attractor_loss_weight: float = 1.0, @@ -49,8 +51,9 @@ class DiarEENDOLAModel(AbsESPnetModel): assert check_argument_types() super().__init__() + self.frontend = frontend self.encoder = encoder - self.eda = eda + self.encoder_decoder_attractor = encoder_decoder_attractor self.attractor_loss_weight = attractor_loss_weight self.max_n_speaker = max_n_speaker if mapping_dict is None: @@ -187,16 +190,18 @@ class DiarEENDOLAModel(AbsESPnetModel): shuffle: bool = True, threshold: float = 0.5, **kwargs): + if self.frontend is not None: + speech = self.frontend(speech) speech = [s[:s_len] for s, s_len in zip(speech, speech_lengths)] emb = self.forward_encoder(speech, speech_lengths) if shuffle: orders = [np.arange(e.shape[0]) for e in emb] for order in orders: np.random.shuffle(order) - attractors, probs = self.eda.estimate( + attractors, probs = self.encoder_decoder_attractor.estimate( [e[torch.from_numpy(order).to(torch.long).to(speech[0].device)] for e, order in zip(emb, orders)]) else: - attractors, probs = self.eda.estimate(emb) + attractors, probs = self.encoder_decoder_attractor.estimate(emb) attractors_active = [] for p, att, e in zip(probs, attractors, emb): if n_speakers and n_speakers >= 0: diff --git a/funasr/modules/eend_ola/encoder.py b/funasr/modules/eend_ola/encoder.py index 17d11ace7..4999031b1 100644 --- a/funasr/modules/eend_ola/encoder.py +++ b/funasr/modules/eend_ola/encoder.py @@ -1,5 +1,5 @@ import math -import numpy as np + import torch import torch.nn.functional as F from torch import nn @@ -81,10 +81,16 @@ class PositionalEncoding(torch.nn.Module): return self.dropout(x) -class TransformerEncoder(nn.Module): - def __init__(self, idim, n_layers, n_units, - e_units=2048, h=8, dropout_rate=0.1, use_pos_emb=False): - super(TransformerEncoder, self).__init__() +class EENDOLATransformerEncoder(nn.Module): + def __init__(self, + idim: int, + n_layers: int, + n_units: int, + e_units: int = 2048, + h: int = 8, + dropout_rate: float = 0.1, + use_pos_emb: bool = False): + super(EENDOLATransformerEncoder, self).__init__() self.lnorm_in = nn.LayerNorm(n_units) self.n_layers = n_layers self.dropout = nn.Dropout(dropout_rate) diff --git a/funasr/tasks/diar.py b/funasr/tasks/diar.py index e699dccb0..953ab82c8 100644 --- a/funasr/tasks/diar.py +++ b/funasr/tasks/diar.py @@ -20,19 +20,18 @@ from funasr.datasets.collate_fn import CommonCollateFn from funasr.datasets.preprocessor import CommonPreprocessor from funasr.layers.abs_normalize import AbsNormalize from funasr.layers.global_mvn import GlobalMVN -from funasr.layers.utterance_mvn import UtteranceMVN from funasr.layers.label_aggregation import LabelAggregate -from funasr.models.ctc import CTC -from funasr.models.encoder.resnet34_encoder import ResNet34Diar, ResNet34SpL2RegDiar -from funasr.models.encoder.ecapa_tdnn_encoder import ECAPA_TDNN -from funasr.models.encoder.opennmt_encoders.conv_encoder import ConvEncoder -from funasr.models.encoder.opennmt_encoders.fsmn_encoder import FsmnEncoder -from funasr.models.encoder.opennmt_encoders.self_attention_encoder import SelfAttentionEncoder -from funasr.models.encoder.opennmt_encoders.ci_scorers import DotScorer, CosScorer +from funasr.layers.utterance_mvn import UtteranceMVN from funasr.models.e2e_diar_sond import DiarSondModel from funasr.models.encoder.abs_encoder import AbsEncoder from funasr.models.encoder.conformer_encoder import ConformerEncoder from funasr.models.encoder.data2vec_encoder import Data2VecEncoder +from funasr.models.encoder.ecapa_tdnn_encoder import ECAPA_TDNN +from funasr.models.encoder.opennmt_encoders.ci_scorers import DotScorer, CosScorer +from funasr.models.encoder.opennmt_encoders.conv_encoder import ConvEncoder +from funasr.models.encoder.opennmt_encoders.fsmn_encoder import FsmnEncoder +from funasr.models.encoder.opennmt_encoders.self_attention_encoder import SelfAttentionEncoder +from funasr.models.encoder.resnet34_encoder import ResNet34Diar, ResNet34SpL2RegDiar from funasr.models.encoder.rnn_encoder import RNNEncoder from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt from funasr.models.encoder.transformer_encoder import TransformerEncoder @@ -41,17 +40,13 @@ from funasr.models.frontend.default import DefaultFrontend from funasr.models.frontend.fused import FusedFrontends from funasr.models.frontend.s3prl import S3prlFrontend from funasr.models.frontend.wav_frontend import WavFrontend +from funasr.models.frontend.wav_frontend import WavFrontendMel23 from funasr.models.frontend.windowing import SlidingWindow -from funasr.models.postencoder.abs_postencoder import AbsPostEncoder -from funasr.models.postencoder.hugging_face_transformers_postencoder import ( - HuggingFaceTransformersPostEncoder, # noqa: H301 -) -from funasr.models.preencoder.abs_preencoder import AbsPreEncoder -from funasr.models.preencoder.linear import LinearProjection -from funasr.models.preencoder.sinc import LightweightSincConvs from funasr.models.specaug.abs_specaug import AbsSpecAug from funasr.models.specaug.specaug import SpecAug from funasr.models.specaug.specaug import SpecAugLFR +from funasr.modules.eend_ola.encoder import EENDOLATransformerEncoder +from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor from funasr.tasks.abs_task import AbsTask from funasr.torch_utils.initialize import initialize from funasr.train.abs_espnet_model import AbsESPnetModel @@ -70,6 +65,7 @@ frontend_choices = ClassChoices( s3prl=S3prlFrontend, fused=FusedFrontends, wav_frontend=WavFrontend, + wav_frontend_mel23=WavFrontendMel23, ), type_check=AbsFrontend, default="default", @@ -126,6 +122,7 @@ encoder_choices = ClassChoices( sanm_chunk_opt=SANMEncoderChunkOpt, data2vec_encoder=Data2VecEncoder, ecapa_tdnn=ECAPA_TDNN, + eend_ola_transformer=EENDOLATransformerEncoder, ), type_check=torch.nn.Module, default="resnet34", @@ -177,6 +174,15 @@ decoder_choices = ClassChoices( type_check=torch.nn.Module, default="fsmn", ) +# encoder_decoder_attractor is used for EEND-OLA +encoder_decoder_attractor_choices = ClassChoices( + "encoder_decoder_attractor", + classes=dict( + eda=EncoderDecoderAttractor, + ), + type_check=torch.nn.Module, + default="eda", +) class DiarTask(AbsTask): @@ -594,3 +600,294 @@ class DiarTask(AbsTask): var_dict_torch_update.update(var_dict_torch_update_local) return var_dict_torch_update + + +class EENDOLADiarTask(AbsTask): + # If you need more than 1 optimizer, change this value + num_optimizers: int = 1 + + # Add variable objects configurations + class_choices_list = [ + # --frontend and --frontend_conf + frontend_choices, + # --specaug and --specaug_conf + model_choices, + # --encoder and --encoder_conf + encoder_choices, + # --speaker_encoder and --speaker_encoder_conf + encoder_decoder_attractor_choices, + ] + + # If you need to modify train() or eval() procedures, change Trainer class here + trainer = Trainer + + @classmethod + def add_task_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group(description="Task related") + + # NOTE(kamo): add_arguments(..., required=True) can't be used + # to provide --print_config mode. Instead of it, do as + # required = parser.get_default("required") + # required += ["token_list"] + + group.add_argument( + "--token_list", + type=str_or_none, + default=None, + help="A text mapping int-id to token", + ) + group.add_argument( + "--split_with_space", + type=str2bool, + default=True, + help="whether to split text using ", + ) + group.add_argument( + "--seg_dict_file", + type=str, + default=None, + help="seg_dict_file for text processing", + ) + group.add_argument( + "--init", + type=lambda x: str_or_none(x.lower()), + default=None, + help="The initialization method", + choices=[ + "chainer", + "xavier_uniform", + "xavier_normal", + "kaiming_uniform", + "kaiming_normal", + None, + ], + ) + + group.add_argument( + "--input_size", + type=int_or_none, + default=None, + help="The number of input dimension of the feature", + ) + + group = parser.add_argument_group(description="Preprocess related") + group.add_argument( + "--use_preprocessor", + type=str2bool, + default=True, + help="Apply preprocessing to data or not", + ) + group.add_argument( + "--token_type", + type=str, + default="char", + choices=["char"], + help="The text will be tokenized in the specified level token", + ) + parser.add_argument( + "--speech_volume_normalize", + type=float_or_none, + default=None, + help="Scale the maximum amplitude to the given value.", + ) + parser.add_argument( + "--rir_scp", + type=str_or_none, + default=None, + help="The file path of rir scp file.", + ) + parser.add_argument( + "--rir_apply_prob", + type=float, + default=1.0, + help="THe probability for applying RIR convolution.", + ) + parser.add_argument( + "--cmvn_file", + type=str_or_none, + default=None, + help="The file path of noise scp file.", + ) + parser.add_argument( + "--noise_scp", + type=str_or_none, + default=None, + help="The file path of noise scp file.", + ) + parser.add_argument( + "--noise_apply_prob", + type=float, + default=1.0, + help="The probability applying Noise adding.", + ) + parser.add_argument( + "--noise_db_range", + type=str, + default="13_15", + help="The range of noise decibel level.", + ) + + for class_choices in cls.class_choices_list: + # Append -- and --_conf. + # e.g. --encoder and --encoder_conf + class_choices.add_arguments(group) + + @classmethod + def build_collate_fn( + cls, args: argparse.Namespace, train: bool + ) -> Callable[ + [Collection[Tuple[str, Dict[str, np.ndarray]]]], + Tuple[List[str], Dict[str, torch.Tensor]], + ]: + assert check_argument_types() + # NOTE(kamo): int value = 0 is reserved by CTC-blank symbol + return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1) + + @classmethod + def build_preprocess_fn( + cls, args: argparse.Namespace, train: bool + ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]: + assert check_argument_types() + if args.use_preprocessor: + retval = CommonPreprocessor( + train=train, + token_type=args.token_type, + token_list=args.token_list, + bpemodel=None, + non_linguistic_symbols=None, + text_cleaner=None, + g2p_type=None, + split_with_space=args.split_with_space if hasattr(args, "split_with_space") else False, + seg_dict_file=args.seg_dict_file if hasattr(args, "seg_dict_file") else None, + # NOTE(kamo): Check attribute existence for backward compatibility + rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None, + rir_apply_prob=args.rir_apply_prob + if hasattr(args, "rir_apply_prob") + else 1.0, + noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None, + noise_apply_prob=args.noise_apply_prob + if hasattr(args, "noise_apply_prob") + else 1.0, + noise_db_range=args.noise_db_range + if hasattr(args, "noise_db_range") + else "13_15", + speech_volume_normalize=args.speech_volume_normalize + if hasattr(args, "rir_scp") + else None, + ) + else: + retval = None + assert check_return_type(retval) + return retval + + @classmethod + def required_data_names( + cls, train: bool = True, inference: bool = False + ) -> Tuple[str, ...]: + if not inference: + retval = ("speech", "profile", "binary_labels") + else: + # Recognition mode + retval = ("speech") + return retval + + @classmethod + def optional_data_names( + cls, train: bool = True, inference: bool = False + ) -> Tuple[str, ...]: + retval = () + assert check_return_type(retval) + return retval + + @classmethod + def build_model(cls, args: argparse.Namespace): + assert check_argument_types() + + # 1. frontend + if args.input_size is None or args.frontend == "wav_frontend_mel23": + # Extract features in the model + frontend_class = frontend_choices.get_class(args.frontend) + if args.frontend == 'wav_frontend': + frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf) + else: + frontend = frontend_class(**args.frontend_conf) + input_size = frontend.output_size() + else: + # Give features from data-loader + args.frontend = None + args.frontend_conf = {} + frontend = None + input_size = args.input_size + + # 2. Encoder + encoder_class = encoder_choices.get_class(args.encoder) + encoder = encoder_class(input_size=input_size, **args.encoder_conf) + + # 3. EncoderDecoderAttractor + encoder_decoder_attractor_class = encoder_decoder_attractor_choices.get_class(args.encoder_decoder_attractor) + encoder_decoder_attractor = encoder_decoder_attractor_class(**args.encoder_decoder_attractor_conf) + + # 9. Build model + model_class = model_choices.get_class(args.model) + model = model_class( + frontend=frontend, + encoder=encoder, + encoder_decoder_attractor=encoder_decoder_attractor, + **args.model_conf, + ) + + # 10. Initialize + if args.init is not None: + initialize(model, args.init) + + assert check_return_type(model) + return model + + # ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~ + @classmethod + def build_model_from_file( + cls, + config_file: Union[Path, str] = None, + model_file: Union[Path, str] = None, + cmvn_file: Union[Path, str] = None, + device: str = "cpu", + ): + """Build model from the files. + + This method is used for inference or fine-tuning. + + Args: + config_file: The yaml file saved when training. + model_file: The model file saved when training. + cmvn_file: The cmvn file for front-end + device: Device type, "cpu", "cuda", or "cuda:N". + + """ + assert check_argument_types() + if config_file is None: + assert model_file is not None, ( + "The argument 'model_file' must be provided " + "if the argument 'config_file' is not specified." + ) + config_file = Path(model_file).parent / "config.yaml" + else: + config_file = Path(config_file) + + with config_file.open("r", encoding="utf-8") as f: + args = yaml.safe_load(f) + args = argparse.Namespace(**args) + model = cls.build_model(args) + if not isinstance(model, AbsESPnetModel): + raise RuntimeError( + f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}" + ) + if model_file is not None: + if device == "cuda": + device = f"cuda:{torch.cuda.current_device()}" + checkpoint = torch.load(model_file, map_location=device) + if "state_dict" in checkpoint.keys(): + model.load_state_dict(checkpoint["state_dict"]) + else: + model.load_state_dict(checkpoint) + model.to(device) + return model, args From 7de11ad9efa625b716730ef8dbbd9aa63b6c7dc3 Mon Sep 17 00:00:00 2001 From: speech_asr Date: Tue, 14 Mar 2023 10:36:20 +0800 Subject: [PATCH 07/11] update ola --- funasr/bin/eend_ola_inference.py | 413 +++++++++++++++++++++++++++++++ 1 file changed, 413 insertions(+) create mode 100755 funasr/bin/eend_ola_inference.py diff --git a/funasr/bin/eend_ola_inference.py b/funasr/bin/eend_ola_inference.py new file mode 100755 index 000000000..d191877ac --- /dev/null +++ b/funasr/bin/eend_ola_inference.py @@ -0,0 +1,413 @@ +#!/usr/bin/env python3 +# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. +# MIT License (https://opensource.org/licenses/MIT) + +import argparse +import logging +import os +import sys +from pathlib import Path +from typing import Any +from typing import List +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Union + +import numpy as np +import torch +from typeguard import check_argument_types + +from funasr.models.frontend.wav_frontend import WavFrontendMel23 +from funasr.tasks.diar import EENDOLADiarTask +from funasr.torch_utils.device_funcs import to_device +from funasr.utils import config_argparse +from funasr.utils.cli_utils import get_commandline_args +from funasr.utils.types import str2bool +from funasr.utils.types import str2triple_str +from funasr.utils.types import str_or_none + + +class Speech2Diarization: + """Speech2Diarlization class + + Examples: + >>> import soundfile + >>> import numpy as np + >>> speech2diar = Speech2Diarization("diar_sond_config.yml", "diar_sond.pth") + >>> profile = np.load("profiles.npy") + >>> audio, rate = soundfile.read("speech.wav") + >>> speech2diar(audio, profile) + {"spk1": [(int, int), ...], ...} + + """ + + def __init__( + self, + diar_train_config: Union[Path, str] = None, + diar_model_file: Union[Path, str] = None, + device: str = "cpu", + dtype: str = "float32", + ): + assert check_argument_types() + + # 1. Build Diarization model + diar_model, diar_train_args = EENDOLADiarTask.build_model_from_file( + config_file=diar_train_config, + model_file=diar_model_file, + device=device + ) + frontend = None + if diar_train_args.frontend is not None and diar_train_args.frontend_conf is not None: + frontend = WavFrontendMel23(**diar_train_args.frontend_conf) + + # set up seed for eda + np.random.seed(diar_train_args.seed) + torch.manual_seed(diar_train_args.seed) + torch.cuda.manual_seed(diar_train_args.seed) + os.environ['PYTORCH_SEED'] = str(diar_train_args.seed) + logging.info("diar_model: {}".format(diar_model)) + logging.info("diar_train_args: {}".format(diar_train_args)) + diar_model.to(dtype=getattr(torch, dtype)).eval() + + self.diar_model = diar_model + self.diar_train_args = diar_train_args + self.device = device + self.dtype = dtype + self.frontend = frontend + + @torch.no_grad() + def __call__( + self, + speech: Union[torch.Tensor, np.ndarray], + speech_lengths: Union[torch.Tensor, np.ndarray] = None + ): + """Inference + + Args: + speech: Input speech data + Returns: + diarization results + + """ + assert check_argument_types() + # Input as audio signal + if isinstance(speech, np.ndarray): + speech = torch.tensor(speech) + + if self.frontend is not None: + feats, feats_len = self.frontend.forward(speech, speech_lengths) + feats = to_device(feats, device=self.device) + feats_len = feats_len.int() + self.diar_model.frontend = None + else: + feats = speech + feats_len = speech_lengths + batch = {"speech": feats, "speech_lengths": feats_len} + batch = to_device(batch, device=self.device) + results = self.diar_model.estimate_sequential(**batch) + + return results + + @staticmethod + def from_pretrained( + model_tag: Optional[str] = None, + **kwargs: Optional[Any], + ): + """Build Speech2Diarization instance from the pretrained model. + + Args: + model_tag (Optional[str]): Model tag of the pretrained models. + Currently, the tags of espnet_model_zoo are supported. + + Returns: + Speech2Xvector: Speech2Xvector instance. + + """ + if model_tag is not None: + try: + from espnet_model_zoo.downloader import ModelDownloader + + except ImportError: + logging.error( + "`espnet_model_zoo` is not installed. " + "Please install via `pip install -U espnet_model_zoo`." + ) + raise + d = ModelDownloader() + kwargs.update(**d.download_and_unpack(model_tag)) + + return Speech2Diarization(**kwargs) + + +def inference_modelscope( + diar_train_config: str, + diar_model_file: str, + output_dir: Optional[str] = None, + batch_size: int = 1, + dtype: str = "float32", + ngpu: int = 0, + num_workers: int = 0, + log_level: Union[int, str] = "INFO", + key_file: Optional[str] = None, + model_tag: Optional[str] = None, + allow_variable_data_keys: bool = True, + streaming: bool = False, + param_dict: Optional[dict] = None, + **kwargs, +): + assert check_argument_types() + if batch_size > 1: + raise NotImplementedError("batch decoding is not implemented") + if ngpu > 1: + raise NotImplementedError("only single GPU decoding is supported") + + logging.basicConfig( + level=log_level, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + logging.info("param_dict: {}".format(param_dict)) + + if ngpu >= 1 and torch.cuda.is_available(): + device = "cuda" + else: + device = "cpu" + + # 1. Build speech2diar + speech2diar_kwargs = dict( + diar_train_config=diar_train_config, + diar_model_file=diar_model_file, + device=device, + dtype=dtype, + streaming=streaming, + ) + logging.info("speech2diarization_kwargs: {}".format(speech2diar_kwargs)) + speech2diar = Speech2Diarization.from_pretrained( + model_tag=model_tag, + **speech2diar_kwargs, + ) + speech2diar.diar_model.eval() + + def output_results_str(results: dict, uttid: str): + rst = [] + mid = uttid.rsplit("-", 1)[0] + for key in results: + results[key] = [(x[0] / 100, x[1] / 100) for x in results[key]] + template = "SPEAKER {} 0 {:.2f} {:.2f} {} " + for spk, segs in results.items(): + rst.extend([template.format(mid, st, ed, spk) for st, ed in segs]) + + return "\n".join(rst) + + def _forward( + data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None, + raw_inputs: List[List[Union[np.ndarray, torch.Tensor, str, bytes]]] = None, + output_dir_v2: Optional[str] = None, + param_dict: Optional[dict] = None, + ): + # 2. Build data-iterator + if data_path_and_name_and_type is None and raw_inputs is not None: + if isinstance(raw_inputs, torch.Tensor): + raw_inputs = raw_inputs.numpy() + data_path_and_name_and_type = [raw_inputs, "speech", "waveform"] + loader = EENDOLADiarTask.build_streaming_iterator( + data_path_and_name_and_type, + dtype=dtype, + batch_size=batch_size, + key_file=key_file, + num_workers=num_workers, + preprocess_fn=EENDOLADiarTask.build_preprocess_fn(speech2diar.diar_train_args, False), + collate_fn=EENDOLADiarTask.build_collate_fn(speech2diar.diar_train_args, False), + allow_variable_data_keys=allow_variable_data_keys, + inference=True, + ) + + # 3. Start for-loop + output_path = output_dir_v2 if output_dir_v2 is not None else output_dir + if output_path is not None: + os.makedirs(output_path, exist_ok=True) + output_writer = open("{}/result.txt".format(output_path), "w") + result_list = [] + for keys, batch in loader: + assert isinstance(batch, dict), type(batch) + assert all(isinstance(s, str) for s in keys), keys + _bs = len(next(iter(batch.values()))) + assert len(keys) == _bs, f"{len(keys)} != {_bs}" + # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")} + + results = speech2diar(**batch) + # Only supporting batch_size==1 + key, value = keys[0], output_results_str(results, keys[0]) + item = {"key": key, "value": value} + result_list.append(item) + if output_path is not None: + output_writer.write(value) + output_writer.flush() + + if output_path is not None: + output_writer.close() + + return result_list + + return _forward + + +def inference( + data_path_and_name_and_type: Sequence[Tuple[str, str, str]], + diar_train_config: Optional[str], + diar_model_file: Optional[str], + output_dir: Optional[str] = None, + batch_size: int = 1, + dtype: str = "float32", + ngpu: int = 0, + seed: int = 0, + num_workers: int = 1, + log_level: Union[int, str] = "INFO", + key_file: Optional[str] = None, + model_tag: Optional[str] = None, + allow_variable_data_keys: bool = True, + streaming: bool = False, + smooth_size: int = 83, + dur_threshold: int = 10, + out_format: str = "vad", + **kwargs, +): + inference_pipeline = inference_modelscope( + diar_train_config=diar_train_config, + diar_model_file=diar_model_file, + output_dir=output_dir, + batch_size=batch_size, + dtype=dtype, + ngpu=ngpu, + seed=seed, + num_workers=num_workers, + log_level=log_level, + key_file=key_file, + model_tag=model_tag, + allow_variable_data_keys=allow_variable_data_keys, + streaming=streaming, + smooth_size=smooth_size, + dur_threshold=dur_threshold, + out_format=out_format, + **kwargs, + ) + + return inference_pipeline(data_path_and_name_and_type, raw_inputs=None) + + +def get_parser(): + parser = config_argparse.ArgumentParser( + description="Speaker verification/x-vector extraction", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + # Note(kamo): Use '_' instead of '-' as separator. + # '-' is confusing if written in yaml. + parser.add_argument( + "--log_level", + type=lambda x: x.upper(), + default="INFO", + choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"), + help="The verbose level of logging", + ) + + parser.add_argument("--output_dir", type=str, required=False) + parser.add_argument( + "--ngpu", + type=int, + default=0, + help="The number of gpus. 0 indicates CPU mode", + ) + parser.add_argument( + "--gpuid_list", + type=str, + default="", + help="The visible gpus", + ) + parser.add_argument("--seed", type=int, default=0, help="Random seed") + parser.add_argument( + "--dtype", + default="float32", + choices=["float16", "float32", "float64"], + help="Data type", + ) + parser.add_argument( + "--num_workers", + type=int, + default=1, + help="The number of workers used for DataLoader", + ) + + group = parser.add_argument_group("Input data related") + group.add_argument( + "--data_path_and_name_and_type", + type=str2triple_str, + required=False, + action="append", + ) + group.add_argument("--key_file", type=str_or_none) + group.add_argument("--allow_variable_data_keys", type=str2bool, default=False) + + group = parser.add_argument_group("The model configuration related") + group.add_argument( + "--diar_train_config", + type=str, + help="diarization training configuration", + ) + group.add_argument( + "--diar_model_file", + type=str, + help="diarization model parameter file", + ) + group.add_argument( + "--dur_threshold", + type=int, + default=10, + help="The threshold for short segments in number frames" + ) + parser.add_argument( + "--smooth_size", + type=int, + default=83, + help="The smoothing window length in number frames" + ) + group.add_argument( + "--model_tag", + type=str, + help="Pretrained model tag. If specify this option, *_train_config and " + "*_file will be overwritten", + ) + parser.add_argument( + "--batch_size", + type=int, + default=1, + help="The batch size for inference", + ) + parser.add_argument("--streaming", type=str2bool, default=False) + + return parser + + +def main(cmd=None): + print(get_commandline_args(), file=sys.stderr) + parser = get_parser() + args = parser.parse_args(cmd) + kwargs = vars(args) + kwargs.pop("config", None) + logging.info("args: {}".format(kwargs)) + if args.output_dir is None: + jobid, n_gpu = 1, 1 + gpuid = args.gpuid_list.split(",")[jobid - 1] + else: + jobid = int(args.output_dir.split(".")[-1]) + n_gpu = len(args.gpuid_list.split(",")) + gpuid = args.gpuid_list.split(",")[(jobid - 1) % n_gpu] + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + os.environ["CUDA_VISIBLE_DEVICES"] = gpuid + results_list = inference(**kwargs) + for results in results_list: + print("{} {}".format(results["key"], results["value"])) + + +if __name__ == "__main__": + main() From ad2ef723410927bcb89494478bf8f2defa9f76b5 Mon Sep 17 00:00:00 2001 From: speech_asr Date: Tue, 14 Mar 2023 11:33:20 +0800 Subject: [PATCH 08/11] update ola --- funasr/bin/eend_ola_inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/funasr/bin/eend_ola_inference.py b/funasr/bin/eend_ola_inference.py index d191877ac..d65895f30 100755 --- a/funasr/bin/eend_ola_inference.py +++ b/funasr/bin/eend_ola_inference.py @@ -121,7 +121,7 @@ class Speech2Diarization: Currently, the tags of espnet_model_zoo are supported. Returns: - Speech2Xvector: Speech2Xvector instance. + Speech2Diarization: Speech2Diarization instance. """ if model_tag is not None: From 141a4737f779fcf435a0ece5434b9c73eda7d2a9 Mon Sep 17 00:00:00 2001 From: speech_asr Date: Tue, 14 Mar 2023 15:54:28 +0800 Subject: [PATCH 09/11] update --- funasr/models/frontend/wav_frontend.py | 16 +--------------- funasr/tasks/diar.py | 2 ++ 2 files changed, 3 insertions(+), 15 deletions(-) diff --git a/funasr/models/frontend/wav_frontend.py b/funasr/models/frontend/wav_frontend.py index 4e52b90f1..6af707478 100644 --- a/funasr/models/frontend/wav_frontend.py +++ b/funasr/models/frontend/wav_frontend.py @@ -6,6 +6,7 @@ import numpy as np import torch import torchaudio.compliance.kaldi as kaldi from funasr.models.frontend.abs_frontend import AbsFrontend +import funasr.models.frontend.eend_ola_feature as eend_ola_feature from torch.nn.utils.rnn import pad_sequence from typeguard import check_argument_types from typing import Tuple @@ -213,33 +214,18 @@ class WavFrontendMel23(AbsFrontend): def __init__( self, fs: int = 16000, - window: str = 'hamming', - n_mels: int = 80, frame_length: int = 25, frame_shift: int = 10, - filter_length_min: int = -1, - filter_length_max: int = -1, lfr_m: int = 1, lfr_n: int = 1, - dither: float = 1.0, - snip_edges: bool = True, - upsacle_samples: bool = True, ): assert check_argument_types() super().__init__() self.fs = fs - self.window = window - self.n_mels = n_mels self.frame_length = frame_length self.frame_shift = frame_shift - self.filter_length_min = filter_length_min - self.filter_length_max = filter_length_max self.lfr_m = lfr_m self.lfr_n = lfr_n - self.cmvn_file = cmvn_file - self.dither = dither - self.snip_edges = snip_edges - self.upsacle_samples = upsacle_samples def output_size(self) -> int: return self.n_mels * self.lfr_m diff --git a/funasr/tasks/diar.py b/funasr/tasks/diar.py index 953ab82c8..ae7ee9b40 100644 --- a/funasr/tasks/diar.py +++ b/funasr/tasks/diar.py @@ -23,6 +23,7 @@ from funasr.layers.global_mvn import GlobalMVN from funasr.layers.label_aggregation import LabelAggregate from funasr.layers.utterance_mvn import UtteranceMVN from funasr.models.e2e_diar_sond import DiarSondModel +from funasr.models.e2e_diar_eend_ola import DiarEENDOLAModel from funasr.models.encoder.abs_encoder import AbsEncoder from funasr.models.encoder.conformer_encoder import ConformerEncoder from funasr.models.encoder.data2vec_encoder import Data2VecEncoder @@ -103,6 +104,7 @@ model_choices = ClassChoices( "model", classes=dict( sond=DiarSondModel, + eend_ola=DiarEENDOLAModel, ), type_check=AbsESPnetModel, default="sond", From 64b591eb6f6f8b8e80d6e94c00f2770a386b15cd Mon Sep 17 00:00:00 2001 From: speech_asr Date: Tue, 14 Mar 2023 16:20:24 +0800 Subject: [PATCH 10/11] update --- funasr/models/frontend/wav_frontend.py | 223 +++++++++++++++++++++++-- 1 file changed, 206 insertions(+), 17 deletions(-) diff --git a/funasr/models/frontend/wav_frontend.py b/funasr/models/frontend/wav_frontend.py index 6af707478..0bf5ce16d 100644 --- a/funasr/models/frontend/wav_frontend.py +++ b/funasr/models/frontend/wav_frontend.py @@ -1,15 +1,14 @@ # Copyright (c) Alibaba, Inc. and its affiliates. # Part of the implementation is borrowed from espnet/espnet. +from abc import ABC +from typing import Tuple -import funasr.models.frontend.eend_ola_feature import numpy as np import torch import torchaudio.compliance.kaldi as kaldi from funasr.models.frontend.abs_frontend import AbsFrontend -import funasr.models.frontend.eend_ola_feature as eend_ola_feature -from torch.nn.utils.rnn import pad_sequence from typeguard import check_argument_types -from typing import Tuple +from torch.nn.utils.rnn import pad_sequence def load_cmvn(cmvn_file): @@ -207,51 +206,241 @@ class WavFrontend(AbsFrontend): return feats_pad, feats_lens -class WavFrontendMel23(AbsFrontend): - """Conventional frontend structure for ASR. +class WavFrontendOnline(AbsFrontend): + """Conventional frontend structure for streaming ASR/VAD. """ def __init__( self, + cmvn_file: str = None, fs: int = 16000, + window: str = 'hamming', + n_mels: int = 80, frame_length: int = 25, frame_shift: int = 10, + filter_length_min: int = -1, + filter_length_max: int = -1, lfr_m: int = 1, lfr_n: int = 1, + dither: float = 1.0, + snip_edges: bool = True, + upsacle_samples: bool = True, ): assert check_argument_types() super().__init__() self.fs = fs + self.window = window + self.n_mels = n_mels self.frame_length = frame_length self.frame_shift = frame_shift + self.frame_sample_length = int(self.frame_length * self.fs / 1000) + self.frame_shift_sample_length = int(self.frame_shift * self.fs / 1000) + self.filter_length_min = filter_length_min + self.filter_length_max = filter_length_max self.lfr_m = lfr_m self.lfr_n = lfr_n + self.cmvn_file = cmvn_file + self.dither = dither + self.snip_edges = snip_edges + self.upsacle_samples = upsacle_samples + self.waveforms = None + self.reserve_waveforms = None + self.fbanks = None + self.fbanks_lens = None + self.cmvn = None if self.cmvn_file is None else load_cmvn(self.cmvn_file) + self.input_cache = None + self.lfr_splice_cache = [] def output_size(self) -> int: return self.n_mels * self.lfr_m - def forward( + @staticmethod + def apply_cmvn(inputs: torch.Tensor, cmvn: torch.Tensor) -> torch.Tensor: + """ + Apply CMVN with mvn data + """ + + device = inputs.device + dtype = inputs.dtype + frame, dim = inputs.shape + + means = np.tile(cmvn[0:1, :dim], (frame, 1)) + vars = np.tile(cmvn[1:2, :dim], (frame, 1)) + inputs += torch.from_numpy(means).type(dtype).to(device) + inputs *= torch.from_numpy(vars).type(dtype).to(device) + + return inputs.type(torch.float32) + + @staticmethod + # inputs tensor has catted the cache tensor + # def apply_lfr(inputs: torch.Tensor, lfr_m: int, lfr_n: int, inputs_lfr_cache: torch.Tensor = None, + # is_final: bool = False) -> Tuple[torch.Tensor, torch.Tensor, int]: + def apply_lfr(inputs: torch.Tensor, lfr_m: int, lfr_n: int, is_final: bool = False) -> Tuple[torch.Tensor, torch.Tensor, int]: + """ + Apply lfr with data + """ + + LFR_inputs = [] + # inputs = torch.vstack((inputs_lfr_cache, inputs)) + T = inputs.shape[0] # include the right context + T_lfr = int(np.ceil((T - (lfr_m - 1) // 2) / lfr_n)) # minus the right context: (lfr_m - 1) // 2 + splice_idx = T_lfr + for i in range(T_lfr): + if lfr_m <= T - i * lfr_n: + LFR_inputs.append((inputs[i * lfr_n:i * lfr_n + lfr_m]).view(1, -1)) + else: # process last LFR frame + if is_final: + num_padding = lfr_m - (T - i * lfr_n) + frame = (inputs[i * lfr_n:]).view(-1) + for _ in range(num_padding): + frame = torch.hstack((frame, inputs[-1])) + LFR_inputs.append(frame) + else: + # update splice_idx and break the circle + splice_idx = i + break + splice_idx = min(T - 1, splice_idx * lfr_n) + lfr_splice_cache = inputs[splice_idx:, :] + LFR_outputs = torch.vstack(LFR_inputs) + return LFR_outputs.type(torch.float32), lfr_splice_cache, splice_idx + + @staticmethod + def compute_frame_num(sample_length: int, frame_sample_length: int, frame_shift_sample_length: int) -> int: + frame_num = int((sample_length - frame_sample_length) / frame_shift_sample_length + 1) + return frame_num if frame_num >= 1 and sample_length >= frame_sample_length else 0 + + def forward_fbank( self, input: torch.Tensor, - input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + input_lengths: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + batch_size = input.size(0) + if self.input_cache is None: + self.input_cache = torch.empty(0) + input = torch.cat((self.input_cache, input), dim=1) + frame_num = self.compute_frame_num(input.shape[-1], self.frame_sample_length, self.frame_shift_sample_length) + # update self.in_cache + self.input_cache = input[:, -(input.shape[-1] - frame_num * self.frame_shift_sample_length):] + waveforms = torch.empty(0) + feats_pad = torch.empty(0) + feats_lens = torch.empty(0) + if frame_num: + waveforms = [] + feats = [] + feats_lens = [] + for i in range(batch_size): + waveform = input[i] + # we need accurate wave samples that used for fbank extracting + waveforms.append( + waveform[:((frame_num - 1) * self.frame_shift_sample_length + self.frame_sample_length)]) + waveform = waveform * (1 << 15) + waveform = waveform.unsqueeze(0) + mat = kaldi.fbank(waveform, + num_mel_bins=self.n_mels, + frame_length=self.frame_length, + frame_shift=self.frame_shift, + dither=self.dither, + energy_floor=0.0, + window_type=self.window, + sample_frequency=self.fs) + + feat_length = mat.size(0) + feats.append(mat) + feats_lens.append(feat_length) + + waveforms = torch.stack(waveforms) + feats_lens = torch.as_tensor(feats_lens) + feats_pad = pad_sequence(feats, + batch_first=True, + padding_value=0.0) + self.fbanks = feats_pad + import copy + self.fbanks_lens = copy.deepcopy(feats_lens) + return waveforms, feats_pad, feats_lens + + def get_fbank(self) -> Tuple[torch.Tensor, torch.Tensor]: + return self.fbanks, self.fbanks_lens + + def forward_lfr_cmvn( + self, + input: torch.Tensor, + input_lengths: torch.Tensor, + is_final: bool = False + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: batch_size = input.size(0) feats = [] feats_lens = [] + lfr_splice_frame_idxs = [] for i in range(batch_size): - waveform_length = input_lengths[i] - waveform = input[i][:waveform_length] - waveform = waveform.unsqueeze(0).numpy() - mat = eend_ola_feature.stft(waveform, self.frame_length, self.frame_shift) - mat = eend_ola_feature.transform(mat) - mat = mat.splice(mat, context_size=self.lfr_m) - mat = mat[::self.lfr_n] - mat = torch.from_numpy(mat) + mat = input[i, :input_lengths[i], :] + if self.lfr_m != 1 or self.lfr_n != 1: + # update self.lfr_splice_cache in self.apply_lfr + # mat, self.lfr_splice_cache[i], lfr_splice_frame_idx = self.apply_lfr(mat, self.lfr_m, self.lfr_n, self.lfr_splice_cache[i], + mat, self.lfr_splice_cache[i], lfr_splice_frame_idx = self.apply_lfr(mat, self.lfr_m, self.lfr_n, is_final) + if self.cmvn_file is not None: + mat = self.apply_cmvn(mat, self.cmvn) feat_length = mat.size(0) feats.append(mat) feats_lens.append(feat_length) + lfr_splice_frame_idxs.append(lfr_splice_frame_idx) feats_lens = torch.as_tensor(feats_lens) feats_pad = pad_sequence(feats, batch_first=True, padding_value=0.0) - return feats_pad, feats_lens + lfr_splice_frame_idxs = torch.as_tensor(lfr_splice_frame_idxs) + return feats_pad, feats_lens, lfr_splice_frame_idxs + + def forward( + self, input: torch.Tensor, input_lengths: torch.Tensor, is_final: bool = False + ) -> Tuple[torch.Tensor, torch.Tensor]: + batch_size = input.shape[0] + assert batch_size == 1, 'we support to extract feature online only when the batch size is equal to 1 now' + waveforms, feats, feats_lengths = self.forward_fbank(input, input_lengths) # input shape: B T D + if feats.shape[0]: + #if self.reserve_waveforms is None and self.lfr_m > 1: + # self.reserve_waveforms = waveforms[:, :(self.lfr_m - 1) // 2 * self.frame_shift_sample_length] + self.waveforms = waveforms if self.reserve_waveforms is None else torch.cat((self.reserve_waveforms, waveforms), dim=1) + if not self.lfr_splice_cache: # 初始化splice_cache + for i in range(batch_size): + self.lfr_splice_cache.append(feats[i][0, :].unsqueeze(dim=0).repeat((self.lfr_m - 1) // 2, 1)) + # need the number of the input frames + self.lfr_splice_cache[0].shape[0] is greater than self.lfr_m + if feats_lengths[0] + self.lfr_splice_cache[0].shape[0] >= self.lfr_m: + lfr_splice_cache_tensor = torch.stack(self.lfr_splice_cache) # B T D + feats = torch.cat((lfr_splice_cache_tensor, feats), dim=1) + feats_lengths += lfr_splice_cache_tensor[0].shape[0] + frame_from_waveforms = int((self.waveforms.shape[1] - self.frame_sample_length) / self.frame_shift_sample_length + 1) + minus_frame = (self.lfr_m - 1) // 2 if self.reserve_waveforms is None else 0 + feats, feats_lengths, lfr_splice_frame_idxs = self.forward_lfr_cmvn(feats, feats_lengths, is_final) + if self.lfr_m == 1: + self.reserve_waveforms = None + else: + reserve_frame_idx = lfr_splice_frame_idxs[0] - minus_frame + # print('reserve_frame_idx: ' + str(reserve_frame_idx)) + # print('frame_frame: ' + str(frame_from_waveforms)) + self.reserve_waveforms = self.waveforms[:, reserve_frame_idx * self.frame_shift_sample_length:frame_from_waveforms * self.frame_shift_sample_length] + sample_length = (frame_from_waveforms - 1) * self.frame_shift_sample_length + self.frame_sample_length + self.waveforms = self.waveforms[:, :sample_length] + else: + # update self.reserve_waveforms and self.lfr_splice_cache + self.reserve_waveforms = self.waveforms[:, :-(self.frame_sample_length - self.frame_shift_sample_length)] + for i in range(batch_size): + self.lfr_splice_cache[i] = torch.cat((self.lfr_splice_cache[i], feats[i]), dim=0) + return torch.empty(0), feats_lengths + else: + if is_final: + self.waveforms = waveforms if self.reserve_waveforms is None else self.reserve_waveforms + feats = torch.stack(self.lfr_splice_cache) + feats_lengths = torch.zeros(batch_size, dtype=torch.int) + feats.shape[1] + feats, feats_lengths, _ = self.forward_lfr_cmvn(feats, feats_lengths, is_final) + if is_final: + self.cache_reset() + return feats, feats_lengths + + def get_waveforms(self): + return self.waveforms + + def cache_reset(self): + self.reserve_waveforms = None + self.input_cache = None + self.lfr_splice_cache = [] \ No newline at end of file From ee1b0ec605fc0c077867dfb6e1cbb65614eb1347 Mon Sep 17 00:00:00 2001 From: speech_asr Date: Tue, 14 Mar 2023 16:37:36 +0800 Subject: [PATCH 11/11] update --- funasr/models/frontend/wav_frontend.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/funasr/models/frontend/wav_frontend.py b/funasr/models/frontend/wav_frontend.py index 0bf5ce16d..445efca24 100644 --- a/funasr/models/frontend/wav_frontend.py +++ b/funasr/models/frontend/wav_frontend.py @@ -430,7 +430,7 @@ class WavFrontendOnline(AbsFrontend): else: if is_final: self.waveforms = waveforms if self.reserve_waveforms is None else self.reserve_waveforms - feats = torch.stack(self.lfr_splice_cache) + feats = torch.stack(self.lfr_splice_cache) feats_lengths = torch.zeros(batch_size, dtype=torch.int) + feats.shape[1] feats, feats_lengths, _ = self.forward_lfr_cmvn(feats, feats_lengths, is_final) if is_final: @@ -443,4 +443,4 @@ class WavFrontendOnline(AbsFrontend): def cache_reset(self): self.reserve_waveforms = None self.input_cache = None - self.lfr_splice_cache = [] \ No newline at end of file + self.lfr_splice_cache = []