From ff0310bfb4ed69f00cbeab89a58f958ae5091d70 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=98=89=E6=B8=8A?= Date: Thu, 6 Jul 2023 16:24:35 +0800 Subject: [PATCH 01/42] update eend_ola --- funasr/build_utils/build_args.py | 6 + funasr/build_utils/build_dataloader.py | 17 +- funasr/build_utils/build_diar_model.py | 6 +- .../small_datasets/sequence_iter_factory.py | 4 +- funasr/models/e2e_diar_eend_ola.py | 167 ++++++++---------- .../modules/eend_ola/eend_ola_dataloader.py | 57 ++++++ funasr/modules/eend_ola/encoder.py | 20 +-- funasr/modules/eend_ola/utils/losses.py | 77 +++----- 8 files changed, 184 insertions(+), 170 deletions(-) create mode 100644 funasr/modules/eend_ola/eend_ola_dataloader.py diff --git a/funasr/build_utils/build_args.py b/funasr/build_utils/build_args.py index 632c13481..31f210eba 100644 --- a/funasr/build_utils/build_args.py +++ b/funasr/build_utils/build_args.py @@ -86,6 +86,12 @@ def build_args(args, parser, extra_task_params): from funasr.build_utils.build_diar_model import class_choices_list for class_choices in class_choices_list: class_choices.add_arguments(task_parser) + task_parser.add_argument( + "--input_size", + type=int_or_none, + default=None, + help="The number of input dimension of the feature", + ) elif args.task_name == "sv": from funasr.build_utils.build_sv_model import class_choices_list diff --git a/funasr/build_utils/build_dataloader.py b/funasr/build_utils/build_dataloader.py index c95c40d8c..473097eda 100644 --- a/funasr/build_utils/build_dataloader.py +++ b/funasr/build_utils/build_dataloader.py @@ -4,8 +4,21 @@ from funasr.datasets.small_datasets.sequence_iter_factory import SequenceIterFac def build_dataloader(args): if args.dataset_type == "small": - train_iter_factory = SequenceIterFactory(args, mode="train") - valid_iter_factory = SequenceIterFactory(args, mode="valid") + if args.task_name == "diar" and args.model == "eend_ola": + from funasr.modules.eend_ola.eend_ola_dataloader import EENDOLADataLoader + train_iter_factory = EENDOLADataLoader( + data_file=args.train_data_path_and_name_and_type[0][0], + batch_size=args.dataset_conf["batch_conf"]["batch_size"], + num_workers=args.dataset_conf["num_workers"], + shuffle=True) + valid_iter_factory = EENDOLADataLoader( + data_file=args.valid_data_path_and_name_and_type[0][0], + batch_size=args.dataset_conf["batch_conf"]["batch_size"], + num_workers=0, + shuffle=False) + else: + train_iter_factory = SequenceIterFactory(args, mode="train") + valid_iter_factory = SequenceIterFactory(args, mode="valid") elif args.dataset_type == "large": train_iter_factory = LargeDataLoader(args, mode="train") valid_iter_factory = LargeDataLoader(args, mode="valid") diff --git a/funasr/build_utils/build_diar_model.py b/funasr/build_utils/build_diar_model.py index 0ea31270e..444636a02 100644 --- a/funasr/build_utils/build_diar_model.py +++ b/funasr/build_utils/build_diar_model.py @@ -198,16 +198,14 @@ def build_diar_model(args): frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf) else: frontend = frontend_class(**args.frontend_conf) - input_size = frontend.output_size() else: args.frontend = None args.frontend_conf = {} frontend = None - input_size = args.input_size # encoder encoder_class = encoder_choices.get_class(args.encoder) - encoder = encoder_class(input_size=input_size, **args.encoder_conf) + encoder = encoder_class(**args.encoder_conf) if args.model == "sond": # data augmentation for spectrogram @@ -272,7 +270,7 @@ def build_diar_model(args): **args.model_conf, ) - elif args.model_name == "eend_ola": + elif args.model == "eend_ola": # encoder-decoder attractor encoder_decoder_attractor_class = encoder_decoder_attractor_choices.get_class(args.encoder_decoder_attractor) encoder_decoder_attractor = encoder_decoder_attractor_class(**args.encoder_decoder_attractor_conf) diff --git a/funasr/datasets/small_datasets/sequence_iter_factory.py b/funasr/datasets/small_datasets/sequence_iter_factory.py index 3ebcc5ac6..e748c3de5 100644 --- a/funasr/datasets/small_datasets/sequence_iter_factory.py +++ b/funasr/datasets/small_datasets/sequence_iter_factory.py @@ -57,7 +57,7 @@ class SequenceIterFactory(AbsIterFactory): data_path_and_name_and_type, preprocess=preprocess_fn, dest_sample_rate=dest_sample_rate, - speed_perturb=args.speed_perturb if mode=="train" else None, + speed_perturb=args.speed_perturb if mode == "train" else None, ) # sampler @@ -84,7 +84,7 @@ class SequenceIterFactory(AbsIterFactory): args.max_update = len(bs_list) * args.max_epoch logging.info("Max update: {}".format(args.max_update)) - if args.distributed and mode=="train": + if args.distributed and mode == "train": world_size = torch.distributed.get_world_size() rank = torch.distributed.get_rank() for batch in batches: diff --git a/funasr/models/e2e_diar_eend_ola.py b/funasr/models/e2e_diar_eend_ola.py index ae3a436e9..af0fd62c8 100644 --- a/funasr/models/e2e_diar_eend_ola.py +++ b/funasr/models/e2e_diar_eend_ola.py @@ -1,21 +1,21 @@ -# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved. -# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) - from contextlib import contextmanager from distutils.version import LooseVersion -from typing import Dict -from typing import Tuple +from typing import Dict, List, Tuple, Optional import numpy as np import torch import torch.nn as nn +import torch.nn.functional as F +from typeguard import check_argument_types +from funasr.models.base_model import FunASRModel from funasr.models.frontend.wav_frontend import WavFrontendMel23 from funasr.modules.eend_ola.encoder import EENDOLATransformerEncoder from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor +from funasr.modules.eend_ola.utils.losses import fast_batch_pit_n_speaker_loss, standard_loss, cal_power_loss +from funasr.modules.eend_ola.utils.power import create_powerlabel from funasr.modules.eend_ola.utils.power import generate_mapping_dict from funasr.torch_utils.device_funcs import force_gatherable -from funasr.models.base_model import FunASRModel if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): pass @@ -33,12 +33,35 @@ def pad_attractor(att, max_n_speakers): return att +def pad_labels(ts, out_size): + for i, t in enumerate(ts): + if t.shape[1] < out_size: + ts[i] = F.pad( + t, + (0, out_size - t.shape[1], 0, 0), + mode='constant', + value=0. + ) + return ts + + +def pad_results(ys, out_size): + ys_padded = [] + for i, y in enumerate(ys): + if y.shape[1] < out_size: + ys_padded.append( + torch.cat([y, torch.zeros(y.shape[0], out_size - y.shape[1]).to(torch.float32).to(y.device)], dim=1)) + else: + ys_padded.append(y) + return ys_padded + + class DiarEENDOLAModel(FunASRModel): """EEND-OLA diarization model""" def __init__( self, - frontend: WavFrontendMel23, + frontend: Optional[WavFrontendMel23], encoder: EENDOLATransformerEncoder, encoder_decoder_attractor: EncoderDecoderAttractor, n_units: int = 256, @@ -47,11 +70,12 @@ class DiarEENDOLAModel(FunASRModel): mapping_dict=None, **kwargs, ): + assert check_argument_types() super().__init__() self.frontend = frontend self.enc = encoder - self.eda = encoder_decoder_attractor + self.encoder_decoder_attractor = encoder_decoder_attractor self.attractor_loss_weight = attractor_loss_weight self.max_n_speaker = max_n_speaker if mapping_dict is None: @@ -74,7 +98,8 @@ class DiarEENDOLAModel(FunASRModel): def forward_post_net(self, logits, ilens): maxlen = torch.max(ilens).to(torch.int).item() logits = nn.utils.rnn.pad_sequence(logits, batch_first=True, padding_value=-1) - logits = nn.utils.rnn.pack_padded_sequence(logits, ilens.cpu().to(torch.int64), batch_first=True, enforce_sorted=False) + logits = nn.utils.rnn.pack_padded_sequence(logits, ilens.cpu().to(torch.int64), batch_first=True, + enforce_sorted=False) outputs, (_, _) = self.postnet(logits) outputs = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True, padding_value=-1, total_length=maxlen)[0] outputs = [output[:ilens[i].to(torch.int).item()] for i, output in enumerate(outputs)] @@ -83,95 +108,51 @@ class DiarEENDOLAModel(FunASRModel): def forward( self, - speech: torch.Tensor, - speech_lengths: torch.Tensor, - text: torch.Tensor, - text_lengths: torch.Tensor, + speech: List[torch.Tensor], + speech_lengths: torch.Tensor, # num_frames of each sample + speaker_labels: List[torch.Tensor], + speaker_labels_lengths: torch.Tensor, # num_speakers of each sample + orders: torch.Tensor, ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: - """Frontend + Encoder + Decoder + Calc loss - Args: - speech: (Batch, Length, ...) - speech_lengths: (Batch, ) - text: (Batch, Length) - text_lengths: (Batch,) - """ - assert text_lengths.dim() == 1, text_lengths.shape + # Check that batch_size is unified assert ( - speech.shape[0] - == speech_lengths.shape[0] - == text.shape[0] - == text_lengths.shape[0] - ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape) - batch_size = speech.shape[0] + len(speech) + == len(speech_lengths) + == len(speaker_labels) + == len(speaker_labels_lengths) + ), (len(speech), len(speech_lengths), len(speaker_labels), len(speaker_labels_lengths)) + batch_size = len(speech) - # for data-parallel - text = text[:, : text_lengths.max()] + # Encoder + speech = [s[:s_len] for s, s_len in zip(speech, speech_lengths)] + encoder_out = self.forward_encoder(speech, speech_lengths) - # 1. Encoder - encoder_out, encoder_out_lens = self.enc(speech, speech_lengths) - intermediate_outs = None - if isinstance(encoder_out, tuple): - intermediate_outs = encoder_out[1] - encoder_out = encoder_out[0] + # Encoder-decoder attractor + attractor_loss, attractors = self.encoder_decoder_attractor([e[order] for e, order in zip(encoder_out, orders)], + speaker_labels_lengths) + speaker_logits = [torch.matmul(e, att.permute(1, 0)) for e, att in zip(encoder_out, attractors)] + + # pit loss + pit_speaker_labels = fast_batch_pit_n_speaker_loss(speaker_logits, speaker_labels) + pit_loss = standard_loss(speaker_logits, pit_speaker_labels) + + # pse loss + with torch.no_grad(): + power_ts = [create_powerlabel(label.cpu().numpy(), self.mapping_dict, self.max_n_speaker). + to(encoder_out[0].device, non_blocking=True) for label in pit_speaker_labels] + pad_attractors = [pad_attractor(att, self.max_n_speaker) for att in attractors] + pse_speaker_logits = [torch.matmul(e, pad_att.permute(1, 0)) for e, pad_att in zip(encoder_out, pad_attractors)] + pse_speaker_logits = self.forward_post_net(pse_speaker_logits, speech_lengths) + pse_loss = cal_power_loss(pse_speaker_logits, power_ts) + + loss = pse_loss + pit_loss + self.attractor_loss_weight * attractor_loss - loss_att, acc_att, cer_att, wer_att = None, None, None, None - loss_ctc, cer_ctc = None, None stats = dict() - - # 1. CTC branch - if self.ctc_weight != 0.0: - loss_ctc, cer_ctc = self._calc_ctc_loss( - encoder_out, encoder_out_lens, text, text_lengths - ) - - # Collect CTC branch stats - stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None - stats["cer_ctc"] = cer_ctc - - # Intermediate CTC (optional) - loss_interctc = 0.0 - if self.interctc_weight != 0.0 and intermediate_outs is not None: - for layer_idx, intermediate_out in intermediate_outs: - # we assume intermediate_out has the same length & padding - # as those of encoder_out - loss_ic, cer_ic = self._calc_ctc_loss( - intermediate_out, encoder_out_lens, text, text_lengths - ) - loss_interctc = loss_interctc + loss_ic - - # Collect Intermedaite CTC stats - stats["loss_interctc_layer{}".format(layer_idx)] = ( - loss_ic.detach() if loss_ic is not None else None - ) - stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic - - loss_interctc = loss_interctc / len(intermediate_outs) - - # calculate whole encoder loss - loss_ctc = ( - 1 - self.interctc_weight - ) * loss_ctc + self.interctc_weight * loss_interctc - - # 2b. Attention decoder branch - if self.ctc_weight != 1.0: - loss_att, acc_att, cer_att, wer_att = self._calc_att_loss( - encoder_out, encoder_out_lens, text, text_lengths - ) - - # 3. CTC-Att loss definition - if self.ctc_weight == 0.0: - loss = loss_att - elif self.ctc_weight == 1.0: - loss = loss_ctc - else: - loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att - - # Collect Attn branch stats - stats["loss_att"] = loss_att.detach() if loss_att is not None else None - stats["acc"] = acc_att - stats["cer"] = cer_att - stats["wer"] = wer_att + stats["pse_loss"] = pse_loss.detach() + stats["pit_loss"] = pit_loss.detach() + stats["attractor_loss"] = attractor_loss.detach() + stats["batch_size"] = batch_size # Collect total loss stats stats["loss"] = torch.clone(loss.detach()) @@ -193,10 +174,10 @@ class DiarEENDOLAModel(FunASRModel): orders = [np.arange(e.shape[0]) for e in emb] for order in orders: np.random.shuffle(order) - attractors, probs = self.eda.estimate( + attractors, probs = self.encoder_decoder_attractor.estimate( [e[torch.from_numpy(order).to(torch.long).to(speech[0].device)] for e, order in zip(emb, orders)]) else: - attractors, probs = self.eda.estimate(emb) + attractors, probs = self.encoder_decoder_attractor.estimate(emb) attractors_active = [] for p, att, e in zip(probs, attractors, emb): if n_speakers and n_speakers >= 0: diff --git a/funasr/modules/eend_ola/eend_ola_dataloader.py b/funasr/modules/eend_ola/eend_ola_dataloader.py new file mode 100644 index 000000000..2ee9272f5 --- /dev/null +++ b/funasr/modules/eend_ola/eend_ola_dataloader.py @@ -0,0 +1,57 @@ +import logging + +import kaldiio +import numpy as np +import torch +from torch.utils.data import DataLoader +from torch.utils.data import Dataset + + +def custom_collate(batch): + keys, speech, speaker_labels, orders = zip(*batch) + speech = [torch.from_numpy(np.copy(sph)).to(torch.float32) for sph in speech] + speaker_labels = [torch.from_numpy(np.copy(spk)).to(torch.float32) for spk in speaker_labels] + orders = [torch.from_numpy(np.copy(o)).to(torch.int64) for o in orders] + batch = dict(speech=speech, + speaker_labels=speaker_labels, + orders=orders) + + return keys, batch + + +class EENDOLADataset(Dataset): + def __init__( + self, + data_file, + ): + self.data_file = data_file + with open(data_file) as f: + lines = f.readlines() + self.samples = [line.strip().split() for line in lines] + logging.info("total samples: {}".format(len(self.samples))) + + def __len__(self): + return len(self.samples) + + def __getitem__(self, idx): + key, speech_path, speaker_label_path = self.samples[idx] + speech = kaldiio.load_mat(speech_path) + speaker_label = kaldiio.load_mat(speaker_label_path).reshape(speech.shape[0], -1) + + order = np.arange(speech.shape[0]) + np.random.shuffle(order) + + return key, speech, speaker_label, order + + +class EENDOLADataLoader(): + def __init__(self, data_file, batch_size, shuffle=True, num_workers=8): + dataset = EENDOLADataset(data_file) + self.data_loader = DataLoader(dataset, + batch_size=batch_size, + collate_fn=custom_collate, + shuffle=shuffle, + num_workers=num_workers) + + def build_iter(self, epoch): + return self.data_loader \ No newline at end of file diff --git a/funasr/modules/eend_ola/encoder.py b/funasr/modules/eend_ola/encoder.py index 90a63f369..3065884cb 100644 --- a/funasr/modules/eend_ola/encoder.py +++ b/funasr/modules/eend_ola/encoder.py @@ -91,6 +91,7 @@ class EENDOLATransformerEncoder(nn.Module): dropout_rate: float = 0.1, use_pos_emb: bool = False): super(EENDOLATransformerEncoder, self).__init__() + self.linear_in = nn.Linear(idim, n_units) self.lnorm_in = nn.LayerNorm(n_units) self.n_layers = n_layers self.dropout = nn.Dropout(dropout_rate) @@ -104,25 +105,10 @@ class EENDOLATransformerEncoder(nn.Module): setattr(self, '{}{:d}'.format("ff_", i), PositionwiseFeedForward(n_units, e_units, dropout_rate)) self.lnorm_out = nn.LayerNorm(n_units) - if use_pos_emb: - self.pos_enc = torch.nn.Sequential( - torch.nn.Linear(idim, n_units), - torch.nn.LayerNorm(n_units), - torch.nn.Dropout(dropout_rate), - torch.nn.ReLU(), - PositionalEncoding(n_units, dropout_rate), - ) - else: - self.linear_in = nn.Linear(idim, n_units) - self.pos_enc = None def __call__(self, x, x_mask=None): BT_size = x.shape[0] * x.shape[1] - if self.pos_enc is not None: - e = self.pos_enc(x) - e = e.view(BT_size, -1) - else: - e = self.linear_in(x.reshape(BT_size, -1)) + e = self.linear_in(x.reshape(BT_size, -1)) for i in range(self.n_layers): e = getattr(self, '{}{:d}'.format("lnorm1_", i))(e) s = getattr(self, '{}{:d}'.format("self_att_", i))(e, x.shape[0], x_mask) @@ -130,4 +116,4 @@ class EENDOLATransformerEncoder(nn.Module): e = getattr(self, '{}{:d}'.format("lnorm2_", i))(e) s = getattr(self, '{}{:d}'.format("ff_", i))(e) e = e + self.dropout(s) - return self.lnorm_out(e) + return self.lnorm_out(e) \ No newline at end of file diff --git a/funasr/modules/eend_ola/utils/losses.py b/funasr/modules/eend_ola/utils/losses.py index af0181dda..756952d03 100644 --- a/funasr/modules/eend_ola/utils/losses.py +++ b/funasr/modules/eend_ola/utils/losses.py @@ -1,11 +1,10 @@ import numpy as np import torch import torch.nn.functional as F -from itertools import permutations -from torch import nn +from scipy.optimize import linear_sum_assignment -def standard_loss(ys, ts, label_delay=0): +def standard_loss(ys, ts): losses = [F.binary_cross_entropy(torch.sigmoid(y), t) * len(y) for y, t in zip(ys, ts)] loss = torch.sum(torch.stack(losses)) n_frames = torch.from_numpy(np.array(np.sum([t.shape[0] for t in ts]))).to(torch.float32).to(ys[0].device) @@ -13,55 +12,29 @@ def standard_loss(ys, ts, label_delay=0): return loss -def batch_pit_n_speaker_loss(ys, ts, n_speakers_list): - max_n_speakers = ts[0].shape[1] - olens = [y.shape[0] for y in ys] - ys = nn.utils.rnn.pad_sequence(ys, batch_first=True, padding_value=-1) - ys_mask = [torch.ones(olen).to(ys.device) for olen in olens] - ys_mask = torch.nn.utils.rnn.pad_sequence(ys_mask, batch_first=True, padding_value=0).unsqueeze(-1) +def fast_batch_pit_n_speaker_loss(ys, ts): + with torch.no_grad(): + bs = len(ys) + indices = [] + for b in range(bs): + y = ys[b].transpose(0, 1) + t = ts[b].transpose(0, 1) + C, _ = t.shape + y = y[:, None, :].repeat(1, C, 1) + t = t[None, :, :].repeat(C, 1, 1) + bce_loss = F.binary_cross_entropy(torch.sigmoid(y), t, reduction="none").mean(-1) + C = bce_loss.cpu() + indices.append(linear_sum_assignment(C)) + labels_perm = [t[:, idx[1]] for t, idx in zip(ts, indices)] - losses = [] - for shift in range(max_n_speakers): - ts_roll = [torch.roll(t, -shift, dims=1) for t in ts] - ts_roll = nn.utils.rnn.pad_sequence(ts_roll, batch_first=True, padding_value=-1) - loss = F.binary_cross_entropy(torch.sigmoid(ys), ts_roll, reduction='none') - if ys_mask is not None: - loss = loss * ys_mask - loss = torch.sum(loss, dim=1) - losses.append(loss) - losses = torch.stack(losses, dim=2) + return labels_perm - perms = np.array(list(permutations(range(max_n_speakers)))).astype(np.float32) - perms = torch.from_numpy(perms).to(losses.device) - y_ind = torch.arange(max_n_speakers, dtype=torch.float32, device=losses.device) - t_inds = torch.fmod(perms - y_ind, max_n_speakers).to(torch.long) - losses_perm = [] - for t_ind in t_inds: - losses_perm.append( - torch.mean(losses[:, y_ind.to(torch.long), t_ind], dim=1)) - losses_perm = torch.stack(losses_perm, dim=1) - - def select_perm_indices(num, max_num): - perms = list(permutations(range(max_num))) - sub_perms = list(permutations(range(num))) - return [ - [x[:num] for x in perms].index(perm) - for perm in sub_perms] - - masks = torch.full_like(losses_perm, device=losses.device, fill_value=float('inf')) - for i, t in enumerate(ts): - n_speakers = n_speakers_list[i] - indices = select_perm_indices(n_speakers, max_n_speakers) - masks[i, indices] = 0 - losses_perm += masks - - min_loss = torch.sum(torch.min(losses_perm, dim=1)[0]) - n_frames = torch.from_numpy(np.array(np.sum([t.shape[0] for t in ts]))).to(losses.device) - min_loss = min_loss / n_frames - - min_indices = torch.argmin(losses_perm, dim=1) - labels_perm = [t[:, perms[idx].to(torch.long)] for t, idx in zip(ts, min_indices)] - labels_perm = [t[:, :n_speakers] for t, n_speakers in zip(labels_perm, n_speakers_list)] - - return min_loss, labels_perm +def cal_power_loss(logits, power_ts): + losses = [F.cross_entropy(input=logit, target=power_t.to(torch.long)) * len(logit) for logit, power_t in + zip(logits, power_ts)] + loss = torch.sum(torch.stack(losses)) + n_frames = torch.from_numpy(np.array(np.sum([power_t.shape[0] for power_t in power_ts]))).to(torch.float32).to( + power_ts[0].device) + loss = loss / n_frames + return loss From 8b7c32b0f6616e47b5ac5037d4231cc5d77e74ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=98=89=E6=B8=8A?= Date: Thu, 6 Jul 2023 16:28:31 +0800 Subject: [PATCH 02/42] update eend_ola --- funasr/utils/prepare_data.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/funasr/utils/prepare_data.py b/funasr/utils/prepare_data.py index 0e773bbed..8d82a2fc8 100644 --- a/funasr/utils/prepare_data.py +++ b/funasr/utils/prepare_data.py @@ -196,12 +196,16 @@ def generate_data_list(args, data_dir, dataset, nj=64): def prepare_data(args, distributed_option): distributed = distributed_option.distributed + data_names = args.dataset_conf.get("data_names", "speech,text").split(",") + data_types = args.dataset_conf.get("data_types", "sound,text").split(",") + file_names = args.data_file_names.split(",") + batch_type = args.dataset_conf["batch_conf"]["batch_type"] if not distributed or distributed_option.dist_rank == 0: if hasattr(args, "filter_input") and args.filter_input: filter_wav_text(args.data_dir, args.train_set) filter_wav_text(args.data_dir, args.valid_set) - if args.dataset_type == "small": + if args.dataset_type == "small" and batch_type != "unsorted": calc_shape(args, args.train_set) calc_shape(args, args.valid_set) @@ -209,9 +213,6 @@ def prepare_data(args, distributed_option): generate_data_list(args, args.data_dir, args.train_set) generate_data_list(args, args.data_dir, args.valid_set) - data_names = args.dataset_conf.get("data_names", "speech,text").split(",") - data_types = args.dataset_conf.get("data_types", "sound,text").split(",") - file_names = args.data_file_names.split(",") print("data_names: {}, data_types: {}, file_names: {}".format(data_names, data_types, file_names)) assert len(data_names) == len(data_types) == len(file_names) if args.dataset_type == "small": From c83c406b72623deb973d391635475c5dfd9f8b93 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=98=89=E6=B8=8A?= Date: Thu, 6 Jul 2023 17:12:19 +0800 Subject: [PATCH 03/42] update eend_ola --- ...rain_diar_eend_ola_callhome_chunk2000.yaml | 45 ++++ .../conf/train_diar_eend_ola_simu_2spkr.yaml | 52 ++++ .../train_diar_eend_ola_simu_allspkr.yaml | 52 ++++ ..._diar_eend_ola_simu_allspkr_chunk2000.yaml | 44 ++++ .../eend_ola/local/model_averaging.py | 28 ++ egs/callhome/eend_ola/path.sh | 6 + egs/callhome/eend_ola/run.sh | 242 ++++++++++++++++++ egs/callhome/eend_ola/utils | 1 + 8 files changed, 470 insertions(+) create mode 100644 egs/callhome/eend_ola/conf/train_diar_eend_ola_callhome_chunk2000.yaml create mode 100644 egs/callhome/eend_ola/conf/train_diar_eend_ola_simu_2spkr.yaml create mode 100644 egs/callhome/eend_ola/conf/train_diar_eend_ola_simu_allspkr.yaml create mode 100644 egs/callhome/eend_ola/conf/train_diar_eend_ola_simu_allspkr_chunk2000.yaml create mode 100644 egs/callhome/eend_ola/local/model_averaging.py create mode 100755 egs/callhome/eend_ola/path.sh create mode 100644 egs/callhome/eend_ola/run.sh create mode 120000 egs/callhome/eend_ola/utils diff --git a/egs/callhome/eend_ola/conf/train_diar_eend_ola_callhome_chunk2000.yaml b/egs/callhome/eend_ola/conf/train_diar_eend_ola_callhome_chunk2000.yaml new file mode 100644 index 000000000..71ea9f0e9 --- /dev/null +++ b/egs/callhome/eend_ola/conf/train_diar_eend_ola_callhome_chunk2000.yaml @@ -0,0 +1,45 @@ +# network architecture +# encoder related +encoder: eend_ola_transformer +encoder_conf: + idim: 345 + n_layers: 4 + n_units: 256 + +# encoder-decoder attractor related +encoder_decoder_attractor: eda +encoder_decoder_attractor_conf: + n_units: 256 + +# model related +model: eend_ola_similar_eend +model_conf: + attractor_loss_weight: 0.01 + max_n_speaker: 8 + +# optimization related +accum_grad: 1 +grad_clip: 5 +max_epoch: 100 +val_scheduler_criterion: + - valid + - loss +best_model_criterion: +- - valid + - loss + - min +keep_nbest_models: 100 + +optim: adam +optim_conf: + lr: 0.00001 + +dataset_conf: + data_names: speech_speaker_labels + data_types: kaldi_ark + batch_conf: + batch_type: unsorted + batch_size: 8 + num_workers: 8 + +log_interval: 50 \ No newline at end of file diff --git a/egs/callhome/eend_ola/conf/train_diar_eend_ola_simu_2spkr.yaml b/egs/callhome/eend_ola/conf/train_diar_eend_ola_simu_2spkr.yaml new file mode 100644 index 000000000..baf43424f --- /dev/null +++ b/egs/callhome/eend_ola/conf/train_diar_eend_ola_simu_2spkr.yaml @@ -0,0 +1,52 @@ +# network architecture +# encoder related +encoder: eend_ola_transformer +encoder_conf: + idim: 345 + n_layers: 4 + n_units: 256 + +# encoder-decoder attractor related +encoder_decoder_attractor: eda +encoder_decoder_attractor_conf: + n_units: 256 + +# model related +model: eend_ola_similar_eend +model_conf: + max_n_speaker: 8 + +# optimization related +accum_grad: 1 +grad_clip: 5 +max_epoch: 100 +val_scheduler_criterion: + - valid + - loss +best_model_criterion: +- - valid + - loss + - min +keep_nbest_models: 100 + +optim: adam +optim_conf: + lr: 1.0 + betas: + - 0.9 + - 0.98 + eps: 1.0e-9 +scheduler: noamlr +scheduler_conf: + model_size: 256 + warmup_steps: 100000 + +dataset_conf: + data_names: speech_speaker_labels + data_types: kaldi_ark + batch_conf: + batch_type: unsorted + batch_size: 64 + num_workers: 8 + +log_interval: 50 \ No newline at end of file diff --git a/egs/callhome/eend_ola/conf/train_diar_eend_ola_simu_allspkr.yaml b/egs/callhome/eend_ola/conf/train_diar_eend_ola_simu_allspkr.yaml new file mode 100644 index 000000000..83a6eeeb9 --- /dev/null +++ b/egs/callhome/eend_ola/conf/train_diar_eend_ola_simu_allspkr.yaml @@ -0,0 +1,52 @@ +# network architecture +# encoder related +encoder: eend_ola_transformer +encoder_conf: + idim: 345 + n_layers: 4 + n_units: 256 + +# encoder-decoder attractor related +encoder_decoder_attractor: eda +encoder_decoder_attractor_conf: + n_units: 256 + +# model related +model: eend_ola_similar_eend +model_conf: + max_n_speaker: 8 + +# optimization related +accum_grad: 1 +grad_clip: 5 +max_epoch: 25 +val_scheduler_criterion: + - valid + - loss +best_model_criterion: +- - valid + - loss + - min +keep_nbest_models: 100 + +optim: adam +optim_conf: + lr: 1.0 + betas: + - 0.9 + - 0.98 + eps: 1.0e-9 +scheduler: noamlr +scheduler_conf: + model_size: 256 + warmup_steps: 100000 + +dataset_conf: + data_names: speech_speaker_labels + data_types: kaldi_ark + batch_conf: + batch_type: unsorted + batch_size: 64 + num_workers: 8 + +log_interval: 50 \ No newline at end of file diff --git a/egs/callhome/eend_ola/conf/train_diar_eend_ola_simu_allspkr_chunk2000.yaml b/egs/callhome/eend_ola/conf/train_diar_eend_ola_simu_allspkr_chunk2000.yaml new file mode 100644 index 000000000..f47850417 --- /dev/null +++ b/egs/callhome/eend_ola/conf/train_diar_eend_ola_simu_allspkr_chunk2000.yaml @@ -0,0 +1,44 @@ +# network architecture +# encoder related +encoder: eend_ola_transformer +encoder_conf: + idim: 345 + n_layers: 4 + n_units: 256 + +# encoder-decoder attractor related +encoder_decoder_attractor: eda +encoder_decoder_attractor_conf: + n_units: 256 + +# model related +model: eend_ola_similar_eend +model_conf: + max_n_speaker: 8 + +# optimization related +accum_grad: 1 +grad_clip: 5 +max_epoch: 1 +val_scheduler_criterion: + - valid + - loss +best_model_criterion: +- - valid + - loss + - min +keep_nbest_models: 100 + +optim: adam +optim_conf: + lr: 0.00001 + +dataset_conf: + data_names: speech_speaker_labels + data_types: kaldi_ark + batch_conf: + batch_type: unsorted + batch_size: 8 + num_workers: 8 + +log_interval: 50 \ No newline at end of file diff --git a/egs/callhome/eend_ola/local/model_averaging.py b/egs/callhome/eend_ola/local/model_averaging.py new file mode 100644 index 000000000..1871cd9cb --- /dev/null +++ b/egs/callhome/eend_ola/local/model_averaging.py @@ -0,0 +1,28 @@ +#!/usr/bin/env python3 + +import argparse + +import torch + + +def average_model(input_files, output_file): + output_model = {} + for ckpt_path in input_files: + model_params = torch.load(ckpt_path, map_location="cpu") + for key, value in model_params.items(): + if key not in output_model: + output_model[key] = value + else: + output_model[key] += value + for key in output_model.keys(): + output_model[key] /= len(input_files) + torch.save(output_model, output_file) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("output_file") + parser.add_argument("input_files", nargs='+') + args = parser.parse_args() + + average_model(args.input_files, args.output_file) \ No newline at end of file diff --git a/egs/callhome/eend_ola/path.sh b/egs/callhome/eend_ola/path.sh new file mode 100755 index 000000000..ea3c0be2f --- /dev/null +++ b/egs/callhome/eend_ola/path.sh @@ -0,0 +1,6 @@ +export FUNASR_DIR=$PWD/../../.. + +# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C +export PYTHONIOENCODING=UTF-8 +export PYTHONPATH=../../../:$PYTHONPATH +export PATH=$FUNASR_DIR/funasr/bin:$PATH diff --git a/egs/callhome/eend_ola/run.sh b/egs/callhome/eend_ola/run.sh new file mode 100644 index 000000000..893613752 --- /dev/null +++ b/egs/callhome/eend_ola/run.sh @@ -0,0 +1,242 @@ +#!/usr/bin/env bash + +. ./path.sh || exit 1; + +# machines configuration +CUDA_VISIBLE_DEVICES="7" +gpu_num=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') +count=1 + +# general configuration +simu_feats_dir="/nfs/wangjiaming.wjm/EEND_ARK_DATA/dump/simu_data/data" +simu_feats_dir_chunk2000="/nfs/wangjiaming.wjm/EEND_ARK_DATA/dump/simu_data_chunk2000/data" +callhome_feats_dir_chunk2000="/nfs/wangjiaming.wjm/EEND_ARK_DATA/dump/callhome_chunk2000/data" +simu_train_dataset=train +simu_valid_dataset=dev +callhome_train_dataset=callhome1_allspk +callhome_valid_dataset=callhome2_allspk +callhome2_wav_scp_file=wav.scp + +# model average +simu_average_2spkr_start=91 +simu_average_2spkr_end=100 +simu_average_allspkr_start=16 +simu_average_allspkr_end=25 +callhome_average_start=91 +callhome_average_end=100 + +exp_dir="." +input_size=345 +stage=1 +stop_stage=4 + +# exp tag +tag="exp_fix" + +. utils/parse_options.sh || exit 1; + +# Set bash to 'debug' mode, it will exit on : +# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands', +set -e +set -u +set -o pipefail + +simu_2spkr_diar_config=conf/train_diar_eend_ola_simu_2spkr.yaml +simu_allspkr_diar_config=conf/train_diar_eend_ola_simu_allspkr.yaml +simu_allspkr_chunk2000_diar_config=conf/train_diar_eend_ola_simu_allspkr_chunk2000.yaml +callhome_diar_config=conf/train_diar_eend_ola_callhome_chunk2000.yaml +simu_2spkr_model_dir="baseline_$(basename "${simu_2spkr_diar_config}" .yaml)_${tag}" +simu_allspkr_model_dir="baseline_$(basename "${simu_allspkr_diar_config}" .yaml)_${tag}" +simu_allspkr_chunk2000_model_dir="baseline_$(basename "${simu_allspkr_chunk2000_diar_config}" .yaml)_${tag}" +callhome_model_dir="baseline_$(basename "${callhome_diar_config}" .yaml)_${tag}" + +# Prepare data for training and inference +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + echo "stage 0: Prepare data for training and inference" +fi + +# Training on simulated two-speaker data +world_size=$gpu_num +simu_2spkr_ave_id=avg${simu_average_2spkr_start}-${simu_average_2spkr_end} +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + echo "stage 1: Training on simulated two-speaker data" + mkdir -p ${exp_dir}/exp/${simu_2spkr_model_dir} + mkdir -p ${exp_dir}/exp/${simu_2spkr_model_dir}/log + INIT_FILE=${exp_dir}/exp/${simu_2spkr_model_dir}/ddp_init + if [ -f $INIT_FILE ];then + rm -f $INIT_FILE + fi + init_method=file://$(readlink -f $INIT_FILE) + echo "$0: init method is $init_method" + for ((i = 0; i < $gpu_num; ++i)); do + { + rank=$i + local_rank=$i + gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1]) + train.py \ + --task_name diar \ + --gpu_id $gpu_id \ + --use_preprocessor false \ + --input_size $input_size \ + --data_dir ${simu_feats_dir} \ + --train_set ${simu_train_dataset} \ + --valid_set ${simu_valid_dataset} \ + --data_file_names "feats_2spkr.scp" \ + --resume true \ + --output_dir ${exp_dir}/exp/${simu_2spkr_model_dir} \ + --config $simu_2spkr_diar_config \ + --ngpu $gpu_num \ + --num_worker_count $count \ + --dist_init_method $init_method \ + --dist_world_size $world_size \ + --dist_rank $rank \ + --local_rank $local_rank 1> ${exp_dir}/exp/${simu_2spkr_model_dir}/log/train.log.$i 2>&1 + } & + done + wait + echo "averaging model parameters into ${exp_dir}/exp/$simu_2spkr_model_dir/$simu_2spkr_ave_id.pb" + models=`eval echo ${exp_dir}/exp/${simu_2spkr_model_dir}/{$simu_average_2spkr_start..$simu_average_2spkr_end}epoch.pb` + python local/model_averaging.py ${exp_dir}/exp/${simu_2spkr_model_dir}/$simu_2spkr_ave_id.pb $models +fi + +# Training on simulated all-speaker data +world_size=$gpu_num +simu_allspkr_ave_id=avg${simu_average_allspkr_start}-${simu_average_allspkr_end} +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + echo "stage 2: Training on simulated all-speaker data" + mkdir -p ${exp_dir}/exp/${simu_allspkr_model_dir} + mkdir -p ${exp_dir}/exp/${simu_allspkr_model_dir}/log + INIT_FILE=${exp_dir}/exp/${simu_allspkr_model_dir}/ddp_init + if [ -f $INIT_FILE ];then + rm -f $INIT_FILE + fi + init_method=file://$(readlink -f $INIT_FILE) + echo "$0: init method is $init_method" + for ((i = 0; i < $gpu_num; ++i)); do + { + rank=$i + local_rank=$i + gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1]) + train.py \ + --task_name diar \ + --gpu_id $gpu_id \ + --use_preprocessor false \ + --input_size $input_size \ + --data_dir ${simu_feats_dir} \ + --train_set ${simu_train_dataset} \ + --valid_set ${simu_valid_dataset} \ + --data_file_names "feats.scp" \ + --resume true \ + --init_param ${exp_dir}/exp/${simu_2spkr_model_dir}/$simu_2spkr_ave_id.pb \ + --output_dir ${exp_dir}/exp/${simu_allspkr_model_dir} \ + --config $simu_allspkr_diar_config \ + --ngpu $gpu_num \ + --num_worker_count $count \ + --dist_init_method $init_method \ + --dist_world_size $world_size \ + --dist_rank $rank \ + --local_rank $local_rank 1> ${exp_dir}/exp/${simu_allspkr_model_dir}/log/train.log.$i 2>&1 + } & + done + wait + echo "averaging model parameters into ${exp_dir}/exp/$simu_allspkr_model_dir/$simu_allspkr_ave_id.pb" + models=`eval echo ${exp_dir}/exp/${simu_allspkr_model_dir}/{$simu_average_allspkr_start..$simu_average_allspkr_end}epoch.pb` + python local/model_averaging.py ${exp_dir}/exp/${simu_allspkr_model_dir}/$simu_allspkr_ave_id.pb $models +fi + +# Training on simulated all-speaker data with chunk_size=2000 +world_size=$gpu_num +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + echo "stage 3: Training on simulated all-speaker data with chunk_size=2000" + mkdir -p ${exp_dir}/exp/${simu_allspkr_chunk2000_model_dir} + mkdir -p ${exp_dir}/exp/${simu_allspkr_chunk2000_model_dir}/log + INIT_FILE=${exp_dir}/exp/${simu_allspkr_chunk2000_model_dir}/ddp_init + if [ -f $INIT_FILE ];then + rm -f $INIT_FILE + fi + init_method=file://$(readlink -f $INIT_FILE) + echo "$0: init method is $init_method" + for ((i = 0; i < $gpu_num; ++i)); do + { + rank=$i + local_rank=$i + gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1]) + train.py \ + --task_name diar \ + --gpu_id $gpu_id \ + --use_preprocessor false \ + --input_size $input_size \ + --data_dir ${simu_feats_dir_chunk2000} \ + --train_set ${simu_train_dataset} \ + --valid_set ${simu_valid_dataset} \ + --data_file_names "feats.scp" \ + --resume true \ + --init_param ${exp_dir}/exp/${simu_allspkr_model_dir}/$simu_allspkr_ave_id.pb \ + --output_dir ${exp_dir}/exp/${simu_allspkr_chunk2000_model_dir} \ + --config $simu_allspkr_chunk2000_diar_config \ + --ngpu $gpu_num \ + --num_worker_count $count \ + --dist_init_method $init_method \ + --dist_world_size $world_size \ + --dist_rank $rank \ + --local_rank $local_rank 1> ${exp_dir}/exp/${simu_allspkr_chunk2000_model_dir}/log/train.log.$i 2>&1 + } & + done + wait +fi + +# Training on callhome all-speaker data with chunk_size=2000 +world_size=$gpu_num +callhome_ave_id=avg${callhome_average_start}-${callhome_average_end} +if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then + echo "stage 4: Training on callhome all-speaker data with chunk_size=2000" + mkdir -p ${exp_dir}/exp/${callhome_model_dir} + mkdir -p ${exp_dir}/exp/${callhome_model_dir}/log + INIT_FILE=${exp_dir}/exp/${callhome_model_dir}/ddp_init + if [ -f $INIT_FILE ];then + rm -f $INIT_FILE + fi + init_method=file://$(readlink -f $INIT_FILE) + echo "$0: init method is $init_method" + for ((i = 0; i < $gpu_num; ++i)); do + { + rank=$i + local_rank=$i + gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1]) + train.py \ + --task_name diar \ + --gpu_id $gpu_id \ + --use_preprocessor false \ + --input_size $input_size \ + --data_dir ${callhome_feats_dir_chunk2000} \ + --train_set ${callhome_train_dataset} \ + --valid_set ${callhome_valid_dataset} \ + --data_file_names "feats.scp" \ + --resume true \ + --init_param ${exp_dir}/exp/${simu_allspkr_chunk2000_model_dir}/1epoch.pb \ + --output_dir ${exp_dir}/exp/${callhome_model_dir} \ + --config $callhome_diar_config \ + --ngpu $gpu_num \ + --num_worker_count $count \ + --dist_init_method $init_method \ + --dist_world_size $world_size \ + --dist_rank $rank \ + --local_rank $local_rank 1> ${exp_dir}/exp/${callhome_model_dir}/log/train.log.$i 2>&1 + } & + done + wait + echo "averaging model parameters into ${exp_dir}/exp/$callhome_model_dir/$callhome_ave_id.pb" + models=`eval echo ${exp_dir}/exp/${callhome_model_dir}/{$callhome_average_start..$callhome_average_end}epoch.pb` + python local/model_averaging.py ${exp_dir}/exp/${callhome_model_dir}/$callhome_ave_id.pb $models +fi + +## inference +#if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then +# echo "Inference" +# mkdir -p ${exp_dir}/exp/${callhome_model_dir}/inference/log +# CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES python local/infer.py \ +# --config_file ${exp_dir}/exp/${callhome_model_dir}/config.yaml \ +# --model_file ${exp_dir}/exp/${callhome_model_dir}/$callhome_ave_id.pb \ +# --output_rttm_file ${exp_dir}/exp/${callhome_model_dir}/inference/rttm \ +# --wav_scp_file ${callhome_feats_dir_chunk2000}/${callhome_valid_dataset}/${callhome2_wav_scp_file} 1> ${exp_dir}/exp/${callhome_model_dir}/inference/log/infer.log 2>&1 +#fi \ No newline at end of file diff --git a/egs/callhome/eend_ola/utils b/egs/callhome/eend_ola/utils new file mode 120000 index 000000000..fe070dd3a --- /dev/null +++ b/egs/callhome/eend_ola/utils @@ -0,0 +1 @@ +../../aishell/transformer/utils \ No newline at end of file From c46a271415d312de777ffbfe3d735e44e39ff68b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=98=89=E6=B8=8A?= Date: Thu, 6 Jul 2023 17:46:27 +0800 Subject: [PATCH 04/42] update eend_ola --- funasr/build_utils/build_diar_model.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/funasr/build_utils/build_diar_model.py b/funasr/build_utils/build_diar_model.py index 444636a02..2969fd2bf 100644 --- a/funasr/build_utils/build_diar_model.py +++ b/funasr/build_utils/build_diar_model.py @@ -178,18 +178,22 @@ class_choices_list = [ def build_diar_model(args): # token_list - if isinstance(args.token_list, str): - with open(args.token_list, encoding="utf-8") as f: - token_list = [line.rstrip() for line in f] + if args.token_list is not None: + if isinstance(args.token_list, str): + with open(args.token_list, encoding="utf-8") as f: + token_list = [line.rstrip() for line in f] - # Overwriting token_list to keep it as "portable". - args.token_list = list(token_list) - elif isinstance(args.token_list, (tuple, list)): - token_list = list(args.token_list) + # Overwriting token_list to keep it as "portable". + args.token_list = list(token_list) + elif isinstance(args.token_list, (tuple, list)): + token_list = list(args.token_list) + else: + raise RuntimeError("token_list must be str or list") + vocab_size = len(token_list) + logging.info(f"Vocabulary size: {vocab_size}") else: - raise RuntimeError("token_list must be str or list") - vocab_size = len(token_list) - logging.info(f"Vocabulary size: {vocab_size}") + token_list = None + vocab_size = None # frontend if args.input_size is None: From 91425c670b21fa244f739885d34b88742272747c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=98=89=E6=B8=8A?= Date: Thu, 6 Jul 2023 17:54:29 +0800 Subject: [PATCH 05/42] update eend_ola --- .../train_diar_eend_ola_callhome_chunk2000.yaml | 2 +- .../conf/train_diar_eend_ola_simu_2spkr.yaml | 2 +- .../conf/train_diar_eend_ola_simu_allspkr.yaml | 2 +- ...train_diar_eend_ola_simu_allspkr_chunk2000.yaml | 2 +- funasr/models/e2e_diar_eend_ola.py | 14 ++++---------- 5 files changed, 8 insertions(+), 14 deletions(-) diff --git a/egs/callhome/eend_ola/conf/train_diar_eend_ola_callhome_chunk2000.yaml b/egs/callhome/eend_ola/conf/train_diar_eend_ola_callhome_chunk2000.yaml index 71ea9f0e9..cd143f721 100644 --- a/egs/callhome/eend_ola/conf/train_diar_eend_ola_callhome_chunk2000.yaml +++ b/egs/callhome/eend_ola/conf/train_diar_eend_ola_callhome_chunk2000.yaml @@ -12,7 +12,7 @@ encoder_decoder_attractor_conf: n_units: 256 # model related -model: eend_ola_similar_eend +model: eend_ola model_conf: attractor_loss_weight: 0.01 max_n_speaker: 8 diff --git a/egs/callhome/eend_ola/conf/train_diar_eend_ola_simu_2spkr.yaml b/egs/callhome/eend_ola/conf/train_diar_eend_ola_simu_2spkr.yaml index baf43424f..47316fe36 100644 --- a/egs/callhome/eend_ola/conf/train_diar_eend_ola_simu_2spkr.yaml +++ b/egs/callhome/eend_ola/conf/train_diar_eend_ola_simu_2spkr.yaml @@ -12,7 +12,7 @@ encoder_decoder_attractor_conf: n_units: 256 # model related -model: eend_ola_similar_eend +model: eend_ola model_conf: max_n_speaker: 8 diff --git a/egs/callhome/eend_ola/conf/train_diar_eend_ola_simu_allspkr.yaml b/egs/callhome/eend_ola/conf/train_diar_eend_ola_simu_allspkr.yaml index 83a6eeeb9..f55e14895 100644 --- a/egs/callhome/eend_ola/conf/train_diar_eend_ola_simu_allspkr.yaml +++ b/egs/callhome/eend_ola/conf/train_diar_eend_ola_simu_allspkr.yaml @@ -12,7 +12,7 @@ encoder_decoder_attractor_conf: n_units: 256 # model related -model: eend_ola_similar_eend +model: eend_ola model_conf: max_n_speaker: 8 diff --git a/egs/callhome/eend_ola/conf/train_diar_eend_ola_simu_allspkr_chunk2000.yaml b/egs/callhome/eend_ola/conf/train_diar_eend_ola_simu_allspkr_chunk2000.yaml index f47850417..d21d467a1 100644 --- a/egs/callhome/eend_ola/conf/train_diar_eend_ola_simu_allspkr_chunk2000.yaml +++ b/egs/callhome/eend_ola/conf/train_diar_eend_ola_simu_allspkr_chunk2000.yaml @@ -12,7 +12,7 @@ encoder_decoder_attractor_conf: n_units: 256 # model related -model: eend_ola_similar_eend +model: eend_ola model_conf: max_n_speaker: 8 diff --git a/funasr/models/e2e_diar_eend_ola.py b/funasr/models/e2e_diar_eend_ola.py index af0fd62c8..fda24e227 100644 --- a/funasr/models/e2e_diar_eend_ola.py +++ b/funasr/models/e2e_diar_eend_ola.py @@ -12,7 +12,7 @@ from funasr.models.base_model import FunASRModel from funasr.models.frontend.wav_frontend import WavFrontendMel23 from funasr.modules.eend_ola.encoder import EENDOLATransformerEncoder from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor -from funasr.modules.eend_ola.utils.losses import fast_batch_pit_n_speaker_loss, standard_loss, cal_power_loss +from funasr.modules.eend_ola.utils.losses import standard_loss, cal_power_loss, fast_batch_pit_n_speaker_loss from funasr.modules.eend_ola.utils.power import create_powerlabel from funasr.modules.eend_ola.utils.power import generate_mapping_dict from funasr.torch_utils.device_funcs import force_gatherable @@ -109,23 +109,17 @@ class DiarEENDOLAModel(FunASRModel): def forward( self, speech: List[torch.Tensor], - speech_lengths: torch.Tensor, # num_frames of each sample speaker_labels: List[torch.Tensor], - speaker_labels_lengths: torch.Tensor, # num_speakers of each sample orders: torch.Tensor, ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: # Check that batch_size is unified - assert ( - len(speech) - == len(speech_lengths) - == len(speaker_labels) - == len(speaker_labels_lengths) - ), (len(speech), len(speech_lengths), len(speaker_labels), len(speaker_labels_lengths)) + assert (len(speech) == len(speaker_labels)), (len(speech), len(speaker_labels)) + speech_lengths = torch.tensor([len(sph) for sph in speech]).to(torch.int64) + speaker_labels_lengths = torch.tensor([spk.shape[-1] for spk in speaker_labels]).to(torch.int64) batch_size = len(speech) # Encoder - speech = [s[:s_len] for s, s_len in zip(speech, speech_lengths)] encoder_out = self.forward_encoder(speech, speech_lengths) # Encoder-decoder attractor From 6494a503f4ce11634cfd42d562011541b1e4ebf9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=98=89=E6=B8=8A?= Date: Tue, 18 Jul 2023 14:55:24 +0800 Subject: [PATCH 06/42] update --- egs/callhome/eend_ola/local/make_callhome.sh | 73 ++++++ egs/callhome/eend_ola/local/make_mixture.py | 120 +++++++++ egs/callhome/eend_ola/local/make_musan.py | 123 +++++++++ egs/callhome/eend_ola/local/make_musan.sh | 37 +++ egs/callhome/eend_ola/local/make_sre.pl | 63 +++++ egs/callhome/eend_ola/local/make_sre.sh | 48 ++++ .../eend_ola/local/make_swbd2_phase1.pl | 106 ++++++++ .../eend_ola/local/make_swbd2_phase2.pl | 107 ++++++++ .../eend_ola/local/make_swbd2_phase3.pl | 102 ++++++++ .../eend_ola/local/make_swbd_cellular1.pl | 83 +++++++ .../eend_ola/local/make_swbd_cellular2.pl | 83 +++++++ egs/callhome/eend_ola/local/random_mixture.py | 145 +++++++++++ egs/callhome/eend_ola/local/run_blstm.sh | 9 + .../eend_ola/local/run_prepare_shared_eda.sh | 235 ++++++++++++++++++ egs/callhome/eend_ola/path.sh | 7 + egs/callhome/eend_ola/run.sh | 25 +- egs/callhome/{diarization => }/sond/sond.yaml | 0 .../{diarization => }/sond/sond_fbank.yaml | 0 .../{diarization => }/sond/unit_test.py | 0 19 files changed, 1361 insertions(+), 5 deletions(-) create mode 100644 egs/callhome/eend_ola/local/make_callhome.sh create mode 100644 egs/callhome/eend_ola/local/make_mixture.py create mode 100644 egs/callhome/eend_ola/local/make_musan.py create mode 100644 egs/callhome/eend_ola/local/make_musan.sh create mode 100644 egs/callhome/eend_ola/local/make_sre.pl create mode 100644 egs/callhome/eend_ola/local/make_sre.sh create mode 100644 egs/callhome/eend_ola/local/make_swbd2_phase1.pl create mode 100644 egs/callhome/eend_ola/local/make_swbd2_phase2.pl create mode 100644 egs/callhome/eend_ola/local/make_swbd2_phase3.pl create mode 100644 egs/callhome/eend_ola/local/make_swbd_cellular1.pl create mode 100644 egs/callhome/eend_ola/local/make_swbd_cellular2.pl create mode 100644 egs/callhome/eend_ola/local/random_mixture.py create mode 100644 egs/callhome/eend_ola/local/run_blstm.sh create mode 100644 egs/callhome/eend_ola/local/run_prepare_shared_eda.sh rename egs/callhome/{diarization => }/sond/sond.yaml (100%) rename egs/callhome/{diarization => }/sond/sond_fbank.yaml (100%) rename egs/callhome/{diarization => }/sond/unit_test.py (100%) diff --git a/egs/callhome/eend_ola/local/make_callhome.sh b/egs/callhome/eend_ola/local/make_callhome.sh new file mode 100644 index 000000000..caa8f679f --- /dev/null +++ b/egs/callhome/eend_ola/local/make_callhome.sh @@ -0,0 +1,73 @@ +#!/bin/bash +# Copyright 2017 David Snyder +# Apache 2.0. +# +# This script prepares the Callhome portion of the NIST SRE 2000 +# corpus (LDC2001S97). It is the evaluation dataset used in the +# callhome_diarization recipe. + +if [ $# -ne 2 ]; then + echo "Usage: $0 " + echo "e.g.: $0 /mnt/data/LDC2001S97 data/" + exit 1; +fi + +src_dir=$1 +data_dir=$2 + +tmp_dir=$data_dir/callhome/.tmp/ +mkdir -p $tmp_dir + +# Download some metadata that wasn't provided in the LDC release +if [ ! -d "$tmp_dir/sre2000-key" ]; then + wget --no-check-certificate -P $tmp_dir/ \ + http://www.openslr.org/resources/10/sre2000-key.tar.gz + tar -xvf $tmp_dir/sre2000-key.tar.gz -C $tmp_dir/ +fi + +# The list of 500 recordings +awk '{print $1}' $tmp_dir/sre2000-key/reco2num > $tmp_dir/reco.list + +# Create wav.scp file +count=0 +missing=0 +while read reco; do + path=$(find $src_dir -name "$reco.sph") + if [ -z "${path// }" ]; then + >&2 echo "$0: Missing Sphere file for $reco" + missing=$((missing+1)) + else + echo "$reco sph2pipe -f wav -p $path |" + fi + count=$((count+1)) +done < $tmp_dir/reco.list > $data_dir/callhome/wav.scp + +if [ $missing -gt 0 ]; then + echo "$0: Missing $missing out of $count recordings" +fi + +cp $tmp_dir/sre2000-key/segments $data_dir/callhome/ +awk '{print $1, $2}' $data_dir/callhome/segments > $data_dir/callhome/utt2spk +utils/utt2spk_to_spk2utt.pl $data_dir/callhome/utt2spk > $data_dir/callhome/spk2utt +cp $tmp_dir/sre2000-key/reco2num $data_dir/callhome/reco2num_spk +cp $tmp_dir/sre2000-key/fullref.rttm $data_dir/callhome/ + +utils/validate_data_dir.sh --no-text --no-feats $data_dir/callhome +utils/fix_data_dir.sh $data_dir/callhome + +utils/copy_data_dir.sh $data_dir/callhome $data_dir/callhome1 +utils/copy_data_dir.sh $data_dir/callhome $data_dir/callhome2 + +utils/shuffle_list.pl $data_dir/callhome/wav.scp | head -n 250 \ + | utils/filter_scp.pl - $data_dir/callhome/wav.scp \ + > $data_dir/callhome1/wav.scp +utils/fix_data_dir.sh $data_dir/callhome1 +utils/filter_scp.pl --exclude $data_dir/callhome1/wav.scp \ + $data_dir/callhome/wav.scp > $data_dir/callhome2/wav.scp +utils/fix_data_dir.sh $data_dir/callhome2 +utils/filter_scp.pl $data_dir/callhome1/wav.scp $data_dir/callhome/reco2num_spk \ + > $data_dir/callhome1/reco2num_spk +utils/filter_scp.pl $data_dir/callhome2/wav.scp $data_dir/callhome/reco2num_spk \ + > $data_dir/callhome2/reco2num_spk + +rm -rf $tmp_dir 2> /dev/null diff --git a/egs/callhome/eend_ola/local/make_mixture.py b/egs/callhome/eend_ola/local/make_mixture.py new file mode 100644 index 000000000..82d03cd60 --- /dev/null +++ b/egs/callhome/eend_ola/local/make_mixture.py @@ -0,0 +1,120 @@ +#!/usr/bin/env python3 + +# Copyright 2019 Hitachi, Ltd. (author: Yusuke Fujita) +# Licensed under the MIT license. +# +# This script generates simulated multi-talker mixtures for diarization +# +# common/make_mixture.py \ +# mixture.scp \ +# data/mixture \ +# wav/mixture + + +import argparse +import os +from eend import kaldi_data +import numpy as np +import math +import soundfile as sf +import json + +parser = argparse.ArgumentParser() +parser.add_argument('script', + help='list of json') +parser.add_argument('out_data_dir', + help='output data dir of mixture') +parser.add_argument('out_wav_dir', + help='output mixture wav files are stored here') +parser.add_argument('--rate', type=int, default=16000, + help='sampling rate') +args = parser.parse_args() + +# open output data files +segments_f = open(args.out_data_dir + '/segments', 'w') +utt2spk_f = open(args.out_data_dir + '/utt2spk', 'w') +wav_scp_f = open(args.out_data_dir + '/wav.scp', 'w') + +# "-R" forces the default random seed for reproducibility +resample_cmd = "sox -R -t wav - -t wav - rate {}".format(args.rate) + +for line in open(args.script): + recid, jsonstr = line.strip().split(None, 1) + indata = json.loads(jsonstr) + wavfn = indata['recid'] + # recid now include out_wav_dir + recid = os.path.join(args.out_wav_dir, wavfn).replace('/','_') + noise = indata['noise'] + noise_snr = indata['snr'] + mixture = [] + for speaker in indata['speakers']: + spkid = speaker['spkid'] + utts = speaker['utts'] + intervals = speaker['intervals'] + rir = speaker['rir'] + data = [] + pos = 0 + for interval, utt in zip(intervals, utts): + # append silence interval data + silence = np.zeros(int(interval * args.rate)) + data.append(silence) + # utterance is reverberated using room impulse response + preprocess = "wav-reverberate --print-args=false " \ + " --impulse-response={} - -".format(rir) + if isinstance(utt, list): + rec, st, et = utt + st = np.rint(st * args.rate).astype(int) + et = np.rint(et * args.rate).astype(int) + else: + rec = utt + st = 0 + et = None + if rir is not None: + wav_rxfilename = kaldi_data.process_wav(rec, preprocess) + else: + wav_rxfilename = rec + wav_rxfilename = kaldi_data.process_wav( + wav_rxfilename, resample_cmd) + speech, _ = kaldi_data.load_wav(wav_rxfilename, st, et) + data.append(speech) + # calculate start/end position in samples + startpos = pos + len(silence) + endpos = startpos + len(speech) + # write segments and utt2spk + uttid = '{}_{}_{:07d}_{:07d}'.format( + spkid, recid, int(startpos / args.rate * 100), + int(endpos / args.rate * 100)) + print(uttid, recid, + startpos / args.rate, endpos / args.rate, file=segments_f) + print(uttid, spkid, file=utt2spk_f) + # update position for next utterance + pos = endpos + data = np.concatenate(data) + mixture.append(data) + + # fitting to the maximum-length speaker data, then mix all speakers + maxlen = max(len(x) for x in mixture) + mixture = [np.pad(x, (0, maxlen - len(x)), 'constant') for x in mixture] + mixture = np.sum(mixture, axis=0) + # noise is repeated or cutted for fitting to the mixture data length + noise_resampled = kaldi_data.process_wav(noise, resample_cmd) + noise_data, _ = kaldi_data.load_wav(noise_resampled) + if maxlen > len(noise_data): + noise_data = np.pad(noise_data, (0, maxlen - len(noise_data)), 'wrap') + else: + noise_data = noise_data[:maxlen] + # noise power is scaled according to selected SNR, then mixed + signal_power = np.sum(mixture**2) / len(mixture) + noise_power = np.sum(noise_data**2) / len(noise_data) + scale = math.sqrt( + math.pow(10, - noise_snr / 10) * signal_power / noise_power) + mixture += noise_data * scale + # output the wav file and write wav.scp + outfname = '{}.wav'.format(wavfn) + outpath = os.path.join(args.out_wav_dir, outfname) + sf.write(outpath, mixture, args.rate) + print(recid, os.path.abspath(outpath), file=wav_scp_f) + +wav_scp_f.close() +segments_f.close() +utt2spk_f.close() diff --git a/egs/callhome/eend_ola/local/make_musan.py b/egs/callhome/eend_ola/local/make_musan.py new file mode 100644 index 000000000..833da0619 --- /dev/null +++ b/egs/callhome/eend_ola/local/make_musan.py @@ -0,0 +1,123 @@ +#!/usr/bin/env python3 +# Copyright 2015 David Snyder +# 2018 Ewald Enzinger +# Apache 2.0. +# +# Modified version of egs/sre16/v1/local/make_musan.py (commit e3fb7c4a0da4167f8c94b80f4d3cc5ab4d0e22e8). +# This version uses the raw MUSAN audio files (16 kHz) and does not use sox to resample at 8 kHz. +# +# This file is meant to be invoked by make_musan.sh. + +import os, sys + +def process_music_annotations(path): + utt2spk = {} + utt2vocals = {} + lines = open(path, 'r').readlines() + for line in lines: + utt, genres, vocals, musician = line.rstrip().split()[:4] + # For this application, the musican ID isn't important + utt2spk[utt] = utt + utt2vocals[utt] = vocals == "Y" + return utt2spk, utt2vocals + +def prepare_music(root_dir, use_vocals): + utt2vocals = {} + utt2spk = {} + utt2wav = {} + num_good_files = 0 + num_bad_files = 0 + music_dir = os.path.join(root_dir, "music") + for root, dirs, files in os.walk(music_dir): + for file in files: + file_path = os.path.join(root, file) + if file.endswith(".wav"): + utt = str(file).replace(".wav", "") + utt2wav[utt] = file_path + elif str(file) == "ANNOTATIONS": + utt2spk_part, utt2vocals_part = process_music_annotations(file_path) + utt2spk.update(utt2spk_part) + utt2vocals.update(utt2vocals_part) + utt2spk_str = "" + utt2wav_str = "" + for utt in utt2vocals: + if utt in utt2wav: + if use_vocals or not utt2vocals[utt]: + utt2spk_str = utt2spk_str + utt + " " + utt2spk[utt] + "\n" + utt2wav_str = utt2wav_str + utt + " " + utt2wav[utt] + "\n" + num_good_files += 1 + else: + print("Missing file {}".format(utt)) + num_bad_files += 1 + print("In music directory, processed {} files: {} had missing wav data".format(num_good_files, num_bad_files)) + return utt2spk_str, utt2wav_str + +def prepare_speech(root_dir): + utt2spk = {} + utt2wav = {} + num_good_files = 0 + num_bad_files = 0 + speech_dir = os.path.join(root_dir, "speech") + for root, dirs, files in os.walk(speech_dir): + for file in files: + file_path = os.path.join(root, file) + if file.endswith(".wav"): + utt = str(file).replace(".wav", "") + utt2wav[utt] = file_path + utt2spk[utt] = utt + utt2spk_str = "" + utt2wav_str = "" + for utt in utt2spk: + if utt in utt2wav: + utt2spk_str = utt2spk_str + utt + " " + utt2spk[utt] + "\n" + utt2wav_str = utt2wav_str + utt + " " + utt2wav[utt] + "\n" + num_good_files += 1 + else: + print("Missing file {}".format(utt)) + num_bad_files += 1 + print("In speech directory, processed {} files: {} had missing wav data".format(num_good_files, num_bad_files)) + return utt2spk_str, utt2wav_str + +def prepare_noise(root_dir): + utt2spk = {} + utt2wav = {} + num_good_files = 0 + num_bad_files = 0 + noise_dir = os.path.join(root_dir, "noise") + for root, dirs, files in os.walk(noise_dir): + for file in files: + file_path = os.path.join(root, file) + if file.endswith(".wav"): + utt = str(file).replace(".wav", "") + utt2wav[utt] = file_path + utt2spk[utt] = utt + utt2spk_str = "" + utt2wav_str = "" + for utt in utt2spk: + if utt in utt2wav: + utt2spk_str = utt2spk_str + utt + " " + utt2spk[utt] + "\n" + utt2wav_str = utt2wav_str + utt + " " + utt2wav[utt] + "\n" + num_good_files += 1 + else: + print("Missing file {}".format(utt)) + num_bad_files += 1 + print("In noise directory, processed {} files: {} had missing wav data".format(num_good_files, num_bad_files)) + return utt2spk_str, utt2wav_str + +def main(): + in_dir = sys.argv[1] + out_dir = sys.argv[2] + use_vocals = sys.argv[3] == "Y" + utt2spk_music, utt2wav_music = prepare_music(in_dir, use_vocals) + utt2spk_speech, utt2wav_speech = prepare_speech(in_dir) + utt2spk_noise, utt2wav_noise = prepare_noise(in_dir) + utt2spk = utt2spk_speech + utt2spk_music + utt2spk_noise + utt2wav = utt2wav_speech + utt2wav_music + utt2wav_noise + wav_fi = open(os.path.join(out_dir, "wav.scp"), 'w') + wav_fi.write(utt2wav) + utt2spk_fi = open(os.path.join(out_dir, "utt2spk"), 'w') + utt2spk_fi.write(utt2spk) + + +if __name__=="__main__": + main() diff --git a/egs/callhome/eend_ola/local/make_musan.sh b/egs/callhome/eend_ola/local/make_musan.sh new file mode 100644 index 000000000..694940ad7 --- /dev/null +++ b/egs/callhome/eend_ola/local/make_musan.sh @@ -0,0 +1,37 @@ +#!/bin/bash +# Copyright 2015 David Snyder +# Apache 2.0. +# +# This script, called by ../run.sh, creates the MUSAN +# data directory. The required dataset is freely available at +# http://www.openslr.org/17/ + +set -e +in_dir=$1 +data_dir=$2 +use_vocals='Y' + +mkdir -p local/musan.tmp + +echo "Preparing ${data_dir}/musan..." +mkdir -p ${data_dir}/musan +local/make_musan.py ${in_dir} ${data_dir}/musan ${use_vocals} + +utils/fix_data_dir.sh ${data_dir}/musan + +grep "music" ${data_dir}/musan/utt2spk > local/musan.tmp/utt2spk_music +grep "speech" ${data_dir}/musan/utt2spk > local/musan.tmp/utt2spk_speech +grep "noise" ${data_dir}/musan/utt2spk > local/musan.tmp/utt2spk_noise +utils/subset_data_dir.sh --utt-list local/musan.tmp/utt2spk_music \ + ${data_dir}/musan ${data_dir}/musan_music +utils/subset_data_dir.sh --utt-list local/musan.tmp/utt2spk_speech \ + ${data_dir}/musan ${data_dir}/musan_speech +utils/subset_data_dir.sh --utt-list local/musan.tmp/utt2spk_noise \ + ${data_dir}/musan ${data_dir}/musan_noise + +utils/fix_data_dir.sh ${data_dir}/musan_music +utils/fix_data_dir.sh ${data_dir}/musan_speech +utils/fix_data_dir.sh ${data_dir}/musan_noise + +rm -rf local/musan.tmp + diff --git a/egs/callhome/eend_ola/local/make_sre.pl b/egs/callhome/eend_ola/local/make_sre.pl new file mode 100644 index 000000000..b86fa7ee7 --- /dev/null +++ b/egs/callhome/eend_ola/local/make_sre.pl @@ -0,0 +1,63 @@ +#!/usr/bin/perl +# +# Copyright 2015 David Snyder +# Apache 2.0. +# Usage: make_sre.pl + +if (@ARGV != 4) { + print STDERR "Usage: $0 \n"; + print STDERR "e.g. $0 /export/corpora5/LDC/LDC2006S44 sre2004 sre_ref data/sre2004\n"; + exit(1); +} + +($db_base, $sre_name, $sre_ref_filename, $out_dir) = @ARGV; +%utt2sph = (); +%spk2gender = (); + +$tmp_dir = "$out_dir/tmp"; +if (system("mkdir -p $tmp_dir") != 0) { + die "Error making directory $tmp_dir"; +} + +if (system("find $db_base -name '*.sph' > $tmp_dir/sph.list") != 0) { + die "Error getting list of sph files"; +} +open(WAVLIST, "<", "$tmp_dir/sph.list") or die "cannot open wav list"; + +while() { + chomp; + $sph = $_; + @A1 = split("/",$sph); + @A2 = split("[./]",$A1[$#A1]); + $uttId=$A2[0]; + $utt2sph{$uttId} = $sph; +} + +open(GNDR,">", "$out_dir/spk2gender") or die "Could not open the output file $out_dir/spk2gender"; +open(SPKR,">", "$out_dir/utt2spk") or die "Could not open the output file $out_dir/utt2spk"; +open(WAV,">", "$out_dir/wav.scp") or die "Could not open the output file $out_dir/wav.scp"; +open(SRE_REF, "<", $sre_ref_filename) or die "Cannot open SRE reference."; +while () { + chomp; + ($speaker, $gender, $other_sre_name, $utt_id, $channel) = split(" ", $_); + $channel_num = "1"; + if ($channel eq "A") { + $channel_num = "1"; + } else { + $channel_num = "2"; + } + if (($other_sre_name eq $sre_name) and (exists $utt2sph{$utt_id})) { + $full_utt_id = "$speaker-$gender-$sre_name-$utt_id-$channel"; + $spk2gender{"$speaker-$gender"} = $gender; + print WAV "$full_utt_id"," sph2pipe -f wav -p -c $channel_num $utt2sph{$utt_id} |\n"; + print SPKR "$full_utt_id $speaker-$gender","\n"; + } +} +foreach $speaker (keys %spk2gender) { + print GNDR "$speaker $spk2gender{$speaker}\n"; +} + +close(GNDR) || die; +close(SPKR) || die; +close(WAV) || die; +close(SRE_REF) || die; diff --git a/egs/callhome/eend_ola/local/make_sre.sh b/egs/callhome/eend_ola/local/make_sre.sh new file mode 100644 index 000000000..bef4e06e6 --- /dev/null +++ b/egs/callhome/eend_ola/local/make_sre.sh @@ -0,0 +1,48 @@ +#!/bin/bash +# Copyright 2015 David Snyder +# Apache 2.0. +# +# See README.txt for more info on data required. + +set -e + +data_root=$1 +data_dir=$2 + +wget -P data/local/ http://www.openslr.org/resources/15/speaker_list.tgz +tar -C data/local/ -xvf data/local/speaker_list.tgz +sre_ref=data/local/speaker_list + +local/make_sre.pl $data_root/LDC2006S44/ \ + sre2004 $sre_ref $data_dir/sre2004 + +local/make_sre.pl $data_root/LDC2011S01 \ + sre2005 $sre_ref $data_dir/sre2005_train + +local/make_sre.pl $data_root/LDC2011S04 \ + sre2005 $sre_ref $data_dir/sre2005_test + +local/make_sre.pl $data_root/LDC2011S09 \ + sre2006 $sre_ref $data_dir/sre2006_train + +local/make_sre.pl $data_root/LDC2011S10 \ + sre2006 $sre_ref $data_dir/sre2006_test_1 + +local/make_sre.pl $data_root/LDC2012S01 \ + sre2006 $sre_ref $data_dir/sre2006_test_2 + +local/make_sre.pl $data_root/LDC2011S05 \ + sre2008 $sre_ref $data_dir/sre2008_train + +local/make_sre.pl $data_root/LDC2011S08 \ + sre2008 $sre_ref $data_dir/sre2008_test + +utils/combine_data.sh $data_dir/sre \ + $data_dir/sre2004 $data_dir/sre2005_train \ + $data_dir/sre2005_test $data_dir/sre2006_train \ + $data_dir/sre2006_test_1 $data_dir/sre2006_test_2 \ + $data_dir/sre2008_train $data_dir/sre2008_test + +utils/validate_data_dir.sh --no-text --no-feats $data_dir/sre +utils/fix_data_dir.sh $data_dir/sre +rm data/local/speaker_list.* diff --git a/egs/callhome/eend_ola/local/make_swbd2_phase1.pl b/egs/callhome/eend_ola/local/make_swbd2_phase1.pl new file mode 100644 index 000000000..71b26b55d --- /dev/null +++ b/egs/callhome/eend_ola/local/make_swbd2_phase1.pl @@ -0,0 +1,106 @@ +#!/usr/bin/perl +use warnings; #sed replacement for -w perl parameter +# +# Copyright 2017 David Snyder +# Apache 2.0 + +if (@ARGV != 2) { + print STDERR "Usage: $0 \n"; + print STDERR "e.g. $0 /export/corpora3/LDC/LDC98S75 data/swbd2_phase1_train\n"; + exit(1); +} +($db_base, $out_dir) = @ARGV; + +if (system("mkdir -p $out_dir")) { + die "Error making directory $out_dir"; +} + +open(CS, "<$db_base/doc/callstat.tbl") || die "Could not open $db_base/doc/callstat.tbl"; +open(GNDR, ">$out_dir/spk2gender") || die "Could not open the output file $out_dir/spk2gender"; +open(SPKR, ">$out_dir/utt2spk") || die "Could not open the output file $out_dir/utt2spk"; +open(WAV, ">$out_dir/wav.scp") || die "Could not open the output file $out_dir/wav.scp"; + +@badAudio = ("3", "4"); + +$tmp_dir = "$out_dir/tmp"; +if (system("mkdir -p $tmp_dir") != 0) { + die "Error making directory $tmp_dir"; +} + +if (system("find $db_base -name '*.sph' > $tmp_dir/sph.list") != 0) { + die "Error getting list of sph files"; +} + +open(WAVLIST, "<$tmp_dir/sph.list") or die "cannot open wav list"; + +%wavs = (); +while() { + chomp; + $sph = $_; + @t = split("/",$sph); + @t1 = split("[./]",$t[$#t]); + $uttId = $t1[0]; + $wavs{$uttId} = $sph; +} + +while () { + $line = $_ ; + @A = split(",", $line); + @A1 = split("[./]",$A[0]); + $wav = $A1[0]; + if (/$wav/i ~~ @badAudio) { + # do nothing + print "Bad Audio = $wav"; + } else { + $spkr1= "sw_" . $A[2]; + $spkr2= "sw_" . $A[3]; + $gender1 = $A[5]; + $gender2 = $A[6]; + if ($gender1 eq "M") { + $gender1 = "m"; + } elsif ($gender1 eq "F") { + $gender1 = "f"; + } else { + die "Unknown Gender in $line"; + } + if ($gender2 eq "M") { + $gender2 = "m"; + } elsif ($gender2 eq "F") { + $gender2 = "f"; + } else { + die "Unknown Gender in $line"; + } + if (-e "$wavs{$wav}") { + $uttId = $spkr1 ."_" . $wav ."_1"; + if (!$spk2gender{$spkr1}) { + $spk2gender{$spkr1} = $gender1; + print GNDR "$spkr1"," $gender1\n"; + } + print WAV "$uttId"," sph2pipe -f wav -p -c 1 $wavs{$wav} |\n"; + print SPKR "$uttId"," $spkr1","\n"; + + $uttId = $spkr2 . "_" . $wav ."_2"; + if (!$spk2gender{$spkr2}) { + $spk2gender{$spkr2} = $gender2; + print GNDR "$spkr2"," $gender2\n"; + } + print WAV "$uttId"," sph2pipe -f wav -p -c 2 $wavs{$wav} |\n"; + print SPKR "$uttId"," $spkr2","\n"; + } else { + print STDERR "Missing $wavs{$wav} for $wav\n"; + } + } +} + +close(WAV) || die; +close(SPKR) || die; +close(GNDR) || die; +if (system("utils/utt2spk_to_spk2utt.pl $out_dir/utt2spk >$out_dir/spk2utt") != 0) { + die "Error creating spk2utt file in directory $out_dir"; +} +if (system("utils/fix_data_dir.sh $out_dir") != 0) { + die "Error fixing data dir $out_dir"; +} +if (system("utils/validate_data_dir.sh --no-text --no-feats $out_dir") != 0) { + die "Error validating directory $out_dir"; +} diff --git a/egs/callhome/eend_ola/local/make_swbd2_phase2.pl b/egs/callhome/eend_ola/local/make_swbd2_phase2.pl new file mode 100644 index 000000000..337ab9d97 --- /dev/null +++ b/egs/callhome/eend_ola/local/make_swbd2_phase2.pl @@ -0,0 +1,107 @@ +#!/usr/bin/perl +use warnings; #sed replacement for -w perl parameter +# +# Copyright 2013 Daniel Povey +# Apache 2.0 + +if (@ARGV != 2) { + print STDERR "Usage: $0 \n"; + print STDERR "e.g. $0 /export/corpora5/LDC/LDC99S79 data/swbd2_phase2_train\n"; + exit(1); +} +($db_base, $out_dir) = @ARGV; + +if (system("mkdir -p $out_dir")) { + die "Error making directory $out_dir"; +} + +open(CS, "<$db_base/DISC1/doc/callstat.tbl") || die "Could not open $db_base/DISC1/doc/callstat.tbl"; +open(CI, "<$db_base/DISC1/doc/callinfo.tbl") || die "Could not open $db_base/DISC1/doc/callinfo.tbl"; +open(GNDR, ">$out_dir/spk2gender") || die "Could not open the output file $out_dir/spk2gender"; +open(SPKR, ">$out_dir/utt2spk") || die "Could not open the output file $out_dir/utt2spk"; +open(WAV, ">$out_dir/wav.scp") || die "Could not open the output file $out_dir/wav.scp"; + +@badAudio = ("3", "4"); + +$tmp_dir = "$out_dir/tmp"; +if (system("mkdir -p $tmp_dir") != 0) { + die "Error making directory $tmp_dir"; +} + +if (system("find $db_base -name '*.sph' > $tmp_dir/sph.list") != 0) { + die "Error getting list of sph files"; +} + +open(WAVLIST, "<$tmp_dir/sph.list") or die "cannot open wav list"; + +while() { + chomp; + $sph = $_; + @t = split("/",$sph); + @t1 = split("[./]",$t[$#t]); + $uttId=$t1[0]; + $wav{$uttId} = $sph; +} + +while () { + $line = $_ ; + $ci = ; + $ci = ; + @ci = split(",",$ci); + $wav = $ci[0]; + @A = split(",", $line); + if (/$wav/i ~~ @badAudio) { + # do nothing + } else { + $spkr1= "sw_" . $A[2]; + $spkr2= "sw_" . $A[3]; + $gender1 = $A[4]; + $gender2 = $A[5]; + if ($gender1 eq "M") { + $gender1 = "m"; + } elsif ($gender1 eq "F") { + $gender1 = "f"; + } else { + die "Unknown Gender in $line"; + } + if ($gender2 eq "M") { + $gender2 = "m"; + } elsif ($gender2 eq "F") { + $gender2 = "f"; + } else { + die "Unknown Gender in $line"; + } + if (-e "$wav{$wav}") { + $uttId = $spkr1 ."_" . $wav ."_1"; + if (!$spk2gender{$spkr1}) { + $spk2gender{$spkr1} = $gender1; + print GNDR "$spkr1"," $gender1\n"; + } + print WAV "$uttId"," sph2pipe -f wav -p -c 1 $wav{$wav} |\n"; + print SPKR "$uttId"," $spkr1","\n"; + + $uttId = $spkr2 . "_" . $wav ."_2"; + if (!$spk2gender{$spkr2}) { + $spk2gender{$spkr2} = $gender2; + print GNDR "$spkr2"," $gender2\n"; + } + print WAV "$uttId"," sph2pipe -f wav -p -c 2 $wav{$wav} |\n"; + print SPKR "$uttId"," $spkr2","\n"; + } else { + print STDERR "Missing $wav{$wav} for $wav\n"; + } + } +} + +close(WAV) || die; +close(SPKR) || die; +close(GNDR) || die; +if (system("utils/utt2spk_to_spk2utt.pl $out_dir/utt2spk >$out_dir/spk2utt") != 0) { + die "Error creating spk2utt file in directory $out_dir"; +} +if (system("utils/fix_data_dir.sh $out_dir") != 0) { + die "Error fixing data dir $out_dir"; +} +if (system("utils/validate_data_dir.sh --no-text --no-feats $out_dir") != 0) { + die "Error validating directory $out_dir"; +} diff --git a/egs/callhome/eend_ola/local/make_swbd2_phase3.pl b/egs/callhome/eend_ola/local/make_swbd2_phase3.pl new file mode 100644 index 000000000..f27853415 --- /dev/null +++ b/egs/callhome/eend_ola/local/make_swbd2_phase3.pl @@ -0,0 +1,102 @@ +#!/usr/bin/perl +use warnings; #sed replacement for -w perl parameter +# +# Copyright 2013 Daniel Povey +# Apache 2.0 + +if (@ARGV != 2) { + print STDERR "Usage: $0 \n"; + print STDERR "e.g. $0 /export/corpora5/LDC/LDC2002S06 data/swbd2_phase3_train\n"; + exit(1); +} +($db_base, $out_dir) = @ARGV; + +if (system("mkdir -p $out_dir")) { + die "Error making directory $out_dir"; +} + +open(CS, "<$db_base/DISC1/docs/callstat.tbl") || die "Could not open $db_base/DISC1/docs/callstat.tbl"; +open(GNDR, ">$out_dir/spk2gender") || die "Could not open the output file $out_dir/spk2gender"; +open(SPKR, ">$out_dir/utt2spk") || die "Could not open the output file $out_dir/utt2spk"; +open(WAV, ">$out_dir/wav.scp") || die "Could not open the output file $out_dir/wav.scp"; + +@badAudio = ("3", "4"); + +$tmp_dir = "$out_dir/tmp"; +if (system("mkdir -p $tmp_dir") != 0) { + die "Error making directory $tmp_dir"; +} + +if (system("find $db_base -name '*.sph' > $tmp_dir/sph.list") != 0) { + die "Error getting list of sph files"; +} + +open(WAVLIST, "<$tmp_dir/sph.list") or die "cannot open wav list"; +while() { + chomp; + $sph = $_; + @t = split("/",$sph); + @t1 = split("[./]",$t[$#t]); + $uttId=$t1[0]; + $wav{$uttId} = $sph; +} + +while () { + $line = $_ ; + @A = split(",", $line); + $wav = "sw_" . $A[0] ; + if (/$wav/i ~~ @badAudio) { + # do nothing + } else { + $spkr1= "sw_" . $A[3]; + $spkr2= "sw_" . $A[4]; + $gender1 = $A[5]; + $gender2 = $A[6]; + if ($gender1 eq "M") { + $gender1 = "m"; + } elsif ($gender1 eq "F") { + $gender1 = "f"; + } else { + die "Unknown Gender in $line"; + } + if ($gender2 eq "M") { + $gender2 = "m"; + } elsif ($gender2 eq "F") { + $gender2 = "f"; + } else { + die "Unknown Gender in $line"; + } + if (-e "$wav{$wav}") { + $uttId = $spkr1 ."_" . $wav ."_1"; + if (!$spk2gender{$spkr1}) { + $spk2gender{$spkr1} = $gender1; + print GNDR "$spkr1"," $gender1\n"; + } + print WAV "$uttId"," sph2pipe -f wav -p -c 1 $wav{$wav} |\n"; + print SPKR "$uttId"," $spkr1","\n"; + + $uttId = $spkr2 . "_" . $wav ."_2"; + if (!$spk2gender{$spkr2}) { + $spk2gender{$spkr2} = $gender2; + print GNDR "$spkr2"," $gender2\n"; + } + print WAV "$uttId"," sph2pipe -f wav -p -c 2 $wav{$wav} |\n"; + print SPKR "$uttId"," $spkr2","\n"; + } else { + print STDERR "Missing $wav{$wav} for $wav\n"; + } + } +} + +close(WAV) || die; +close(SPKR) || die; +close(GNDR) || die; +if (system("utils/utt2spk_to_spk2utt.pl $out_dir/utt2spk >$out_dir/spk2utt") != 0) { + die "Error creating spk2utt file in directory $out_dir"; +} +if (system("utils/fix_data_dir.sh $out_dir") != 0) { + die "Error fixing data dir $out_dir"; +} +if (system("utils/validate_data_dir.sh --no-text --no-feats $out_dir") != 0) { + die "Error validating directory $out_dir"; +} diff --git a/egs/callhome/eend_ola/local/make_swbd_cellular1.pl b/egs/callhome/eend_ola/local/make_swbd_cellular1.pl new file mode 100644 index 000000000..e30c710e6 --- /dev/null +++ b/egs/callhome/eend_ola/local/make_swbd_cellular1.pl @@ -0,0 +1,83 @@ +#!/usr/bin/perl +use warnings; #sed replacement for -w perl parameter +# +# Copyright 2013 Daniel Povey +# Apache 2.0 + +if (@ARGV != 2) { + print STDERR "Usage: $0 \n"; + print STDERR "e.g. $0 /export/corpora5/LDC/LDC2001S13 data/swbd_cellular1_train\n"; + exit(1); +} +($db_base, $out_dir) = @ARGV; + +if (system("mkdir -p $out_dir")) { + die "Error making directory $out_dir"; +} + +open(CS, "<$db_base/doc/swb_callstats.tbl") || die "Could not open $db_base/doc/swb_callstats.tbl"; +open(GNDR, ">$out_dir/spk2gender") || die "Could not open the output file $out_dir/spk2gender"; +open(SPKR, ">$out_dir/utt2spk") || die "Could not open the output file $out_dir/utt2spk"; +open(WAV, ">$out_dir/wav.scp") || die "Could not open the output file $out_dir/wav.scp"; + +@badAudio = ("40019", "45024", "40022"); + +while () { + $line = $_ ; + @A = split(",", $line); + if (/$A[0]/i ~~ @badAudio) { + # do nothing + } else { + $wav = "sw_" . $A[0]; + $spkr1= "sw_" . $A[1]; + $spkr2= "sw_" . $A[2]; + $gender1 = $A[3]; + $gender2 = $A[4]; + if ($A[3] eq "M") { + $gender1 = "m"; + } elsif ($A[3] eq "F") { + $gender1 = "f"; + } else { + die "Unknown Gender in $line"; + } + if ($A[4] eq "M") { + $gender2 = "m"; + } elsif ($A[4] eq "F") { + $gender2 = "f"; + } else { + die "Unknown Gender in $line"; + } + if (-e "$db_base/$wav.sph") { + $uttId = $spkr1 . "-swbdc_" . $wav ."_1"; + if (!$spk2gender{$spkr1}) { + $spk2gender{$spkr1} = $gender1; + print GNDR "$spkr1"," $gender1\n"; + } + print WAV "$uttId"," sph2pipe -f wav -p -c 1 $db_base/$wav.sph |\n"; + print SPKR "$uttId"," $spkr1","\n"; + + $uttId = $spkr2 . "-swbdc_" . $wav ."_2"; + if (!$spk2gender{$spkr2}) { + $spk2gender{$spkr2} = $gender2; + print GNDR "$spkr2"," $gender2\n"; + } + print WAV "$uttId"," sph2pipe -f wav -p -c 2 $db_base/$wav.sph |\n"; + print SPKR "$uttId"," $spkr2","\n"; + } else { + print STDERR "Missing $db_base/$wav.sph\n"; + } + } +} + +close(WAV) || die; +close(SPKR) || die; +close(GNDR) || die; +if (system("utils/utt2spk_to_spk2utt.pl $out_dir/utt2spk >$out_dir/spk2utt") != 0) { + die "Error creating spk2utt file in directory $out_dir"; +} +if (system("utils/fix_data_dir.sh $out_dir") != 0) { + die "Error fixing data dir $out_dir"; +} +if (system("utils/validate_data_dir.sh --no-text --no-feats $out_dir") != 0) { + die "Error validating directory $out_dir"; +} diff --git a/egs/callhome/eend_ola/local/make_swbd_cellular2.pl b/egs/callhome/eend_ola/local/make_swbd_cellular2.pl new file mode 100644 index 000000000..4de954c19 --- /dev/null +++ b/egs/callhome/eend_ola/local/make_swbd_cellular2.pl @@ -0,0 +1,83 @@ +#!/usr/bin/perl +use warnings; #sed replacement for -w perl parameter +# +# Copyright 2013 Daniel Povey +# Apache 2.0 + +if (@ARGV != 2) { + print STDERR "Usage: $0 \n"; + print STDERR "e.g. $0 /export/corpora5/LDC/LDC2004S07 data/swbd_cellular2_train\n"; + exit(1); +} +($db_base, $out_dir) = @ARGV; + +if (system("mkdir -p $out_dir")) { + die "Error making directory $out_dir"; +} + +open(CS, "<$db_base/docs/swb_callstats.tbl") || die "Could not open $db_base/docs/swb_callstats.tbl"; +open(GNDR, ">$out_dir/spk2gender") || die "Could not open the output file $out_dir/spk2gender"; +open(SPKR, ">$out_dir/utt2spk") || die "Could not open the output file $out_dir/utt2spk"; +open(WAV, ">$out_dir/wav.scp") || die "Could not open the output file $out_dir/wav.scp"; + +@badAudio=("45024", "40022"); + +while () { + $line = $_ ; + @A = split(",", $line); + if (/$A[0]/i ~~ @badAudio) { + # do nothing + } else { + $wav = "sw_" . $A[0]; + $spkr1= "sw_" . $A[1]; + $spkr2= "sw_" . $A[2]; + $gender1 = $A[3]; + $gender2 = $A[4]; + if ($A[3] eq "M") { + $gender1 = "m"; + } elsif ($A[3] eq "F") { + $gender1 = "f"; + } else { + die "Unknown Gender in $line"; + } + if ($A[4] eq "M") { + $gender2 = "m"; + } elsif ($A[4] eq "F") { + $gender2 = "f"; + } else { + die "Unknown Gender in $line"; + } + if (-e "$db_base/data/$wav.sph") { + $uttId = $spkr1 . "-swbdc_" . $wav ."_1"; + if (!$spk2gender{$spkr1}) { + $spk2gender{$spkr1} = $gender1; + print GNDR "$spkr1"," $gender1\n"; + } + print WAV "$uttId"," sph2pipe -f wav -p -c 1 $db_base/data/$wav.sph |\n"; + print SPKR "$uttId"," $spkr1","\n"; + + $uttId = $spkr2 . "-swbdc_" . $wav ."_2"; + if (!$spk2gender{$spkr2}) { + $spk2gender{$spkr2} = $gender2; + print GNDR "$spkr2"," $gender2\n"; + } + print WAV "$uttId"," sph2pipe -f wav -p -c 2 $db_base/data/$wav.sph |\n"; + print SPKR "$uttId"," $spkr2","\n"; + } else { + print STDERR "Missing $db_base/data/$wav.sph\n"; + } + } +} + +close(WAV) || die; +close(SPKR) || die; +close(GNDR) || die; +if (system("utils/utt2spk_to_spk2utt.pl $out_dir/utt2spk >$out_dir/spk2utt") != 0) { + die "Error creating spk2utt file in directory $out_dir"; +} +if (system("utils/fix_data_dir.sh $out_dir") != 0) { + die "Error fixing data dir $out_dir"; +} +if (system("utils/validate_data_dir.sh --no-text --no-feats $out_dir") != 0) { + die "Error validating directory $out_dir"; +} diff --git a/egs/callhome/eend_ola/local/random_mixture.py b/egs/callhome/eend_ola/local/random_mixture.py new file mode 100644 index 000000000..0032ef926 --- /dev/null +++ b/egs/callhome/eend_ola/local/random_mixture.py @@ -0,0 +1,145 @@ +#!/usr/bin/env python3 + +# Copyright 2019 Hitachi, Ltd. (author: Yusuke Fujita) +# Licensed under the MIT license. + +""" +This script generates random multi-talker mixtures for diarization. +It generates a scp-like outputs: lines of "[recid] [json]". + recid: recording id of mixture + serial numbers like mix_0000001, mix_0000002, ... + json: mixture configuration formatted in "one-line" +The json format is as following: +{ + 'speakers':[ # list of speakers + { + 'spkid': 'Name', # speaker id + 'rir': '/rirdir/rir.wav', # wav_rxfilename of room impulse response + 'utts': [ # list of wav_rxfilenames of utterances + '/wavdir/utt1.wav', + '/wavdir/utt2.wav',...], + 'intervals': [1.2, 3.4, ...] # list of silence durations before utterances + }, ... ], + 'noise': '/noisedir/noise.wav' # wav_rxfilename of background noise + 'snr': 15.0, # SNR for mixing background noise + 'recid': 'mix_000001' # recording id of the mixture +} + +Usage: + common/random_mixture.py \ + --n_mixtures=10000 \ # number of mixtures + data/voxceleb1_train \ # kaldi-style data dir of utterances + data/musan_noise_bg \ # background noises + data/simu_rirs \ # room impulse responses + > mixture.scp # output scp-like file + +The actual data dir and wav files are generated using make_mixture.py: + common/make_mixture.py \ + mixture.scp \ # scp-like file for mixture + data/mixture \ # output data dir + wav/mixture # output wav dir +""" + +import argparse +import os +from eend import kaldi_data +import random +import numpy as np +import json +import itertools + +parser = argparse.ArgumentParser() +parser.add_argument('data_dir', + help='data dir of single-speaker recordings') +parser.add_argument('noise_dir', + help='data dir of background noise recordings') +parser.add_argument('rir_dir', + help='data dir of room impulse responses') +parser.add_argument('--n_mixtures', type=int, default=10, + help='number of mixture recordings') +parser.add_argument('--n_speakers', type=int, default=4, + help='number of speakers in a mixture') +parser.add_argument('--min_utts', type=int, default=10, + help='minimum number of uttenraces per speaker') +parser.add_argument('--max_utts', type=int, default=20, + help='maximum number of utterances per speaker') +parser.add_argument('--sil_scale', type=float, default=10.0, + help='average silence time') +parser.add_argument('--noise_snrs', default="10:15:20", + help='colon-delimited SNRs for background noises') +parser.add_argument('--random_seed', type=int, default=777, + help='random seed') +parser.add_argument('--speech_rvb_probability', type=float, default=1, + help='reverb probability') +args = parser.parse_args() + +random.seed(args.random_seed) +np.random.seed(args.random_seed) + +# load list of wav files from kaldi-style data dirs +wavs = kaldi_data.load_wav_scp( + os.path.join(args.data_dir, 'wav.scp')) +noises = kaldi_data.load_wav_scp( + os.path.join(args.noise_dir, 'wav.scp')) +rirs = kaldi_data.load_wav_scp( + os.path.join(args.rir_dir, 'wav.scp')) + +# spk2utt is used for counting number of utterances per speaker +spk2utt = kaldi_data.load_spk2utt( + os.path.join(args.data_dir, 'spk2utt')) + +segments = kaldi_data.load_segments_hash( + os.path.join(args.data_dir, 'segments')) + +# choice lists for random sampling +all_speakers = list(spk2utt.keys()) +all_noises = list(noises.keys()) +all_rirs = list(rirs.keys()) +noise_snrs = [float(x) for x in args.noise_snrs.split(':')] + +mixtures = [] +for it in range(args.n_mixtures): + # recording ids are mix_0000001, mix_0000002, ... + recid = 'mix_{:07d}'.format(it + 1) + # randomly select speakers, a background noise and a SNR + speakers = random.sample(all_speakers, args.n_speakers) + noise = random.choice(all_noises) + noise_snr = random.choice(noise_snrs) + mixture = {'speakers': []} + for speaker in speakers: + # randomly select the number of utterances + n_utts = np.random.randint(args.min_utts, args.max_utts + 1) + # utts = spk2utt[speaker][:n_utts] + cycle_utts = itertools.cycle(spk2utt[speaker]) + # random start utterance + roll = np.random.randint(0, len(spk2utt[speaker])) + for i in range(roll): + next(cycle_utts) + utts = [next(cycle_utts) for i in range(n_utts)] + # randomly select wait time before appending utterance + intervals = np.random.exponential(args.sil_scale, size=n_utts) + # randomly select a room impulse response + if random.random() < args.speech_rvb_probability: + rir = rirs[random.choice(all_rirs)] + else: + rir = None + if segments is not None: + utts = [segments[utt] for utt in utts] + utts = [(wavs[rec], st, et) for (rec, st, et) in utts] + mixture['speakers'].append({ + 'spkid': speaker, + 'rir': rir, + 'utts': utts, + 'intervals': intervals.tolist() + }) + else: + mixture['speakers'].append({ + 'spkid': speaker, + 'rir': rir, + 'utts': [wavs[utt] for utt in utts], + 'intervals': intervals.tolist() + }) + mixture['noise'] = noises[noise] + mixture['snr'] = noise_snr + mixture['recid'] = recid + print(recid, json.dumps(mixture)) diff --git a/egs/callhome/eend_ola/local/run_blstm.sh b/egs/callhome/eend_ola/local/run_blstm.sh new file mode 100644 index 000000000..71270a4a2 --- /dev/null +++ b/egs/callhome/eend_ola/local/run_blstm.sh @@ -0,0 +1,9 @@ +#!/bin/bash + +# Copyright 2019 Hitachi, Ltd. (author: Yusuke Fujita) +# Licensed under the MIT license. +# +# BLSTM-based model experiment +./run.sh --train-config conf/blstm/train.yaml --average-start 20 --average-end 20 \ + --adapt-config conf/blstm/adapt.yaml --adapt-average-start 10 --adapt-average-end 10 \ + --infer-config conf/blstm/infer.yaml $* diff --git a/egs/callhome/eend_ola/local/run_prepare_shared_eda.sh b/egs/callhome/eend_ola/local/run_prepare_shared_eda.sh new file mode 100644 index 000000000..f48adc54f --- /dev/null +++ b/egs/callhome/eend_ola/local/run_prepare_shared_eda.sh @@ -0,0 +1,235 @@ +#!/bin/bash + +# Copyright 2019 Hitachi, Ltd. (author: Yusuke Fujita, Shota Horiguchi) +# Licensed under the MIT license. +# +# This script prepares kaldi-style data sets shared with different experiments +# - data/xxxx +# callhome, sre, swb2, and swb_cellular datasets +# - data/simu_${simu_outputs} +# simulation mixtures generated with various options + +stage=0 + +# Modify corpus directories +# - callhome_dir +# CALLHOME (LDC2001S97) +# - swb2_phase1_train +# Switchboard-2 Phase 1 (LDC98S75) +# - data_root +# LDC99S79, LDC2002S06, LDC2001S13, LDC2004S07, +# LDC2006S44, LDC2011S01, LDC2011S04, LDC2011S09, +# LDC2011S10, LDC2012S01, LDC2011S05, LDC2011S08 +# - musan_root +# MUSAN corpus (https://www.openslr.org/17/) +callhome_dir=/export/corpora/NIST/LDC2001S97 +swb2_phase1_train=/export/corpora/LDC/LDC98S75 +data_root=/export/corpora5/LDC +musan_root=/export/corpora/JHU/musan +# Modify simulated data storage area. +# This script distributes simulated data under these directories +simu_actual_dirs=( +/export/c05/$USER/diarization-data +/export/c08/$USER/diarization-data +/export/c09/$USER/diarization-data +) + +# data preparation options +max_jobs_run=4 +sad_num_jobs=30 +sad_opts="--extra-left-context 79 --extra-right-context 21 --frames-per-chunk 150 --extra-left-context-initial 0 --extra-right-context-final 0 --acwt 0.3" +sad_graph_opts="--min-silence-duration=0.03 --min-speech-duration=0.3 --max-speech-duration=10.0" +sad_priors_opts="--sil-scale=0.1" + +# simulation options +simu_opts_overlap=yes +simu_opts_num_speaker_array=(1 2 3 4) +simu_opts_sil_scale_array=(2 2 5 9) +simu_opts_rvb_prob=0.5 +simu_opts_num_train=100000 +simu_opts_min_utts=10 +simu_opts_max_utts=20 + +simu_cmd="run.pl" +train_cmd="run.pl" +random_mixture_cmd="run.pl" +make_mixture_cmd="run.pl" + +. parse_options.sh || exit + +if [ $stage -le 0 ]; then + echo "prepare kaldi-style datasets" + # Prepare CALLHOME dataset. This will be used to evaluation. + if ! validate_data_dir.sh --no-text --no-feats data/callhome1_spkall \ + || ! validate_data_dir.sh --no-text --no-feats data/callhome2_spkall; then + # imported from https://github.com/kaldi-asr/kaldi/blob/master/egs/callhome_diarization/v1 + local/make_callhome.sh $callhome_dir data + # Generate two-speaker subsets + for dset in callhome1 callhome2; do + # Extract two-speaker recordings in wav.scp + copy_data_dir.sh data/${dset} data/${dset}_spkall + # Regenerate segments file from fullref.rttm + # $2: recid, $4: start_time, $5: duration, $8: speakerid + awk '{printf "%s_%s_%07d_%07d %s %.2f %.2f\n", \ + $2, $8, $4*100, ($4+$5)*100, $2, $4, $4+$5}' \ + data/callhome/fullref.rttm | sort > data/${dset}_spkall/segments + utils/fix_data_dir.sh data/${dset}_spkall + # Speaker ID is '[recid]_[speakerid] + awk '{split($1,A,"_"); printf "%s %s_%s\n", $1, A[1], A[2]}' \ + data/${dset}_spkall/segments > data/${dset}_spkall/utt2spk + utils/fix_data_dir.sh data/${dset}_spkall + # Generate rttm files for scoring + steps/segmentation/convert_utt2spk_and_segments_to_rttm.py \ + data/${dset}_spkall/utt2spk data/${dset}_spkall/segments \ + data/${dset}_spkall/rttm + utils/data/get_reco2dur.sh data/${dset}_spkall + done + fi + # Prepare a collection of NIST SRE and SWB data. This will be used to train, + if ! validate_data_dir.sh --no-text --no-feats data/swb_sre_comb; then + local/make_sre.sh $data_root data + # Prepare SWB for x-vector DNN training. + local/make_swbd2_phase1.pl $swb2_phase1_train \ + data/swbd2_phase1_train + local/make_swbd2_phase2.pl $data_root/LDC99S79 \ + data/swbd2_phase2_train + local/make_swbd2_phase3.pl $data_root/LDC2002S06 \ + data/swbd2_phase3_train + local/make_swbd_cellular1.pl $data_root/LDC2001S13 \ + data/swbd_cellular1_train + local/make_swbd_cellular2.pl $data_root/LDC2004S07 \ + data/swbd_cellular2_train + # Combine swb and sre data + utils/combine_data.sh data/swb_sre_comb \ + data/swbd_cellular1_train data/swbd_cellular2_train \ + data/swbd2_phase1_train \ + data/swbd2_phase2_train data/swbd2_phase3_train data/sre + fi + # musan data. "back-ground + if ! validate_data_dir.sh --no-text --no-feats data/musan_noise_bg; then + local/make_musan.sh $musan_root data + utils/copy_data_dir.sh data/musan_noise data/musan_noise_bg + awk '{if(NR>1) print $1,$1}' $musan_root/noise/free-sound/ANNOTATIONS > data/musan_noise_bg/utt2spk + utils/fix_data_dir.sh data/musan_noise_bg + fi + # simu rirs 8k + if ! validate_data_dir.sh --no-text --no-feats data/simu_rirs_8k; then + mkdir -p data/simu_rirs_8k + if [ ! -e sim_rir_8k.zip ]; then + wget --no-check-certificate http://www.openslr.org/resources/26/sim_rir_8k.zip + fi + unzip sim_rir_8k.zip -d data/sim_rir_8k + find $PWD/data/sim_rir_8k -iname "*.wav" \ + | awk '{n=split($1,A,/[\/\.]/); print A[n-3]"_"A[n-1], $1}' \ + | sort > data/simu_rirs_8k/wav.scp + awk '{print $1, $1}' data/simu_rirs_8k/wav.scp > data/simu_rirs_8k/utt2spk + utils/fix_data_dir.sh data/simu_rirs_8k + fi + # Automatic segmentation using pretrained SAD model + # it will take one day using 30 CPU jobs: + # make_mfcc: 1 hour, compute_output: 18 hours, decode: 0.5 hours + sad_nnet_dir=exp/segmentation_1a/tdnn_stats_asr_sad_1a + sad_work_dir=exp/segmentation_1a/tdnn_stats_asr_sad_1a + if ! validate_data_dir.sh --no-text $sad_work_dir/swb_sre_comb_seg; then + if [ ! -d exp/segmentation_1a ]; then + wget http://kaldi-asr.org/models/4/0004_tdnn_stats_asr_sad_1a.tar.gz + tar zxf 0004_tdnn_stats_asr_sad_1a.tar.gz + fi + steps/segmentation/detect_speech_activity.sh \ + --nj $sad_num_jobs \ + --graph-opts "$sad_graph_opts" \ + --transform-probs-opts "$sad_priors_opts" $sad_opts \ + data/swb_sre_comb $sad_nnet_dir mfcc_hires $sad_work_dir \ + $sad_work_dir/swb_sre_comb || exit 1 + fi + # Extract >1.5 sec segments and split into train/valid sets + if ! validate_data_dir.sh --no-text --no-feats data/swb_sre_cv; then + copy_data_dir.sh data/swb_sre_comb data/swb_sre_comb_seg + awk '$4-$3>1.5{print;}' $sad_work_dir/swb_sre_comb_seg/segments > data/swb_sre_comb_seg/segments + cp $sad_work_dir/swb_sre_comb_seg/{utt2spk,spk2utt} data/swb_sre_comb_seg + fix_data_dir.sh data/swb_sre_comb_seg + utils/subset_data_dir_tr_cv.sh data/swb_sre_comb_seg data/swb_sre_tr data/swb_sre_cv + fi +fi + +simudir=data/simu +if [ $stage -le 1 ]; then + echo "simulation of mixture" + mkdir -p $simudir/.work + local/random_mixture_cmd=random_mixture.py + local/make_mixture_cmd=make_mixture.py + + for ((i=0; i<${#simu_opts_sil_scale_array[@]}; ++i)); do + simu_opts_num_speaker=${simu_opts_num_speaker_array[i]} + simu_opts_sil_scale=${simu_opts_sil_scale_array[i]} + for dset in swb_sre_tr swb_sre_cv; do + if [ "$dset" == "swb_sre_tr" ]; then + n_mixtures=${simu_opts_num_train} + else + n_mixtures=500 + fi + simuid=${dset}_ns${simu_opts_num_speaker}_beta${simu_opts_sil_scale}_${n_mixtures} + # check if you have the simulation + if ! validate_data_dir.sh --no-text --no-feats $simudir/data/$simuid; then + # random mixture generation + $train_cmd $simudir/.work/random_mixture_$simuid.log \ + $random_mixture_cmd --n_speakers $simu_opts_num_speaker --n_mixtures $n_mixtures \ + --speech_rvb_probability $simu_opts_rvb_prob \ + --sil_scale $simu_opts_sil_scale \ + data/$dset data/musan_noise_bg data/simu_rirs_8k \ + \> $simudir/.work/mixture_$simuid.scp + nj=64 + mkdir -p $simudir/wav/$simuid + # distribute simulated data to $simu_actual_dir + split_scps= + for n in $(seq $nj); do + split_scps="$split_scps $simudir/.work/mixture_$simuid.$n.scp" + mkdir -p $simudir/.work/data_$simuid.$n + actual=${simu_actual_dirs[($n-1)%${#simu_actual_dirs[@]}]}/$simudir/wav/$simuid/$n + mkdir -p $actual + ln -nfs $actual $simudir/wav/$simuid/$n + done + utils/split_scp.pl $simudir/.work/mixture_$simuid.scp $split_scps || exit 1 + + $simu_cmd --max-jobs-run 64 JOB=1:$nj $simudir/.work/make_mixture_$simuid.JOB.log \ + $make_mixture_cmd --rate=8000 \ + $simudir/.work/mixture_$simuid.JOB.scp \ + $simudir/.work/data_$simuid.JOB $simudir/wav/$simuid/JOB + utils/combine_data.sh $simudir/data/$simuid $simudir/.work/data_$simuid.* + steps/segmentation/convert_utt2spk_and_segments_to_rttm.py \ + $simudir/data/$simuid/utt2spk $simudir/data/$simuid/segments \ + $simudir/data/$simuid/rttm + utils/data/get_reco2dur.sh $simudir/data/$simuid + fi + simuid_concat=${dset}_ns"$(IFS="n"; echo "${simu_opts_num_speaker_array[*]}")"_beta"$(IFS="n"; echo "${simu_opts_sil_scale_array[*]}")"_${n_mixtures} + mkdir -p $simudir/data/$simuid_concat + for f in `ls -F $simudir/data/$simuid | grep -v "/"`; do + cat $simudir/data/$simuid/$f >> $simudir/data/$simuid_concat/$f + done + done + done +fi + +if [ $stage -le 3 ]; then + # compose eval/callhome2_spkall + eval_set=data/eval/callhome2_spkall + if ! validate_data_dir.sh --no-text --no-feats $eval_set; then + utils/copy_data_dir.sh data/callhome2_spkall $eval_set + cp data/callhome2_spkall/rttm $eval_set/rttm + awk -v dstdir=wav/eval/callhome2_spkall '{print $1, dstdir"/"$1".wav"}' data/callhome2_spkall/wav.scp > $eval_set/wav.scp + mkdir -p wav/eval/callhome2_spkall + wav-copy scp:data/callhome2_spkall/wav.scp scp:$eval_set/wav.scp + utils/data/get_reco2dur.sh $eval_set + fi + + # compose eval/callhome1_spkall + adapt_set=data/eval/callhome1_spkall + if ! validate_data_dir.sh --no-text --no-feats $adapt_set; then + utils/copy_data_dir.sh data/callhome1_spkall $adapt_set + cp data/callhome1_spkall/rttm $adapt_set/rttm + awk -v dstdir=wav/eval/callhome1_spkall '{print $1, dstdir"/"$1".wav"}' data/callhome1_spkall/wav.scp > $adapt_set/wav.scp + mkdir -p wav/eval/callhome1_spkall + wav-copy scp:data/callhome1_spkall/wav.scp scp:$adapt_set/wav.scp + utils/data/get_reco2dur.sh $adapt_set + fi +fi diff --git a/egs/callhome/eend_ola/path.sh b/egs/callhome/eend_ola/path.sh index ea3c0be2f..e1906b741 100755 --- a/egs/callhome/eend_ola/path.sh +++ b/egs/callhome/eend_ola/path.sh @@ -1,5 +1,12 @@ export FUNASR_DIR=$PWD/../../.. +# kaldi-related +export KALDI_ROOT= +[ -f $KALDI_ROOT/tools/env.sh ] && . $KALDI_ROOT/tools/env.sh +export PATH=$PWD/utils/:$KALDI_ROOT/tools/openfst/bin:$KALDI_ROOT/tools/sph2pipe_v2.5:$KALDI_ROOT/tools/sctk/bin:$PWD:$PATH +[ ! -f $KALDI_ROOT/tools/config/common_path.sh ] && echo >&2 "The standard file $KALDI_ROOT/tools/config/common_path.sh is not present -> Exit!" && exit 1 +. $KALDI_ROOT/tools/config/common_path.sh + # NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C export PYTHONIOENCODING=UTF-8 export PYTHONPATH=../../../:$PYTHONPATH diff --git a/egs/callhome/eend_ola/run.sh b/egs/callhome/eend_ola/run.sh index 893613752..f5afd73ea 100644 --- a/egs/callhome/eend_ola/run.sh +++ b/egs/callhome/eend_ola/run.sh @@ -27,8 +27,8 @@ callhome_average_end=100 exp_dir="." input_size=345 -stage=1 -stop_stage=4 +stage=-1 +stop_stage=-1 # exp tag tag="exp_fix" @@ -50,11 +50,26 @@ simu_allspkr_model_dir="baseline_$(basename "${simu_allspkr_diar_config}" .yaml) simu_allspkr_chunk2000_model_dir="baseline_$(basename "${simu_allspkr_chunk2000_diar_config}" .yaml)_${tag}" callhome_model_dir="baseline_$(basename "${callhome_diar_config}" .yaml)_${tag}" -# Prepare data for training and inference -if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then - echo "stage 0: Prepare data for training and inference" +# simulate mixture data for training and inference +if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then + echo "stage 0: Simulate mixture data for training and inference" + echo "The detail can be found in https://github.com/hitachi-speech/EEND" + ehco "Before running this step, you should download and compile kaldi and set KALDI_ROOT in this script and path.sh" + echo "This stage may take a long time, please waiting..." + KALDI_ROOT= + ln -s $KALDI_ROOT/egs/wsj/s5/steps steps + ln -s $KALDI_ROOT/egs/wsj/s5/utils utils + . local/run_prepare_shared_eda.sh fi +## Prepare data for training and inference +#if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then +# echo "stage 0: Prepare data for training and inference" +# echo "The detail can be found in https://github.com/hitachi-speech/EEND" +# . ./local/ +#fi +# + # Training on simulated two-speaker data world_size=$gpu_num simu_2spkr_ave_id=avg${simu_average_2spkr_start}-${simu_average_2spkr_end} diff --git a/egs/callhome/diarization/sond/sond.yaml b/egs/callhome/sond/sond.yaml similarity index 100% rename from egs/callhome/diarization/sond/sond.yaml rename to egs/callhome/sond/sond.yaml diff --git a/egs/callhome/diarization/sond/sond_fbank.yaml b/egs/callhome/sond/sond_fbank.yaml similarity index 100% rename from egs/callhome/diarization/sond/sond_fbank.yaml rename to egs/callhome/sond/sond_fbank.yaml diff --git a/egs/callhome/diarization/sond/unit_test.py b/egs/callhome/sond/unit_test.py similarity index 100% rename from egs/callhome/diarization/sond/unit_test.py rename to egs/callhome/sond/unit_test.py From 1da8f85c85c7af0e5b49a04425b1c601255f6270 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=98=89=E6=B8=8A?= Date: Tue, 18 Jul 2023 15:02:40 +0800 Subject: [PATCH 07/42] update --- .../eend_ola/local/run_prepare_shared_eda.sh | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/egs/callhome/eend_ola/local/run_prepare_shared_eda.sh b/egs/callhome/eend_ola/local/run_prepare_shared_eda.sh index f48adc54f..a256edafc 100644 --- a/egs/callhome/eend_ola/local/run_prepare_shared_eda.sh +++ b/egs/callhome/eend_ola/local/run_prepare_shared_eda.sh @@ -22,16 +22,16 @@ stage=0 # LDC2011S10, LDC2012S01, LDC2011S05, LDC2011S08 # - musan_root # MUSAN corpus (https://www.openslr.org/17/) -callhome_dir=/export/corpora/NIST/LDC2001S97 -swb2_phase1_train=/export/corpora/LDC/LDC98S75 -data_root=/export/corpora5/LDC -musan_root=/export/corpora/JHU/musan +callhome_dir=/nfs/wangjiaming.wjm/speech-data/NIST/LDC2001S97 +swb2_phase1_train=/nfs/wangjiaming.wjm/speech-data/LDC/LDC98S75 +data_root=/nfs/wangjiaming.wjm/speech-data/LDC +musan_root=/nfs/wangjiaming.wjm/speech-data/JHU/musan # Modify simulated data storage area. # This script distributes simulated data under these directories simu_actual_dirs=( -/export/c05/$USER/diarization-data -/export/c08/$USER/diarization-data -/export/c09/$USER/diarization-data +/nfs/wangjiaming.wjm/EEND_DATA_sad30_snr10n15n20_funasr_test/s05/$USER/diarization-data +/nfs/wangjiaming.wjm/EEND_DATA_sad30_snr10n15n20_funasr_test/s08/$USER/diarization-data +/nfs/wangjiaming.wjm/EEND_DATA_sad30_snr10n15n20_funasr_test/s09/$USER/diarization-data ) # data preparation options @@ -115,9 +115,9 @@ if [ $stage -le 0 ]; then # simu rirs 8k if ! validate_data_dir.sh --no-text --no-feats data/simu_rirs_8k; then mkdir -p data/simu_rirs_8k - if [ ! -e sim_rir_8k.zip ]; then - wget --no-check-certificate http://www.openslr.org/resources/26/sim_rir_8k.zip - fi +# if [ ! -e sim_rir_8k.zip ]; then +# wget --no-check-certificate http://www.openslr.org/resources/26/sim_rir_8k.zip +# fi unzip sim_rir_8k.zip -d data/sim_rir_8k find $PWD/data/sim_rir_8k -iname "*.wav" \ | awk '{n=split($1,A,/[\/\.]/); print A[n-3]"_"A[n-1], $1}' \ @@ -132,7 +132,7 @@ if [ $stage -le 0 ]; then sad_work_dir=exp/segmentation_1a/tdnn_stats_asr_sad_1a if ! validate_data_dir.sh --no-text $sad_work_dir/swb_sre_comb_seg; then if [ ! -d exp/segmentation_1a ]; then - wget http://kaldi-asr.org/models/4/0004_tdnn_stats_asr_sad_1a.tar.gz +# wget http://kaldi-asr.org/models/4/0004_tdnn_stats_asr_sad_1a.tar.gz tar zxf 0004_tdnn_stats_asr_sad_1a.tar.gz fi steps/segmentation/detect_speech_activity.sh \ From bc1641e37ef211fe1ddf046497e8d3c5af648841 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=98=89=E6=B8=8A?= Date: Tue, 18 Jul 2023 15:33:26 +0800 Subject: [PATCH 08/42] update --- egs/callhome/eend_ola/local/parse_options.sh | 97 ++++++++++++++++++++ egs/callhome/eend_ola/run.sh | 4 +- egs/callhome/eend_ola/utils | 1 - 3 files changed, 99 insertions(+), 3 deletions(-) create mode 100755 egs/callhome/eend_ola/local/parse_options.sh delete mode 120000 egs/callhome/eend_ola/utils diff --git a/egs/callhome/eend_ola/local/parse_options.sh b/egs/callhome/eend_ola/local/parse_options.sh new file mode 100755 index 000000000..71fb9e5ea --- /dev/null +++ b/egs/callhome/eend_ola/local/parse_options.sh @@ -0,0 +1,97 @@ +#!/usr/bin/env bash + +# Copyright 2012 Johns Hopkins University (Author: Daniel Povey); +# Arnab Ghoshal, Karel Vesely + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + + +# Parse command-line options. +# To be sourced by another script (as in ". parse_options.sh"). +# Option format is: --option-name arg +# and shell variable "option_name" gets set to value "arg." +# The exception is --help, which takes no arguments, but prints the +# $help_message variable (if defined). + + +### +### The --config file options have lower priority to command line +### options, so we need to import them first... +### + +# Now import all the configs specified by command-line, in left-to-right order +for ((argpos=1; argpos<$#; argpos++)); do + if [ "${!argpos}" == "--config" ]; then + argpos_plus1=$((argpos+1)) + config=${!argpos_plus1} + [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1 + . $config # source the config file. + fi +done + + +### +### Now we process the command line options +### +while true; do + [ -z "${1:-}" ] && break; # break if there are no arguments + case "$1" in + # If the enclosing script is called with --help option, print the help + # message and exit. Scripts should put help messages in $help_message + --help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2; + else printf "$help_message\n" 1>&2 ; fi; + exit 0 ;; + --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'" + exit 1 ;; + # If the first command-line argument begins with "--" (e.g. --foo-bar), + # then work out the variable name as $name, which will equal "foo_bar". + --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`; + # Next we test whether the variable in question is undefned-- if so it's + # an invalid option and we die. Note: $0 evaluates to the name of the + # enclosing script. + # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar + # is undefined. We then have to wrap this test inside "eval" because + # foo_bar is itself inside a variable ($name). + eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1; + + oldval="`eval echo \\$$name`"; + # Work out whether we seem to be expecting a Boolean argument. + if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then + was_bool=true; + else + was_bool=false; + fi + + # Set the variable to the right value-- the escaped quotes make it work if + # the option had spaces, like --cmd "queue.pl -sync y" + eval $name=\"$2\"; + + # Check that Boolean-valued arguments are really Boolean. + if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then + echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2 + exit 1; + fi + shift 2; + ;; + *) break; + esac +done + + +# Check for an empty argument to the --cmd option, which can easily occur as a +# result of scripting errors. +[ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1; + + +true; # so this script returns exit code 0. diff --git a/egs/callhome/eend_ola/run.sh b/egs/callhome/eend_ola/run.sh index f5afd73ea..fb030c5ce 100644 --- a/egs/callhome/eend_ola/run.sh +++ b/egs/callhome/eend_ola/run.sh @@ -33,7 +33,7 @@ stop_stage=-1 # exp tag tag="exp_fix" -. utils/parse_options.sh || exit 1; +. local/parse_options.sh || exit 1; # Set bash to 'debug' mode, it will exit on : # -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands', @@ -54,7 +54,7 @@ callhome_model_dir="baseline_$(basename "${callhome_diar_config}" .yaml)_${tag}" if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then echo "stage 0: Simulate mixture data for training and inference" echo "The detail can be found in https://github.com/hitachi-speech/EEND" - ehco "Before running this step, you should download and compile kaldi and set KALDI_ROOT in this script and path.sh" + echo "Before running this step, you should download and compile kaldi and set KALDI_ROOT in this script and path.sh" echo "This stage may take a long time, please waiting..." KALDI_ROOT= ln -s $KALDI_ROOT/egs/wsj/s5/steps steps diff --git a/egs/callhome/eend_ola/utils b/egs/callhome/eend_ola/utils deleted file mode 120000 index fe070dd3a..000000000 --- a/egs/callhome/eend_ola/utils +++ /dev/null @@ -1 +0,0 @@ -../../aishell/transformer/utils \ No newline at end of file From 570c1793cbbaa6aadea35d874f2022a547c5682b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=98=89=E6=B8=8A?= Date: Tue, 18 Jul 2023 18:57:28 +0800 Subject: [PATCH 09/42] update --- egs/callhome/eend_ola/local/run_blstm.sh | 9 --------- 1 file changed, 9 deletions(-) delete mode 100644 egs/callhome/eend_ola/local/run_blstm.sh diff --git a/egs/callhome/eend_ola/local/run_blstm.sh b/egs/callhome/eend_ola/local/run_blstm.sh deleted file mode 100644 index 71270a4a2..000000000 --- a/egs/callhome/eend_ola/local/run_blstm.sh +++ /dev/null @@ -1,9 +0,0 @@ -#!/bin/bash - -# Copyright 2019 Hitachi, Ltd. (author: Yusuke Fujita) -# Licensed under the MIT license. -# -# BLSTM-based model experiment -./run.sh --train-config conf/blstm/train.yaml --average-start 20 --average-end 20 \ - --adapt-config conf/blstm/adapt.yaml --adapt-average-start 10 --adapt-average-end 10 \ - --infer-config conf/blstm/infer.yaml $* From 23a1c295db498baa9662ca2080b3ce063d63b9b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=98=89=E6=B8=8A?= Date: Tue, 18 Jul 2023 18:58:16 +0800 Subject: [PATCH 10/42] update --- egs/callhome/eend_ola/run.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/callhome/eend_ola/run.sh b/egs/callhome/eend_ola/run.sh index fb030c5ce..e31457300 100644 --- a/egs/callhome/eend_ola/run.sh +++ b/egs/callhome/eend_ola/run.sh @@ -59,7 +59,7 @@ if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then KALDI_ROOT= ln -s $KALDI_ROOT/egs/wsj/s5/steps steps ln -s $KALDI_ROOT/egs/wsj/s5/utils utils - . local/run_prepare_shared_eda.sh + local/run_prepare_shared_eda.sh fi ## Prepare data for training and inference From f69ad44640583335606233acae29457fdbbeb3db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=98=89=E6=B8=8A?= Date: Tue, 18 Jul 2023 19:01:48 +0800 Subject: [PATCH 11/42] update --- egs/callhome/eend_ola/local/make_callhome.sh | 0 egs/callhome/eend_ola/local/make_mixture.py | 0 egs/callhome/eend_ola/local/make_musan.py | 0 egs/callhome/eend_ola/local/make_musan.sh | 0 egs/callhome/eend_ola/local/make_sre.pl | 0 egs/callhome/eend_ola/local/make_sre.sh | 0 egs/callhome/eend_ola/local/make_swbd2_phase1.pl | 0 egs/callhome/eend_ola/local/make_swbd2_phase2.pl | 0 egs/callhome/eend_ola/local/make_swbd2_phase3.pl | 0 egs/callhome/eend_ola/local/make_swbd_cellular1.pl | 0 egs/callhome/eend_ola/local/make_swbd_cellular2.pl | 0 egs/callhome/eend_ola/local/model_averaging.py | 0 egs/callhome/eend_ola/local/random_mixture.py | 0 egs/callhome/eend_ola/local/run_prepare_shared_eda.sh | 0 14 files changed, 0 insertions(+), 0 deletions(-) mode change 100644 => 100755 egs/callhome/eend_ola/local/make_callhome.sh mode change 100644 => 100755 egs/callhome/eend_ola/local/make_mixture.py mode change 100644 => 100755 egs/callhome/eend_ola/local/make_musan.py mode change 100644 => 100755 egs/callhome/eend_ola/local/make_musan.sh mode change 100644 => 100755 egs/callhome/eend_ola/local/make_sre.pl mode change 100644 => 100755 egs/callhome/eend_ola/local/make_sre.sh mode change 100644 => 100755 egs/callhome/eend_ola/local/make_swbd2_phase1.pl mode change 100644 => 100755 egs/callhome/eend_ola/local/make_swbd2_phase2.pl mode change 100644 => 100755 egs/callhome/eend_ola/local/make_swbd2_phase3.pl mode change 100644 => 100755 egs/callhome/eend_ola/local/make_swbd_cellular1.pl mode change 100644 => 100755 egs/callhome/eend_ola/local/make_swbd_cellular2.pl mode change 100644 => 100755 egs/callhome/eend_ola/local/model_averaging.py mode change 100644 => 100755 egs/callhome/eend_ola/local/random_mixture.py mode change 100644 => 100755 egs/callhome/eend_ola/local/run_prepare_shared_eda.sh diff --git a/egs/callhome/eend_ola/local/make_callhome.sh b/egs/callhome/eend_ola/local/make_callhome.sh old mode 100644 new mode 100755 diff --git a/egs/callhome/eend_ola/local/make_mixture.py b/egs/callhome/eend_ola/local/make_mixture.py old mode 100644 new mode 100755 diff --git a/egs/callhome/eend_ola/local/make_musan.py b/egs/callhome/eend_ola/local/make_musan.py old mode 100644 new mode 100755 diff --git a/egs/callhome/eend_ola/local/make_musan.sh b/egs/callhome/eend_ola/local/make_musan.sh old mode 100644 new mode 100755 diff --git a/egs/callhome/eend_ola/local/make_sre.pl b/egs/callhome/eend_ola/local/make_sre.pl old mode 100644 new mode 100755 diff --git a/egs/callhome/eend_ola/local/make_sre.sh b/egs/callhome/eend_ola/local/make_sre.sh old mode 100644 new mode 100755 diff --git a/egs/callhome/eend_ola/local/make_swbd2_phase1.pl b/egs/callhome/eend_ola/local/make_swbd2_phase1.pl old mode 100644 new mode 100755 diff --git a/egs/callhome/eend_ola/local/make_swbd2_phase2.pl b/egs/callhome/eend_ola/local/make_swbd2_phase2.pl old mode 100644 new mode 100755 diff --git a/egs/callhome/eend_ola/local/make_swbd2_phase3.pl b/egs/callhome/eend_ola/local/make_swbd2_phase3.pl old mode 100644 new mode 100755 diff --git a/egs/callhome/eend_ola/local/make_swbd_cellular1.pl b/egs/callhome/eend_ola/local/make_swbd_cellular1.pl old mode 100644 new mode 100755 diff --git a/egs/callhome/eend_ola/local/make_swbd_cellular2.pl b/egs/callhome/eend_ola/local/make_swbd_cellular2.pl old mode 100644 new mode 100755 diff --git a/egs/callhome/eend_ola/local/model_averaging.py b/egs/callhome/eend_ola/local/model_averaging.py old mode 100644 new mode 100755 diff --git a/egs/callhome/eend_ola/local/random_mixture.py b/egs/callhome/eend_ola/local/random_mixture.py old mode 100644 new mode 100755 diff --git a/egs/callhome/eend_ola/local/run_prepare_shared_eda.sh b/egs/callhome/eend_ola/local/run_prepare_shared_eda.sh old mode 100644 new mode 100755 From 9cdae0e1b073333090ae3f9ba15c1f9caaf1623f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=98=89=E6=B8=8A?= Date: Tue, 18 Jul 2023 21:34:52 +0800 Subject: [PATCH 12/42] update --- egs/callhome/eend_ola/local/make_swbd_cellular1.pl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) mode change 100755 => 100644 egs/callhome/eend_ola/local/make_swbd_cellular1.pl diff --git a/egs/callhome/eend_ola/local/make_swbd_cellular1.pl b/egs/callhome/eend_ola/local/make_swbd_cellular1.pl old mode 100755 new mode 100644 index e30c710e6..ede6cc2f9 --- a/egs/callhome/eend_ola/local/make_swbd_cellular1.pl +++ b/egs/callhome/eend_ola/local/make_swbd_cellular1.pl @@ -47,13 +47,13 @@ while () { } else { die "Unknown Gender in $line"; } - if (-e "$db_base/$wav.sph") { + if (-e "$db_base/data/$wav.sph") { $uttId = $spkr1 . "-swbdc_" . $wav ."_1"; if (!$spk2gender{$spkr1}) { $spk2gender{$spkr1} = $gender1; print GNDR "$spkr1"," $gender1\n"; } - print WAV "$uttId"," sph2pipe -f wav -p -c 1 $db_base/$wav.sph |\n"; + print WAV "$uttId"," sph2pipe -f wav -p -c 1 $db_base/data/$wav.sph |\n"; print SPKR "$uttId"," $spkr1","\n"; $uttId = $spkr2 . "-swbdc_" . $wav ."_2"; @@ -61,10 +61,10 @@ while () { $spk2gender{$spkr2} = $gender2; print GNDR "$spkr2"," $gender2\n"; } - print WAV "$uttId"," sph2pipe -f wav -p -c 2 $db_base/$wav.sph |\n"; + print WAV "$uttId"," sph2pipe -f wav -p -c 2 $db_base/data/$wav.sph |\n"; print SPKR "$uttId"," $spkr2","\n"; } else { - print STDERR "Missing $db_base/$wav.sph\n"; + print STDERR "Missing $db_base/data/$wav.sph\n"; } } } From a748256dbffc60872025e8f73f431add0155c784 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=98=89=E6=B8=8A?= Date: Tue, 18 Jul 2023 21:54:45 +0800 Subject: [PATCH 13/42] update --- egs/callhome/eend_ola/local/make_swbd_cellular1.pl | 0 egs/callhome/eend_ola/run.sh | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) mode change 100644 => 100755 egs/callhome/eend_ola/local/make_swbd_cellular1.pl diff --git a/egs/callhome/eend_ola/local/make_swbd_cellular1.pl b/egs/callhome/eend_ola/local/make_swbd_cellular1.pl old mode 100644 new mode 100755 diff --git a/egs/callhome/eend_ola/run.sh b/egs/callhome/eend_ola/run.sh index e31457300..286fc29aa 100644 --- a/egs/callhome/eend_ola/run.sh +++ b/egs/callhome/eend_ola/run.sh @@ -52,7 +52,7 @@ callhome_model_dir="baseline_$(basename "${callhome_diar_config}" .yaml)_${tag}" # simulate mixture data for training and inference if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then - echo "stage 0: Simulate mixture data for training and inference" + echo "stage -1: Simulate mixture data for training and inference" echo "The detail can be found in https://github.com/hitachi-speech/EEND" echo "Before running this step, you should download and compile kaldi and set KALDI_ROOT in this script and path.sh" echo "This stage may take a long time, please waiting..." From 7159b0963829e49e9774fd433ed071d857396b31 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=98=89=E6=B8=8A?= Date: Tue, 18 Jul 2023 22:06:36 +0800 Subject: [PATCH 14/42] update --- egs/callhome/eend_ola/local/make_swbd_cellular1.pl | 0 1 file changed, 0 insertions(+), 0 deletions(-) mode change 100755 => 100644 egs/callhome/eend_ola/local/make_swbd_cellular1.pl diff --git a/egs/callhome/eend_ola/local/make_swbd_cellular1.pl b/egs/callhome/eend_ola/local/make_swbd_cellular1.pl old mode 100755 new mode 100644 From a7b34960396fa83398e0000e0273ef8e9e6371cc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=98=89=E6=B8=8A?= Date: Wed, 19 Jul 2023 01:49:02 +0800 Subject: [PATCH 15/42] update --- egs/callhome/eend_ola/local/make_mixture.py | 2 +- .../eend_ola/local/run_prepare_shared_eda.sh | 6 +- egs/callhome/eend_ola/run.sh | 2 +- egs/callhome/eend_ola/run_test.sh | 257 ++++++++++++++++++ funasr/modules/eend_ola/utils/kaldi_data.py | 162 +++++++++++ 5 files changed, 424 insertions(+), 5 deletions(-) create mode 100644 egs/callhome/eend_ola/run_test.sh create mode 100644 funasr/modules/eend_ola/utils/kaldi_data.py diff --git a/egs/callhome/eend_ola/local/make_mixture.py b/egs/callhome/eend_ola/local/make_mixture.py index 82d03cd60..6b159034d 100755 --- a/egs/callhome/eend_ola/local/make_mixture.py +++ b/egs/callhome/eend_ola/local/make_mixture.py @@ -13,7 +13,7 @@ import argparse import os -from eend import kaldi_data +from funasr.modules.eend_ola.utils import kaldi_data import numpy as np import math import soundfile as sf diff --git a/egs/callhome/eend_ola/local/run_prepare_shared_eda.sh b/egs/callhome/eend_ola/local/run_prepare_shared_eda.sh index a256edafc..5431ba1de 100755 --- a/egs/callhome/eend_ola/local/run_prepare_shared_eda.sh +++ b/egs/callhome/eend_ola/local/run_prepare_shared_eda.sh @@ -9,7 +9,7 @@ # - data/simu_${simu_outputs} # simulation mixtures generated with various options -stage=0 +stage=1 # Modify corpus directories # - callhome_dir @@ -156,8 +156,8 @@ simudir=data/simu if [ $stage -le 1 ]; then echo "simulation of mixture" mkdir -p $simudir/.work - local/random_mixture_cmd=random_mixture.py - local/make_mixture_cmd=make_mixture.py + random_mixture_cmd=local/random_mixture.py + make_mixture_cmd=local/make_mixture.py for ((i=0; i<${#simu_opts_sil_scale_array[@]}; ++i)); do simu_opts_num_speaker=${simu_opts_num_speaker_array[i]} diff --git a/egs/callhome/eend_ola/run.sh b/egs/callhome/eend_ola/run.sh index 286fc29aa..b4f273945 100644 --- a/egs/callhome/eend_ola/run.sh +++ b/egs/callhome/eend_ola/run.sh @@ -31,7 +31,7 @@ stage=-1 stop_stage=-1 # exp tag -tag="exp_fix" +tag="exp1" . local/parse_options.sh || exit 1; diff --git a/egs/callhome/eend_ola/run_test.sh b/egs/callhome/eend_ola/run_test.sh new file mode 100644 index 000000000..d00444665 --- /dev/null +++ b/egs/callhome/eend_ola/run_test.sh @@ -0,0 +1,257 @@ +#!/usr/bin/env bash + +. ./path.sh || exit 1; + +# machines configuration +CUDA_VISIBLE_DEVICES="7" +gpu_num=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') +count=1 + +# general configuration +simu_feats_dir="/nfs/wangjiaming.wjm/EEND_ARK_DATA/dump/simu_data/data" +simu_feats_dir_chunk2000="/nfs/wangjiaming.wjm/EEND_ARK_DATA/dump/simu_data_chunk2000/data" +callhome_feats_dir_chunk2000="/nfs/wangjiaming.wjm/EEND_ARK_DATA/dump/callhome_chunk2000/data" +simu_train_dataset=train +simu_valid_dataset=dev +callhome_train_dataset=callhome1_allspk +callhome_valid_dataset=callhome2_allspk +callhome2_wav_scp_file=wav.scp + +# model average +simu_average_2spkr_start=91 +simu_average_2spkr_end=100 +simu_average_allspkr_start=16 +simu_average_allspkr_end=25 +callhome_average_start=91 +callhome_average_end=100 + +exp_dir="." +input_size=345 +stage=5 +stop_stage=5 + +# exp tag +tag="exp1" + +. local/parse_options.sh || exit 1; + +# Set bash to 'debug' mode, it will exit on : +# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands', +set -e +set -u +set -o pipefail + +simu_2spkr_diar_config=conf/train_diar_eend_ola_simu_2spkr.yaml +simu_allspkr_diar_config=conf/train_diar_eend_ola_simu_allspkr.yaml +simu_allspkr_chunk2000_diar_config=conf/train_diar_eend_ola_simu_allspkr_chunk2000.yaml +callhome_diar_config=conf/train_diar_eend_ola_callhome_chunk2000.yaml +simu_2spkr_model_dir="baseline_$(basename "${simu_2spkr_diar_config}" .yaml)_${tag}" +simu_allspkr_model_dir="baseline_$(basename "${simu_allspkr_diar_config}" .yaml)_${tag}" +simu_allspkr_chunk2000_model_dir="baseline_$(basename "${simu_allspkr_chunk2000_diar_config}" .yaml)_${tag}" +callhome_model_dir="baseline_$(basename "${callhome_diar_config}" .yaml)_${tag}" + +# simulate mixture data for training and inference +if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then + echo "stage -1: Simulate mixture data for training and inference" + echo "The detail can be found in https://github.com/hitachi-speech/EEND" + echo "Before running this step, you should download and compile kaldi and set KALDI_ROOT in this script and path.sh" + echo "This stage may take a long time, please waiting..." + KALDI_ROOT= + ln -s $KALDI_ROOT/egs/wsj/s5/steps steps + ln -s $KALDI_ROOT/egs/wsj/s5/utils utils + local/run_prepare_shared_eda.sh +fi + +## Prepare data for training and inference +#if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then +# echo "stage 0: Prepare data for training and inference" +# echo "The detail can be found in https://github.com/hitachi-speech/EEND" +# . ./local/ +#fi +# + +# Training on simulated two-speaker data +world_size=$gpu_num +simu_2spkr_ave_id=avg${simu_average_2spkr_start}-${simu_average_2spkr_end} +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + echo "stage 1: Training on simulated two-speaker data" + mkdir -p ${exp_dir}/exp/${simu_2spkr_model_dir} + mkdir -p ${exp_dir}/exp/${simu_2spkr_model_dir}/log + INIT_FILE=${exp_dir}/exp/${simu_2spkr_model_dir}/ddp_init + if [ -f $INIT_FILE ];then + rm -f $INIT_FILE + fi + init_method=file://$(readlink -f $INIT_FILE) + echo "$0: init method is $init_method" + for ((i = 0; i < $gpu_num; ++i)); do + { + rank=$i + local_rank=$i + gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1]) + train.py \ + --task_name diar \ + --gpu_id $gpu_id \ + --use_preprocessor false \ + --input_size $input_size \ + --data_dir ${simu_feats_dir} \ + --train_set ${simu_train_dataset} \ + --valid_set ${simu_valid_dataset} \ + --data_file_names "feats_2spkr.scp" \ + --resume true \ + --output_dir ${exp_dir}/exp/${simu_2spkr_model_dir} \ + --config $simu_2spkr_diar_config \ + --ngpu $gpu_num \ + --num_worker_count $count \ + --dist_init_method $init_method \ + --dist_world_size $world_size \ + --dist_rank $rank \ + --local_rank $local_rank 1> ${exp_dir}/exp/${simu_2spkr_model_dir}/log/train.log.$i 2>&1 + } & + done + wait + echo "averaging model parameters into ${exp_dir}/exp/$simu_2spkr_model_dir/$simu_2spkr_ave_id.pb" + models=`eval echo ${exp_dir}/exp/${simu_2spkr_model_dir}/{$simu_average_2spkr_start..$simu_average_2spkr_end}epoch.pb` + python local/model_averaging.py ${exp_dir}/exp/${simu_2spkr_model_dir}/$simu_2spkr_ave_id.pb $models +fi + +# Training on simulated all-speaker data +world_size=$gpu_num +simu_allspkr_ave_id=avg${simu_average_allspkr_start}-${simu_average_allspkr_end} +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + echo "stage 2: Training on simulated all-speaker data" + mkdir -p ${exp_dir}/exp/${simu_allspkr_model_dir} + mkdir -p ${exp_dir}/exp/${simu_allspkr_model_dir}/log + INIT_FILE=${exp_dir}/exp/${simu_allspkr_model_dir}/ddp_init + if [ -f $INIT_FILE ];then + rm -f $INIT_FILE + fi + init_method=file://$(readlink -f $INIT_FILE) + echo "$0: init method is $init_method" + for ((i = 0; i < $gpu_num; ++i)); do + { + rank=$i + local_rank=$i + gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1]) + train.py \ + --task_name diar \ + --gpu_id $gpu_id \ + --use_preprocessor false \ + --input_size $input_size \ + --data_dir ${simu_feats_dir} \ + --train_set ${simu_train_dataset} \ + --valid_set ${simu_valid_dataset} \ + --data_file_names "feats.scp" \ + --resume true \ + --init_param ${exp_dir}/exp/${simu_2spkr_model_dir}/$simu_2spkr_ave_id.pb \ + --output_dir ${exp_dir}/exp/${simu_allspkr_model_dir} \ + --config $simu_allspkr_diar_config \ + --ngpu $gpu_num \ + --num_worker_count $count \ + --dist_init_method $init_method \ + --dist_world_size $world_size \ + --dist_rank $rank \ + --local_rank $local_rank 1> ${exp_dir}/exp/${simu_allspkr_model_dir}/log/train.log.$i 2>&1 + } & + done + wait + echo "averaging model parameters into ${exp_dir}/exp/$simu_allspkr_model_dir/$simu_allspkr_ave_id.pb" + models=`eval echo ${exp_dir}/exp/${simu_allspkr_model_dir}/{$simu_average_allspkr_start..$simu_average_allspkr_end}epoch.pb` + python local/model_averaging.py ${exp_dir}/exp/${simu_allspkr_model_dir}/$simu_allspkr_ave_id.pb $models +fi + +# Training on simulated all-speaker data with chunk_size=2000 +world_size=$gpu_num +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + echo "stage 3: Training on simulated all-speaker data with chunk_size=2000" + mkdir -p ${exp_dir}/exp/${simu_allspkr_chunk2000_model_dir} + mkdir -p ${exp_dir}/exp/${simu_allspkr_chunk2000_model_dir}/log + INIT_FILE=${exp_dir}/exp/${simu_allspkr_chunk2000_model_dir}/ddp_init + if [ -f $INIT_FILE ];then + rm -f $INIT_FILE + fi + init_method=file://$(readlink -f $INIT_FILE) + echo "$0: init method is $init_method" + for ((i = 0; i < $gpu_num; ++i)); do + { + rank=$i + local_rank=$i + gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1]) + train.py \ + --task_name diar \ + --gpu_id $gpu_id \ + --use_preprocessor false \ + --input_size $input_size \ + --data_dir ${simu_feats_dir_chunk2000} \ + --train_set ${simu_train_dataset} \ + --valid_set ${simu_valid_dataset} \ + --data_file_names "feats.scp" \ + --resume true \ + --init_param ${exp_dir}/exp/${simu_allspkr_model_dir}/$simu_allspkr_ave_id.pb \ + --output_dir ${exp_dir}/exp/${simu_allspkr_chunk2000_model_dir} \ + --config $simu_allspkr_chunk2000_diar_config \ + --ngpu $gpu_num \ + --num_worker_count $count \ + --dist_init_method $init_method \ + --dist_world_size $world_size \ + --dist_rank $rank \ + --local_rank $local_rank 1> ${exp_dir}/exp/${simu_allspkr_chunk2000_model_dir}/log/train.log.$i 2>&1 + } & + done + wait +fi + +# Training on callhome all-speaker data with chunk_size=2000 +world_size=$gpu_num +callhome_ave_id=avg${callhome_average_start}-${callhome_average_end} +if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then + echo "stage 4: Training on callhome all-speaker data with chunk_size=2000" + mkdir -p ${exp_dir}/exp/${callhome_model_dir} + mkdir -p ${exp_dir}/exp/${callhome_model_dir}/log + INIT_FILE=${exp_dir}/exp/${callhome_model_dir}/ddp_init + if [ -f $INIT_FILE ];then + rm -f $INIT_FILE + fi + init_method=file://$(readlink -f $INIT_FILE) + echo "$0: init method is $init_method" + for ((i = 0; i < $gpu_num; ++i)); do + { + rank=$i + local_rank=$i + gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1]) + train.py \ + --task_name diar \ + --gpu_id $gpu_id \ + --use_preprocessor false \ + --input_size $input_size \ + --data_dir ${callhome_feats_dir_chunk2000} \ + --train_set ${callhome_train_dataset} \ + --valid_set ${callhome_valid_dataset} \ + --data_file_names "feats.scp" \ + --resume true \ + --init_param ${exp_dir}/exp/${simu_allspkr_chunk2000_model_dir}/1epoch.pb \ + --output_dir ${exp_dir}/exp/${callhome_model_dir} \ + --config $callhome_diar_config \ + --ngpu $gpu_num \ + --num_worker_count $count \ + --dist_init_method $init_method \ + --dist_world_size $world_size \ + --dist_rank $rank \ + --local_rank $local_rank 1> ${exp_dir}/exp/${callhome_model_dir}/log/train.log.$i 2>&1 + } & + done + wait + echo "averaging model parameters into ${exp_dir}/exp/$callhome_model_dir/$callhome_ave_id.pb" + models=`eval echo ${exp_dir}/exp/${callhome_model_dir}/{$callhome_average_start..$callhome_average_end}epoch.pb` + python local/model_averaging.py ${exp_dir}/exp/${callhome_model_dir}/$callhome_ave_id.pb $models +fi + +# inference +if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then + echo "Inference" + mkdir -p ${exp_dir}/exp/${callhome_model_dir}/inference/log + CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES python local/infer.py \ + --config_file ${exp_dir}/exp/${callhome_model_dir}/config.yaml \ + --model_file ${exp_dir}/exp/${callhome_model_dir}/$callhome_ave_id.pb \ + --output_rttm_file ${exp_dir}/exp/${callhome_model_dir}/inference/rttm \ + --wav_scp_file ${callhome_feats_dir_chunk2000}/${callhome_valid_dataset}/${callhome2_wav_scp_file} 1> ${exp_dir}/exp/${callhome_model_dir}/inference/log/infer.log 2>&1 +fi \ No newline at end of file diff --git a/funasr/modules/eend_ola/utils/kaldi_data.py b/funasr/modules/eend_ola/utils/kaldi_data.py new file mode 100644 index 000000000..42f6d5ebc --- /dev/null +++ b/funasr/modules/eend_ola/utils/kaldi_data.py @@ -0,0 +1,162 @@ +# Copyright 2019 Hitachi, Ltd. (author: Yusuke Fujita) +# Licensed under the MIT license. +# +# This library provides utilities for kaldi-style data directory. + + +from __future__ import print_function +import os +import sys +import numpy as np +import subprocess +import soundfile as sf +import io +from functools import lru_cache + + +def load_segments(segments_file): + """ load segments file as array """ + if not os.path.exists(segments_file): + return None + return np.loadtxt( + segments_file, + dtype=[('utt', 'object'), + ('rec', 'object'), + ('st', 'f'), + ('et', 'f')], + ndmin=1) + + +def load_segments_hash(segments_file): + ret = {} + if not os.path.exists(segments_file): + return None + for line in open(segments_file): + utt, rec, st, et = line.strip().split() + ret[utt] = (rec, float(st), float(et)) + return ret + + +def load_segments_rechash(segments_file): + ret = {} + if not os.path.exists(segments_file): + return None + for line in open(segments_file): + utt, rec, st, et = line.strip().split() + if rec not in ret: + ret[rec] = [] + ret[rec].append({'utt':utt, 'st':float(st), 'et':float(et)}) + return ret + + +def load_wav_scp(wav_scp_file): + """ return dictionary { rec: wav_rxfilename } """ + lines = [line.strip().split(None, 1) for line in open(wav_scp_file)] + return {x[0]: x[1] for x in lines} + + +@lru_cache(maxsize=1) +def load_wav(wav_rxfilename, start=0, end=None): + """ This function reads audio file and return data in numpy.float32 array. + "lru_cache" holds recently loaded audio so that can be called + many times on the same audio file. + OPTIMIZE: controls lru_cache size for random access, + considering memory size + """ + if wav_rxfilename.endswith('|'): + # input piped command + p = subprocess.Popen(wav_rxfilename[:-1], shell=True, + stdout=subprocess.PIPE) + data, samplerate = sf.read(io.BytesIO(p.stdout.read()), + dtype='float32') + # cannot seek + data = data[start:end] + elif wav_rxfilename == '-': + # stdin + data, samplerate = sf.read(sys.stdin, dtype='float32') + # cannot seek + data = data[start:end] + else: + # normal wav file + data, samplerate = sf.read(wav_rxfilename, start=start, stop=end) + return data, samplerate + + +def load_utt2spk(utt2spk_file): + """ returns dictionary { uttid: spkid } """ + lines = [line.strip().split(None, 1) for line in open(utt2spk_file)] + return {x[0]: x[1] for x in lines} + + +def load_spk2utt(spk2utt_file): + """ returns dictionary { spkid: list of uttids } """ + if not os.path.exists(spk2utt_file): + return None + lines = [line.strip().split() for line in open(spk2utt_file)] + return {x[0]: x[1:] for x in lines} + + +def load_reco2dur(reco2dur_file): + """ returns dictionary { recid: duration } """ + if not os.path.exists(reco2dur_file): + return None + lines = [line.strip().split(None, 1) for line in open(reco2dur_file)] + return {x[0]: float(x[1]) for x in lines} + + +def process_wav(wav_rxfilename, process): + """ This function returns preprocessed wav_rxfilename + Args: + wav_rxfilename: input + process: command which can be connected via pipe, + use stdin and stdout + Returns: + wav_rxfilename: output piped command + """ + if wav_rxfilename.endswith('|'): + # input piped command + return wav_rxfilename + process + "|" + else: + # stdin "-" or normal file + return "cat {} | {} |".format(wav_rxfilename, process) + + +def extract_segments(wavs, segments=None): + """ This function returns generator of segmented audio as + (utterance id, numpy.float32 array) + TODO?: sampling rate is not converted. + """ + if segments is not None: + # segments should be sorted by rec-id + for seg in segments: + wav = wavs[seg['rec']] + data, samplerate = load_wav(wav) + st_sample = np.rint(seg['st'] * samplerate).astype(int) + et_sample = np.rint(seg['et'] * samplerate).astype(int) + yield seg['utt'], data[st_sample:et_sample] + else: + # segments file not found, + # wav.scp is used as segmented audio list + for rec in wavs: + data, samplerate = load_wav(wavs[rec]) + yield rec, data + + +class KaldiData: + def __init__(self, data_dir): + self.data_dir = data_dir + self.segments = load_segments_rechash( + os.path.join(self.data_dir, 'segments')) + self.utt2spk = load_utt2spk( + os.path.join(self.data_dir, 'utt2spk')) + self.wavs = load_wav_scp( + os.path.join(self.data_dir, 'wav.scp')) + self.reco2dur = load_reco2dur( + os.path.join(self.data_dir, 'reco2dur')) + self.spk2utt = load_spk2utt( + os.path.join(self.data_dir, 'spk2utt')) + + def load_wav(self, recid, start=0, end=None): + data, rate = load_wav( + self.wavs[recid], start, end) + return data, rate From 92e8d4358a0c0ea323f00fa578382252c5b18732 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=98=89=E6=B8=8A?= Date: Wed, 19 Jul 2023 10:35:12 +0800 Subject: [PATCH 16/42] update --- egs/callhome/eend_ola/local/infer.py | 132 ++++++++++++++++++ egs/callhome/eend_ola/local/random_mixture.py | 2 +- .../eend_ola/local/run_prepare_shared_eda.sh | 2 +- 3 files changed, 134 insertions(+), 2 deletions(-) create mode 100644 egs/callhome/eend_ola/local/infer.py diff --git a/egs/callhome/eend_ola/local/infer.py b/egs/callhome/eend_ola/local/infer.py new file mode 100644 index 000000000..78d160d3a --- /dev/null +++ b/egs/callhome/eend_ola/local/infer.py @@ -0,0 +1,132 @@ +import argparse +import os + +import numpy as np +import soundfile as sf +import torch +import yaml +from scipy.signal import medfilt + +import funasr.models.frontend.eend_ola_feature as eend_ola_feature +from funasr.build_utils.build_model_from_file import build_model_from_file + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument( + "--config_file", + type=str, + help="model config file", + ) + parser.add_argument( + "--model_file", + type=str, + help="model path", + ) + parser.add_argument( + "--output_rttm_file", + type=str, + help="output rttm path", + ) + parser.add_argument( + "--wav_scp_file", + type=str, + default="wav.scp", + help="input data path", + ) + parser.add_argument( + "--frame_shift", + type=int, + default=80, + help="frame shift", + ) + parser.add_argument( + "--frame_size", + type=int, + default=200, + help="frame size", + ) + parser.add_argument( + "--context_size", + type=int, + default=7, + help="context size", + ) + parser.add_argument( + "--sampling_rate", + type=int, + default=10, + help="sampling rate", + ) + parser.add_argument( + "--subsampling", + type=int, + default=10, + help="setting subsampling", + ) + parser.add_argument( + "--attractor_threshold", + type=float, + default=0.5, + help="threshold for selecting attractors", + ) + parser.add_argument( + "--device", + type=str, + default="cuda", + ) + args = parser.parse_args() + + with open(args.config_file) as f: + configs = yaml.safe_load(f) + for k, v in configs.items(): + if not hasattr(args, k): + setattr(args, k, v) + + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed(args.seed) + os.environ['PYTORCH_SEED'] = str(args.seed) + + model, _ = build_model_from_file(config_file=args.config_file, model_file=args.model_file, task_name="diar", + device=args.device) + model.eval() + + with open(args.wav_scp_file) as f: + wav_lines = [line.strip().split() for line in f.readlines()] + wav_items = {x[0]: x[1] for x in wav_lines} + + print("Start inference") + with open(args.output_rttm_file, "w") as wf: + for wav_id in wav_items.keys(): + print("Process wav: {}\n".format(wav_id)) + data, rate = sf.read(wav_items[wav_id]) + speech = eend_ola_feature.stft(data, args.frame_size, args.frame_shift) + speech = eend_ola_feature.transform(speech) + speech = eend_ola_feature.splice(speech, context_size=args.context_size) + speech = speech[::args.subsampling] # sampling + speech = torch.from_numpy(speech) + + with torch.no_grad(): + speech = speech.to(args.device) + ys, _, _, _ = model.estimate_sequential( + [speech], + n_speakers=None, + th=args.attractor_threshold, + shuffle=args.shuffle + ) + + a = ys[0].cpu().numpy() + a = medfilt(a, (11, 1)) + rst = [] + for spkr_id, frames in enumerate(a.T): + frames = np.pad(frames, (1, 1), 'constant') + changes, = np.where(np.diff(frames, axis=0) != 0) + fmt = "SPEAKER {:s} 1 {:7.2f} {:7.2f} {:s} " + for s, e in zip(changes[::2], changes[1::2]): + st = s * args.frame_shift * args.subsampling / args.sampling_rate + dur = (e - s) * args.frame_shift * args.subsampling / args.sampling_rate + print(fmt.format( + wav_id, + st, + dur, + wav_id + "_" + str(spkr_id)), file=wf) \ No newline at end of file diff --git a/egs/callhome/eend_ola/local/random_mixture.py b/egs/callhome/eend_ola/local/random_mixture.py index 0032ef926..05d782845 100755 --- a/egs/callhome/eend_ola/local/random_mixture.py +++ b/egs/callhome/eend_ola/local/random_mixture.py @@ -42,7 +42,7 @@ The actual data dir and wav files are generated using make_mixture.py: import argparse import os -from eend import kaldi_data +from funasr.modules.eend_ola.utils import kaldi_data import random import numpy as np import json diff --git a/egs/callhome/eend_ola/local/run_prepare_shared_eda.sh b/egs/callhome/eend_ola/local/run_prepare_shared_eda.sh index 5431ba1de..aec1ff2a9 100755 --- a/egs/callhome/eend_ola/local/run_prepare_shared_eda.sh +++ b/egs/callhome/eend_ola/local/run_prepare_shared_eda.sh @@ -9,7 +9,7 @@ # - data/simu_${simu_outputs} # simulation mixtures generated with various options -stage=1 +stage=0 # Modify corpus directories # - callhome_dir From 81fe1e0a098458a22a961eeb3e7d3dbcf8e43663 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=98=89=E6=B8=8A?= Date: Wed, 19 Jul 2023 10:58:08 +0800 Subject: [PATCH 17/42] update --- egs/callhome/eend_ola/local/infer.py | 6 ++++++ egs/callhome/eend_ola/run_test.sh | 3 ++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/egs/callhome/eend_ola/local/infer.py b/egs/callhome/eend_ola/local/infer.py index 78d160d3a..132246875 100644 --- a/egs/callhome/eend_ola/local/infer.py +++ b/egs/callhome/eend_ola/local/infer.py @@ -63,6 +63,12 @@ if __name__ == '__main__': default=10, help="setting subsampling", ) + parser.add_argument( + "--shuffle", + type=bool, + default=True, + help="shuffle speech in time", + ) parser.add_argument( "--attractor_threshold", type=float, diff --git a/egs/callhome/eend_ola/run_test.sh b/egs/callhome/eend_ola/run_test.sh index d00444665..c198e7375 100644 --- a/egs/callhome/eend_ola/run_test.sh +++ b/egs/callhome/eend_ola/run_test.sh @@ -253,5 +253,6 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then --config_file ${exp_dir}/exp/${callhome_model_dir}/config.yaml \ --model_file ${exp_dir}/exp/${callhome_model_dir}/$callhome_ave_id.pb \ --output_rttm_file ${exp_dir}/exp/${callhome_model_dir}/inference/rttm \ - --wav_scp_file ${callhome_feats_dir_chunk2000}/${callhome_valid_dataset}/${callhome2_wav_scp_file} 1> ${exp_dir}/exp/${callhome_model_dir}/inference/log/infer.log 2>&1 + --wav_scp_file ${callhome_feats_dir_chunk2000}/${callhome_valid_dataset}/${callhome2_wav_scp_file} \ + 1> ${exp_dir}/exp/${callhome_model_dir}/inference/log/infer.log 2>&1 fi \ No newline at end of file From f5bd371837cc3b89e6d387ecc84469a0e513fbd6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=98=89=E6=B8=8A?= Date: Wed, 19 Jul 2023 22:34:52 +0800 Subject: [PATCH 18/42] update --- egs/callhome/eend_ola/local/infer.py | 4 ++-- egs/callhome/eend_ola/run.sh | 24 ++++++++++++++---------- egs/callhome/eend_ola/run_test.sh | 5 ++++- funasr/models/e2e_diar_eend_ola.py | 3 +-- 4 files changed, 21 insertions(+), 15 deletions(-) diff --git a/egs/callhome/eend_ola/local/infer.py b/egs/callhome/eend_ola/local/infer.py index 132246875..23e1d52c9 100644 --- a/egs/callhome/eend_ola/local/infer.py +++ b/egs/callhome/eend_ola/local/infer.py @@ -54,7 +54,7 @@ if __name__ == '__main__': parser.add_argument( "--sampling_rate", type=int, - default=10, + default=8000, help="sampling rate", ) parser.add_argument( @@ -104,7 +104,7 @@ if __name__ == '__main__': print("Start inference") with open(args.output_rttm_file, "w") as wf: for wav_id in wav_items.keys(): - print("Process wav: {}\n".format(wav_id)) + print("Process wav: {}".format(wav_id)) data, rate = sf.read(wav_items[wav_id]) speech = eend_ola_feature.stft(data, args.frame_size, args.frame_shift) speech = eend_ola_feature.transform(speech) diff --git a/egs/callhome/eend_ola/run.sh b/egs/callhome/eend_ola/run.sh index b4f273945..40fb04113 100644 --- a/egs/callhome/eend_ola/run.sh +++ b/egs/callhome/eend_ola/run.sh @@ -245,13 +245,17 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then python local/model_averaging.py ${exp_dir}/exp/${callhome_model_dir}/$callhome_ave_id.pb $models fi -## inference -#if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then -# echo "Inference" -# mkdir -p ${exp_dir}/exp/${callhome_model_dir}/inference/log -# CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES python local/infer.py \ -# --config_file ${exp_dir}/exp/${callhome_model_dir}/config.yaml \ -# --model_file ${exp_dir}/exp/${callhome_model_dir}/$callhome_ave_id.pb \ -# --output_rttm_file ${exp_dir}/exp/${callhome_model_dir}/inference/rttm \ -# --wav_scp_file ${callhome_feats_dir_chunk2000}/${callhome_valid_dataset}/${callhome2_wav_scp_file} 1> ${exp_dir}/exp/${callhome_model_dir}/inference/log/infer.log 2>&1 -#fi \ No newline at end of file +# inference and compute DER +if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then + echo "Inference" + mkdir -p ${exp_dir}/exp/${callhome_model_dir}/inference/log + CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES python local/infer.py \ + --config_file ${exp_dir}/exp/${callhome_model_dir}/config.yaml \ + --model_file ${exp_dir}/exp/${callhome_model_dir}/$callhome_ave_id.pb \ + --output_rttm_file ${exp_dir}/exp/${callhome_model_dir}/inference/rttm \ + --wav_scp_file ${callhome_feats_dir_chunk2000}/${callhome_valid_dataset}/${callhome2_wav_scp_file} \ + 1> ${exp_dir}/exp/${callhome_model_dir}/inference/log/infer.log 2>&1 + md-eval.pl -c 0.25 \ + -r ${callhome_feats_dir_chunk2000}/${callhome_valid_dataset}/rttm \ + -s ${exp_dir}/exp/${callhome_model_dir}/inference/rttm > ${exp_dir}/exp/${callhome_model_dir}/inference/result_med11_collar0.25 2>/dev/null || exit +fi \ No newline at end of file diff --git a/egs/callhome/eend_ola/run_test.sh b/egs/callhome/eend_ola/run_test.sh index c198e7375..9173e6fec 100644 --- a/egs/callhome/eend_ola/run_test.sh +++ b/egs/callhome/eend_ola/run_test.sh @@ -245,7 +245,7 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then python local/model_averaging.py ${exp_dir}/exp/${callhome_model_dir}/$callhome_ave_id.pb $models fi -# inference +# inference and compute DER if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then echo "Inference" mkdir -p ${exp_dir}/exp/${callhome_model_dir}/inference/log @@ -255,4 +255,7 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then --output_rttm_file ${exp_dir}/exp/${callhome_model_dir}/inference/rttm \ --wav_scp_file ${callhome_feats_dir_chunk2000}/${callhome_valid_dataset}/${callhome2_wav_scp_file} \ 1> ${exp_dir}/exp/${callhome_model_dir}/inference/log/infer.log 2>&1 + md-eval.pl -c 0.25 \ + -r ${callhome_feats_dir_chunk2000}/${callhome_valid_dataset}/rttm \ + -s ${exp_dir}/exp/${callhome_model_dir}/inference/rttm > ${exp_dir}/exp/${callhome_model_dir}/inference/result_med11_collar0.25 2>/dev/null || exit fi \ No newline at end of file diff --git a/funasr/models/e2e_diar_eend_ola.py b/funasr/models/e2e_diar_eend_ola.py index fda24e227..0225a7a4c 100644 --- a/funasr/models/e2e_diar_eend_ola.py +++ b/funasr/models/e2e_diar_eend_ola.py @@ -157,12 +157,11 @@ class DiarEENDOLAModel(FunASRModel): def estimate_sequential(self, speech: torch.Tensor, - speech_lengths: torch.Tensor, n_speakers: int = None, shuffle: bool = True, threshold: float = 0.5, **kwargs): - speech = [s[:s_len] for s, s_len in zip(speech, speech_lengths)] + speech_lengths = torch.tensor([len(sph) for sph in speech]).to(torch.int64) emb = self.forward_encoder(speech, speech_lengths) if shuffle: orders = [np.arange(e.shape[0]) for e in emb] From 21536068b9e1d94a3c0de09b6b166a786f98361f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=98=89=E6=B8=8A?= Date: Thu, 20 Jul 2023 17:09:45 +0800 Subject: [PATCH 19/42] update --- egs/callhome/eend_ola/local/dump_feature.py | 127 +++++++++ egs/callhome/eend_ola/local/split.py | 117 ++++++++ egs/callhome/eend_ola/run.sh | 39 ++- funasr/modules/eend_ola/utils/feature.py | 286 ++++++++++++++++++++ 4 files changed, 562 insertions(+), 7 deletions(-) create mode 100644 egs/callhome/eend_ola/local/dump_feature.py create mode 100644 egs/callhome/eend_ola/local/split.py create mode 100644 funasr/modules/eend_ola/utils/feature.py diff --git a/egs/callhome/eend_ola/local/dump_feature.py b/egs/callhome/eend_ola/local/dump_feature.py new file mode 100644 index 000000000..169615e1b --- /dev/null +++ b/egs/callhome/eend_ola/local/dump_feature.py @@ -0,0 +1,127 @@ +import argparse +import os + +import numpy as np + +import funasr.modules.eend_ola.utils.feature as feature +import funasr.modules.eend_ola.utils.kaldi_data as kaldi_data + + +def _count_frames(data_len, size, step): + return int((data_len - size + step) / step) + + +def _gen_frame_indices( + data_length, size=2000, step=2000, + use_last_samples=False, + label_delay=0, + subsampling=1): + i = -1 + for i in range(_count_frames(data_length, size, step)): + yield i * step, i * step + size + if use_last_samples and i * step + size < data_length: + if data_length - (i + 1) * step - subsampling * label_delay > 0: + yield (i + 1) * step, data_length + + +class KaldiDiarizationDataset(): + def __init__( + self, + data_dir, + chunk_size=2000, + context_size=0, + frame_size=1024, + frame_shift=256, + subsampling=1, + rate=16000, + input_transform=None, + use_last_samples=False, + label_delay=0, + n_speakers=None, + ): + self.data_dir = data_dir + self.chunk_size = chunk_size + self.context_size = context_size + self.frame_size = frame_size + self.frame_shift = frame_shift + self.subsampling = subsampling + self.input_transform = input_transform + self.n_speakers = n_speakers + self.chunk_indices = [] + self.label_delay = label_delay + + self.data = kaldi_data.KaldiData(self.data_dir) + + # make chunk indices: filepath, start_frame, end_frame + for rec, path in self.data.wavs.items(): + data_len = int(self.data.reco2dur[rec] * rate / frame_shift) + data_len = int(data_len / self.subsampling) + for st, ed in _gen_frame_indices( + data_len, chunk_size, chunk_size, use_last_samples, + label_delay=self.label_delay, + subsampling=self.subsampling): + self.chunk_indices.append( + (rec, path, st * self.subsampling, ed * self.subsampling)) + print(len(self.chunk_indices), " chunks") + + +def convert(args): + f = open(out_wav_file, 'w') + dataset = KaldiDiarizationDataset( + data_dir=args.data_dir, + chunk_size=args.num_frames, + context_size=args.context_size, + input_transform=args.input_transform, + frame_size=args.frame_size, + frame_shift=args.frame_shift, + subsampling=args.subsampling, + rate=8000, + use_last_samples=True, + ) + length = len(dataset.chunk_indices) + for idx, (rec, path, st, ed) in enumerate(dataset.chunk_indices): + Y, T = feature.get_labeledSTFT( + dataset.data, + rec, + st, + ed, + dataset.frame_size, + dataset.frame_shift, + dataset.n_speakers) + Y = feature.transform(Y, dataset.input_transform) + Y_spliced = feature.splice(Y, dataset.context_size) + Y_ss, T_ss = feature.subsample(Y_spliced, T, dataset.subsampling) + st = '{:0>7d}'.format(st) + ed = '{:0>7d}'.format(ed) + suffix = '_' + st + '_' + ed + + parts = os.readlink('/'.join(path.split('/')[:-1])).split('/') + # print('parts: ', parts) + parts = parts[:4] + ['numpy_data'] + parts[4:] + cur_path = '/'.join(parts) + # print('cur path: ', cur_path) + out_path = os.path.join(cur_path, path.split('/')[-1].split('.')[0] + suffix + '.npz') + # print(out_path) + # print(cur_path) + if not os.path.exists(cur_path): + os.makedirs(cur_path) + np.savez(out_path, Y=Y_ss, T=T_ss) + if idx == length - 1: + f.write(rec + suffix + ' ' + out_path) + else: + f.write(rec + suffix + ' ' + out_path + '\n') + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("data_dir") + parser.add_argument("num_frames") + parser.add_argument("context_size") + parser.add_argument("frame_size") + parser.add_argument("frame_shift") + parser.add_argument("subsampling") + + + + args = parser.parse_args() + convert(args) diff --git a/egs/callhome/eend_ola/local/split.py b/egs/callhome/eend_ola/local/split.py new file mode 100644 index 000000000..6f313ccd4 --- /dev/null +++ b/egs/callhome/eend_ola/local/split.py @@ -0,0 +1,117 @@ +import argparse +import os + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('root_path', help='raw data path') + args = parser.parse_args() + + root_path = args.root_path + work_path = os.path.join(root_path, ".work") + scp_files = os.listdir(work_path) + + reco2dur_dict = {} + with open(root_path + 'reco2dur') as f: + lines = f.readlines() + for line in lines: + parts = line.strip().split() + reco2dur_dict[parts[0]] = parts[1] + + spk2utt_dict = {} + with open(root_path + 'spk2utt') as f: + lines = f.readlines() + for line in lines: + parts = line.strip().split() + spk = parts[0] + utts = parts[1:] + for utt in utts: + tmp = utt.split('data') + rec = 'data_' + '_'.join(tmp[1][1:].split('_')[:-2]) + if rec in spk2utt_dict.keys(): + spk2utt_dict[rec].append((spk, utt)) + else: + spk2utt_dict[rec] = [] + spk2utt_dict[rec].append((spk, utt)) + + segment_dict = {} + with open(root_path + 'segments') as f: + lines = f.readlines() + for line in lines: + parts = line.strip().split() + if parts[1] in segment_dict.keys(): + segment_dict[parts[1]].append((parts[0], parts[2], parts[3])) + else: + segment_dict[parts[1]] = [] + segment_dict[parts[1]].append((parts[0], parts[2], parts[3])) + + utt2spk_dict = {} + with open(root_path + 'utt2spk') as f: + lines = f.readlines() + for line in lines: + parts = line.strip().split() + utt = parts[0] + tmp = utt.split('data') + rec = 'data_' + '_'.join(tmp[1][1:].split('_')[:-2]) + if rec in utt2spk_dict.keys(): + utt2spk_dict[rec].append((parts[0], parts[1])) + else: + utt2spk_dict[rec] = [] + utt2spk_dict[rec].append((parts[0], parts[1])) + + for file in scp_files: + scp_file = work_path + file + idx = scp_file.split('.')[-2] + reco2dur_file = work_path + 'reco2dur.' + idx + spk2utt_file = work_path + 'spk2utt.' + idx + segment_file = work_path + 'segments.' + idx + utt2spk_file = work_path + 'utt2spk.' + idx + + fpp = open(scp_file) + scp_lines = fpp.readlines() + keys = [] + for line in scp_lines: + name = line.strip().split()[0] + keys.append(name) + + with open(reco2dur_file, 'w') as f: + lines = [] + for key in keys: + string = key + ' ' + reco2dur_dict[key] + lines.append(string + '\n') + lines[-1] = lines[-1][:-1] + f.writelines(lines) + + with open(spk2utt_file, 'w') as f: + lines = [] + for key in keys: + items = spk2utt_dict[key] + for item in items: + string = item[0] + for it in item[1:]: + string += ' ' + string += it + lines.append(string + '\n') + lines[-1] = lines[-1][:-1] + f.writelines(lines) + + with open(segment_file, 'w') as f: + lines = [] + for key in keys: + items = segment_dict[key] + for item in items: + string = item[0] + ' ' + key + ' ' + item[1] + ' ' + item[2] + lines.append(string + '\n') + lines[-1] = lines[-1][:-1] + f.writelines(lines) + + with open(utt2spk_file, 'w') as f: + lines = [] + for key in keys: + items = utt2spk_dict[key] + for item in items: + string = item[0] + ' ' + item[1] + lines.append(string + '\n') + lines[-1] = lines[-1][:-1] + f.writelines(lines) + + fpp.close() diff --git a/egs/callhome/eend_ola/run.sh b/egs/callhome/eend_ola/run.sh index 40fb04113..cd246feee 100644 --- a/egs/callhome/eend_ola/run.sh +++ b/egs/callhome/eend_ola/run.sh @@ -8,6 +8,11 @@ gpu_num=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') count=1 # general configuration +dump_cmd=utils/run.pl +nj=64 + +# feature configuration +data_dir="./data" simu_feats_dir="/nfs/wangjiaming.wjm/EEND_ARK_DATA/dump/simu_data/data" simu_feats_dir_chunk2000="/nfs/wangjiaming.wjm/EEND_ARK_DATA/dump/simu_data_chunk2000/data" callhome_feats_dir_chunk2000="/nfs/wangjiaming.wjm/EEND_ARK_DATA/dump/callhome_chunk2000/data" @@ -62,13 +67,33 @@ if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then local/run_prepare_shared_eda.sh fi -## Prepare data for training and inference -#if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then -# echo "stage 0: Prepare data for training and inference" -# echo "The detail can be found in https://github.com/hitachi-speech/EEND" -# . ./local/ -#fi -# +# Prepare data for training and inference +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + echo "stage 0: Prepare data for training and inference" + simu_opts_num_speaker_array=(1 2 3 4) + simu_opts_sil_scale_array=(2 2 5 9) + simu_opts_num_speaker=${simu_opts_num_speaker_array[i]} + simu_opts_sil_scale=${simu_opts_sil_scale_array[i]} + simu_opts_num_train=100000 + + # for simulated data of chunk500 + for dset in swb_sre_tr swb_sre_cv; do + if [ "$dset" == "swb_sre_tr" ]; then + n_mixtures=${simu_opts_num_train} + else + n_mixtures=500 + fi + simu_data_dir=${dset}_ns${simu_opts_num_speaker}_beta${simu_opts_sil_scale}_${n_mixtures} + mkdir ${data_dir}/simu/data/${simu_data_dir}/.work + split_scps= + for n in $(seq $nj); do + split_scps="$split_scps ${data_dir}/simu/data/${simu_data_dir}/.work/wav.$n.scp" + done + utils/split_scp.pl "${data_dir}/simu/data/${simu_data_dir}/wav.scp" $split_scps || exit 1 + python local/split.py ${data_dir}/simu/data/${simu_data_dir} + done +fi + # Training on simulated two-speaker data world_size=$gpu_num diff --git a/funasr/modules/eend_ola/utils/feature.py b/funasr/modules/eend_ola/utils/feature.py new file mode 100644 index 000000000..544a3521d --- /dev/null +++ b/funasr/modules/eend_ola/utils/feature.py @@ -0,0 +1,286 @@ +# Copyright 2019 Hitachi, Ltd. (author: Yusuke Fujita) +# Licensed under the MIT license. +# +# This module is for computing audio features + +import numpy as np +import librosa + + +def get_input_dim( + frame_size, + context_size, + transform_type, +): + if transform_type.startswith('logmel23'): + frame_size = 23 + elif transform_type.startswith('logmel'): + frame_size = 40 + else: + fft_size = 1 << (frame_size - 1).bit_length() + frame_size = int(fft_size / 2) + 1 + input_dim = (2 * context_size + 1) * frame_size + return input_dim + + +def transform( + Y, + transform_type=None, + dtype=np.float32): + """ Transform STFT feature + + Args: + Y: STFT + (n_frames, n_bins)-shaped np.complex array + transform_type: + None, "log" + dtype: output data type + np.float32 is expected + Returns: + Y (numpy.array): transformed feature + """ + Y = np.abs(Y) + if not transform_type: + pass + elif transform_type == 'log': + Y = np.log(np.maximum(Y, 1e-10)) + elif transform_type == 'logmel': + n_fft = 2 * (Y.shape[1] - 1) + sr = 16000 + n_mels = 40 + mel_basis = librosa.filters.mel(sr, n_fft, n_mels) + Y = np.dot(Y ** 2, mel_basis.T) + Y = np.log10(np.maximum(Y, 1e-10)) + elif transform_type == 'logmel23': + n_fft = 2 * (Y.shape[1] - 1) + sr = 8000 + n_mels = 23 + mel_basis = librosa.filters.mel(sr, n_fft, n_mels) + Y = np.dot(Y ** 2, mel_basis.T) + Y = np.log10(np.maximum(Y, 1e-10)) + elif transform_type == 'logmel23_mn': + n_fft = 2 * (Y.shape[1] - 1) + sr = 8000 + n_mels = 23 + mel_basis = librosa.filters.mel(sr, n_fft, n_mels) + Y = np.dot(Y ** 2, mel_basis.T) + Y = np.log10(np.maximum(Y, 1e-10)) + mean = np.mean(Y, axis=0) + Y = Y - mean + elif transform_type == 'logmel23_swn': + n_fft = 2 * (Y.shape[1] - 1) + sr = 8000 + n_mels = 23 + mel_basis = librosa.filters.mel(sr, n_fft, n_mels) + Y = np.dot(Y ** 2, mel_basis.T) + Y = np.log10(np.maximum(Y, 1e-10)) + # b = np.ones(300)/300 + # mean = scipy.signal.convolve2d(Y, b[:, None], mode='same') + + # simple 2-means based threshoding for mean calculation + powers = np.sum(Y, axis=1) + th = (np.max(powers) + np.min(powers)) / 2.0 + for i in range(10): + th = (np.mean(powers[powers >= th]) + np.mean(powers[powers < th])) / 2 + mean = np.mean(Y[powers > th, :], axis=0) + Y = Y - mean + elif transform_type == 'logmel23_mvn': + n_fft = 2 * (Y.shape[1] - 1) + sr = 8000 + n_mels = 23 + mel_basis = librosa.filters.mel(sr, n_fft, n_mels) + Y = np.dot(Y ** 2, mel_basis.T) + Y = np.log10(np.maximum(Y, 1e-10)) + mean = np.mean(Y, axis=0) + Y = Y - mean + std = np.maximum(np.std(Y, axis=0), 1e-10) + Y = Y / std + else: + raise ValueError('Unknown transform_type: %s' % transform_type) + return Y.astype(dtype) + + +def subsample(Y, T, subsampling=1): + """ Frame subsampling + """ + Y_ss = Y[::subsampling] + T_ss = T[::subsampling] + return Y_ss, T_ss + + +def splice(Y, context_size=0): + """ Frame splicing + + Args: + Y: feature + (n_frames, n_featdim)-shaped numpy array + context_size: + number of frames concatenated on left-side + if context_size = 5, 11 frames are concatenated. + + Returns: + Y_spliced: spliced feature + (n_frames, n_featdim * (2 * context_size + 1))-shaped + """ + Y_pad = np.pad( + Y, + [(context_size, context_size), (0, 0)], + 'constant') + Y_spliced = np.lib.stride_tricks.as_strided( + np.ascontiguousarray(Y_pad), + (Y.shape[0], Y.shape[1] * (2 * context_size + 1)), + (Y.itemsize * Y.shape[1], Y.itemsize), writeable=False) + return Y_spliced + + +def stft( + data, + frame_size=1024, + frame_shift=256): + """ Compute STFT features + + Args: + data: audio signal + (n_samples,)-shaped np.float32 array + frame_size: number of samples in a frame (must be a power of two) + frame_shift: number of samples between frames + + Returns: + stft: STFT frames + (n_frames, n_bins)-shaped np.complex64 array + """ + # round up to nearest power of 2 + fft_size = 1 << (frame_size - 1).bit_length() + # HACK: The last frame is ommited + # as librosa.stft produces such an excessive frame + if len(data) % frame_shift == 0: + return librosa.stft(data, n_fft=fft_size, win_length=frame_size, + hop_length=frame_shift).T[:-1] + else: + return librosa.stft(data, n_fft=fft_size, win_length=frame_size, + hop_length=frame_shift).T + + +def _count_frames(data_len, size, shift): + # HACK: Assuming librosa.stft(..., center=True) + n_frames = 1 + int(data_len / shift) + if data_len % shift == 0: + n_frames = n_frames - 1 + return n_frames + + +def get_frame_labels( + kaldi_obj, + rec, + start=0, + end=None, + frame_size=1024, + frame_shift=256, + n_speakers=None): + """ Get frame-aligned labels of given recording + Args: + kaldi_obj (KaldiData) + rec (str): recording id + start (int): start frame index + end (int): end frame index + None means the last frame of recording + frame_size (int): number of frames in a frame + frame_shift (int): number of shift samples + n_speakers (int): number of speakers + if None, the value is given from data + Returns: + T: label + (n_frames, n_speakers)-shaped np.int32 array + """ + filtered_segments = kaldi_obj.segments[kaldi_obj.segments['rec'] == rec] + speakers = np.unique( + [kaldi_obj.utt2spk[seg['utt']] for seg + in filtered_segments]).tolist() + if n_speakers is None: + n_speakers = len(speakers) + es = end * frame_shift if end is not None else None + data, rate = kaldi_obj.load_wav(rec, start * frame_shift, es) + n_frames = _count_frames(len(data), frame_size, frame_shift) + T = np.zeros((n_frames, n_speakers), dtype=np.int32) + if end is None: + end = n_frames + + for seg in filtered_segments: + speaker_index = speakers.index(kaldi_obj.utt2spk[seg['utt']]) + start_frame = np.rint( + seg['st'] * rate / frame_shift).astype(int) + end_frame = np.rint( + seg['et'] * rate / frame_shift).astype(int) + rel_start = rel_end = None + if start <= start_frame and start_frame < end: + rel_start = start_frame - start + if start < end_frame and end_frame <= end: + rel_end = end_frame - start + if rel_start is not None or rel_end is not None: + T[rel_start:rel_end, speaker_index] = 1 + return T + + +def get_labeledSTFT( + kaldi_obj, + rec, start, end, frame_size, frame_shift, + n_speakers=None, + use_speaker_id=False): + """ Extracts STFT and corresponding labels + + Extracts STFT and corresponding diarization labels for + given recording id and start/end times + + Args: + kaldi_obj (KaldiData) + rec (str): recording id + start (int): start frame index + end (int): end frame index + frame_size (int): number of samples in a frame + frame_shift (int): number of shift samples + n_speakers (int): number of speakers + if None, the value is given from data + Returns: + Y: STFT + (n_frames, n_bins)-shaped np.complex64 array, + T: label + (n_frmaes, n_speakers)-shaped np.int32 array. + """ + data, rate = kaldi_obj.load_wav( + rec, start * frame_shift, end * frame_shift) + Y = stft(data, frame_size, frame_shift) + filtered_segments = kaldi_obj.segments[rec] + # filtered_segments = kaldi_obj.segments[kaldi_obj.segments['rec'] == rec] + speakers = np.unique( + [kaldi_obj.utt2spk[seg['utt']] for seg + in filtered_segments]).tolist() + if n_speakers is None: + n_speakers = len(speakers) + T = np.zeros((Y.shape[0], n_speakers), dtype=np.int32) + + if use_speaker_id: + all_speakers = sorted(kaldi_obj.spk2utt.keys()) + S = np.zeros((Y.shape[0], len(all_speakers)), dtype=np.int32) + + for seg in filtered_segments: + speaker_index = speakers.index(kaldi_obj.utt2spk[seg['utt']]) + if use_speaker_id: + all_speaker_index = all_speakers.index(kaldi_obj.utt2spk[seg['utt']]) + start_frame = np.rint( + seg['st'] * rate / frame_shift).astype(int) + end_frame = np.rint( + seg['et'] * rate / frame_shift).astype(int) + rel_start = rel_end = None + if start <= start_frame and start_frame < end: + rel_start = start_frame - start + if start < end_frame and end_frame <= end: + rel_end = end_frame - start + if rel_start is not None or rel_end is not None: + T[rel_start:rel_end, speaker_index] = 1 + if use_speaker_id: + S[rel_start:rel_end, all_speaker_index] = 1 + + if use_speaker_id: + return Y, T, S + else: + return Y, T From 7fb605cc8831227c3a66d2c9da93dffa8049a5c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=98=89=E6=B8=8A?= Date: Thu, 20 Jul 2023 17:11:30 +0800 Subject: [PATCH 20/42] update --- egs/callhome/eend_ola/run.sh | 1 - egs/callhome/eend_ola/run_test.sh | 42 ++++++++++++++++++++++++------- 2 files changed, 33 insertions(+), 10 deletions(-) diff --git a/egs/callhome/eend_ola/run.sh b/egs/callhome/eend_ola/run.sh index cd246feee..c8f4c3cea 100644 --- a/egs/callhome/eend_ola/run.sh +++ b/egs/callhome/eend_ola/run.sh @@ -94,7 +94,6 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then done fi - # Training on simulated two-speaker data world_size=$gpu_num simu_2spkr_ave_id=avg${simu_average_2spkr_start}-${simu_average_2spkr_end} diff --git a/egs/callhome/eend_ola/run_test.sh b/egs/callhome/eend_ola/run_test.sh index 9173e6fec..a3257a08e 100644 --- a/egs/callhome/eend_ola/run_test.sh +++ b/egs/callhome/eend_ola/run_test.sh @@ -8,6 +8,11 @@ gpu_num=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') count=1 # general configuration +dump_cmd=utils/run.pl +nj=64 + +# feature configuration +data_dir="/nfs/wangjiaming.wjm/EEND_DATA_sad30_snr10n15n20/convert_chunk2000/data" simu_feats_dir="/nfs/wangjiaming.wjm/EEND_ARK_DATA/dump/simu_data/data" simu_feats_dir_chunk2000="/nfs/wangjiaming.wjm/EEND_ARK_DATA/dump/simu_data_chunk2000/data" callhome_feats_dir_chunk2000="/nfs/wangjiaming.wjm/EEND_ARK_DATA/dump/callhome_chunk2000/data" @@ -27,8 +32,8 @@ callhome_average_end=100 exp_dir="." input_size=345 -stage=5 -stop_stage=5 +stage=0 +stop_stage=0 # exp tag tag="exp1" @@ -62,13 +67,32 @@ if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then local/run_prepare_shared_eda.sh fi -## Prepare data for training and inference -#if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then -# echo "stage 0: Prepare data for training and inference" -# echo "The detail can be found in https://github.com/hitachi-speech/EEND" -# . ./local/ -#fi -# +# Prepare data for training and inference +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + echo "stage 0: Prepare data for training and inference" + simu_opts_num_speaker_array=(1 2 3 4) + simu_opts_sil_scale_array=(2 2 5 9) + simu_opts_num_speaker=${simu_opts_num_speaker_array[i]} + simu_opts_sil_scale=${simu_opts_sil_scale_array[i]} + simu_opts_num_train=100000 + + # for simulated data of chunk500 + for dset in swb_sre_tr swb_sre_cv; do + if [ "$dset" == "swb_sre_tr" ]; then + n_mixtures=${simu_opts_num_train} + else + n_mixtures=500 + fi + simu_data_dir=${dset}_ns${simu_opts_num_speaker}_beta${simu_opts_sil_scale}_${n_mixtures} + mkdir ${data_dir}/simu/data/${simu_data_dir}/.work + split_scps= + for n in $(seq $nj); do + split_scps="$split_scps ${data_dir}/simu/data/${simu_data_dir}/.work/wav.$n.scp" + done + utils/split_scp.pl "${data_dir}/simu/data/${simu_data_dir}/wav.scp" $split_scps || exit 1 + python local/split.py ${data_dir}/simu/data/${simu_data_dir} + done +fi # Training on simulated two-speaker data world_size=$gpu_num From e215a76bb39f6ec35b81d1e6d131012a49d2a404 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=98=89=E6=B8=8A?= Date: Thu, 20 Jul 2023 17:20:47 +0800 Subject: [PATCH 21/42] update --- egs/callhome/eend_ola/run.sh | 2 +- egs/callhome/eend_ola/run_test.sh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/callhome/eend_ola/run.sh b/egs/callhome/eend_ola/run.sh index c8f4c3cea..7796a456a 100644 --- a/egs/callhome/eend_ola/run.sh +++ b/egs/callhome/eend_ola/run.sh @@ -83,7 +83,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then else n_mixtures=500 fi - simu_data_dir=${dset}_ns${simu_opts_num_speaker}_beta${simu_opts_sil_scale}_${n_mixtures} + simu_data_dir=${dset}_ns"$(IFS="n"; echo "${simu_opts_num_speaker_array[*]}")"_beta"$(IFS="n"; echo "${simu_opts_sil_scale_array[*]}")"_${n_mixtures} mkdir ${data_dir}/simu/data/${simu_data_dir}/.work split_scps= for n in $(seq $nj); do diff --git a/egs/callhome/eend_ola/run_test.sh b/egs/callhome/eend_ola/run_test.sh index a3257a08e..a26af5279 100644 --- a/egs/callhome/eend_ola/run_test.sh +++ b/egs/callhome/eend_ola/run_test.sh @@ -83,7 +83,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then else n_mixtures=500 fi - simu_data_dir=${dset}_ns${simu_opts_num_speaker}_beta${simu_opts_sil_scale}_${n_mixtures} + simu_data_dir=${dset}_ns"$(IFS="n"; echo "${simu_opts_num_speaker_array[*]}")"_beta"$(IFS="n"; echo "${simu_opts_sil_scale_array[*]}")"_${n_mixtures} mkdir ${data_dir}/simu/data/${simu_data_dir}/.work split_scps= for n in $(seq $nj); do From adabfc2a90c1a315e15732aae6e8691b29b820cd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=98=89=E6=B8=8A?= Date: Thu, 20 Jul 2023 17:23:17 +0800 Subject: [PATCH 22/42] update --- egs/callhome/eend_ola/run.sh | 4 +--- egs/callhome/eend_ola/run_test.sh | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/egs/callhome/eend_ola/run.sh b/egs/callhome/eend_ola/run.sh index 7796a456a..11e177115 100644 --- a/egs/callhome/eend_ola/run.sh +++ b/egs/callhome/eend_ola/run.sh @@ -72,8 +72,6 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then echo "stage 0: Prepare data for training and inference" simu_opts_num_speaker_array=(1 2 3 4) simu_opts_sil_scale_array=(2 2 5 9) - simu_opts_num_speaker=${simu_opts_num_speaker_array[i]} - simu_opts_sil_scale=${simu_opts_sil_scale_array[i]} simu_opts_num_train=100000 # for simulated data of chunk500 @@ -84,7 +82,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then n_mixtures=500 fi simu_data_dir=${dset}_ns"$(IFS="n"; echo "${simu_opts_num_speaker_array[*]}")"_beta"$(IFS="n"; echo "${simu_opts_sil_scale_array[*]}")"_${n_mixtures} - mkdir ${data_dir}/simu/data/${simu_data_dir}/.work + mkdir -p ${data_dir}/simu/data/${simu_data_dir}/.work split_scps= for n in $(seq $nj); do split_scps="$split_scps ${data_dir}/simu/data/${simu_data_dir}/.work/wav.$n.scp" diff --git a/egs/callhome/eend_ola/run_test.sh b/egs/callhome/eend_ola/run_test.sh index a26af5279..8ba8d5742 100644 --- a/egs/callhome/eend_ola/run_test.sh +++ b/egs/callhome/eend_ola/run_test.sh @@ -72,8 +72,6 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then echo "stage 0: Prepare data for training and inference" simu_opts_num_speaker_array=(1 2 3 4) simu_opts_sil_scale_array=(2 2 5 9) - simu_opts_num_speaker=${simu_opts_num_speaker_array[i]} - simu_opts_sil_scale=${simu_opts_sil_scale_array[i]} simu_opts_num_train=100000 # for simulated data of chunk500 @@ -84,7 +82,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then n_mixtures=500 fi simu_data_dir=${dset}_ns"$(IFS="n"; echo "${simu_opts_num_speaker_array[*]}")"_beta"$(IFS="n"; echo "${simu_opts_sil_scale_array[*]}")"_${n_mixtures} - mkdir ${data_dir}/simu/data/${simu_data_dir}/.work + mkdir -p ${data_dir}/simu/data/${simu_data_dir}/.work split_scps= for n in $(seq $nj); do split_scps="$split_scps ${data_dir}/simu/data/${simu_data_dir}/.work/wav.$n.scp" From 5952aa84244e9ac86d6cbbd41150cf5b21a3dced Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=98=89=E6=B8=8A?= Date: Thu, 20 Jul 2023 18:47:26 +0800 Subject: [PATCH 23/42] update --- egs/callhome/eend_ola/local/split.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/egs/callhome/eend_ola/local/split.py b/egs/callhome/eend_ola/local/split.py index 6f313ccd4..0ff9e471d 100644 --- a/egs/callhome/eend_ola/local/split.py +++ b/egs/callhome/eend_ola/local/split.py @@ -11,14 +11,14 @@ if __name__ == '__main__': scp_files = os.listdir(work_path) reco2dur_dict = {} - with open(root_path + 'reco2dur') as f: + with (os.path.join(root_path, 'reco2dur')) as f: lines = f.readlines() for line in lines: parts = line.strip().split() reco2dur_dict[parts[0]] = parts[1] spk2utt_dict = {} - with open(root_path + 'spk2utt') as f: + with open(os.path.join(root_path, 'spk2utt')) as f: lines = f.readlines() for line in lines: parts = line.strip().split() @@ -34,7 +34,7 @@ if __name__ == '__main__': spk2utt_dict[rec].append((spk, utt)) segment_dict = {} - with open(root_path + 'segments') as f: + with open(os.path.join(root_path, 'segments')) as f: lines = f.readlines() for line in lines: parts = line.strip().split() @@ -45,7 +45,7 @@ if __name__ == '__main__': segment_dict[parts[1]].append((parts[0], parts[2], parts[3])) utt2spk_dict = {} - with open(root_path + 'utt2spk') as f: + with open(os.path.join(root_path, 'utt2spk')) as f: lines = f.readlines() for line in lines: parts = line.strip().split() @@ -61,10 +61,10 @@ if __name__ == '__main__': for file in scp_files: scp_file = work_path + file idx = scp_file.split('.')[-2] - reco2dur_file = work_path + 'reco2dur.' + idx - spk2utt_file = work_path + 'spk2utt.' + idx - segment_file = work_path + 'segments.' + idx - utt2spk_file = work_path + 'utt2spk.' + idx + reco2dur_file = os.path.join(work_path, 'reco2dur.'.format(str(idx))) + spk2utt_file = os.path.join(work_path, 'spk2utt.'.format(str(idx))) + segment_file = os.path.join(work_path, 'segments.'.format(str(idx))) + utt2spk_file = os.path.join(work_path, 'utt2spk.'.format(str(idx))) fpp = open(scp_file) scp_lines = fpp.readlines() From c6c181e8202698b1e6d656d360f0e996e4990028 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=98=89=E6=B8=8A?= Date: Thu, 20 Jul 2023 19:08:47 +0800 Subject: [PATCH 24/42] update --- egs/callhome/eend_ola/local/split.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/egs/callhome/eend_ola/local/split.py b/egs/callhome/eend_ola/local/split.py index 0ff9e471d..34b7bb182 100644 --- a/egs/callhome/eend_ola/local/split.py +++ b/egs/callhome/eend_ola/local/split.py @@ -11,7 +11,7 @@ if __name__ == '__main__': scp_files = os.listdir(work_path) reco2dur_dict = {} - with (os.path.join(root_path, 'reco2dur')) as f: + with open(os.path.join(root_path, 'reco2dur')) as f: lines = f.readlines() for line in lines: parts = line.strip().split() @@ -61,10 +61,10 @@ if __name__ == '__main__': for file in scp_files: scp_file = work_path + file idx = scp_file.split('.')[-2] - reco2dur_file = os.path.join(work_path, 'reco2dur.'.format(str(idx))) - spk2utt_file = os.path.join(work_path, 'spk2utt.'.format(str(idx))) - segment_file = os.path.join(work_path, 'segments.'.format(str(idx))) - utt2spk_file = os.path.join(work_path, 'utt2spk.'.format(str(idx))) + reco2dur_file = os.path.join(work_path, 'reco2dur.{}'.format(str(idx))) + spk2utt_file = os.path.join(work_path, 'spk2utt.{}'.format(str(idx))) + segment_file = os.path.join(work_path, 'segments.{}'.format(str(idx))) + utt2spk_file = os.path.join(work_path, 'utt2spk.{}'.format(str(idx))) fpp = open(scp_file) scp_lines = fpp.readlines() From fa618d7634b01fdd213dbe87d89af11decaf6eca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=98=89=E6=B8=8A?= Date: Thu, 20 Jul 2023 19:46:46 +0800 Subject: [PATCH 25/42] update --- egs/callhome/eend_ola/local/split.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/callhome/eend_ola/local/split.py b/egs/callhome/eend_ola/local/split.py index 34b7bb182..aa50b1e1e 100644 --- a/egs/callhome/eend_ola/local/split.py +++ b/egs/callhome/eend_ola/local/split.py @@ -59,7 +59,7 @@ if __name__ == '__main__': utt2spk_dict[rec].append((parts[0], parts[1])) for file in scp_files: - scp_file = work_path + file + scp_file = os.path.join(work_path, file) idx = scp_file.split('.')[-2] reco2dur_file = os.path.join(work_path, 'reco2dur.{}'.format(str(idx))) spk2utt_file = os.path.join(work_path, 'spk2utt.{}'.format(str(idx))) From 0109889f1cbbd7ff703383bfacb204d45f5d37a9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=98=89=E6=B8=8A?= Date: Fri, 21 Jul 2023 02:45:54 +0800 Subject: [PATCH 26/42] update --- egs/callhome/eend_ola/local/dump_feature.py | 105 ++++++++++++-------- egs/callhome/eend_ola/run_test.sh | 23 +++-- 2 files changed, 77 insertions(+), 51 deletions(-) diff --git a/egs/callhome/eend_ola/local/dump_feature.py b/egs/callhome/eend_ola/local/dump_feature.py index 169615e1b..332edd2a1 100644 --- a/egs/callhome/eend_ola/local/dump_feature.py +++ b/egs/callhome/eend_ola/local/dump_feature.py @@ -1,10 +1,11 @@ import argparse import os -import numpy as np +from kaldiio import WriteHelper import funasr.modules.eend_ola.utils.feature as feature -import funasr.modules.eend_ola.utils.kaldi_data as kaldi_data +from funasr.modules.eend_ola.utils.kaldi_data import load_segments_rechash, load_utt2spk, load_wav_scp, load_reco2dur, \ + load_spk2utt, load_wav def _count_frames(data_len, size, step): @@ -24,10 +25,34 @@ def _gen_frame_indices( yield (i + 1) * step, data_length +class KaldiData: + def __init__(self, data_dir, idx): + self.data_dir = data_dir + segment_file = os.path.join(self.data_dir, 'segments.{}'.format(idx)) + self.segments = load_segments_rechash(segment_file) + + utt2spk_file = os.path.join(self.data_dir, 'utt2spk.{}'.format(idx)) + self.utt2spk = load_utt2spk(utt2spk_file) + + wav_file = os.path.join(self.data_dir, 'wav.scp.{}'.format(idx)) + self.wavs = load_wav_scp(wav_file) + + reco2dur_file = os.path.join(self.data_dir, 'reco2dur.{}'.format(idx)) + self.reco2dur = load_reco2dur(reco2dur_file) + + spk2utt_file = os.path.join(self.data_dir, 'spk2utt.{}'.format(idx)) + self.spk2utt = load_spk2utt(spk2utt_file) + + def load_wav(self, recid, start=0, end=None): + data, rate = load_wav(self.wavs[recid], start, end) + return data, rate + + class KaldiDiarizationDataset(): def __init__( self, data_dir, + index, chunk_size=2000, context_size=0, frame_size=1024, @@ -40,6 +65,7 @@ class KaldiDiarizationDataset(): n_speakers=None, ): self.data_dir = data_dir + self.index = index self.chunk_size = chunk_size self.context_size = context_size self.frame_size = frame_size @@ -50,9 +76,8 @@ class KaldiDiarizationDataset(): self.chunk_indices = [] self.label_delay = label_delay - self.data = kaldi_data.KaldiData(self.data_dir) + self.data = KaldiData(self.data_dir, index) - # make chunk indices: filepath, start_frame, end_frame for rec, path in self.data.wavs.items(): data_len = int(self.data.reco2dur[rec] * rate / frame_shift) data_len = int(data_len / self.subsampling) @@ -66,62 +91,54 @@ class KaldiDiarizationDataset(): def convert(args): - f = open(out_wav_file, 'w') dataset = KaldiDiarizationDataset( data_dir=args.data_dir, + index=args.index, chunk_size=args.num_frames, context_size=args.context_size, - input_transform=args.input_transform, + input_transform="logmel23_mn", frame_size=args.frame_size, frame_shift=args.frame_shift, subsampling=args.subsampling, rate=8000, use_last_samples=True, ) - length = len(dataset.chunk_indices) - for idx, (rec, path, st, ed) in enumerate(dataset.chunk_indices): - Y, T = feature.get_labeledSTFT( - dataset.data, - rec, - st, - ed, - dataset.frame_size, - dataset.frame_shift, - dataset.n_speakers) - Y = feature.transform(Y, dataset.input_transform) - Y_spliced = feature.splice(Y, dataset.context_size) - Y_ss, T_ss = feature.subsample(Y_spliced, T, dataset.subsampling) - st = '{:0>7d}'.format(st) - ed = '{:0>7d}'.format(ed) - suffix = '_' + st + '_' + ed - parts = os.readlink('/'.join(path.split('/')[:-1])).split('/') - # print('parts: ', parts) - parts = parts[:4] + ['numpy_data'] + parts[4:] - cur_path = '/'.join(parts) - # print('cur path: ', cur_path) - out_path = os.path.join(cur_path, path.split('/')[-1].split('.')[0] + suffix + '.npz') - # print(out_path) - # print(cur_path) - if not os.path.exists(cur_path): - os.makedirs(cur_path) - np.savez(out_path, Y=Y_ss, T=T_ss) - if idx == length - 1: - f.write(rec + suffix + ' ' + out_path) - else: - f.write(rec + suffix + ' ' + out_path + '\n') + feature_ark_file = os.path.join(args.output_dir, "feature.ark.{}".format(args.index)) + feature_scp_file = os.path.join(args.output_dir, "feature.scp.{}".format(args.index)) + label_ark_file = os.path.join(args.output_dir, "label.ark.{}".format(args.index)) + label_scp_file = os.path.join(args.output_dir, "label.scp.{}".format(args.index)) + with WriteHelper('ark,scp:{},{}'.format(feature_ark_file, feature_scp_file)) as feature_writer, \ + WriteHelper('ark,scp:{},{}'.format(label_ark_file, label_scp_file)) as label_writer: + for idx, (rec, path, st, ed) in enumerate(dataset.chunk_indices): + Y, T = feature.get_labeledSTFT( + dataset.data, + rec, + st, + ed, + dataset.frame_size, + dataset.frame_shift, + dataset.n_speakers) + Y = feature.transform(Y, dataset.input_transform) + Y_spliced = feature.splice(Y, dataset.context_size) + Y_ss, T_ss = feature.subsample(Y_spliced, T, dataset.subsampling) + st = '{:0>7d}'.format(st) + ed = '{:0>7d}'.format(ed) + key = "{}_{}_{}".format(rec, st, ed) + feature_writer(key, Y_ss) + label_writer(key, T_ss.reshape(-1)) if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument("data_dir") - parser.add_argument("num_frames") - parser.add_argument("context_size") - parser.add_argument("frame_size") - parser.add_argument("frame_shift") - parser.add_argument("subsampling") - - + parser.add_argument("output_dir") + parser.add_argument("index") + parser.add_argument("num_frames", default=500) + parser.add_argument("context_size", default=7) + parser.add_argument("frame_size", default=200) + parser.add_argument("frame_shift", default=80) + parser.add_argument("subsampling", default=10) args = parser.parse_args() convert(args) diff --git a/egs/callhome/eend_ola/run_test.sh b/egs/callhome/eend_ola/run_test.sh index 8ba8d5742..c6a3a7109 100644 --- a/egs/callhome/eend_ola/run_test.sh +++ b/egs/callhome/eend_ola/run_test.sh @@ -78,17 +78,26 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then for dset in swb_sre_tr swb_sre_cv; do if [ "$dset" == "swb_sre_tr" ]; then n_mixtures=${simu_opts_num_train} + dataset=train else n_mixtures=500 + dataset=dev fi simu_data_dir=${dset}_ns"$(IFS="n"; echo "${simu_opts_num_speaker_array[*]}")"_beta"$(IFS="n"; echo "${simu_opts_sil_scale_array[*]}")"_${n_mixtures} - mkdir -p ${data_dir}/simu/data/${simu_data_dir}/.work - split_scps= - for n in $(seq $nj); do - split_scps="$split_scps ${data_dir}/simu/data/${simu_data_dir}/.work/wav.$n.scp" - done - utils/split_scp.pl "${data_dir}/simu/data/${simu_data_dir}/wav.scp" $split_scps || exit 1 - python local/split.py ${data_dir}/simu/data/${simu_data_dir} +# mkdir -p ${data_dir}/simu/data/${simu_data_dir}/.work +# split_scps= +# for n in $(seq $nj); do +# split_scps="$split_scps ${data_dir}/simu/data/${simu_data_dir}/.work/wav.$n.scp" +# done +# utils/split_scp.pl "${data_dir}/simu/data/${simu_data_dir}/wav.scp" $split_scps || exit 1 +# python local/split.py ${data_dir}/simu/data/${simu_data_dir} + output_dir=${data_dir}/ark_data/dump/simu_data/$dataset + mkdir -p $output_dir/.logs + $dump_cmd --max-jobs-run $nj JOB=1:$nj $output_dir/.logs/dump.JOB.log \ + python local/dump_feature.py \ + --data_dir ${data_dir}/simu/data/${simu_data_dir}/.work \ + --output_dir ${data_dir}/ark_data/dump/simu_data/$dataset \ + --index JOB done fi From 6205da5c22b881608b6a43ed908d41f5f5895e1e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=98=89=E6=B8=8A?= Date: Fri, 21 Jul 2023 02:50:06 +0800 Subject: [PATCH 27/42] update --- egs/callhome/eend_ola/local/dump_feature.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/egs/callhome/eend_ola/local/dump_feature.py b/egs/callhome/eend_ola/local/dump_feature.py index 332edd2a1..5d7a0610c 100644 --- a/egs/callhome/eend_ola/local/dump_feature.py +++ b/egs/callhome/eend_ola/local/dump_feature.py @@ -131,14 +131,14 @@ def convert(args): if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument("data_dir") - parser.add_argument("output_dir") - parser.add_argument("index") - parser.add_argument("num_frames", default=500) - parser.add_argument("context_size", default=7) - parser.add_argument("frame_size", default=200) - parser.add_argument("frame_shift", default=80) - parser.add_argument("subsampling", default=10) + parser.add_argument("--data_dir", type=str) + parser.add_argument("--output_dir", type=str) + parser.add_argument("--index", type=str) + parser.add_argument("--num_frames", type=int, default=500) + parser.add_argument("--context_size", type=int, default=7) + parser.add_argument("--frame_size", type=int, default=200) + parser.add_argument("--frame_shift", type=int, default=80) + parser.add_argument("--subsampling", type=int, default=10) args = parser.parse_args() convert(args) From 311894a7aa56d0b02fbaa229be8b680ee4b48543 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=98=89=E6=B8=8A?= Date: Fri, 21 Jul 2023 02:52:39 +0800 Subject: [PATCH 28/42] update --- egs/callhome/eend_ola/local/dump_feature.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/callhome/eend_ola/local/dump_feature.py b/egs/callhome/eend_ola/local/dump_feature.py index 5d7a0610c..8549c31fa 100644 --- a/egs/callhome/eend_ola/local/dump_feature.py +++ b/egs/callhome/eend_ola/local/dump_feature.py @@ -34,7 +34,7 @@ class KaldiData: utt2spk_file = os.path.join(self.data_dir, 'utt2spk.{}'.format(idx)) self.utt2spk = load_utt2spk(utt2spk_file) - wav_file = os.path.join(self.data_dir, 'wav.scp.{}'.format(idx)) + wav_file = os.path.join(self.data_dir, 'wav.{}.scp'.format(idx)) self.wavs = load_wav_scp(wav_file) reco2dur_file = os.path.join(self.data_dir, 'reco2dur.{}'.format(idx)) From 5f3f194ffd459fd5ebe6ea46da5e532820a66060 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=98=89=E6=B8=8A?= Date: Fri, 21 Jul 2023 11:24:31 +0800 Subject: [PATCH 29/42] update --- egs/callhome/eend_ola/local/gen_feats_scp.py | 27 ++++++++++++++++++++ egs/callhome/eend_ola/run_test.sh | 24 ++++++++++------- 2 files changed, 42 insertions(+), 9 deletions(-) create mode 100644 egs/callhome/eend_ola/local/gen_feats_scp.py diff --git a/egs/callhome/eend_ola/local/gen_feats_scp.py b/egs/callhome/eend_ola/local/gen_feats_scp.py new file mode 100644 index 000000000..5667f827b --- /dev/null +++ b/egs/callhome/eend_ola/local/gen_feats_scp.py @@ -0,0 +1,27 @@ +import os +import argparse + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("--root_path", type=str) + parser.add_argument("--out_path", type=str) + parser.add_argument("--split_num", type=int, default=64) + args = parser.parse_args() + root_path = args.root_path + out_path = args.out_path + datasets = ["train", "dev"] + split_num = args.split_num + + for dataset in datasets: + with open(os.path.join(out_path, dataset, "feats.scp"), "w") as out_f: + for i in range(split_num): + idx = str(i + 1) + feature_file = os.path.join(root_path, dataset, "feature.scp.{}".format(idx)) + label_file = os.path.join(root_path, dataset, "label.scp.{}".format(idx)) + with open(feature_file) as ff, open(label_file) as fl: + ff_lines = ff.readlines() + fl_lines = fl.readlines() + for ff_line, fl_line in zip(ff_lines, fl_lines): + sample_name, f_path = ff_line.strip().split() + _, l_path = fl_line.strip().split() + out_f.write("{} {} {}\n".format(sample_name, f_path, l_path)) \ No newline at end of file diff --git a/egs/callhome/eend_ola/run_test.sh b/egs/callhome/eend_ola/run_test.sh index c6a3a7109..31b177e4a 100644 --- a/egs/callhome/eend_ola/run_test.sh +++ b/egs/callhome/eend_ola/run_test.sh @@ -12,7 +12,7 @@ dump_cmd=utils/run.pl nj=64 # feature configuration -data_dir="/nfs/wangjiaming.wjm/EEND_DATA_sad30_snr10n15n20/convert_chunk2000/data" +data_dir="/nfs/wangjiaming.wjm/EEND_DATA_sad30_snr10n15n20/convert_test/data" simu_feats_dir="/nfs/wangjiaming.wjm/EEND_ARK_DATA/dump/simu_data/data" simu_feats_dir_chunk2000="/nfs/wangjiaming.wjm/EEND_ARK_DATA/dump/simu_data_chunk2000/data" callhome_feats_dir_chunk2000="/nfs/wangjiaming.wjm/EEND_ARK_DATA/dump/callhome_chunk2000/data" @@ -74,7 +74,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then simu_opts_sil_scale_array=(2 2 5 9) simu_opts_num_train=100000 - # for simulated data of chunk500 + # for simulated data of chunk500 and chunk2000 for dset in swb_sre_tr swb_sre_cv; do if [ "$dset" == "swb_sre_tr" ]; then n_mixtures=${simu_opts_num_train} @@ -91,13 +91,19 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then # done # utils/split_scp.pl "${data_dir}/simu/data/${simu_data_dir}/wav.scp" $split_scps || exit 1 # python local/split.py ${data_dir}/simu/data/${simu_data_dir} - output_dir=${data_dir}/ark_data/dump/simu_data/$dataset - mkdir -p $output_dir/.logs - $dump_cmd --max-jobs-run $nj JOB=1:$nj $output_dir/.logs/dump.JOB.log \ - python local/dump_feature.py \ - --data_dir ${data_dir}/simu/data/${simu_data_dir}/.work \ - --output_dir ${data_dir}/ark_data/dump/simu_data/$dataset \ - --index JOB +# # for chunk_size=500 +# output_dir=${data_dir}/ark_data/dump/simu_data/$dataset +# mkdir -p $output_dir/.logs +# $dump_cmd --max-jobs-run $nj JOB=1:$nj $output_dir/.logs/dump.JOB.log \ +# python local/dump_feature.py \ +# --data_dir ${data_dir}/simu/data/${simu_data_dir}/.work \ +# --output_dir ${data_dir}/ark_data/dump/simu_data/$dataset \ +# --index JOB + mkdir -p ${data_dir}/ark_data/dump/simu_data/data/$dataset + python local_rank/gen_feats_scp.py \ + --root_path ${data_dir}/ark_data/dump/simu_data \ + --out_path ${data_dir}/ark_data/dump/simu_data/data/$dataset \ + --split_num $nj done fi From e79c9a801e7e7458ce6894fa85178fa974dfd18b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=98=89=E6=B8=8A?= Date: Fri, 21 Jul 2023 11:28:18 +0800 Subject: [PATCH 30/42] update --- egs/callhome/eend_ola/run_test.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/callhome/eend_ola/run_test.sh b/egs/callhome/eend_ola/run_test.sh index 31b177e4a..57d6418cc 100644 --- a/egs/callhome/eend_ola/run_test.sh +++ b/egs/callhome/eend_ola/run_test.sh @@ -100,7 +100,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then # --output_dir ${data_dir}/ark_data/dump/simu_data/$dataset \ # --index JOB mkdir -p ${data_dir}/ark_data/dump/simu_data/data/$dataset - python local_rank/gen_feats_scp.py \ + python local/gen_feats_scp.py \ --root_path ${data_dir}/ark_data/dump/simu_data \ --out_path ${data_dir}/ark_data/dump/simu_data/data/$dataset \ --split_num $nj From 13a1ba2b1d12dd8d144660e7678645db41d67960 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=98=89=E6=B8=8A?= Date: Fri, 21 Jul 2023 14:16:36 +0800 Subject: [PATCH 31/42] update --- egs/callhome/eend_ola/run_test.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/callhome/eend_ola/run_test.sh b/egs/callhome/eend_ola/run_test.sh index 57d6418cc..02c51f69b 100644 --- a/egs/callhome/eend_ola/run_test.sh +++ b/egs/callhome/eend_ola/run_test.sh @@ -102,7 +102,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then mkdir -p ${data_dir}/ark_data/dump/simu_data/data/$dataset python local/gen_feats_scp.py \ --root_path ${data_dir}/ark_data/dump/simu_data \ - --out_path ${data_dir}/ark_data/dump/simu_data/data/$dataset \ + --out_path ${data_dir}/ark_data/dump/simu_data/data \ --split_num $nj done fi From d273f7e12693e5b366cbf2ff7d01dde0264b01d9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=98=89=E6=B8=8A?= Date: Fri, 21 Jul 2023 14:26:15 +0800 Subject: [PATCH 32/42] update --- egs/callhome/eend_ola/local/gen_feats_scp.py | 26 +++++++++----------- egs/callhome/eend_ola/run_test.sh | 4 +-- 2 files changed, 14 insertions(+), 16 deletions(-) diff --git a/egs/callhome/eend_ola/local/gen_feats_scp.py b/egs/callhome/eend_ola/local/gen_feats_scp.py index 5667f827b..88a94f218 100644 --- a/egs/callhome/eend_ola/local/gen_feats_scp.py +++ b/egs/callhome/eend_ola/local/gen_feats_scp.py @@ -9,19 +9,17 @@ if __name__ == '__main__': args = parser.parse_args() root_path = args.root_path out_path = args.out_path - datasets = ["train", "dev"] split_num = args.split_num - for dataset in datasets: - with open(os.path.join(out_path, dataset, "feats.scp"), "w") as out_f: - for i in range(split_num): - idx = str(i + 1) - feature_file = os.path.join(root_path, dataset, "feature.scp.{}".format(idx)) - label_file = os.path.join(root_path, dataset, "label.scp.{}".format(idx)) - with open(feature_file) as ff, open(label_file) as fl: - ff_lines = ff.readlines() - fl_lines = fl.readlines() - for ff_line, fl_line in zip(ff_lines, fl_lines): - sample_name, f_path = ff_line.strip().split() - _, l_path = fl_line.strip().split() - out_f.write("{} {} {}\n".format(sample_name, f_path, l_path)) \ No newline at end of file + with open(os.path.join(out_path, "feats.scp"), "w") as out_f: + for i in range(split_num): + idx = str(i + 1) + feature_file = os.path.join(root_path, "feature.scp.{}".format(idx)) + label_file = os.path.join(root_path, "label.scp.{}".format(idx)) + with open(feature_file) as ff, open(label_file) as fl: + ff_lines = ff.readlines() + fl_lines = fl.readlines() + for ff_line, fl_line in zip(ff_lines, fl_lines): + sample_name, f_path = ff_line.strip().split() + _, l_path = fl_line.strip().split() + out_f.write("{} {} {}\n".format(sample_name, f_path, l_path)) \ No newline at end of file diff --git a/egs/callhome/eend_ola/run_test.sh b/egs/callhome/eend_ola/run_test.sh index 02c51f69b..77aa08db8 100644 --- a/egs/callhome/eend_ola/run_test.sh +++ b/egs/callhome/eend_ola/run_test.sh @@ -101,8 +101,8 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then # --index JOB mkdir -p ${data_dir}/ark_data/dump/simu_data/data/$dataset python local/gen_feats_scp.py \ - --root_path ${data_dir}/ark_data/dump/simu_data \ - --out_path ${data_dir}/ark_data/dump/simu_data/data \ + --root_path ${data_dir}/ark_data/dump/simu_data/$dataset \ + --out_path ${data_dir}/ark_data/dump/simu_data/data/$dataset \ --split_num $nj done fi From c8680dcb7be0cd0a254655cd11c8ffe873acdb40 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=98=89=E6=B8=8A?= Date: Fri, 21 Jul 2023 14:35:04 +0800 Subject: [PATCH 33/42] update --- egs/callhome/eend_ola/run_test.sh | 32 +++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/egs/callhome/eend_ola/run_test.sh b/egs/callhome/eend_ola/run_test.sh index 77aa08db8..a824a6832 100644 --- a/egs/callhome/eend_ola/run_test.sh +++ b/egs/callhome/eend_ola/run_test.sh @@ -75,7 +75,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then simu_opts_num_train=100000 # for simulated data of chunk500 and chunk2000 - for dset in swb_sre_tr swb_sre_cv; do + for dset in swb_sre_cv swb_sre_tr; do if [ "$dset" == "swb_sre_tr" ]; then n_mixtures=${simu_opts_num_train} dataset=train @@ -84,21 +84,21 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then dataset=dev fi simu_data_dir=${dset}_ns"$(IFS="n"; echo "${simu_opts_num_speaker_array[*]}")"_beta"$(IFS="n"; echo "${simu_opts_sil_scale_array[*]}")"_${n_mixtures} -# mkdir -p ${data_dir}/simu/data/${simu_data_dir}/.work -# split_scps= -# for n in $(seq $nj); do -# split_scps="$split_scps ${data_dir}/simu/data/${simu_data_dir}/.work/wav.$n.scp" -# done -# utils/split_scp.pl "${data_dir}/simu/data/${simu_data_dir}/wav.scp" $split_scps || exit 1 -# python local/split.py ${data_dir}/simu/data/${simu_data_dir} -# # for chunk_size=500 -# output_dir=${data_dir}/ark_data/dump/simu_data/$dataset -# mkdir -p $output_dir/.logs -# $dump_cmd --max-jobs-run $nj JOB=1:$nj $output_dir/.logs/dump.JOB.log \ -# python local/dump_feature.py \ -# --data_dir ${data_dir}/simu/data/${simu_data_dir}/.work \ -# --output_dir ${data_dir}/ark_data/dump/simu_data/$dataset \ -# --index JOB + mkdir -p ${data_dir}/simu/data/${simu_data_dir}/.work + split_scps= + for n in $(seq $nj); do + split_scps="$split_scps ${data_dir}/simu/data/${simu_data_dir}/.work/wav.$n.scp" + done + utils/split_scp.pl "${data_dir}/simu/data/${simu_data_dir}/wav.scp" $split_scps || exit 1 + python local/split.py ${data_dir}/simu/data/${simu_data_dir} + # for chunk_size=500 + output_dir=${data_dir}/ark_data/dump/simu_data/$dataset + mkdir -p $output_dir/.logs + $dump_cmd --max-jobs-run $nj JOB=1:$nj $output_dir/.logs/dump.JOB.log \ + python local/dump_feature.py \ + --data_dir ${data_dir}/simu/data/${simu_data_dir}/.work \ + --output_dir ${data_dir}/ark_data/dump/simu_data/$dataset \ + --index JOB mkdir -p ${data_dir}/ark_data/dump/simu_data/data/$dataset python local/gen_feats_scp.py \ --root_path ${data_dir}/ark_data/dump/simu_data/$dataset \ From 3aad0e15ecf53aa22e89c82f48fcf356df16df20 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=98=89=E6=B8=8A?= Date: Sun, 23 Jul 2023 00:12:57 +0800 Subject: [PATCH 34/42] update --- egs/callhome/eend_ola/local/dump_feature.py | 2 +- egs/callhome/eend_ola/run_test.sh | 92 ++++++++++++++------- 2 files changed, 63 insertions(+), 31 deletions(-) diff --git a/egs/callhome/eend_ola/local/dump_feature.py b/egs/callhome/eend_ola/local/dump_feature.py index 8549c31fa..5d7a0610c 100644 --- a/egs/callhome/eend_ola/local/dump_feature.py +++ b/egs/callhome/eend_ola/local/dump_feature.py @@ -34,7 +34,7 @@ class KaldiData: utt2spk_file = os.path.join(self.data_dir, 'utt2spk.{}'.format(idx)) self.utt2spk = load_utt2spk(utt2spk_file) - wav_file = os.path.join(self.data_dir, 'wav.{}.scp'.format(idx)) + wav_file = os.path.join(self.data_dir, 'wav.scp.{}'.format(idx)) self.wavs = load_wav_scp(wav_file) reco2dur_file = os.path.join(self.data_dir, 'reco2dur.{}'.format(idx)) diff --git a/egs/callhome/eend_ola/run_test.sh b/egs/callhome/eend_ola/run_test.sh index a824a6832..188b61ed5 100644 --- a/egs/callhome/eend_ola/run_test.sh +++ b/egs/callhome/eend_ola/run_test.sh @@ -3,7 +3,7 @@ . ./path.sh || exit 1; # machines configuration -CUDA_VISIBLE_DEVICES="7" +CUDA_VISIBLE_DEVICES="0" gpu_num=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') count=1 @@ -12,7 +12,7 @@ dump_cmd=utils/run.pl nj=64 # feature configuration -data_dir="/nfs/wangjiaming.wjm/EEND_DATA_sad30_snr10n15n20/convert_test/data" +data_dir="./data" simu_feats_dir="/nfs/wangjiaming.wjm/EEND_ARK_DATA/dump/simu_data/data" simu_feats_dir_chunk2000="/nfs/wangjiaming.wjm/EEND_ARK_DATA/dump/simu_data_chunk2000/data" callhome_feats_dir_chunk2000="/nfs/wangjiaming.wjm/EEND_ARK_DATA/dump/callhome_chunk2000/data" @@ -74,36 +74,68 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then simu_opts_sil_scale_array=(2 2 5 9) simu_opts_num_train=100000 - # for simulated data of chunk500 and chunk2000 - for dset in swb_sre_cv swb_sre_tr; do - if [ "$dset" == "swb_sre_tr" ]; then - n_mixtures=${simu_opts_num_train} - dataset=train - else - n_mixtures=500 - dataset=dev - fi - simu_data_dir=${dset}_ns"$(IFS="n"; echo "${simu_opts_num_speaker_array[*]}")"_beta"$(IFS="n"; echo "${simu_opts_sil_scale_array[*]}")"_${n_mixtures} - mkdir -p ${data_dir}/simu/data/${simu_data_dir}/.work - split_scps= - for n in $(seq $nj); do - split_scps="$split_scps ${data_dir}/simu/data/${simu_data_dir}/.work/wav.$n.scp" - done - utils/split_scp.pl "${data_dir}/simu/data/${simu_data_dir}/wav.scp" $split_scps || exit 1 - python local/split.py ${data_dir}/simu/data/${simu_data_dir} - # for chunk_size=500 - output_dir=${data_dir}/ark_data/dump/simu_data/$dataset - mkdir -p $output_dir/.logs - $dump_cmd --max-jobs-run $nj JOB=1:$nj $output_dir/.logs/dump.JOB.log \ +# # for simulated data of chunk500 and chunk2000 +# for dset in swb_sre_cv swb_sre_tr; do +# if [ "$dset" == "swb_sre_tr" ]; then +# n_mixtures=${simu_opts_num_train} +# dataset=train +# else +# n_mixtures=500 +# dataset=dev +# fi +# simu_data_dir=${dset}_ns"$(IFS="n"; echo "${simu_opts_num_speaker_array[*]}")"_beta"$(IFS="n"; echo "${simu_opts_sil_scale_array[*]}")"_${n_mixtures} +# mkdir -p ${data_dir}/simu/data/${simu_data_dir}/.work +# split_scps= +# for n in $(seq $nj); do +# split_scps="$split_scps ${data_dir}/simu/data/${simu_data_dir}/.work/wav.scp.$n" +# done +# utils/split_scp.pl "${data_dir}/simu/data/${simu_data_dir}/wav.scp" $split_scps || exit 1 +# python local/split.py ${data_dir}/simu/data/${simu_data_dir} +# # for chunk_size=500 +# output_dir=${data_dir}/ark_data/dump/simu_data/$dataset +# mkdir -p $output_dir/.logs +# $dump_cmd --max-jobs-run $nj JOB=1:$nj $output_dir/.logs/dump.JOB.log \ +# python local/dump_feature.py \ +# --data_dir ${data_dir}/simu/data/${simu_data_dir}/.work \ +# --output_dir $output_dir \ +# --index JOB +# mkdir -p ${data_dir}/ark_data/dump/simu_data/data/$dataset +# python local/gen_feats_scp.py \ +# --root_path ${data_dir}/ark_data/dump/simu_data/$dataset \ +# --out_path ${data_dir}/ark_data/dump/simu_data/data/$dataset \ +# --split_num $nj +# grep "ns2" ${data_dir}/ark_data/dump/simu_data/data/$dataset/feats.scp > ${data_dir}/ark_data/dump/simu_data/data/$dataset/feats_2spkr.scp +# # for chunk_size=2000 +# output_dir=${data_dir}/ark_data/dump/simu_data_chunk2000/$dataset +# mkdir -p $output_dir/.logs +# $dump_cmd --max-jobs-run $nj JOB=1:$nj $output_dir/.logs/dump.JOB.log \ +# python local/dump_feature.py \ +# --data_dir ${data_dir}/simu/data/${simu_data_dir}/.work \ +# --output_dir $output_dir \ +# --index JOB \ +# --num_frames 2000 +# mkdir -p ${data_dir}/ark_data/dump/simu_data_chunk2000/data/$dataset +# python local/gen_feats_scp.py \ +# --root_path ${data_dir}/ark_data/dump/simu_data_chunk2000/$dataset \ +# --out_path ${data_dir}/ark_data/dump/simu_data_chunk2000/data/$dataset \ +# --split_num $nj +# grep "ns2" ${data_dir}/ark_data/dump/simu_data_chunk2000/data/$dataset/feats.scp > ${data_dir}/ark_data/dump/simu_data_chunk2000/data/$dataset/feats_2spkr.scp +# done + + # for callhome data + for dset in callhome1_spkall callhome2_spkall; do + find $data_dir/eval/$dset -maxdepth 1 -type f -exec cp {} {}.1 \; + output_dir=${data_dir}/ark_data/dump/callhome/$dset python local/dump_feature.py \ - --data_dir ${data_dir}/simu/data/${simu_data_dir}/.work \ - --output_dir ${data_dir}/ark_data/dump/simu_data/$dataset \ - --index JOB - mkdir -p ${data_dir}/ark_data/dump/simu_data/data/$dataset + --data_dir $data_dir/eval/$dset \ + --output_dir $output_dir \ + --index 1 \ + --num_frames 2000 + mkdir -p ${data_dir}/ark_data/dump/callhome/data/$dset python local/gen_feats_scp.py \ - --root_path ${data_dir}/ark_data/dump/simu_data/$dataset \ - --out_path ${data_dir}/ark_data/dump/simu_data/data/$dataset \ - --split_num $nj + --root_path ${data_dir}/ark_data/dump/callhome/$dset \ + --out_path ${data_dir}/ark_data/dump/callhome/data/$dset \ + --split_num 1 done fi From d0b0f85794eb77268f0994a1c7a81ebcebf65a44 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=98=89=E6=B8=8A?= Date: Sun, 23 Jul 2023 00:34:33 +0800 Subject: [PATCH 35/42] update --- egs/callhome/eend_ola/run.sh | 73 +++++++++++++++++++++++++------ egs/callhome/eend_ola/run_test.sh | 9 ++-- 2 files changed, 65 insertions(+), 17 deletions(-) diff --git a/egs/callhome/eend_ola/run.sh b/egs/callhome/eend_ola/run.sh index 11e177115..c3a3f320a 100644 --- a/egs/callhome/eend_ola/run.sh +++ b/egs/callhome/eend_ola/run.sh @@ -3,7 +3,7 @@ . ./path.sh || exit 1; # machines configuration -CUDA_VISIBLE_DEVICES="7" +CUDA_VISIBLE_DEVICES="0" gpu_num=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') count=1 @@ -13,14 +13,13 @@ nj=64 # feature configuration data_dir="./data" -simu_feats_dir="/nfs/wangjiaming.wjm/EEND_ARK_DATA/dump/simu_data/data" -simu_feats_dir_chunk2000="/nfs/wangjiaming.wjm/EEND_ARK_DATA/dump/simu_data_chunk2000/data" -callhome_feats_dir_chunk2000="/nfs/wangjiaming.wjm/EEND_ARK_DATA/dump/callhome_chunk2000/data" +simu_feats_dir=$data_dir/simu/ark_data/dump/simu_data/data +simu_feats_dir_chunk2000=$data_dir/simu/ark_data/dump/simu_data_chunk2000/data +callhome_feats_dir_chunk2000=$data_dir/simu/ark_data/dump/callhome_chunk2000/data simu_train_dataset=train simu_valid_dataset=dev -callhome_train_dataset=callhome1_allspk -callhome_valid_dataset=callhome2_allspk -callhome2_wav_scp_file=wav.scp +callhome_train_dataset=callhome1_spkall +callhome_valid_dataset=callhome2_spkall # model average simu_average_2spkr_start=91 @@ -32,8 +31,8 @@ callhome_average_end=100 exp_dir="." input_size=345 -stage=-1 -stop_stage=-1 +stage=0 +stop_stage=5 # exp tag tag="exp1" @@ -74,21 +73,69 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then simu_opts_sil_scale_array=(2 2 5 9) simu_opts_num_train=100000 - # for simulated data of chunk500 - for dset in swb_sre_tr swb_sre_cv; do + # for simulated data of chunk500 and chunk2000 + for dset in swb_sre_cv swb_sre_tr; do if [ "$dset" == "swb_sre_tr" ]; then n_mixtures=${simu_opts_num_train} + dataset=train else n_mixtures=500 + dataset=dev fi simu_data_dir=${dset}_ns"$(IFS="n"; echo "${simu_opts_num_speaker_array[*]}")"_beta"$(IFS="n"; echo "${simu_opts_sil_scale_array[*]}")"_${n_mixtures} mkdir -p ${data_dir}/simu/data/${simu_data_dir}/.work split_scps= for n in $(seq $nj); do - split_scps="$split_scps ${data_dir}/simu/data/${simu_data_dir}/.work/wav.$n.scp" + split_scps="$split_scps ${data_dir}/simu/data/${simu_data_dir}/.work/wav.scp.$n" done utils/split_scp.pl "${data_dir}/simu/data/${simu_data_dir}/wav.scp" $split_scps || exit 1 python local/split.py ${data_dir}/simu/data/${simu_data_dir} + # for chunk_size=500 + output_dir=${data_dir}/ark_data/dump/simu_data/$dataset + mkdir -p $output_dir/.logs + $dump_cmd --max-jobs-run $nj JOB=1:$nj $output_dir/.logs/dump.JOB.log \ + python local/dump_feature.py \ + --data_dir ${data_dir}/simu/data/${simu_data_dir}/.work \ + --output_dir $output_dir \ + --index JOB + mkdir -p ${data_dir}/ark_data/dump/simu_data/data/$dataset + python local/gen_feats_scp.py \ + --root_path ${data_dir}/ark_data/dump/simu_data/$dataset \ + --out_path ${data_dir}/ark_data/dump/simu_data/data/$dataset \ + --split_num $nj + grep "ns2" ${data_dir}/ark_data/dump/simu_data/data/$dataset/feats.scp > ${data_dir}/ark_data/dump/simu_data/data/$dataset/feats_2spkr.scp + # for chunk_size=2000 + output_dir=${data_dir}/ark_data/dump/simu_data_chunk2000/$dataset + mkdir -p $output_dir/.logs + $dump_cmd --max-jobs-run $nj JOB=1:$nj $output_dir/.logs/dump.JOB.log \ + python local/dump_feature.py \ + --data_dir ${data_dir}/simu/data/${simu_data_dir}/.work \ + --output_dir $output_dir \ + --index JOB \ + --num_frames 2000 + mkdir -p ${data_dir}/ark_data/dump/simu_data_chunk2000/data/$dataset + python local/gen_feats_scp.py \ + --root_path ${data_dir}/ark_data/dump/simu_data_chunk2000/$dataset \ + --out_path ${data_dir}/ark_data/dump/simu_data_chunk2000/data/$dataset \ + --split_num $nj + grep "ns2" ${data_dir}/ark_data/dump/simu_data_chunk2000/data/$dataset/feats.scp > ${data_dir}/ark_data/dump/simu_data_chunk2000/data/$dataset/feats_2spkr.scp + done + + # for callhome data + for dset in callhome1_spkall callhome2_spkall; do + find $data_dir/eval/$dset -maxdepth 1 -type f -exec cp {} {}.1 \; + output_dir=${data_dir}/ark_data/dump/callhome/$dset + mkdir -p $output_dir + python local/dump_feature.py \ + --data_dir $data_dir/eval/$dset \ + --output_dir $output_dir \ + --index 1 \ + --num_frames 2000 + mkdir -p ${data_dir}/ark_data/dump/callhome/data/$dset + python local/gen_feats_scp.py \ + --root_path ${data_dir}/ark_data/dump/callhome/$dset \ + --out_path ${data_dir}/ark_data/dump/callhome/data/$dset \ + --split_num 1 done fi @@ -275,7 +322,7 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then --config_file ${exp_dir}/exp/${callhome_model_dir}/config.yaml \ --model_file ${exp_dir}/exp/${callhome_model_dir}/$callhome_ave_id.pb \ --output_rttm_file ${exp_dir}/exp/${callhome_model_dir}/inference/rttm \ - --wav_scp_file ${callhome_feats_dir_chunk2000}/${callhome_valid_dataset}/${callhome2_wav_scp_file} \ + --wav_scp_file $data_dir/eval/callhome2_spkall/wav.scp \ 1> ${exp_dir}/exp/${callhome_model_dir}/inference/log/infer.log 2>&1 md-eval.pl -c 0.25 \ -r ${callhome_feats_dir_chunk2000}/${callhome_valid_dataset}/rttm \ diff --git a/egs/callhome/eend_ola/run_test.sh b/egs/callhome/eend_ola/run_test.sh index 188b61ed5..36ba1e7ca 100644 --- a/egs/callhome/eend_ola/run_test.sh +++ b/egs/callhome/eend_ola/run_test.sh @@ -125,16 +125,17 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then # for callhome data for dset in callhome1_spkall callhome2_spkall; do find $data_dir/eval/$dset -maxdepth 1 -type f -exec cp {} {}.1 \; - output_dir=${data_dir}/ark_data/dump/callhome/$dset + output_dir=${data_dir}/ark_data/dump/callhome_chunk2000/$dset + mkdir -p $output_dir python local/dump_feature.py \ --data_dir $data_dir/eval/$dset \ --output_dir $output_dir \ --index 1 \ --num_frames 2000 - mkdir -p ${data_dir}/ark_data/dump/callhome/data/$dset + mkdir -p ${data_dir}/ark_data/dump/callhome_chunk2000/data/$dset python local/gen_feats_scp.py \ - --root_path ${data_dir}/ark_data/dump/callhome/$dset \ - --out_path ${data_dir}/ark_data/dump/callhome/data/$dset \ + --root_path ${data_dir}/ark_data/dump/callhome_chunk2000/$dset \ + --out_path ${data_dir}/ark_data/dump/callhome_chunk2000/data/$dset \ --split_num 1 done fi From 9a43b5607d5bb7958c27d06d0252fee1dc858f3e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=98=89=E6=B8=8A?= Date: Sun, 23 Jul 2023 16:47:53 +0800 Subject: [PATCH 36/42] update --- .../eend_ola/local/run_prepare_shared_eda.sh | 14 +++++++------- egs/callhome/eend_ola/local/split.py | 2 +- egs/callhome/eend_ola/run.sh | 8 ++++---- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/egs/callhome/eend_ola/local/run_prepare_shared_eda.sh b/egs/callhome/eend_ola/local/run_prepare_shared_eda.sh index aec1ff2a9..f1019d60b 100755 --- a/egs/callhome/eend_ola/local/run_prepare_shared_eda.sh +++ b/egs/callhome/eend_ola/local/run_prepare_shared_eda.sh @@ -22,16 +22,16 @@ stage=0 # LDC2011S10, LDC2012S01, LDC2011S05, LDC2011S08 # - musan_root # MUSAN corpus (https://www.openslr.org/17/) -callhome_dir=/nfs/wangjiaming.wjm/speech-data/NIST/LDC2001S97 -swb2_phase1_train=/nfs/wangjiaming.wjm/speech-data/LDC/LDC98S75 -data_root=/nfs/wangjiaming.wjm/speech-data/LDC -musan_root=/nfs/wangjiaming.wjm/speech-data/JHU/musan +callhome_dir= +swb2_phase1_train= +data_root= +musan_root= # Modify simulated data storage area. # This script distributes simulated data under these directories simu_actual_dirs=( -/nfs/wangjiaming.wjm/EEND_DATA_sad30_snr10n15n20_funasr_test/s05/$USER/diarization-data -/nfs/wangjiaming.wjm/EEND_DATA_sad30_snr10n15n20_funasr_test/s08/$USER/diarization-data -/nfs/wangjiaming.wjm/EEND_DATA_sad30_snr10n15n20_funasr_test/s09/$USER/diarization-data +./s05/$USER/diarization-data +./s08/$USER/diarization-data +./s09/$USER/diarization-data ) # data preparation options diff --git a/egs/callhome/eend_ola/local/split.py b/egs/callhome/eend_ola/local/split.py index aa50b1e1e..7ad1badd4 100644 --- a/egs/callhome/eend_ola/local/split.py +++ b/egs/callhome/eend_ola/local/split.py @@ -60,7 +60,7 @@ if __name__ == '__main__': for file in scp_files: scp_file = os.path.join(work_path, file) - idx = scp_file.split('.')[-2] + idx = scp_file.split('.')[-1] reco2dur_file = os.path.join(work_path, 'reco2dur.{}'.format(str(idx))) spk2utt_file = os.path.join(work_path, 'spk2utt.{}'.format(str(idx))) segment_file = os.path.join(work_path, 'segments.{}'.format(str(idx))) diff --git a/egs/callhome/eend_ola/run.sh b/egs/callhome/eend_ola/run.sh index c3a3f320a..ff6b75b1f 100644 --- a/egs/callhome/eend_ola/run.sh +++ b/egs/callhome/eend_ola/run.sh @@ -13,9 +13,9 @@ nj=64 # feature configuration data_dir="./data" -simu_feats_dir=$data_dir/simu/ark_data/dump/simu_data/data -simu_feats_dir_chunk2000=$data_dir/simu/ark_data/dump/simu_data_chunk2000/data -callhome_feats_dir_chunk2000=$data_dir/simu/ark_data/dump/callhome_chunk2000/data +simu_feats_dir=$data_dir/ark_data/dump/simu_data/data +simu_feats_dir_chunk2000=$data_dir/ark_data/dump/simu_data_chunk2000/data +callhome_feats_dir_chunk2000=$data_dir/ark_data/dump/callhome_chunk2000/data simu_train_dataset=train simu_valid_dataset=dev callhome_train_dataset=callhome1_spkall @@ -31,7 +31,7 @@ callhome_average_end=100 exp_dir="." input_size=345 -stage=0 +stage=1 stop_stage=5 # exp tag From 197196c712d1f9ce4811edc4d27359b4fce0405a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=98=89=E6=B8=8A?= Date: Thu, 27 Jul 2023 11:23:22 +0800 Subject: [PATCH 37/42] update --- egs/callhome/eend_ola/run.sh | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/egs/callhome/eend_ola/run.sh b/egs/callhome/eend_ola/run.sh index ff6b75b1f..1e7f64b81 100644 --- a/egs/callhome/eend_ola/run.sh +++ b/egs/callhome/eend_ola/run.sh @@ -99,10 +99,9 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then --output_dir $output_dir \ --index JOB mkdir -p ${data_dir}/ark_data/dump/simu_data/data/$dataset - python local/gen_feats_scp.py \ - --root_path ${data_dir}/ark_data/dump/simu_data/$dataset \ - --out_path ${data_dir}/ark_data/dump/simu_data/data/$dataset \ - --split_num $nj + cat ${data_dir}/ark_data/dump/simu_data/$dataset/feature.scp.* > ${data_dir}/ark_data/dump/simu_data/data/$dataset/feature.scp + cat ${data_dir}/ark_data/dump/simu_data/$dataset/label.scp.* > ${data_dir}/ark_data/dump/simu_data/data/$dataset/label.scp + paste -d" " ${data_dir}/ark_data/dump/simu_data/data/$dataset/feature.scp <(cut -f2 -d" " ${data_dir}/ark_data/dump/simu_data/data/$dataset/label.scp) > ${data_dir}/ark_data/dump/simu_data/data/$dataset/feats.scp grep "ns2" ${data_dir}/ark_data/dump/simu_data/data/$dataset/feats.scp > ${data_dir}/ark_data/dump/simu_data/data/$dataset/feats_2spkr.scp # for chunk_size=2000 output_dir=${data_dir}/ark_data/dump/simu_data_chunk2000/$dataset @@ -114,10 +113,9 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then --index JOB \ --num_frames 2000 mkdir -p ${data_dir}/ark_data/dump/simu_data_chunk2000/data/$dataset - python local/gen_feats_scp.py \ - --root_path ${data_dir}/ark_data/dump/simu_data_chunk2000/$dataset \ - --out_path ${data_dir}/ark_data/dump/simu_data_chunk2000/data/$dataset \ - --split_num $nj + cat ${data_dir}/ark_data/dump/simu_data_chunk2000/$dataset/feature.scp.* > ${data_dir}/ark_data/dump/simu_data_chunk2000/data/$dataset/feature.scp + cat ${data_dir}/ark_data/dump/simu_data_chunk2000/$dataset/label.scp.* > ${data_dir}/ark_data/dump/simu_data_chunk2000/data/$dataset/label.scp + paste -d" " ${data_dir}/ark_data/dump/simu_data_chunk2000/data/$dataset/feature.scp <(cut -f2 -d" " ${data_dir}/ark_data/dump/simu_data_chunk2000/data/$dataset/label.scp) > ${data_dir}/ark_data/dump/simu_data_chunk2000/data/$dataset/feats.scp grep "ns2" ${data_dir}/ark_data/dump/simu_data_chunk2000/data/$dataset/feats.scp > ${data_dir}/ark_data/dump/simu_data_chunk2000/data/$dataset/feats_2spkr.scp done @@ -132,10 +130,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then --index 1 \ --num_frames 2000 mkdir -p ${data_dir}/ark_data/dump/callhome/data/$dset - python local/gen_feats_scp.py \ - --root_path ${data_dir}/ark_data/dump/callhome/$dset \ - --out_path ${data_dir}/ark_data/dump/callhome/data/$dset \ - --split_num 1 + paste -d" " ${data_dir}/ark_data/dump/callhome/$dset/feature.scp.1 <(cut -f2 -d" " ${data_dir}/ark_data/dump/callhome/$dset/label.scp.1) > ${data_dir}/ark_data/dump/callhome/data/$dset/feats.scp done fi From 58c072cfbe547fa68dff49ea3b00b753e27457eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=98=89=E6=B8=8A?= Date: Sun, 30 Jul 2023 23:15:15 +0800 Subject: [PATCH 38/42] update --- egs/callhome/eend_ola/run.sh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/egs/callhome/eend_ola/run.sh b/egs/callhome/eend_ola/run.sh index 1e7f64b81..6ad4a0db8 100644 --- a/egs/callhome/eend_ola/run.sh +++ b/egs/callhome/eend_ola/run.sh @@ -122,15 +122,15 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then # for callhome data for dset in callhome1_spkall callhome2_spkall; do find $data_dir/eval/$dset -maxdepth 1 -type f -exec cp {} {}.1 \; - output_dir=${data_dir}/ark_data/dump/callhome/$dset + output_dir=${data_dir}/ark_data/dump/callhome_chunk2000/$dset mkdir -p $output_dir python local/dump_feature.py \ --data_dir $data_dir/eval/$dset \ --output_dir $output_dir \ --index 1 \ --num_frames 2000 - mkdir -p ${data_dir}/ark_data/dump/callhome/data/$dset - paste -d" " ${data_dir}/ark_data/dump/callhome/$dset/feature.scp.1 <(cut -f2 -d" " ${data_dir}/ark_data/dump/callhome/$dset/label.scp.1) > ${data_dir}/ark_data/dump/callhome/data/$dset/feats.scp + mkdir -p ${data_dir}/ark_data/dump/callhome_chunk2000/data/$dset + paste -d" " ${data_dir}/ark_data/dump/callhome_chunk2000/$dset/feature.scp.1 <(cut -f2 -d" " ${data_dir}/ark_data/dump/callhome_chunk2000/$dset/label.scp.1) > ${data_dir}/ark_data/dump/callhome_chunk2000/data/$dset/feats.scp done fi From 154e45b1c40de471615f3d16983f53e688d22e4d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=98=89=E6=B8=8A?= Date: Mon, 31 Jul 2023 01:14:25 +0800 Subject: [PATCH 39/42] update --- egs/callhome/eend_ola/run.sh | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/egs/callhome/eend_ola/run.sh b/egs/callhome/eend_ola/run.sh index 6ad4a0db8..d44241234 100644 --- a/egs/callhome/eend_ola/run.sh +++ b/egs/callhome/eend_ola/run.sh @@ -116,7 +116,6 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then cat ${data_dir}/ark_data/dump/simu_data_chunk2000/$dataset/feature.scp.* > ${data_dir}/ark_data/dump/simu_data_chunk2000/data/$dataset/feature.scp cat ${data_dir}/ark_data/dump/simu_data_chunk2000/$dataset/label.scp.* > ${data_dir}/ark_data/dump/simu_data_chunk2000/data/$dataset/label.scp paste -d" " ${data_dir}/ark_data/dump/simu_data_chunk2000/data/$dataset/feature.scp <(cut -f2 -d" " ${data_dir}/ark_data/dump/simu_data_chunk2000/data/$dataset/label.scp) > ${data_dir}/ark_data/dump/simu_data_chunk2000/data/$dataset/feats.scp - grep "ns2" ${data_dir}/ark_data/dump/simu_data_chunk2000/data/$dataset/feats.scp > ${data_dir}/ark_data/dump/simu_data_chunk2000/data/$dataset/feats_2spkr.scp done # for callhome data @@ -320,6 +319,6 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then --wav_scp_file $data_dir/eval/callhome2_spkall/wav.scp \ 1> ${exp_dir}/exp/${callhome_model_dir}/inference/log/infer.log 2>&1 md-eval.pl -c 0.25 \ - -r ${callhome_feats_dir_chunk2000}/${callhome_valid_dataset}/rttm \ + -r ${data_dir}/eval/${callhome_valid_dataset}/rttm \ -s ${exp_dir}/exp/${callhome_model_dir}/inference/rttm > ${exp_dir}/exp/${callhome_model_dir}/inference/result_med11_collar0.25 2>/dev/null || exit fi \ No newline at end of file From 47343b5c2f4e1256f60f46d8da0aa2e5de39b6c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=98=89=E6=B8=8A?= Date: Sat, 5 Aug 2023 17:53:08 +0800 Subject: [PATCH 40/42] init repo --- egs/callhome/eend_ola/run_test.sh | 331 ------------------------------ 1 file changed, 331 deletions(-) delete mode 100644 egs/callhome/eend_ola/run_test.sh diff --git a/egs/callhome/eend_ola/run_test.sh b/egs/callhome/eend_ola/run_test.sh deleted file mode 100644 index 36ba1e7ca..000000000 --- a/egs/callhome/eend_ola/run_test.sh +++ /dev/null @@ -1,331 +0,0 @@ -#!/usr/bin/env bash - -. ./path.sh || exit 1; - -# machines configuration -CUDA_VISIBLE_DEVICES="0" -gpu_num=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') -count=1 - -# general configuration -dump_cmd=utils/run.pl -nj=64 - -# feature configuration -data_dir="./data" -simu_feats_dir="/nfs/wangjiaming.wjm/EEND_ARK_DATA/dump/simu_data/data" -simu_feats_dir_chunk2000="/nfs/wangjiaming.wjm/EEND_ARK_DATA/dump/simu_data_chunk2000/data" -callhome_feats_dir_chunk2000="/nfs/wangjiaming.wjm/EEND_ARK_DATA/dump/callhome_chunk2000/data" -simu_train_dataset=train -simu_valid_dataset=dev -callhome_train_dataset=callhome1_allspk -callhome_valid_dataset=callhome2_allspk -callhome2_wav_scp_file=wav.scp - -# model average -simu_average_2spkr_start=91 -simu_average_2spkr_end=100 -simu_average_allspkr_start=16 -simu_average_allspkr_end=25 -callhome_average_start=91 -callhome_average_end=100 - -exp_dir="." -input_size=345 -stage=0 -stop_stage=0 - -# exp tag -tag="exp1" - -. local/parse_options.sh || exit 1; - -# Set bash to 'debug' mode, it will exit on : -# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands', -set -e -set -u -set -o pipefail - -simu_2spkr_diar_config=conf/train_diar_eend_ola_simu_2spkr.yaml -simu_allspkr_diar_config=conf/train_diar_eend_ola_simu_allspkr.yaml -simu_allspkr_chunk2000_diar_config=conf/train_diar_eend_ola_simu_allspkr_chunk2000.yaml -callhome_diar_config=conf/train_diar_eend_ola_callhome_chunk2000.yaml -simu_2spkr_model_dir="baseline_$(basename "${simu_2spkr_diar_config}" .yaml)_${tag}" -simu_allspkr_model_dir="baseline_$(basename "${simu_allspkr_diar_config}" .yaml)_${tag}" -simu_allspkr_chunk2000_model_dir="baseline_$(basename "${simu_allspkr_chunk2000_diar_config}" .yaml)_${tag}" -callhome_model_dir="baseline_$(basename "${callhome_diar_config}" .yaml)_${tag}" - -# simulate mixture data for training and inference -if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then - echo "stage -1: Simulate mixture data for training and inference" - echo "The detail can be found in https://github.com/hitachi-speech/EEND" - echo "Before running this step, you should download and compile kaldi and set KALDI_ROOT in this script and path.sh" - echo "This stage may take a long time, please waiting..." - KALDI_ROOT= - ln -s $KALDI_ROOT/egs/wsj/s5/steps steps - ln -s $KALDI_ROOT/egs/wsj/s5/utils utils - local/run_prepare_shared_eda.sh -fi - -# Prepare data for training and inference -if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then - echo "stage 0: Prepare data for training and inference" - simu_opts_num_speaker_array=(1 2 3 4) - simu_opts_sil_scale_array=(2 2 5 9) - simu_opts_num_train=100000 - -# # for simulated data of chunk500 and chunk2000 -# for dset in swb_sre_cv swb_sre_tr; do -# if [ "$dset" == "swb_sre_tr" ]; then -# n_mixtures=${simu_opts_num_train} -# dataset=train -# else -# n_mixtures=500 -# dataset=dev -# fi -# simu_data_dir=${dset}_ns"$(IFS="n"; echo "${simu_opts_num_speaker_array[*]}")"_beta"$(IFS="n"; echo "${simu_opts_sil_scale_array[*]}")"_${n_mixtures} -# mkdir -p ${data_dir}/simu/data/${simu_data_dir}/.work -# split_scps= -# for n in $(seq $nj); do -# split_scps="$split_scps ${data_dir}/simu/data/${simu_data_dir}/.work/wav.scp.$n" -# done -# utils/split_scp.pl "${data_dir}/simu/data/${simu_data_dir}/wav.scp" $split_scps || exit 1 -# python local/split.py ${data_dir}/simu/data/${simu_data_dir} -# # for chunk_size=500 -# output_dir=${data_dir}/ark_data/dump/simu_data/$dataset -# mkdir -p $output_dir/.logs -# $dump_cmd --max-jobs-run $nj JOB=1:$nj $output_dir/.logs/dump.JOB.log \ -# python local/dump_feature.py \ -# --data_dir ${data_dir}/simu/data/${simu_data_dir}/.work \ -# --output_dir $output_dir \ -# --index JOB -# mkdir -p ${data_dir}/ark_data/dump/simu_data/data/$dataset -# python local/gen_feats_scp.py \ -# --root_path ${data_dir}/ark_data/dump/simu_data/$dataset \ -# --out_path ${data_dir}/ark_data/dump/simu_data/data/$dataset \ -# --split_num $nj -# grep "ns2" ${data_dir}/ark_data/dump/simu_data/data/$dataset/feats.scp > ${data_dir}/ark_data/dump/simu_data/data/$dataset/feats_2spkr.scp -# # for chunk_size=2000 -# output_dir=${data_dir}/ark_data/dump/simu_data_chunk2000/$dataset -# mkdir -p $output_dir/.logs -# $dump_cmd --max-jobs-run $nj JOB=1:$nj $output_dir/.logs/dump.JOB.log \ -# python local/dump_feature.py \ -# --data_dir ${data_dir}/simu/data/${simu_data_dir}/.work \ -# --output_dir $output_dir \ -# --index JOB \ -# --num_frames 2000 -# mkdir -p ${data_dir}/ark_data/dump/simu_data_chunk2000/data/$dataset -# python local/gen_feats_scp.py \ -# --root_path ${data_dir}/ark_data/dump/simu_data_chunk2000/$dataset \ -# --out_path ${data_dir}/ark_data/dump/simu_data_chunk2000/data/$dataset \ -# --split_num $nj -# grep "ns2" ${data_dir}/ark_data/dump/simu_data_chunk2000/data/$dataset/feats.scp > ${data_dir}/ark_data/dump/simu_data_chunk2000/data/$dataset/feats_2spkr.scp -# done - - # for callhome data - for dset in callhome1_spkall callhome2_spkall; do - find $data_dir/eval/$dset -maxdepth 1 -type f -exec cp {} {}.1 \; - output_dir=${data_dir}/ark_data/dump/callhome_chunk2000/$dset - mkdir -p $output_dir - python local/dump_feature.py \ - --data_dir $data_dir/eval/$dset \ - --output_dir $output_dir \ - --index 1 \ - --num_frames 2000 - mkdir -p ${data_dir}/ark_data/dump/callhome_chunk2000/data/$dset - python local/gen_feats_scp.py \ - --root_path ${data_dir}/ark_data/dump/callhome_chunk2000/$dset \ - --out_path ${data_dir}/ark_data/dump/callhome_chunk2000/data/$dset \ - --split_num 1 - done -fi - -# Training on simulated two-speaker data -world_size=$gpu_num -simu_2spkr_ave_id=avg${simu_average_2spkr_start}-${simu_average_2spkr_end} -if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then - echo "stage 1: Training on simulated two-speaker data" - mkdir -p ${exp_dir}/exp/${simu_2spkr_model_dir} - mkdir -p ${exp_dir}/exp/${simu_2spkr_model_dir}/log - INIT_FILE=${exp_dir}/exp/${simu_2spkr_model_dir}/ddp_init - if [ -f $INIT_FILE ];then - rm -f $INIT_FILE - fi - init_method=file://$(readlink -f $INIT_FILE) - echo "$0: init method is $init_method" - for ((i = 0; i < $gpu_num; ++i)); do - { - rank=$i - local_rank=$i - gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1]) - train.py \ - --task_name diar \ - --gpu_id $gpu_id \ - --use_preprocessor false \ - --input_size $input_size \ - --data_dir ${simu_feats_dir} \ - --train_set ${simu_train_dataset} \ - --valid_set ${simu_valid_dataset} \ - --data_file_names "feats_2spkr.scp" \ - --resume true \ - --output_dir ${exp_dir}/exp/${simu_2spkr_model_dir} \ - --config $simu_2spkr_diar_config \ - --ngpu $gpu_num \ - --num_worker_count $count \ - --dist_init_method $init_method \ - --dist_world_size $world_size \ - --dist_rank $rank \ - --local_rank $local_rank 1> ${exp_dir}/exp/${simu_2spkr_model_dir}/log/train.log.$i 2>&1 - } & - done - wait - echo "averaging model parameters into ${exp_dir}/exp/$simu_2spkr_model_dir/$simu_2spkr_ave_id.pb" - models=`eval echo ${exp_dir}/exp/${simu_2spkr_model_dir}/{$simu_average_2spkr_start..$simu_average_2spkr_end}epoch.pb` - python local/model_averaging.py ${exp_dir}/exp/${simu_2spkr_model_dir}/$simu_2spkr_ave_id.pb $models -fi - -# Training on simulated all-speaker data -world_size=$gpu_num -simu_allspkr_ave_id=avg${simu_average_allspkr_start}-${simu_average_allspkr_end} -if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then - echo "stage 2: Training on simulated all-speaker data" - mkdir -p ${exp_dir}/exp/${simu_allspkr_model_dir} - mkdir -p ${exp_dir}/exp/${simu_allspkr_model_dir}/log - INIT_FILE=${exp_dir}/exp/${simu_allspkr_model_dir}/ddp_init - if [ -f $INIT_FILE ];then - rm -f $INIT_FILE - fi - init_method=file://$(readlink -f $INIT_FILE) - echo "$0: init method is $init_method" - for ((i = 0; i < $gpu_num; ++i)); do - { - rank=$i - local_rank=$i - gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1]) - train.py \ - --task_name diar \ - --gpu_id $gpu_id \ - --use_preprocessor false \ - --input_size $input_size \ - --data_dir ${simu_feats_dir} \ - --train_set ${simu_train_dataset} \ - --valid_set ${simu_valid_dataset} \ - --data_file_names "feats.scp" \ - --resume true \ - --init_param ${exp_dir}/exp/${simu_2spkr_model_dir}/$simu_2spkr_ave_id.pb \ - --output_dir ${exp_dir}/exp/${simu_allspkr_model_dir} \ - --config $simu_allspkr_diar_config \ - --ngpu $gpu_num \ - --num_worker_count $count \ - --dist_init_method $init_method \ - --dist_world_size $world_size \ - --dist_rank $rank \ - --local_rank $local_rank 1> ${exp_dir}/exp/${simu_allspkr_model_dir}/log/train.log.$i 2>&1 - } & - done - wait - echo "averaging model parameters into ${exp_dir}/exp/$simu_allspkr_model_dir/$simu_allspkr_ave_id.pb" - models=`eval echo ${exp_dir}/exp/${simu_allspkr_model_dir}/{$simu_average_allspkr_start..$simu_average_allspkr_end}epoch.pb` - python local/model_averaging.py ${exp_dir}/exp/${simu_allspkr_model_dir}/$simu_allspkr_ave_id.pb $models -fi - -# Training on simulated all-speaker data with chunk_size=2000 -world_size=$gpu_num -if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then - echo "stage 3: Training on simulated all-speaker data with chunk_size=2000" - mkdir -p ${exp_dir}/exp/${simu_allspkr_chunk2000_model_dir} - mkdir -p ${exp_dir}/exp/${simu_allspkr_chunk2000_model_dir}/log - INIT_FILE=${exp_dir}/exp/${simu_allspkr_chunk2000_model_dir}/ddp_init - if [ -f $INIT_FILE ];then - rm -f $INIT_FILE - fi - init_method=file://$(readlink -f $INIT_FILE) - echo "$0: init method is $init_method" - for ((i = 0; i < $gpu_num; ++i)); do - { - rank=$i - local_rank=$i - gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1]) - train.py \ - --task_name diar \ - --gpu_id $gpu_id \ - --use_preprocessor false \ - --input_size $input_size \ - --data_dir ${simu_feats_dir_chunk2000} \ - --train_set ${simu_train_dataset} \ - --valid_set ${simu_valid_dataset} \ - --data_file_names "feats.scp" \ - --resume true \ - --init_param ${exp_dir}/exp/${simu_allspkr_model_dir}/$simu_allspkr_ave_id.pb \ - --output_dir ${exp_dir}/exp/${simu_allspkr_chunk2000_model_dir} \ - --config $simu_allspkr_chunk2000_diar_config \ - --ngpu $gpu_num \ - --num_worker_count $count \ - --dist_init_method $init_method \ - --dist_world_size $world_size \ - --dist_rank $rank \ - --local_rank $local_rank 1> ${exp_dir}/exp/${simu_allspkr_chunk2000_model_dir}/log/train.log.$i 2>&1 - } & - done - wait -fi - -# Training on callhome all-speaker data with chunk_size=2000 -world_size=$gpu_num -callhome_ave_id=avg${callhome_average_start}-${callhome_average_end} -if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then - echo "stage 4: Training on callhome all-speaker data with chunk_size=2000" - mkdir -p ${exp_dir}/exp/${callhome_model_dir} - mkdir -p ${exp_dir}/exp/${callhome_model_dir}/log - INIT_FILE=${exp_dir}/exp/${callhome_model_dir}/ddp_init - if [ -f $INIT_FILE ];then - rm -f $INIT_FILE - fi - init_method=file://$(readlink -f $INIT_FILE) - echo "$0: init method is $init_method" - for ((i = 0; i < $gpu_num; ++i)); do - { - rank=$i - local_rank=$i - gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1]) - train.py \ - --task_name diar \ - --gpu_id $gpu_id \ - --use_preprocessor false \ - --input_size $input_size \ - --data_dir ${callhome_feats_dir_chunk2000} \ - --train_set ${callhome_train_dataset} \ - --valid_set ${callhome_valid_dataset} \ - --data_file_names "feats.scp" \ - --resume true \ - --init_param ${exp_dir}/exp/${simu_allspkr_chunk2000_model_dir}/1epoch.pb \ - --output_dir ${exp_dir}/exp/${callhome_model_dir} \ - --config $callhome_diar_config \ - --ngpu $gpu_num \ - --num_worker_count $count \ - --dist_init_method $init_method \ - --dist_world_size $world_size \ - --dist_rank $rank \ - --local_rank $local_rank 1> ${exp_dir}/exp/${callhome_model_dir}/log/train.log.$i 2>&1 - } & - done - wait - echo "averaging model parameters into ${exp_dir}/exp/$callhome_model_dir/$callhome_ave_id.pb" - models=`eval echo ${exp_dir}/exp/${callhome_model_dir}/{$callhome_average_start..$callhome_average_end}epoch.pb` - python local/model_averaging.py ${exp_dir}/exp/${callhome_model_dir}/$callhome_ave_id.pb $models -fi - -# inference and compute DER -if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then - echo "Inference" - mkdir -p ${exp_dir}/exp/${callhome_model_dir}/inference/log - CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES python local/infer.py \ - --config_file ${exp_dir}/exp/${callhome_model_dir}/config.yaml \ - --model_file ${exp_dir}/exp/${callhome_model_dir}/$callhome_ave_id.pb \ - --output_rttm_file ${exp_dir}/exp/${callhome_model_dir}/inference/rttm \ - --wav_scp_file ${callhome_feats_dir_chunk2000}/${callhome_valid_dataset}/${callhome2_wav_scp_file} \ - 1> ${exp_dir}/exp/${callhome_model_dir}/inference/log/infer.log 2>&1 - md-eval.pl -c 0.25 \ - -r ${callhome_feats_dir_chunk2000}/${callhome_valid_dataset}/rttm \ - -s ${exp_dir}/exp/${callhome_model_dir}/inference/rttm > ${exp_dir}/exp/${callhome_model_dir}/inference/result_med11_collar0.25 2>/dev/null || exit -fi \ No newline at end of file From ecead60d25954e092ee46f79f17d7ee79da642ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=98=89=E6=B8=8A?= Date: Sat, 5 Aug 2023 17:55:06 +0800 Subject: [PATCH 41/42] update repo --- egs/callhome/eend_ola/run.sh | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/egs/callhome/eend_ola/run.sh b/egs/callhome/eend_ola/run.sh index d44241234..aa441bfcb 100644 --- a/egs/callhome/eend_ola/run.sh +++ b/egs/callhome/eend_ola/run.sh @@ -222,10 +222,10 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then python local/model_averaging.py ${exp_dir}/exp/${simu_allspkr_model_dir}/$simu_allspkr_ave_id.pb $models fi -# Training on simulated all-speaker data with chunk_size=2000 +# Training on simulated all-speaker data with chunk_size 2000 world_size=$gpu_num if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then - echo "stage 3: Training on simulated all-speaker data with chunk_size=2000" + echo "stage 3: Training on simulated all-speaker data with chunk_size 2000" mkdir -p ${exp_dir}/exp/${simu_allspkr_chunk2000_model_dir} mkdir -p ${exp_dir}/exp/${simu_allspkr_chunk2000_model_dir}/log INIT_FILE=${exp_dir}/exp/${simu_allspkr_chunk2000_model_dir}/ddp_init @@ -263,11 +263,11 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then wait fi -# Training on callhome all-speaker data with chunk_size=2000 +# Training on callhome all-speaker data with chunk_size 2000 world_size=$gpu_num callhome_ave_id=avg${callhome_average_start}-${callhome_average_end} if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then - echo "stage 4: Training on callhome all-speaker data with chunk_size=2000" + echo "stage 4: Training on callhome all-speaker data with chunk_size 2000" mkdir -p ${exp_dir}/exp/${callhome_model_dir} mkdir -p ${exp_dir}/exp/${callhome_model_dir}/log INIT_FILE=${exp_dir}/exp/${callhome_model_dir}/ddp_init From ade75d2987f157037ed2d3d8d80d94473337e0d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=98=89=E6=B8=8A?= Date: Mon, 7 Aug 2023 14:26:39 +0800 Subject: [PATCH 42/42] update repo --- funasr/models/e2e_diar_eend_ola.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/funasr/models/e2e_diar_eend_ola.py b/funasr/models/e2e_diar_eend_ola.py index 0225a7a4c..a0b545aac 100644 --- a/funasr/models/e2e_diar_eend_ola.py +++ b/funasr/models/e2e_diar_eend_ola.py @@ -6,7 +6,6 @@ import numpy as np import torch import torch.nn as nn import torch.nn.functional as F -from typeguard import check_argument_types from funasr.models.base_model import FunASRModel from funasr.models.frontend.wav_frontend import WavFrontendMel23 @@ -70,8 +69,6 @@ class DiarEENDOLAModel(FunASRModel): mapping_dict=None, **kwargs, ): - assert check_argument_types() - super().__init__() self.frontend = frontend self.enc = encoder