Merge pull request #806 from alibaba-damo-academy/dev_wjm_sd

update eend-ola
commit 993f226f35
jmwang66 committed 2023-08-07 16:09:39 +08:00 (via GitHub)
39 changed files with 8624 additions and 188 deletions

View File

@@ -0,0 +1,45 @@
# network architecture
# encoder related
encoder: eend_ola_transformer
encoder_conf:
    idim: 345
    n_layers: 4
    n_units: 256

# encoder-decoder attractor related
encoder_decoder_attractor: eda
encoder_decoder_attractor_conf:
    n_units: 256

# model related
model: eend_ola
model_conf:
    attractor_loss_weight: 0.01
    max_n_speaker: 8

# optimization related
accum_grad: 1
grad_clip: 5
max_epoch: 100
val_scheduler_criterion:
    - valid
    - loss
best_model_criterion:
-   - valid
    - loss
    - min
keep_nbest_models: 100

optim: adam
optim_conf:
    lr: 0.00001

dataset_conf:
    data_names: speech_speaker_labels
    data_types: kaldi_ark

batch_conf:
    batch_type: unsorted
    batch_size: 8

num_workers: 8
log_interval: 50
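
The encoder input size idim: 345 follows from the feature pipeline used by the data-preparation scripts later in this commit: the logmel23_mn transform yields 23 log-mel bins, and splicing adds 7 context frames on each side. A minimal sketch of the arithmetic:

# Sketch only: 23 log-mel bins and context_size=7 are taken from the
# feature-extraction defaults elsewhere in this commit.
n_mels = 23
context_size = 7
idim = n_mels * (2 * context_size + 1)
print(idim)  # 345, matching encoder_conf.idim above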

View File

@@ -0,0 +1,52 @@
# network architecture
# encoder related
encoder: eend_ola_transformer
encoder_conf:
    idim: 345
    n_layers: 4
    n_units: 256

# encoder-decoder attractor related
encoder_decoder_attractor: eda
encoder_decoder_attractor_conf:
    n_units: 256

# model related
model: eend_ola
model_conf:
    max_n_speaker: 8

# optimization related
accum_grad: 1
grad_clip: 5
max_epoch: 100
val_scheduler_criterion:
    - valid
    - loss
best_model_criterion:
-   - valid
    - loss
    - min
keep_nbest_models: 100

optim: adam
optim_conf:
    lr: 1.0
    betas:
        - 0.9
        - 0.98
    eps: 1.0e-9
scheduler: noamlr
scheduler_conf:
    model_size: 256
    warmup_steps: 100000

dataset_conf:
    data_names: speech_speaker_labels
    data_types: kaldi_ark

batch_conf:
    batch_type: unsorted
    batch_size: 64

num_workers: 8
log_interval: 50
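
In this config lr: 1.0 is the scale factor of the Noam schedule rather than a literal learning rate. A sketch of the standard Noam formula (assuming FunASR's noamlr follows the usual inverse-sqrt Transformer schedule) with this config's model_size and warmup_steps:

import math

def noam_lr(step, model_size=256, warmup_steps=100000, scale=1.0):
    # Linear warmup followed by inverse-sqrt decay; `scale` plays the
    # role of optim_conf.lr above.
    return scale * model_size ** -0.5 * min(step ** -0.5,
                                            step * warmup_steps ** -1.5)

print(noam_lr(100000))  # peak value, roughly 2e-4 for these settings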

View File

@@ -0,0 +1,52 @@
# network architecture
# encoder related
encoder: eend_ola_transformer
encoder_conf:
    idim: 345
    n_layers: 4
    n_units: 256

# encoder-decoder attractor related
encoder_decoder_attractor: eda
encoder_decoder_attractor_conf:
    n_units: 256

# model related
model: eend_ola
model_conf:
    max_n_speaker: 8

# optimization related
accum_grad: 1
grad_clip: 5
max_epoch: 25
val_scheduler_criterion:
    - valid
    - loss
best_model_criterion:
-   - valid
    - loss
    - min
keep_nbest_models: 100

optim: adam
optim_conf:
    lr: 1.0
    betas:
        - 0.9
        - 0.98
    eps: 1.0e-9
scheduler: noamlr
scheduler_conf:
    model_size: 256
    warmup_steps: 100000

dataset_conf:
    data_names: speech_speaker_labels
    data_types: kaldi_ark

batch_conf:
    batch_type: unsorted
    batch_size: 64

num_workers: 8
log_interval: 50

View File

@@ -0,0 +1,44 @@
# network architecture
# encoder related
encoder: eend_ola_transformer
encoder_conf:
    idim: 345
    n_layers: 4
    n_units: 256

# encoder-decoder attractor related
encoder_decoder_attractor: eda
encoder_decoder_attractor_conf:
    n_units: 256

# model related
model: eend_ola
model_conf:
    max_n_speaker: 8

# optimization related
accum_grad: 1
grad_clip: 5
max_epoch: 1
val_scheduler_criterion:
    - valid
    - loss
best_model_criterion:
-   - valid
    - loss
    - min
keep_nbest_models: 100

optim: adam
optim_conf:
    lr: 0.00001

dataset_conf:
    data_names: speech_speaker_labels
    data_types: kaldi_ark

batch_conf:
    batch_type: unsorted
    batch_size: 8

num_workers: 8
log_interval: 50

View File

@@ -0,0 +1,144 @@
import argparse
import os

from kaldiio import WriteHelper

import funasr.modules.eend_ola.utils.feature as feature
from funasr.modules.eend_ola.utils.kaldi_data import load_segments_rechash, load_utt2spk, load_wav_scp, load_reco2dur, \
    load_spk2utt, load_wav


def _count_frames(data_len, size, step):
    return int((data_len - size + step) / step)


def _gen_frame_indices(
        data_length, size=2000, step=2000,
        use_last_samples=False,
        label_delay=0,
        subsampling=1):
    i = -1
    for i in range(_count_frames(data_length, size, step)):
        yield i * step, i * step + size
    if use_last_samples and i * step + size < data_length:
        if data_length - (i + 1) * step - subsampling * label_delay > 0:
            yield (i + 1) * step, data_length


class KaldiData:
    def __init__(self, data_dir, idx):
        self.data_dir = data_dir
        segment_file = os.path.join(self.data_dir, 'segments.{}'.format(idx))
        self.segments = load_segments_rechash(segment_file)
        utt2spk_file = os.path.join(self.data_dir, 'utt2spk.{}'.format(idx))
        self.utt2spk = load_utt2spk(utt2spk_file)
        wav_file = os.path.join(self.data_dir, 'wav.scp.{}'.format(idx))
        self.wavs = load_wav_scp(wav_file)
        reco2dur_file = os.path.join(self.data_dir, 'reco2dur.{}'.format(idx))
        self.reco2dur = load_reco2dur(reco2dur_file)
        spk2utt_file = os.path.join(self.data_dir, 'spk2utt.{}'.format(idx))
        self.spk2utt = load_spk2utt(spk2utt_file)

    def load_wav(self, recid, start=0, end=None):
        data, rate = load_wav(self.wavs[recid], start, end)
        return data, rate


class KaldiDiarizationDataset():
    def __init__(
            self,
            data_dir,
            index,
            chunk_size=2000,
            context_size=0,
            frame_size=1024,
            frame_shift=256,
            subsampling=1,
            rate=16000,
            input_transform=None,
            use_last_samples=False,
            label_delay=0,
            n_speakers=None,
    ):
        self.data_dir = data_dir
        self.index = index
        self.chunk_size = chunk_size
        self.context_size = context_size
        self.frame_size = frame_size
        self.frame_shift = frame_shift
        self.subsampling = subsampling
        self.input_transform = input_transform
        self.n_speakers = n_speakers
        self.chunk_indices = []
        self.label_delay = label_delay

        self.data = KaldiData(self.data_dir, index)
        for rec, path in self.data.wavs.items():
            data_len = int(self.data.reco2dur[rec] * rate / frame_shift)
            data_len = int(data_len / self.subsampling)
            for st, ed in _gen_frame_indices(
                    data_len, chunk_size, chunk_size, use_last_samples,
                    label_delay=self.label_delay,
                    subsampling=self.subsampling):
                self.chunk_indices.append(
                    (rec, path, st * self.subsampling, ed * self.subsampling))
        print(len(self.chunk_indices), " chunks")


def convert(args):
    dataset = KaldiDiarizationDataset(
        data_dir=args.data_dir,
        index=args.index,
        chunk_size=args.num_frames,
        context_size=args.context_size,
        input_transform="logmel23_mn",
        frame_size=args.frame_size,
        frame_shift=args.frame_shift,
        subsampling=args.subsampling,
        rate=8000,
        use_last_samples=True,
    )
    feature_ark_file = os.path.join(args.output_dir, "feature.ark.{}".format(args.index))
    feature_scp_file = os.path.join(args.output_dir, "feature.scp.{}".format(args.index))
    label_ark_file = os.path.join(args.output_dir, "label.ark.{}".format(args.index))
    label_scp_file = os.path.join(args.output_dir, "label.scp.{}".format(args.index))
    with WriteHelper('ark,scp:{},{}'.format(feature_ark_file, feature_scp_file)) as feature_writer, \
            WriteHelper('ark,scp:{},{}'.format(label_ark_file, label_scp_file)) as label_writer:
        for idx, (rec, path, st, ed) in enumerate(dataset.chunk_indices):
            Y, T = feature.get_labeledSTFT(
                dataset.data,
                rec,
                st,
                ed,
                dataset.frame_size,
                dataset.frame_shift,
                dataset.n_speakers)
            Y = feature.transform(Y, dataset.input_transform)
            Y_spliced = feature.splice(Y, dataset.context_size)
            Y_ss, T_ss = feature.subsample(Y_spliced, T, dataset.subsampling)
            st = '{:0>7d}'.format(st)
            ed = '{:0>7d}'.format(ed)
            key = "{}_{}_{}".format(rec, st, ed)
            feature_writer(key, Y_ss)
            label_writer(key, T_ss.reshape(-1))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", type=str)
    parser.add_argument("--output_dir", type=str)
    parser.add_argument("--index", type=str)
    parser.add_argument("--num_frames", type=int, default=500)
    parser.add_argument("--context_size", type=int, default=7)
    parser.add_argument("--frame_size", type=int, default=200)
    parser.add_argument("--frame_shift", type=int, default=80)
    parser.add_argument("--subsampling", type=int, default=10)
    args = parser.parse_args()
    convert(args)
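
Because the writer above emits standard Kaldi ark/scp pairs, the output can be spot-checked with kaldiio. A hypothetical sanity check (directory and split index are illustrative):

from kaldiio import load_scp

# Lazily load the features written by the script above.
feats = load_scp("exp/features/feature.scp.1")
for key in feats:
    # Each chunk is (frames, 345) after splicing and subsampling.
    print(key, feats[key].shape)
    break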

View File

@@ -0,0 +1,25 @@
import os
import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--root_path", type=str)
    parser.add_argument("--out_path", type=str)
    parser.add_argument("--split_num", type=int, default=64)
    args = parser.parse_args()

    root_path = args.root_path
    out_path = args.out_path
    split_num = args.split_num
    with open(os.path.join(out_path, "feats.scp"), "w") as out_f:
        for i in range(split_num):
            idx = str(i + 1)
            feature_file = os.path.join(root_path, "feature.scp.{}".format(idx))
            label_file = os.path.join(root_path, "label.scp.{}".format(idx))
            with open(feature_file) as ff, open(label_file) as fl:
                ff_lines = ff.readlines()
                fl_lines = fl.readlines()
                for ff_line, fl_line in zip(ff_lines, fl_lines):
                    sample_name, f_path = ff_line.strip().split()
                    _, l_path = fl_line.strip().split()
                    out_f.write("{} {} {}\n".format(sample_name, f_path, l_path))
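
The merged feats.scp therefore carries three columns per chunk: the sample name, the feature ark offset, and the label ark offset, matching the paired speech/speaker-label loading implied by data_names: speech_speaker_labels in the configs above. A minimal sketch of a consumer:

# Sketch only: each line is "<sample_name> <feature ark offset> <label ark offset>",
# e.g. "mix_0000001_0000000_0000500 /path/feature.ark.3:41 /path/label.ark.3:41".
with open("feats.scp") as f:
    for line in f:
        sample_name, feature_path, label_path = line.strip().split()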

View File

@@ -0,0 +1,138 @@
import argparse
import os

import numpy as np
import soundfile as sf
import torch
import yaml
from scipy.signal import medfilt

import funasr.models.frontend.eend_ola_feature as eend_ola_feature
from funasr.build_utils.build_model_from_file import build_model_from_file

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config_file",
        type=str,
        help="model config file",
    )
    parser.add_argument(
        "--model_file",
        type=str,
        help="model path",
    )
    parser.add_argument(
        "--output_rttm_file",
        type=str,
        help="output rttm path",
    )
    parser.add_argument(
        "--wav_scp_file",
        type=str,
        default="wav.scp",
        help="input data path",
    )
    parser.add_argument(
        "--frame_shift",
        type=int,
        default=80,
        help="frame shift",
    )
    parser.add_argument(
        "--frame_size",
        type=int,
        default=200,
        help="frame size",
    )
    parser.add_argument(
        "--context_size",
        type=int,
        default=7,
        help="context size",
    )
    parser.add_argument(
        "--sampling_rate",
        type=int,
        default=8000,
        help="sampling rate",
    )
    parser.add_argument(
        "--subsampling",
        type=int,
        default=10,
        help="setting subsampling",
    )
    parser.add_argument(
        "--shuffle",
        type=bool,
        default=True,
        help="shuffle speech in time",
    )
    parser.add_argument(
        "--attractor_threshold",
        type=float,
        default=0.5,
        help="threshold for selecting attractors",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
    )
    args = parser.parse_args()

    with open(args.config_file) as f:
        configs = yaml.safe_load(f)
        for k, v in configs.items():
            if not hasattr(args, k):
                setattr(args, k, v)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    os.environ['PYTORCH_SEED'] = str(args.seed)

    model, _ = build_model_from_file(config_file=args.config_file, model_file=args.model_file, task_name="diar",
                                     device=args.device)
    model.eval()

    with open(args.wav_scp_file) as f:
        wav_lines = [line.strip().split() for line in f.readlines()]
        wav_items = {x[0]: x[1] for x in wav_lines}

    print("Start inference")
    with open(args.output_rttm_file, "w") as wf:
        for wav_id in wav_items.keys():
            print("Process wav: {}".format(wav_id))
            data, rate = sf.read(wav_items[wav_id])
            speech = eend_ola_feature.stft(data, args.frame_size, args.frame_shift)
            speech = eend_ola_feature.transform(speech)
            speech = eend_ola_feature.splice(speech, context_size=args.context_size)
            speech = speech[::args.subsampling]  # subsampling
            speech = torch.from_numpy(speech)
            with torch.no_grad():
                speech = speech.to(args.device)
                ys, _, _, _ = model.estimate_sequential(
                    [speech],
                    n_speakers=None,
                    th=args.attractor_threshold,
                    shuffle=args.shuffle
                )
            a = ys[0].cpu().numpy()
            a = medfilt(a, (11, 1))
            rst = []
            for spkr_id, frames in enumerate(a.T):
                frames = np.pad(frames, (1, 1), 'constant')
                changes, = np.where(np.diff(frames, axis=0) != 0)
                fmt = "SPEAKER {:s} 1 {:7.2f} {:7.2f} <NA> <NA> {:s} <NA>"
                for s, e in zip(changes[::2], changes[1::2]):
                    st = s * args.frame_shift * args.subsampling / args.sampling_rate
                    dur = (e - s) * args.frame_shift * args.subsampling / args.sampling_rate
                    print(fmt.format(
                        wav_id,
                        st,
                        dur,
                        wav_id + "_" + str(spkr_id)), file=wf)

View File

@@ -0,0 +1,73 @@
#!/bin/bash
# Copyright 2017 David Snyder
# Apache 2.0.
#
# This script prepares the Callhome portion of the NIST SRE 2000
# corpus (LDC2001S97). It is the evaluation dataset used in the
# callhome_diarization recipe.
if [ $# -ne 2 ]; then
echo "Usage: $0 <callhome-speech> <out-data-dir>"
echo "e.g.: $0 /mnt/data/LDC2001S97 data/"
exit 1;
fi
src_dir=$1
data_dir=$2
tmp_dir=$data_dir/callhome/.tmp/
mkdir -p $tmp_dir
# Download some metadata that wasn't provided in the LDC release
if [ ! -d "$tmp_dir/sre2000-key" ]; then
wget --no-check-certificate -P $tmp_dir/ \
http://www.openslr.org/resources/10/sre2000-key.tar.gz
tar -xvf $tmp_dir/sre2000-key.tar.gz -C $tmp_dir/
fi
# The list of 500 recordings
awk '{print $1}' $tmp_dir/sre2000-key/reco2num > $tmp_dir/reco.list
# Create wav.scp file
count=0
missing=0
while read reco; do
path=$(find $src_dir -name "$reco.sph")
if [ -z "${path// }" ]; then
>&2 echo "$0: Missing Sphere file for $reco"
missing=$((missing+1))
else
echo "$reco sph2pipe -f wav -p $path |"
fi
count=$((count+1))
done < $tmp_dir/reco.list > $data_dir/callhome/wav.scp
if [ $missing -gt 0 ]; then
echo "$0: Missing $missing out of $count recordings"
fi
cp $tmp_dir/sre2000-key/segments $data_dir/callhome/
awk '{print $1, $2}' $data_dir/callhome/segments > $data_dir/callhome/utt2spk
utils/utt2spk_to_spk2utt.pl $data_dir/callhome/utt2spk > $data_dir/callhome/spk2utt
cp $tmp_dir/sre2000-key/reco2num $data_dir/callhome/reco2num_spk
cp $tmp_dir/sre2000-key/fullref.rttm $data_dir/callhome/
utils/validate_data_dir.sh --no-text --no-feats $data_dir/callhome
utils/fix_data_dir.sh $data_dir/callhome
utils/copy_data_dir.sh $data_dir/callhome $data_dir/callhome1
utils/copy_data_dir.sh $data_dir/callhome $data_dir/callhome2
utils/shuffle_list.pl $data_dir/callhome/wav.scp | head -n 250 \
| utils/filter_scp.pl - $data_dir/callhome/wav.scp \
> $data_dir/callhome1/wav.scp
utils/fix_data_dir.sh $data_dir/callhome1
utils/filter_scp.pl --exclude $data_dir/callhome1/wav.scp \
$data_dir/callhome/wav.scp > $data_dir/callhome2/wav.scp
utils/fix_data_dir.sh $data_dir/callhome2
utils/filter_scp.pl $data_dir/callhome1/wav.scp $data_dir/callhome/reco2num_spk \
> $data_dir/callhome1/reco2num_spk
utils/filter_scp.pl $data_dir/callhome2/wav.scp $data_dir/callhome/reco2num_spk \
> $data_dir/callhome2/reco2num_spk
rm -rf $tmp_dir 2> /dev/null

View File

@@ -0,0 +1,120 @@
#!/usr/bin/env python3
# Copyright 2019 Hitachi, Ltd. (author: Yusuke Fujita)
# Licensed under the MIT license.
#
# This script generates simulated multi-talker mixtures for diarization.
#
# common/make_mixture.py \
#     mixture.scp \
#     data/mixture \
#     wav/mixture

import argparse
import os

from funasr.modules.eend_ola.utils import kaldi_data
import numpy as np
import math
import soundfile as sf
import json

parser = argparse.ArgumentParser()
parser.add_argument('script',
                    help='list of json')
parser.add_argument('out_data_dir',
                    help='output data dir of mixture')
parser.add_argument('out_wav_dir',
                    help='output mixture wav files are stored here')
parser.add_argument('--rate', type=int, default=16000,
                    help='sampling rate')
args = parser.parse_args()

# open output data files
segments_f = open(args.out_data_dir + '/segments', 'w')
utt2spk_f = open(args.out_data_dir + '/utt2spk', 'w')
wav_scp_f = open(args.out_data_dir + '/wav.scp', 'w')

# "-R" forces the default random seed for reproducibility
resample_cmd = "sox -R -t wav - -t wav - rate {}".format(args.rate)

for line in open(args.script):
    recid, jsonstr = line.strip().split(None, 1)
    indata = json.loads(jsonstr)
    wavfn = indata['recid']
    # recid now includes out_wav_dir
    recid = os.path.join(args.out_wav_dir, wavfn).replace('/', '_')
    noise = indata['noise']
    noise_snr = indata['snr']
    mixture = []
    for speaker in indata['speakers']:
        spkid = speaker['spkid']
        utts = speaker['utts']
        intervals = speaker['intervals']
        rir = speaker['rir']
        data = []
        pos = 0
        for interval, utt in zip(intervals, utts):
            # append silence interval data
            silence = np.zeros(int(interval * args.rate))
            data.append(silence)
            # the utterance is reverberated using a room impulse response
            preprocess = "wav-reverberate --print-args=false " \
                         " --impulse-response={} - -".format(rir)
            if isinstance(utt, list):
                rec, st, et = utt
                st = np.rint(st * args.rate).astype(int)
                et = np.rint(et * args.rate).astype(int)
            else:
                rec = utt
                st = 0
                et = None
            if rir is not None:
                wav_rxfilename = kaldi_data.process_wav(rec, preprocess)
            else:
                wav_rxfilename = rec
            wav_rxfilename = kaldi_data.process_wav(
                wav_rxfilename, resample_cmd)
            speech, _ = kaldi_data.load_wav(wav_rxfilename, st, et)
            data.append(speech)
            # calculate start/end position in samples
            startpos = pos + len(silence)
            endpos = startpos + len(speech)
            # write segments and utt2spk
            uttid = '{}_{}_{:07d}_{:07d}'.format(
                spkid, recid, int(startpos / args.rate * 100),
                int(endpos / args.rate * 100))
            print(uttid, recid,
                  startpos / args.rate, endpos / args.rate, file=segments_f)
            print(uttid, spkid, file=utt2spk_f)
            # update position for the next utterance
            pos = endpos
        data = np.concatenate(data)
        mixture.append(data)
    # pad to the longest speaker track, then mix all speakers
    maxlen = max(len(x) for x in mixture)
    mixture = [np.pad(x, (0, maxlen - len(x)), 'constant') for x in mixture]
    mixture = np.sum(mixture, axis=0)
    # the noise is repeated or truncated to fit the mixture length
    noise_resampled = kaldi_data.process_wav(noise, resample_cmd)
    noise_data, _ = kaldi_data.load_wav(noise_resampled)
    if maxlen > len(noise_data):
        noise_data = np.pad(noise_data, (0, maxlen - len(noise_data)), 'wrap')
    else:
        noise_data = noise_data[:maxlen]
    # the noise power is scaled according to the selected SNR, then mixed in
    signal_power = np.sum(mixture ** 2) / len(mixture)
    noise_power = np.sum(noise_data ** 2) / len(noise_data)
    scale = math.sqrt(
        math.pow(10, - noise_snr / 10) * signal_power / noise_power)
    mixture += noise_data * scale
    # output the wav file and write wav.scp
    outfname = '{}.wav'.format(wavfn)
    outpath = os.path.join(args.out_wav_dir, outfname)
    sf.write(outpath, mixture, args.rate)
    print(recid, os.path.abspath(outpath), file=wav_scp_f)

wav_scp_f.close()
segments_f.close()
utt2spk_f.close()
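
The final mixing step solves for a gain on the noise track so that the signal-to-noise ratio matches the requested value. A worked numeric example of that formula (the power values are illustrative):

import math

signal_power, noise_power, snr = 1.0, 0.25, 15.0
scale = math.sqrt(10 ** (-snr / 10) * signal_power / noise_power)
# Verify: the achieved SNR after scaling is exactly 15 dB.
achieved = 10 * math.log10(signal_power / (scale ** 2 * noise_power))
print(round(scale, 3), round(achieved, 1))  # 0.356 15.0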

View File

@@ -0,0 +1,123 @@
#!/usr/bin/env python3
# Copyright 2015   David Snyder
#           2018   Ewald Enzinger
# Apache 2.0.
#
# Modified version of egs/sre16/v1/local/make_musan.py (commit e3fb7c4a0da4167f8c94b80f4d3cc5ab4d0e22e8).
# This version uses the raw MUSAN audio files (16 kHz) and does not use sox to resample at 8 kHz.
#
# This file is meant to be invoked by make_musan.sh.

import os, sys


def process_music_annotations(path):
    utt2spk = {}
    utt2vocals = {}
    lines = open(path, 'r').readlines()
    for line in lines:
        utt, genres, vocals, musician = line.rstrip().split()[:4]
        # For this application, the musician ID isn't important
        utt2spk[utt] = utt
        utt2vocals[utt] = vocals == "Y"
    return utt2spk, utt2vocals


def prepare_music(root_dir, use_vocals):
    utt2vocals = {}
    utt2spk = {}
    utt2wav = {}
    num_good_files = 0
    num_bad_files = 0
    music_dir = os.path.join(root_dir, "music")
    for root, dirs, files in os.walk(music_dir):
        for file in files:
            file_path = os.path.join(root, file)
            if file.endswith(".wav"):
                utt = str(file).replace(".wav", "")
                utt2wav[utt] = file_path
            elif str(file) == "ANNOTATIONS":
                utt2spk_part, utt2vocals_part = process_music_annotations(file_path)
                utt2spk.update(utt2spk_part)
                utt2vocals.update(utt2vocals_part)
    utt2spk_str = ""
    utt2wav_str = ""
    for utt in utt2vocals:
        if utt in utt2wav:
            if use_vocals or not utt2vocals[utt]:
                utt2spk_str = utt2spk_str + utt + " " + utt2spk[utt] + "\n"
                utt2wav_str = utt2wav_str + utt + " " + utt2wav[utt] + "\n"
                num_good_files += 1
        else:
            print("Missing file {}".format(utt))
            num_bad_files += 1
    print("In music directory, processed {} files: {} had missing wav data".format(num_good_files, num_bad_files))
    return utt2spk_str, utt2wav_str


def prepare_speech(root_dir):
    utt2spk = {}
    utt2wav = {}
    num_good_files = 0
    num_bad_files = 0
    speech_dir = os.path.join(root_dir, "speech")
    for root, dirs, files in os.walk(speech_dir):
        for file in files:
            file_path = os.path.join(root, file)
            if file.endswith(".wav"):
                utt = str(file).replace(".wav", "")
                utt2wav[utt] = file_path
                utt2spk[utt] = utt
    utt2spk_str = ""
    utt2wav_str = ""
    for utt in utt2spk:
        if utt in utt2wav:
            utt2spk_str = utt2spk_str + utt + " " + utt2spk[utt] + "\n"
            utt2wav_str = utt2wav_str + utt + " " + utt2wav[utt] + "\n"
            num_good_files += 1
        else:
            print("Missing file {}".format(utt))
            num_bad_files += 1
    print("In speech directory, processed {} files: {} had missing wav data".format(num_good_files, num_bad_files))
    return utt2spk_str, utt2wav_str


def prepare_noise(root_dir):
    utt2spk = {}
    utt2wav = {}
    num_good_files = 0
    num_bad_files = 0
    noise_dir = os.path.join(root_dir, "noise")
    for root, dirs, files in os.walk(noise_dir):
        for file in files:
            file_path = os.path.join(root, file)
            if file.endswith(".wav"):
                utt = str(file).replace(".wav", "")
                utt2wav[utt] = file_path
                utt2spk[utt] = utt
    utt2spk_str = ""
    utt2wav_str = ""
    for utt in utt2spk:
        if utt in utt2wav:
            utt2spk_str = utt2spk_str + utt + " " + utt2spk[utt] + "\n"
            utt2wav_str = utt2wav_str + utt + " " + utt2wav[utt] + "\n"
            num_good_files += 1
        else:
            print("Missing file {}".format(utt))
            num_bad_files += 1
    print("In noise directory, processed {} files: {} had missing wav data".format(num_good_files, num_bad_files))
    return utt2spk_str, utt2wav_str


def main():
    in_dir = sys.argv[1]
    out_dir = sys.argv[2]
    use_vocals = sys.argv[3] == "Y"
    utt2spk_music, utt2wav_music = prepare_music(in_dir, use_vocals)
    utt2spk_speech, utt2wav_speech = prepare_speech(in_dir)
    utt2spk_noise, utt2wav_noise = prepare_noise(in_dir)
    utt2spk = utt2spk_speech + utt2spk_music + utt2spk_noise
    utt2wav = utt2wav_speech + utt2wav_music + utt2wav_noise
    wav_fi = open(os.path.join(out_dir, "wav.scp"), 'w')
    wav_fi.write(utt2wav)
    utt2spk_fi = open(os.path.join(out_dir, "utt2spk"), 'w')
    utt2spk_fi.write(utt2spk)


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,37 @@
#!/bin/bash
# Copyright 2015 David Snyder
# Apache 2.0.
#
# This script, called by ../run.sh, creates the MUSAN
# data directory. The required dataset is freely available at
# http://www.openslr.org/17/
set -e
in_dir=$1
data_dir=$2
use_vocals='Y'
mkdir -p local/musan.tmp
echo "Preparing ${data_dir}/musan..."
mkdir -p ${data_dir}/musan
local/make_musan.py ${in_dir} ${data_dir}/musan ${use_vocals}
utils/fix_data_dir.sh ${data_dir}/musan
grep "music" ${data_dir}/musan/utt2spk > local/musan.tmp/utt2spk_music
grep "speech" ${data_dir}/musan/utt2spk > local/musan.tmp/utt2spk_speech
grep "noise" ${data_dir}/musan/utt2spk > local/musan.tmp/utt2spk_noise
utils/subset_data_dir.sh --utt-list local/musan.tmp/utt2spk_music \
${data_dir}/musan ${data_dir}/musan_music
utils/subset_data_dir.sh --utt-list local/musan.tmp/utt2spk_speech \
${data_dir}/musan ${data_dir}/musan_speech
utils/subset_data_dir.sh --utt-list local/musan.tmp/utt2spk_noise \
${data_dir}/musan ${data_dir}/musan_noise
utils/fix_data_dir.sh ${data_dir}/musan_music
utils/fix_data_dir.sh ${data_dir}/musan_speech
utils/fix_data_dir.sh ${data_dir}/musan_noise
rm -rf local/musan.tmp

View File

@@ -0,0 +1,63 @@
#!/usr/bin/perl
#
# Copyright 2015 David Snyder
# Apache 2.0.
# Usage: make_sre.pl <path-to-data> <name-of-source> <sre-ref> <output-dir>
if (@ARGV != 4) {
print STDERR "Usage: $0 <path-to-data> <name-of-source> <sre-ref> <output-dir>\n";
print STDERR "e.g. $0 /export/corpora5/LDC/LDC2006S44 sre2004 sre_ref data/sre2004\n";
exit(1);
}
($db_base, $sre_name, $sre_ref_filename, $out_dir) = @ARGV;
%utt2sph = ();
%spk2gender = ();
$tmp_dir = "$out_dir/tmp";
if (system("mkdir -p $tmp_dir") != 0) {
die "Error making directory $tmp_dir";
}
if (system("find $db_base -name '*.sph' > $tmp_dir/sph.list") != 0) {
die "Error getting list of sph files";
}
open(WAVLIST, "<", "$tmp_dir/sph.list") or die "cannot open wav list";
while(<WAVLIST>) {
chomp;
$sph = $_;
@A1 = split("/",$sph);
@A2 = split("[./]",$A1[$#A1]);
$uttId=$A2[0];
$utt2sph{$uttId} = $sph;
}
open(GNDR,">", "$out_dir/spk2gender") or die "Could not open the output file $out_dir/spk2gender";
open(SPKR,">", "$out_dir/utt2spk") or die "Could not open the output file $out_dir/utt2spk";
open(WAV,">", "$out_dir/wav.scp") or die "Could not open the output file $out_dir/wav.scp";
open(SRE_REF, "<", $sre_ref_filename) or die "Cannot open SRE reference.";
while (<SRE_REF>) {
chomp;
($speaker, $gender, $other_sre_name, $utt_id, $channel) = split(" ", $_);
$channel_num = "1";
if ($channel eq "A") {
$channel_num = "1";
} else {
$channel_num = "2";
}
if (($other_sre_name eq $sre_name) and (exists $utt2sph{$utt_id})) {
$full_utt_id = "$speaker-$gender-$sre_name-$utt_id-$channel";
$spk2gender{"$speaker-$gender"} = $gender;
print WAV "$full_utt_id"," sph2pipe -f wav -p -c $channel_num $utt2sph{$utt_id} |\n";
print SPKR "$full_utt_id $speaker-$gender","\n";
}
}
foreach $speaker (keys %spk2gender) {
print GNDR "$speaker $spk2gender{$speaker}\n";
}
close(GNDR) || die;
close(SPKR) || die;
close(WAV) || die;
close(SRE_REF) || die;

View File

@@ -0,0 +1,48 @@
#!/bin/bash
# Copyright 2015 David Snyder
# Apache 2.0.
#
# See README.txt for more info on data required.
set -e
data_root=$1
data_dir=$2
wget -P data/local/ http://www.openslr.org/resources/15/speaker_list.tgz
tar -C data/local/ -xvf data/local/speaker_list.tgz
sre_ref=data/local/speaker_list
local/make_sre.pl $data_root/LDC2006S44/ \
sre2004 $sre_ref $data_dir/sre2004
local/make_sre.pl $data_root/LDC2011S01 \
sre2005 $sre_ref $data_dir/sre2005_train
local/make_sre.pl $data_root/LDC2011S04 \
sre2005 $sre_ref $data_dir/sre2005_test
local/make_sre.pl $data_root/LDC2011S09 \
sre2006 $sre_ref $data_dir/sre2006_train
local/make_sre.pl $data_root/LDC2011S10 \
sre2006 $sre_ref $data_dir/sre2006_test_1
local/make_sre.pl $data_root/LDC2012S01 \
sre2006 $sre_ref $data_dir/sre2006_test_2
local/make_sre.pl $data_root/LDC2011S05 \
sre2008 $sre_ref $data_dir/sre2008_train
local/make_sre.pl $data_root/LDC2011S08 \
sre2008 $sre_ref $data_dir/sre2008_test
utils/combine_data.sh $data_dir/sre \
$data_dir/sre2004 $data_dir/sre2005_train \
$data_dir/sre2005_test $data_dir/sre2006_train \
$data_dir/sre2006_test_1 $data_dir/sre2006_test_2 \
$data_dir/sre2008_train $data_dir/sre2008_test
utils/validate_data_dir.sh --no-text --no-feats $data_dir/sre
utils/fix_data_dir.sh $data_dir/sre
rm data/local/speaker_list.*

View File

@@ -0,0 +1,106 @@
#!/usr/bin/perl
use warnings; #sed replacement for -w perl parameter
#
# Copyright 2017 David Snyder
# Apache 2.0
if (@ARGV != 2) {
print STDERR "Usage: $0 <path-to-LDC98S75> <path-to-output>\n";
print STDERR "e.g. $0 /export/corpora3/LDC/LDC98S75 data/swbd2_phase1_train\n";
exit(1);
}
($db_base, $out_dir) = @ARGV;
if (system("mkdir -p $out_dir")) {
die "Error making directory $out_dir";
}
open(CS, "<$db_base/doc/callstat.tbl") || die "Could not open $db_base/doc/callstat.tbl";
open(GNDR, ">$out_dir/spk2gender") || die "Could not open the output file $out_dir/spk2gender";
open(SPKR, ">$out_dir/utt2spk") || die "Could not open the output file $out_dir/utt2spk";
open(WAV, ">$out_dir/wav.scp") || die "Could not open the output file $out_dir/wav.scp";
@badAudio = ("3", "4");
$tmp_dir = "$out_dir/tmp";
if (system("mkdir -p $tmp_dir") != 0) {
die "Error making directory $tmp_dir";
}
if (system("find $db_base -name '*.sph' > $tmp_dir/sph.list") != 0) {
die "Error getting list of sph files";
}
open(WAVLIST, "<$tmp_dir/sph.list") or die "cannot open wav list";
%wavs = ();
while(<WAVLIST>) {
chomp;
$sph = $_;
@t = split("/",$sph);
@t1 = split("[./]",$t[$#t]);
$uttId = $t1[0];
$wavs{$uttId} = $sph;
}
while (<CS>) {
$line = $_ ;
@A = split(",", $line);
@A1 = split("[./]",$A[0]);
$wav = $A1[0];
if (/$wav/i ~~ @badAudio) {
# bad audio: skip this recording, but report it
print "Bad Audio = $wav";
} else {
$spkr1= "sw_" . $A[2];
$spkr2= "sw_" . $A[3];
$gender1 = $A[5];
$gender2 = $A[6];
if ($gender1 eq "M") {
$gender1 = "m";
} elsif ($gender1 eq "F") {
$gender1 = "f";
} else {
die "Unknown Gender in $line";
}
if ($gender2 eq "M") {
$gender2 = "m";
} elsif ($gender2 eq "F") {
$gender2 = "f";
} else {
die "Unknown Gender in $line";
}
if (-e "$wavs{$wav}") {
$uttId = $spkr1 ."_" . $wav ."_1";
if (!$spk2gender{$spkr1}) {
$spk2gender{$spkr1} = $gender1;
print GNDR "$spkr1"," $gender1\n";
}
print WAV "$uttId"," sph2pipe -f wav -p -c 1 $wavs{$wav} |\n";
print SPKR "$uttId"," $spkr1","\n";
$uttId = $spkr2 . "_" . $wav ."_2";
if (!$spk2gender{$spkr2}) {
$spk2gender{$spkr2} = $gender2;
print GNDR "$spkr2"," $gender2\n";
}
print WAV "$uttId"," sph2pipe -f wav -p -c 2 $wavs{$wav} |\n";
print SPKR "$uttId"," $spkr2","\n";
} else {
print STDERR "Missing $wavs{$wav} for $wav\n";
}
}
}
close(WAV) || die;
close(SPKR) || die;
close(GNDR) || die;
if (system("utils/utt2spk_to_spk2utt.pl $out_dir/utt2spk >$out_dir/spk2utt") != 0) {
die "Error creating spk2utt file in directory $out_dir";
}
if (system("utils/fix_data_dir.sh $out_dir") != 0) {
die "Error fixing data dir $out_dir";
}
if (system("utils/validate_data_dir.sh --no-text --no-feats $out_dir") != 0) {
die "Error validating directory $out_dir";
}

View File

@@ -0,0 +1,107 @@
#!/usr/bin/perl
use warnings; #sed replacement for -w perl parameter
#
# Copyright 2013 Daniel Povey
# Apache 2.0
if (@ARGV != 2) {
print STDERR "Usage: $0 <path-to-LDC99S79> <path-to-output>\n";
print STDERR "e.g. $0 /export/corpora5/LDC/LDC99S79 data/swbd2_phase2_train\n";
exit(1);
}
($db_base, $out_dir) = @ARGV;
if (system("mkdir -p $out_dir")) {
die "Error making directory $out_dir";
}
open(CS, "<$db_base/DISC1/doc/callstat.tbl") || die "Could not open $db_base/DISC1/doc/callstat.tbl";
open(CI, "<$db_base/DISC1/doc/callinfo.tbl") || die "Could not open $db_base/DISC1/doc/callinfo.tbl";
open(GNDR, ">$out_dir/spk2gender") || die "Could not open the output file $out_dir/spk2gender";
open(SPKR, ">$out_dir/utt2spk") || die "Could not open the output file $out_dir/utt2spk";
open(WAV, ">$out_dir/wav.scp") || die "Could not open the output file $out_dir/wav.scp";
@badAudio = ("3", "4");
$tmp_dir = "$out_dir/tmp";
if (system("mkdir -p $tmp_dir") != 0) {
die "Error making directory $tmp_dir";
}
if (system("find $db_base -name '*.sph' > $tmp_dir/sph.list") != 0) {
die "Error getting list of sph files";
}
open(WAVLIST, "<$tmp_dir/sph.list") or die "cannot open wav list";
while(<WAVLIST>) {
chomp;
$sph = $_;
@t = split("/",$sph);
@t1 = split("[./]",$t[$#t]);
$uttId=$t1[0];
$wav{$uttId} = $sph;
}
while (<CS>) {
$line = $_ ;
$ci = <CI>;
$ci = <CI>;
@ci = split(",",$ci);
$wav = $ci[0];
@A = split(",", $line);
if (/$wav/i ~~ @badAudio) {
# do nothing
} else {
$spkr1= "sw_" . $A[2];
$spkr2= "sw_" . $A[3];
$gender1 = $A[4];
$gender2 = $A[5];
if ($gender1 eq "M") {
$gender1 = "m";
} elsif ($gender1 eq "F") {
$gender1 = "f";
} else {
die "Unknown Gender in $line";
}
if ($gender2 eq "M") {
$gender2 = "m";
} elsif ($gender2 eq "F") {
$gender2 = "f";
} else {
die "Unknown Gender in $line";
}
if (-e "$wav{$wav}") {
$uttId = $spkr1 ."_" . $wav ."_1";
if (!$spk2gender{$spkr1}) {
$spk2gender{$spkr1} = $gender1;
print GNDR "$spkr1"," $gender1\n";
}
print WAV "$uttId"," sph2pipe -f wav -p -c 1 $wav{$wav} |\n";
print SPKR "$uttId"," $spkr1","\n";
$uttId = $spkr2 . "_" . $wav ."_2";
if (!$spk2gender{$spkr2}) {
$spk2gender{$spkr2} = $gender2;
print GNDR "$spkr2"," $gender2\n";
}
print WAV "$uttId"," sph2pipe -f wav -p -c 2 $wav{$wav} |\n";
print SPKR "$uttId"," $spkr2","\n";
} else {
print STDERR "Missing $wav{$wav} for $wav\n";
}
}
}
close(WAV) || die;
close(SPKR) || die;
close(GNDR) || die;
if (system("utils/utt2spk_to_spk2utt.pl $out_dir/utt2spk >$out_dir/spk2utt") != 0) {
die "Error creating spk2utt file in directory $out_dir";
}
if (system("utils/fix_data_dir.sh $out_dir") != 0) {
die "Error fixing data dir $out_dir";
}
if (system("utils/validate_data_dir.sh --no-text --no-feats $out_dir") != 0) {
die "Error validating directory $out_dir";
}

View File

@@ -0,0 +1,102 @@
#!/usr/bin/perl
use warnings; #sed replacement for -w perl parameter
#
# Copyright 2013 Daniel Povey
# Apache 2.0
if (@ARGV != 2) {
print STDERR "Usage: $0 <path-to-LDC2002S06> <path-to-output>\n";
print STDERR "e.g. $0 /export/corpora5/LDC/LDC2002S06 data/swbd2_phase3_train\n";
exit(1);
}
($db_base, $out_dir) = @ARGV;
if (system("mkdir -p $out_dir")) {
die "Error making directory $out_dir";
}
open(CS, "<$db_base/DISC1/docs/callstat.tbl") || die "Could not open $db_base/DISC1/docs/callstat.tbl";
open(GNDR, ">$out_dir/spk2gender") || die "Could not open the output file $out_dir/spk2gender";
open(SPKR, ">$out_dir/utt2spk") || die "Could not open the output file $out_dir/utt2spk";
open(WAV, ">$out_dir/wav.scp") || die "Could not open the output file $out_dir/wav.scp";
@badAudio = ("3", "4");
$tmp_dir = "$out_dir/tmp";
if (system("mkdir -p $tmp_dir") != 0) {
die "Error making directory $tmp_dir";
}
if (system("find $db_base -name '*.sph' > $tmp_dir/sph.list") != 0) {
die "Error getting list of sph files";
}
open(WAVLIST, "<$tmp_dir/sph.list") or die "cannot open wav list";
while(<WAVLIST>) {
chomp;
$sph = $_;
@t = split("/",$sph);
@t1 = split("[./]",$t[$#t]);
$uttId=$t1[0];
$wav{$uttId} = $sph;
}
while (<CS>) {
$line = $_ ;
@A = split(",", $line);
$wav = "sw_" . $A[0] ;
if (/$wav/i ~~ @badAudio) {
# do nothing
} else {
$spkr1= "sw_" . $A[3];
$spkr2= "sw_" . $A[4];
$gender1 = $A[5];
$gender2 = $A[6];
if ($gender1 eq "M") {
$gender1 = "m";
} elsif ($gender1 eq "F") {
$gender1 = "f";
} else {
die "Unknown Gender in $line";
}
if ($gender2 eq "M") {
$gender2 = "m";
} elsif ($gender2 eq "F") {
$gender2 = "f";
} else {
die "Unknown Gender in $line";
}
if (-e "$wav{$wav}") {
$uttId = $spkr1 ."_" . $wav ."_1";
if (!$spk2gender{$spkr1}) {
$spk2gender{$spkr1} = $gender1;
print GNDR "$spkr1"," $gender1\n";
}
print WAV "$uttId"," sph2pipe -f wav -p -c 1 $wav{$wav} |\n";
print SPKR "$uttId"," $spkr1","\n";
$uttId = $spkr2 . "_" . $wav ."_2";
if (!$spk2gender{$spkr2}) {
$spk2gender{$spkr2} = $gender2;
print GNDR "$spkr2"," $gender2\n";
}
print WAV "$uttId"," sph2pipe -f wav -p -c 2 $wav{$wav} |\n";
print SPKR "$uttId"," $spkr2","\n";
} else {
print STDERR "Missing $wav{$wav} for $wav\n";
}
}
}
close(WAV) || die;
close(SPKR) || die;
close(GNDR) || die;
if (system("utils/utt2spk_to_spk2utt.pl $out_dir/utt2spk >$out_dir/spk2utt") != 0) {
die "Error creating spk2utt file in directory $out_dir";
}
if (system("utils/fix_data_dir.sh $out_dir") != 0) {
die "Error fixing data dir $out_dir";
}
if (system("utils/validate_data_dir.sh --no-text --no-feats $out_dir") != 0) {
die "Error validating directory $out_dir";
}

View File

@@ -0,0 +1,83 @@
#!/usr/bin/perl
use warnings; #sed replacement for -w perl parameter
#
# Copyright 2013 Daniel Povey
# Apache 2.0
if (@ARGV != 2) {
print STDERR "Usage: $0 <path-to-LDC2001S13> <path-to-output>\n";
print STDERR "e.g. $0 /export/corpora5/LDC/LDC2001S13 data/swbd_cellular1_train\n";
exit(1);
}
($db_base, $out_dir) = @ARGV;
if (system("mkdir -p $out_dir")) {
die "Error making directory $out_dir";
}
open(CS, "<$db_base/doc/swb_callstats.tbl") || die "Could not open $db_base/doc/swb_callstats.tbl";
open(GNDR, ">$out_dir/spk2gender") || die "Could not open the output file $out_dir/spk2gender";
open(SPKR, ">$out_dir/utt2spk") || die "Could not open the output file $out_dir/utt2spk";
open(WAV, ">$out_dir/wav.scp") || die "Could not open the output file $out_dir/wav.scp";
@badAudio = ("40019", "45024", "40022");
while (<CS>) {
$line = $_ ;
@A = split(",", $line);
if (/$A[0]/i ~~ @badAudio) {
# do nothing
} else {
$wav = "sw_" . $A[0];
$spkr1= "sw_" . $A[1];
$spkr2= "sw_" . $A[2];
$gender1 = $A[3];
$gender2 = $A[4];
if ($A[3] eq "M") {
$gender1 = "m";
} elsif ($A[3] eq "F") {
$gender1 = "f";
} else {
die "Unknown Gender in $line";
}
if ($A[4] eq "M") {
$gender2 = "m";
} elsif ($A[4] eq "F") {
$gender2 = "f";
} else {
die "Unknown Gender in $line";
}
if (-e "$db_base/data/$wav.sph") {
$uttId = $spkr1 . "-swbdc_" . $wav ."_1";
if (!$spk2gender{$spkr1}) {
$spk2gender{$spkr1} = $gender1;
print GNDR "$spkr1"," $gender1\n";
}
print WAV "$uttId"," sph2pipe -f wav -p -c 1 $db_base/data/$wav.sph |\n";
print SPKR "$uttId"," $spkr1","\n";
$uttId = $spkr2 . "-swbdc_" . $wav ."_2";
if (!$spk2gender{$spkr2}) {
$spk2gender{$spkr2} = $gender2;
print GNDR "$spkr2"," $gender2\n";
}
print WAV "$uttId"," sph2pipe -f wav -p -c 2 $db_base/data/$wav.sph |\n";
print SPKR "$uttId"," $spkr2","\n";
} else {
print STDERR "Missing $db_base/data/$wav.sph\n";
}
}
}
close(WAV) || die;
close(SPKR) || die;
close(GNDR) || die;
if (system("utils/utt2spk_to_spk2utt.pl $out_dir/utt2spk >$out_dir/spk2utt") != 0) {
die "Error creating spk2utt file in directory $out_dir";
}
if (system("utils/fix_data_dir.sh $out_dir") != 0) {
die "Error fixing data dir $out_dir";
}
if (system("utils/validate_data_dir.sh --no-text --no-feats $out_dir") != 0) {
die "Error validating directory $out_dir";
}

View File

@@ -0,0 +1,83 @@
#!/usr/bin/perl
use warnings; #sed replacement for -w perl parameter
#
# Copyright 2013 Daniel Povey
# Apache 2.0
if (@ARGV != 2) {
print STDERR "Usage: $0 <path-to-LDC2004S07> <path-to-output>\n";
print STDERR "e.g. $0 /export/corpora5/LDC/LDC2004S07 data/swbd_cellular2_train\n";
exit(1);
}
($db_base, $out_dir) = @ARGV;
if (system("mkdir -p $out_dir")) {
die "Error making directory $out_dir";
}
open(CS, "<$db_base/docs/swb_callstats.tbl") || die "Could not open $db_base/docs/swb_callstats.tbl";
open(GNDR, ">$out_dir/spk2gender") || die "Could not open the output file $out_dir/spk2gender";
open(SPKR, ">$out_dir/utt2spk") || die "Could not open the output file $out_dir/utt2spk";
open(WAV, ">$out_dir/wav.scp") || die "Could not open the output file $out_dir/wav.scp";
@badAudio=("45024", "40022");
while (<CS>) {
$line = $_ ;
@A = split(",", $line);
if (/$A[0]/i ~~ @badAudio) {
# do nothing
} else {
$wav = "sw_" . $A[0];
$spkr1= "sw_" . $A[1];
$spkr2= "sw_" . $A[2];
$gender1 = $A[3];
$gender2 = $A[4];
if ($A[3] eq "M") {
$gender1 = "m";
} elsif ($A[3] eq "F") {
$gender1 = "f";
} else {
die "Unknown Gender in $line";
}
if ($A[4] eq "M") {
$gender2 = "m";
} elsif ($A[4] eq "F") {
$gender2 = "f";
} else {
die "Unknown Gender in $line";
}
if (-e "$db_base/data/$wav.sph") {
$uttId = $spkr1 . "-swbdc_" . $wav ."_1";
if (!$spk2gender{$spkr1}) {
$spk2gender{$spkr1} = $gender1;
print GNDR "$spkr1"," $gender1\n";
}
print WAV "$uttId"," sph2pipe -f wav -p -c 1 $db_base/data/$wav.sph |\n";
print SPKR "$uttId"," $spkr1","\n";
$uttId = $spkr2 . "-swbdc_" . $wav ."_2";
if (!$spk2gender{$spkr2}) {
$spk2gender{$spkr2} = $gender2;
print GNDR "$spkr2"," $gender2\n";
}
print WAV "$uttId"," sph2pipe -f wav -p -c 2 $db_base/data/$wav.sph |\n";
print SPKR "$uttId"," $spkr2","\n";
} else {
print STDERR "Missing $db_base/data/$wav.sph\n";
}
}
}
close(WAV) || die;
close(SPKR) || die;
close(GNDR) || die;
if (system("utils/utt2spk_to_spk2utt.pl $out_dir/utt2spk >$out_dir/spk2utt") != 0) {
die "Error creating spk2utt file in directory $out_dir";
}
if (system("utils/fix_data_dir.sh $out_dir") != 0) {
die "Error fixing data dir $out_dir";
}
if (system("utils/validate_data_dir.sh --no-text --no-feats $out_dir") != 0) {
die "Error validating directory $out_dir";
}

View File

@@ -0,0 +1,28 @@
#!/usr/bin/env python3
import argparse

import torch


def average_model(input_files, output_file):
    output_model = {}
    for ckpt_path in input_files:
        model_params = torch.load(ckpt_path, map_location="cpu")
        for key, value in model_params.items():
            if key not in output_model:
                output_model[key] = value
            else:
                output_model[key] += value
    for key in output_model.keys():
        # use true division (not in-place) so integer buffers are averaged
        # without an in-place dtype error
        output_model[key] = output_model[key] / len(input_files)
    torch.save(output_model, output_file)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("output_file")
    parser.add_argument("input_files", nargs='+')
    args = parser.parse_args()
    average_model(args.input_files, args.output_file)
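
A typical use is averaging the N best checkpoints kept by keep_nbest_models before inference; a hypothetical invocation (checkpoint names are illustrative): python average_model.py averaged.pth ckpt.ep.98.pth ckpt.ep.99.pth ckpt.ep.100.pth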

View File

@@ -0,0 +1,97 @@
#!/usr/bin/env bash
# Copyright 2012 Johns Hopkins University (Author: Daniel Povey);
# Arnab Ghoshal, Karel Vesely
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# Parse command-line options.
# To be sourced by another script (as in ". parse_options.sh").
# Option format is: --option-name arg
# and shell variable "option_name" gets set to value "arg."
# The exception is --help, which takes no arguments, but prints the
# $help_message variable (if defined).
###
### The --config file options have lower priority to command line
### options, so we need to import them first...
###
# Now import all the configs specified by command-line, in left-to-right order
for ((argpos=1; argpos<$#; argpos++)); do
if [ "${!argpos}" == "--config" ]; then
argpos_plus1=$((argpos+1))
config=${!argpos_plus1}
[ ! -r $config ] && echo "$0: missing config '$config'" && exit 1
. $config # source the config file.
fi
done
###
### Now we process the command line options
###
while true; do
[ -z "${1:-}" ] && break; # break if there are no arguments
case "$1" in
# If the enclosing script is called with --help option, print the help
# message and exit. Scripts should put help messages in $help_message
--help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2;
else printf "$help_message\n" 1>&2 ; fi;
exit 0 ;;
--*=*) echo "$0: options to scripts must be of the form --name value, got '$1'"
exit 1 ;;
# If the first command-line argument begins with "--" (e.g. --foo-bar),
# then work out the variable name as $name, which will equal "foo_bar".
--*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`;
# Next we test whether the variable in question is undefined -- if so it's
# an invalid option and we die. Note: $0 evaluates to the name of the
# enclosing script.
# The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar
# is undefined. We then have to wrap this test inside "eval" because
# foo_bar is itself inside a variable ($name).
eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
oldval="`eval echo \\$$name`";
# Work out whether we seem to be expecting a Boolean argument.
if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then
was_bool=true;
else
was_bool=false;
fi
# Set the variable to the right value-- the escaped quotes make it work if
# the option had spaces, like --cmd "queue.pl -sync y"
eval $name=\"$2\";
# Check that Boolean-valued arguments are really Boolean.
if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
exit 1;
fi
shift 2;
;;
*) break;
esac
done
# Check for an empty argument to the --cmd option, which can easily occur as a
# result of scripting errors.
[ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1;
true; # so this script returns exit code 0.

View File

@@ -0,0 +1,145 @@
#!/usr/bin/env python3
# Copyright 2019 Hitachi, Ltd. (author: Yusuke Fujita)
# Licensed under the MIT license.
"""
This script generates random multi-talker mixtures for diarization.
It generates scp-like output: lines of "[recid] [json]".
    recid: recording id of the mixture,
           serial numbers like mix_0000001, mix_0000002, ...
    json: mixture configuration formatted as a single line
The json format is as follows:
    {
        'speakers': [  # list of speakers
            {
                'spkid': 'Name',  # speaker id
                'rir': '/rirdir/rir.wav',  # wav_rxfilename of room impulse response
                'utts': [  # list of wav_rxfilenames of utterances
                    '/wavdir/utt1.wav',
                    '/wavdir/utt2.wav', ...],
                'intervals': [1.2, 3.4, ...]  # list of silence durations before utterances
            }, ... ],
        'noise': '/noisedir/noise.wav',  # wav_rxfilename of background noise
        'snr': 15.0,  # SNR for mixing background noise
        'recid': 'mix_000001'  # recording id of the mixture
    }
Usage:
    common/random_mixture.py \
        --n_mixtures=10000 \    # number of mixtures
        data/voxceleb1_train \  # kaldi-style data dir of utterances
        data/musan_noise_bg \   # background noises
        data/simu_rirs \        # room impulse responses
        > mixture.scp           # output scp-like file
The actual data dir and wav files are generated using make_mixture.py:
    common/make_mixture.py \
        mixture.scp \   # scp-like file for mixtures
        data/mixture \  # output data dir
        wav/mixture     # output wav dir
"""
import argparse
import os

from funasr.modules.eend_ola.utils import kaldi_data
import random
import numpy as np
import json
import itertools

parser = argparse.ArgumentParser()
parser.add_argument('data_dir',
                    help='data dir of single-speaker recordings')
parser.add_argument('noise_dir',
                    help='data dir of background noise recordings')
parser.add_argument('rir_dir',
                    help='data dir of room impulse responses')
parser.add_argument('--n_mixtures', type=int, default=10,
                    help='number of mixture recordings')
parser.add_argument('--n_speakers', type=int, default=4,
                    help='number of speakers in a mixture')
parser.add_argument('--min_utts', type=int, default=10,
                    help='minimum number of utterances per speaker')
parser.add_argument('--max_utts', type=int, default=20,
                    help='maximum number of utterances per speaker')
parser.add_argument('--sil_scale', type=float, default=10.0,
                    help='average silence time')
parser.add_argument('--noise_snrs', default="10:15:20",
                    help='colon-delimited SNRs for background noises')
parser.add_argument('--random_seed', type=int, default=777,
                    help='random seed')
parser.add_argument('--speech_rvb_probability', type=float, default=1,
                    help='reverb probability')
args = parser.parse_args()

random.seed(args.random_seed)
np.random.seed(args.random_seed)

# load list of wav files from kaldi-style data dirs
wavs = kaldi_data.load_wav_scp(
    os.path.join(args.data_dir, 'wav.scp'))
noises = kaldi_data.load_wav_scp(
    os.path.join(args.noise_dir, 'wav.scp'))
rirs = kaldi_data.load_wav_scp(
    os.path.join(args.rir_dir, 'wav.scp'))

# spk2utt is used for counting the number of utterances per speaker
spk2utt = kaldi_data.load_spk2utt(
    os.path.join(args.data_dir, 'spk2utt'))
segments = kaldi_data.load_segments_hash(
    os.path.join(args.data_dir, 'segments'))

# choice lists for random sampling
all_speakers = list(spk2utt.keys())
all_noises = list(noises.keys())
all_rirs = list(rirs.keys())
noise_snrs = [float(x) for x in args.noise_snrs.split(':')]

mixtures = []
for it in range(args.n_mixtures):
    # recording ids are mix_0000001, mix_0000002, ...
    recid = 'mix_{:07d}'.format(it + 1)
    # randomly select speakers, a background noise and an SNR
    speakers = random.sample(all_speakers, args.n_speakers)
    noise = random.choice(all_noises)
    noise_snr = random.choice(noise_snrs)
    mixture = {'speakers': []}
    for speaker in speakers:
        # randomly select the number of utterances
        n_utts = np.random.randint(args.min_utts, args.max_utts + 1)
        # utts = spk2utt[speaker][:n_utts]
        cycle_utts = itertools.cycle(spk2utt[speaker])
        # random start utterance
        roll = np.random.randint(0, len(spk2utt[speaker]))
        for i in range(roll):
            next(cycle_utts)
        utts = [next(cycle_utts) for i in range(n_utts)]
        # randomly select wait times before appending utterances
        intervals = np.random.exponential(args.sil_scale, size=n_utts)
        # randomly select a room impulse response
        if random.random() < args.speech_rvb_probability:
            rir = rirs[random.choice(all_rirs)]
        else:
            rir = None
        if segments is not None:
            utts = [segments[utt] for utt in utts]
            utts = [(wavs[rec], st, et) for (rec, st, et) in utts]
            mixture['speakers'].append({
                'spkid': speaker,
                'rir': rir,
                'utts': utts,
                'intervals': intervals.tolist()
            })
        else:
            mixture['speakers'].append({
                'spkid': speaker,
                'rir': rir,
                'utts': [wavs[utt] for utt in utts],
                'intervals': intervals.tolist()
            })
    mixture['noise'] = noises[noise]
    mixture['snr'] = noise_snr
    mixture['recid'] = recid
    print(recid, json.dumps(mixture))
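
The intervals drawn above govern both silence and overlap: the gap before each utterance follows an exponential distribution whose mean is sil_scale (the beta value in the simulated-set names in run.sh below), so smaller sil_scale values yield denser, more overlapped mixtures. A minimal sketch:

import numpy as np

# Mean gap equals sil_scale; e.g. beta2 in the set names means sil_scale=2.
intervals = np.random.exponential(scale=2.0, size=100000)
print(round(intervals.mean(), 2))  # close to 2.0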

View File

@@ -0,0 +1,235 @@
#!/bin/bash
# Copyright 2019 Hitachi, Ltd. (author: Yusuke Fujita, Shota Horiguchi)
# Licensed under the MIT license.
#
# This script prepares kaldi-style data sets shared with different experiments
# - data/xxxx
# callhome, sre, swb2, and swb_cellular datasets
# - data/simu_${simu_outputs}
# simulation mixtures generated with various options
stage=0
# Modify corpus directories
# - callhome_dir
# CALLHOME (LDC2001S97)
# - swb2_phase1_train
# Switchboard-2 Phase 1 (LDC98S75)
# - data_root
# LDC99S79, LDC2002S06, LDC2001S13, LDC2004S07,
# LDC2006S44, LDC2011S01, LDC2011S04, LDC2011S09,
# LDC2011S10, LDC2012S01, LDC2011S05, LDC2011S08
# - musan_root
# MUSAN corpus (https://www.openslr.org/17/)
callhome_dir=
swb2_phase1_train=
data_root=
musan_root=
# Modify simulated data storage area.
# This script distributes simulated data under these directories
simu_actual_dirs=(
./s05/$USER/diarization-data
./s08/$USER/diarization-data
./s09/$USER/diarization-data
)
# data preparation options
max_jobs_run=4
sad_num_jobs=30
sad_opts="--extra-left-context 79 --extra-right-context 21 --frames-per-chunk 150 --extra-left-context-initial 0 --extra-right-context-final 0 --acwt 0.3"
sad_graph_opts="--min-silence-duration=0.03 --min-speech-duration=0.3 --max-speech-duration=10.0"
sad_priors_opts="--sil-scale=0.1"
# simulation options
simu_opts_overlap=yes
simu_opts_num_speaker_array=(1 2 3 4)
simu_opts_sil_scale_array=(2 2 5 9)
simu_opts_rvb_prob=0.5
simu_opts_num_train=100000
simu_opts_min_utts=10
simu_opts_max_utts=20
simu_cmd="run.pl"
train_cmd="run.pl"
random_mixture_cmd="run.pl"
make_mixture_cmd="run.pl"
. parse_options.sh || exit
if [ $stage -le 0 ]; then
echo "prepare kaldi-style datasets"
# Prepare the CALLHOME dataset. This will be used for evaluation.
if ! validate_data_dir.sh --no-text --no-feats data/callhome1_spkall \
|| ! validate_data_dir.sh --no-text --no-feats data/callhome2_spkall; then
# imported from https://github.com/kaldi-asr/kaldi/blob/master/egs/callhome_diarization/v1
local/make_callhome.sh $callhome_dir data
# Generate two-speaker subsets
for dset in callhome1 callhome2; do
# Extract two-speaker recordings in wav.scp
copy_data_dir.sh data/${dset} data/${dset}_spkall
# Regenerate segments file from fullref.rttm
# $2: recid, $4: start_time, $5: duration, $8: speakerid
awk '{printf "%s_%s_%07d_%07d %s %.2f %.2f\n", \
$2, $8, $4*100, ($4+$5)*100, $2, $4, $4+$5}' \
data/callhome/fullref.rttm | sort > data/${dset}_spkall/segments
utils/fix_data_dir.sh data/${dset}_spkall
# Speaker ID is '[recid]_[speakerid]'
awk '{split($1,A,"_"); printf "%s %s_%s\n", $1, A[1], A[2]}' \
data/${dset}_spkall/segments > data/${dset}_spkall/utt2spk
utils/fix_data_dir.sh data/${dset}_spkall
# Generate rttm files for scoring
steps/segmentation/convert_utt2spk_and_segments_to_rttm.py \
data/${dset}_spkall/utt2spk data/${dset}_spkall/segments \
data/${dset}_spkall/rttm
utils/data/get_reco2dur.sh data/${dset}_spkall
done
fi
# Prepare a collection of NIST SRE and SWB data. This will be used for training.
if ! validate_data_dir.sh --no-text --no-feats data/swb_sre_comb; then
local/make_sre.sh $data_root data
# Prepare SWB for x-vector DNN training.
local/make_swbd2_phase1.pl $swb2_phase1_train \
data/swbd2_phase1_train
local/make_swbd2_phase2.pl $data_root/LDC99S79 \
data/swbd2_phase2_train
local/make_swbd2_phase3.pl $data_root/LDC2002S06 \
data/swbd2_phase3_train
local/make_swbd_cellular1.pl $data_root/LDC2001S13 \
data/swbd_cellular1_train
local/make_swbd_cellular2.pl $data_root/LDC2004S07 \
data/swbd_cellular2_train
# Combine swb and sre data
utils/combine_data.sh data/swb_sre_comb \
data/swbd_cellular1_train data/swbd_cellular2_train \
data/swbd2_phase1_train \
data/swbd2_phase2_train data/swbd2_phase3_train data/sre
fi
# MUSAN data (background noise)
if ! validate_data_dir.sh --no-text --no-feats data/musan_noise_bg; then
local/make_musan.sh $musan_root data
utils/copy_data_dir.sh data/musan_noise data/musan_noise_bg
awk '{if(NR>1) print $1,$1}' $musan_root/noise/free-sound/ANNOTATIONS > data/musan_noise_bg/utt2spk
utils/fix_data_dir.sh data/musan_noise_bg
fi
# simu rirs 8k
if ! validate_data_dir.sh --no-text --no-feats data/simu_rirs_8k; then
mkdir -p data/simu_rirs_8k
# if [ ! -e sim_rir_8k.zip ]; then
# wget --no-check-certificate http://www.openslr.org/resources/26/sim_rir_8k.zip
# fi
unzip sim_rir_8k.zip -d data/sim_rir_8k
find $PWD/data/sim_rir_8k -iname "*.wav" \
| awk '{n=split($1,A,/[\/\.]/); print A[n-3]"_"A[n-1], $1}' \
| sort > data/simu_rirs_8k/wav.scp
awk '{print $1, $1}' data/simu_rirs_8k/wav.scp > data/simu_rirs_8k/utt2spk
utils/fix_data_dir.sh data/simu_rirs_8k
fi
# Automatic segmentation using pretrained SAD model
# it will take one day using 30 CPU jobs:
# make_mfcc: 1 hour, compute_output: 18 hours, decode: 0.5 hours
sad_nnet_dir=exp/segmentation_1a/tdnn_stats_asr_sad_1a
sad_work_dir=exp/segmentation_1a/tdnn_stats_asr_sad_1a
if ! validate_data_dir.sh --no-text $sad_work_dir/swb_sre_comb_seg; then
if [ ! -d exp/segmentation_1a ]; then
# wget http://kaldi-asr.org/models/4/0004_tdnn_stats_asr_sad_1a.tar.gz
tar zxf 0004_tdnn_stats_asr_sad_1a.tar.gz
fi
steps/segmentation/detect_speech_activity.sh \
--nj $sad_num_jobs \
--graph-opts "$sad_graph_opts" \
--transform-probs-opts "$sad_priors_opts" $sad_opts \
data/swb_sre_comb $sad_nnet_dir mfcc_hires $sad_work_dir \
$sad_work_dir/swb_sre_comb || exit 1
fi
# Extract >1.5 sec segments and split into train/valid sets
if ! validate_data_dir.sh --no-text --no-feats data/swb_sre_cv; then
copy_data_dir.sh data/swb_sre_comb data/swb_sre_comb_seg
awk '$4-$3>1.5{print;}' $sad_work_dir/swb_sre_comb_seg/segments > data/swb_sre_comb_seg/segments
cp $sad_work_dir/swb_sre_comb_seg/{utt2spk,spk2utt} data/swb_sre_comb_seg
fix_data_dir.sh data/swb_sre_comb_seg
utils/subset_data_dir_tr_cv.sh data/swb_sre_comb_seg data/swb_sre_tr data/swb_sre_cv
fi
fi
simudir=data/simu
if [ $stage -le 1 ]; then
echo "simulation of mixture"
mkdir -p $simudir/.work
random_mixture_cmd=local/random_mixture.py
make_mixture_cmd=local/make_mixture.py
for ((i=0; i<${#simu_opts_sil_scale_array[@]}; ++i)); do
simu_opts_num_speaker=${simu_opts_num_speaker_array[i]}
simu_opts_sil_scale=${simu_opts_sil_scale_array[i]}
for dset in swb_sre_tr swb_sre_cv; do
if [ "$dset" == "swb_sre_tr" ]; then
n_mixtures=${simu_opts_num_train}
else
n_mixtures=500
fi
simuid=${dset}_ns${simu_opts_num_speaker}_beta${simu_opts_sil_scale}_${n_mixtures}
# check if you have the simulation
if ! validate_data_dir.sh --no-text --no-feats $simudir/data/$simuid; then
# random mixture generation
$train_cmd $simudir/.work/random_mixture_$simuid.log \
$random_mixture_cmd --n_speakers $simu_opts_num_speaker --n_mixtures $n_mixtures \
--speech_rvb_probability $simu_opts_rvb_prob \
--sil_scale $simu_opts_sil_scale \
data/$dset data/musan_noise_bg data/simu_rirs_8k \
\> $simudir/.work/mixture_$simuid.scp
nj=64
mkdir -p $simudir/wav/$simuid
# distribute simulated data to $simu_actual_dir
split_scps=
for n in $(seq $nj); do
split_scps="$split_scps $simudir/.work/mixture_$simuid.$n.scp"
mkdir -p $simudir/.work/data_$simuid.$n
actual=${simu_actual_dirs[($n-1)%${#simu_actual_dirs[@]}]}/$simudir/wav/$simuid/$n
mkdir -p $actual
ln -nfs $actual $simudir/wav/$simuid/$n
done
utils/split_scp.pl $simudir/.work/mixture_$simuid.scp $split_scps || exit 1
$simu_cmd --max-jobs-run 64 JOB=1:$nj $simudir/.work/make_mixture_$simuid.JOB.log \
$make_mixture_cmd --rate=8000 \
$simudir/.work/mixture_$simuid.JOB.scp \
$simudir/.work/data_$simuid.JOB $simudir/wav/$simuid/JOB
utils/combine_data.sh $simudir/data/$simuid $simudir/.work/data_$simuid.*
steps/segmentation/convert_utt2spk_and_segments_to_rttm.py \
$simudir/data/$simuid/utt2spk $simudir/data/$simuid/segments \
$simudir/data/$simuid/rttm
utils/data/get_reco2dur.sh $simudir/data/$simuid
fi
simuid_concat=${dset}_ns"$(IFS="n"; echo "${simu_opts_num_speaker_array[*]}")"_beta"$(IFS="n"; echo "${simu_opts_sil_scale_array[*]}")"_${n_mixtures}
mkdir -p $simudir/data/$simuid_concat
for f in `ls -F $simudir/data/$simuid | grep -v "/"`; do
cat $simudir/data/$simuid/$f >> $simudir/data/$simuid_concat/$f
done
done
done
fi
if [ $stage -le 3 ]; then
# compose eval/callhome2_spkall
eval_set=data/eval/callhome2_spkall
if ! validate_data_dir.sh --no-text --no-feats $eval_set; then
utils/copy_data_dir.sh data/callhome2_spkall $eval_set
cp data/callhome2_spkall/rttm $eval_set/rttm
awk -v dstdir=wav/eval/callhome2_spkall '{print $1, dstdir"/"$1".wav"}' data/callhome2_spkall/wav.scp > $eval_set/wav.scp
mkdir -p wav/eval/callhome2_spkall
wav-copy scp:data/callhome2_spkall/wav.scp scp:$eval_set/wav.scp
utils/data/get_reco2dur.sh $eval_set
fi
# compose eval/callhome1_spkall
adapt_set=data/eval/callhome1_spkall
if ! validate_data_dir.sh --no-text --no-feats $adapt_set; then
utils/copy_data_dir.sh data/callhome1_spkall $adapt_set
cp data/callhome1_spkall/rttm $adapt_set/rttm
awk -v dstdir=wav/eval/callhome1_spkall '{print $1, dstdir"/"$1".wav"}' data/callhome1_spkall/wav.scp > $adapt_set/wav.scp
mkdir -p wav/eval/callhome1_spkall
wav-copy scp:data/callhome1_spkall/wav.scp scp:$adapt_set/wav.scp
utils/data/get_reco2dur.sh $adapt_set
fi
fi
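
For reference, the awk one-liner in stage 0 maps a fullref.rttm line to a segments entry keyed by recording and speaker. The same transformation in Python, with an illustrative input line:

# "SPEAKER iaaa 1 2.50 1.25 <NA> <NA> A <NA>" -> fields $2, $4, $5, $8
rec, start, dur, spk = "iaaa", 2.50, 1.25, "A"
utt = "%s_%s_%07d_%07d" % (rec, spk, start * 100, (start + dur) * 100)
print("%s %s %.2f %.2f" % (utt, rec, start, start + dur))
# -> iaaa_A_0000250_0000375 iaaa 2.50 3.75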

View File

@ -0,0 +1,117 @@
import argparse
import os
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('root_path', help='raw data path')
args = parser.parse_args()
root_path = args.root_path
work_path = os.path.join(root_path, ".work")
scp_files = os.listdir(work_path)
reco2dur_dict = {}
with open(os.path.join(root_path, 'reco2dur')) as f:
lines = f.readlines()
for line in lines:
parts = line.strip().split()
reco2dur_dict[parts[0]] = parts[1]
spk2utt_dict = {}
with open(os.path.join(root_path, 'spk2utt')) as f:
lines = f.readlines()
for line in lines:
parts = line.strip().split()
spk = parts[0]
utts = parts[1:]
for utt in utts:
                # derive the recording id from the utterance id: keep the text after
                # 'data', dropping its leading separator and the last two '_' fields
                tmp = utt.split('data')
                rec = 'data_' + '_'.join(tmp[1][1:].split('_')[:-2])
if rec in spk2utt_dict.keys():
spk2utt_dict[rec].append((spk, utt))
else:
spk2utt_dict[rec] = []
spk2utt_dict[rec].append((spk, utt))
segment_dict = {}
with open(os.path.join(root_path, 'segments')) as f:
lines = f.readlines()
for line in lines:
parts = line.strip().split()
if parts[1] in segment_dict.keys():
segment_dict[parts[1]].append((parts[0], parts[2], parts[3]))
else:
segment_dict[parts[1]] = []
segment_dict[parts[1]].append((parts[0], parts[2], parts[3]))
utt2spk_dict = {}
with open(os.path.join(root_path, 'utt2spk')) as f:
lines = f.readlines()
for line in lines:
parts = line.strip().split()
utt = parts[0]
tmp = utt.split('data')
rec = 'data_' + '_'.join(tmp[1][1:].split('_')[:-2])
if rec in utt2spk_dict.keys():
utt2spk_dict[rec].append((parts[0], parts[1]))
else:
utt2spk_dict[rec] = []
utt2spk_dict[rec].append((parts[0], parts[1]))
for file in scp_files:
scp_file = os.path.join(work_path, file)
idx = scp_file.split('.')[-1]
reco2dur_file = os.path.join(work_path, 'reco2dur.{}'.format(str(idx)))
spk2utt_file = os.path.join(work_path, 'spk2utt.{}'.format(str(idx)))
segment_file = os.path.join(work_path, 'segments.{}'.format(str(idx)))
utt2spk_file = os.path.join(work_path, 'utt2spk.{}'.format(str(idx)))
fpp = open(scp_file)
scp_lines = fpp.readlines()
keys = []
for line in scp_lines:
name = line.strip().split()[0]
keys.append(name)
with open(reco2dur_file, 'w') as f:
lines = []
for key in keys:
string = key + ' ' + reco2dur_dict[key]
lines.append(string + '\n')
            lines[-1] = lines[-1][:-1]  # strip the trailing newline from the last line
f.writelines(lines)
with open(spk2utt_file, 'w') as f:
lines = []
for key in keys:
items = spk2utt_dict[key]
for item in items:
string = item[0]
for it in item[1:]:
string += ' '
string += it
lines.append(string + '\n')
lines[-1] = lines[-1][:-1]
f.writelines(lines)
with open(segment_file, 'w') as f:
lines = []
for key in keys:
items = segment_dict[key]
for item in items:
string = item[0] + ' ' + key + ' ' + item[1] + ' ' + item[2]
lines.append(string + '\n')
lines[-1] = lines[-1][:-1]
f.writelines(lines)
with open(utt2spk_file, 'w') as f:
lines = []
for key in keys:
items = utt2spk_dict[key]
for item in items:
string = item[0] + ' ' + item[1]
lines.append(string + '\n')
lines[-1] = lines[-1][:-1]
f.writelines(lines)
fpp.close()
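For reference, the recording-id recovery in split.py above is purely string-based. A minimal sketch of what it does, on a hypothetical utterance id (the id below is illustrative, not taken from this PR):

# split.py keeps the text after 'data', drops the leading separator and the
# last two '_'-separated fields, then re-prefixes 'data_':
utt = "spk0001_data_swb_sre_tr_ns2_beta2_00001_0000000_0001500"
tmp = utt.split('data')
rec = 'data_' + '_'.join(tmp[1][1:].split('_')[:-2])
print(rec)  # -> data_swb_sre_tr_ns2_beta2_00001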

egs/callhome/eend_ola/path.sh Executable file
View File

@ -0,0 +1,13 @@
export FUNASR_DIR=$PWD/../../..
# kaldi-related
export KALDI_ROOT=
[ -f $KALDI_ROOT/tools/env.sh ] && . $KALDI_ROOT/tools/env.sh
export PATH=$PWD/utils/:$KALDI_ROOT/tools/openfst/bin:$KALDI_ROOT/tools/sph2pipe_v2.5:$KALDI_ROOT/tools/sctk/bin:$PWD:$PATH
[ ! -f $KALDI_ROOT/tools/config/common_path.sh ] && echo >&2 "The standard file $KALDI_ROOT/tools/config/common_path.sh is not present -> Exit!" && exit 1
. $KALDI_ROOT/tools/config/common_path.sh
# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=../../../:$PYTHONPATH
export PATH=$FUNASR_DIR/funasr/bin:$PATH

View File

@ -0,0 +1,324 @@
#!/usr/bin/env bash
. ./path.sh || exit 1;
# machines configuration
CUDA_VISIBLE_DEVICES="0"
gpu_num=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
count=1
# general configuration
dump_cmd=utils/run.pl
nj=64
# feature configuration
data_dir="./data"
simu_feats_dir=$data_dir/ark_data/dump/simu_data/data
simu_feats_dir_chunk2000=$data_dir/ark_data/dump/simu_data_chunk2000/data
callhome_feats_dir_chunk2000=$data_dir/ark_data/dump/callhome_chunk2000/data
simu_train_dataset=train
simu_valid_dataset=dev
callhome_train_dataset=callhome1_spkall
callhome_valid_dataset=callhome2_spkall
# model average
simu_average_2spkr_start=91
simu_average_2spkr_end=100
simu_average_allspkr_start=16
simu_average_allspkr_end=25
callhome_average_start=91
callhome_average_end=100
exp_dir="."
input_size=345
stage=1
stop_stage=5
# exp tag
tag="exp1"
. local/parse_options.sh || exit 1;
# Set bash to 'debug' mode: it will exit on
# -e 'error', -u 'undefined variable', and -o pipefail 'error in pipeline' (-x would also print commands).
set -e
set -u
set -o pipefail
simu_2spkr_diar_config=conf/train_diar_eend_ola_simu_2spkr.yaml
simu_allspkr_diar_config=conf/train_diar_eend_ola_simu_allspkr.yaml
simu_allspkr_chunk2000_diar_config=conf/train_diar_eend_ola_simu_allspkr_chunk2000.yaml
callhome_diar_config=conf/train_diar_eend_ola_callhome_chunk2000.yaml
simu_2spkr_model_dir="baseline_$(basename "${simu_2spkr_diar_config}" .yaml)_${tag}"
simu_allspkr_model_dir="baseline_$(basename "${simu_allspkr_diar_config}" .yaml)_${tag}"
simu_allspkr_chunk2000_model_dir="baseline_$(basename "${simu_allspkr_chunk2000_diar_config}" .yaml)_${tag}"
callhome_model_dir="baseline_$(basename "${callhome_diar_config}" .yaml)_${tag}"
# simulate mixture data for training and inference
if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
echo "stage -1: Simulate mixture data for training and inference"
echo "The detail can be found in https://github.com/hitachi-speech/EEND"
echo "Before running this step, you should download and compile kaldi and set KALDI_ROOT in this script and path.sh"
echo "This stage may take a long time, please waiting..."
KALDI_ROOT=
ln -s $KALDI_ROOT/egs/wsj/s5/steps steps
ln -s $KALDI_ROOT/egs/wsj/s5/utils utils
local/run_prepare_shared_eda.sh
fi
# Prepare data for training and inference
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
echo "stage 0: Prepare data for training and inference"
simu_opts_num_speaker_array=(1 2 3 4)
simu_opts_sil_scale_array=(2 2 5 9)
simu_opts_num_train=100000
# for simulated data of chunk500 and chunk2000
for dset in swb_sre_cv swb_sre_tr; do
if [ "$dset" == "swb_sre_tr" ]; then
n_mixtures=${simu_opts_num_train}
dataset=train
else
n_mixtures=500
dataset=dev
fi
simu_data_dir=${dset}_ns"$(IFS="n"; echo "${simu_opts_num_speaker_array[*]}")"_beta"$(IFS="n"; echo "${simu_opts_sil_scale_array[*]}")"_${n_mixtures}
mkdir -p ${data_dir}/simu/data/${simu_data_dir}/.work
split_scps=
for n in $(seq $nj); do
split_scps="$split_scps ${data_dir}/simu/data/${simu_data_dir}/.work/wav.scp.$n"
done
utils/split_scp.pl "${data_dir}/simu/data/${simu_data_dir}/wav.scp" $split_scps || exit 1
python local/split.py ${data_dir}/simu/data/${simu_data_dir}
# for chunk_size=500
output_dir=${data_dir}/ark_data/dump/simu_data/$dataset
mkdir -p $output_dir/.logs
$dump_cmd --max-jobs-run $nj JOB=1:$nj $output_dir/.logs/dump.JOB.log \
python local/dump_feature.py \
--data_dir ${data_dir}/simu/data/${simu_data_dir}/.work \
--output_dir $output_dir \
--index JOB
mkdir -p ${data_dir}/ark_data/dump/simu_data/data/$dataset
cat ${data_dir}/ark_data/dump/simu_data/$dataset/feature.scp.* > ${data_dir}/ark_data/dump/simu_data/data/$dataset/feature.scp
cat ${data_dir}/ark_data/dump/simu_data/$dataset/label.scp.* > ${data_dir}/ark_data/dump/simu_data/data/$dataset/label.scp
paste -d" " ${data_dir}/ark_data/dump/simu_data/data/$dataset/feature.scp <(cut -f2 -d" " ${data_dir}/ark_data/dump/simu_data/data/$dataset/label.scp) > ${data_dir}/ark_data/dump/simu_data/data/$dataset/feats.scp
grep "ns2" ${data_dir}/ark_data/dump/simu_data/data/$dataset/feats.scp > ${data_dir}/ark_data/dump/simu_data/data/$dataset/feats_2spkr.scp
# for chunk_size=2000
output_dir=${data_dir}/ark_data/dump/simu_data_chunk2000/$dataset
mkdir -p $output_dir/.logs
$dump_cmd --max-jobs-run $nj JOB=1:$nj $output_dir/.logs/dump.JOB.log \
python local/dump_feature.py \
--data_dir ${data_dir}/simu/data/${simu_data_dir}/.work \
--output_dir $output_dir \
--index JOB \
--num_frames 2000
mkdir -p ${data_dir}/ark_data/dump/simu_data_chunk2000/data/$dataset
cat ${data_dir}/ark_data/dump/simu_data_chunk2000/$dataset/feature.scp.* > ${data_dir}/ark_data/dump/simu_data_chunk2000/data/$dataset/feature.scp
cat ${data_dir}/ark_data/dump/simu_data_chunk2000/$dataset/label.scp.* > ${data_dir}/ark_data/dump/simu_data_chunk2000/data/$dataset/label.scp
paste -d" " ${data_dir}/ark_data/dump/simu_data_chunk2000/data/$dataset/feature.scp <(cut -f2 -d" " ${data_dir}/ark_data/dump/simu_data_chunk2000/data/$dataset/label.scp) > ${data_dir}/ark_data/dump/simu_data_chunk2000/data/$dataset/feats.scp
done
# for callhome data
for dset in callhome1_spkall callhome2_spkall; do
find $data_dir/eval/$dset -maxdepth 1 -type f -exec cp {} {}.1 \;
output_dir=${data_dir}/ark_data/dump/callhome_chunk2000/$dset
mkdir -p $output_dir
python local/dump_feature.py \
--data_dir $data_dir/eval/$dset \
--output_dir $output_dir \
--index 1 \
--num_frames 2000
mkdir -p ${data_dir}/ark_data/dump/callhome_chunk2000/data/$dset
paste -d" " ${data_dir}/ark_data/dump/callhome_chunk2000/$dset/feature.scp.1 <(cut -f2 -d" " ${data_dir}/ark_data/dump/callhome_chunk2000/$dset/label.scp.1) > ${data_dir}/ark_data/dump/callhome_chunk2000/data/$dset/feats.scp
done
fi
# Training on simulated two-speaker data
world_size=$gpu_num
simu_2spkr_ave_id=avg${simu_average_2spkr_start}-${simu_average_2spkr_end}
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
echo "stage 1: Training on simulated two-speaker data"
mkdir -p ${exp_dir}/exp/${simu_2spkr_model_dir}
mkdir -p ${exp_dir}/exp/${simu_2spkr_model_dir}/log
INIT_FILE=${exp_dir}/exp/${simu_2spkr_model_dir}/ddp_init
if [ -f $INIT_FILE ];then
rm -f $INIT_FILE
fi
init_method=file://$(readlink -f $INIT_FILE)
echo "$0: init method is $init_method"
for ((i = 0; i < $gpu_num; ++i)); do
{
rank=$i
local_rank=$i
        gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$((i+1)))
train.py \
--task_name diar \
--gpu_id $gpu_id \
--use_preprocessor false \
--input_size $input_size \
--data_dir ${simu_feats_dir} \
--train_set ${simu_train_dataset} \
--valid_set ${simu_valid_dataset} \
--data_file_names "feats_2spkr.scp" \
--resume true \
--output_dir ${exp_dir}/exp/${simu_2spkr_model_dir} \
--config $simu_2spkr_diar_config \
--ngpu $gpu_num \
--num_worker_count $count \
--dist_init_method $init_method \
--dist_world_size $world_size \
--dist_rank $rank \
--local_rank $local_rank 1> ${exp_dir}/exp/${simu_2spkr_model_dir}/log/train.log.$i 2>&1
} &
done
wait
echo "averaging model parameters into ${exp_dir}/exp/$simu_2spkr_model_dir/$simu_2spkr_ave_id.pb"
models=`eval echo ${exp_dir}/exp/${simu_2spkr_model_dir}/{$simu_average_2spkr_start..$simu_average_2spkr_end}epoch.pb`
python local/model_averaging.py ${exp_dir}/exp/${simu_2spkr_model_dir}/$simu_2spkr_ave_id.pb $models
fi
# Training on simulated all-speaker data
world_size=$gpu_num
simu_allspkr_ave_id=avg${simu_average_allspkr_start}-${simu_average_allspkr_end}
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
echo "stage 2: Training on simulated all-speaker data"
mkdir -p ${exp_dir}/exp/${simu_allspkr_model_dir}
mkdir -p ${exp_dir}/exp/${simu_allspkr_model_dir}/log
INIT_FILE=${exp_dir}/exp/${simu_allspkr_model_dir}/ddp_init
if [ -f $INIT_FILE ];then
rm -f $INIT_FILE
fi
init_method=file://$(readlink -f $INIT_FILE)
echo "$0: init method is $init_method"
for ((i = 0; i < $gpu_num; ++i)); do
{
rank=$i
local_rank=$i
        gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$((i+1)))
train.py \
--task_name diar \
--gpu_id $gpu_id \
--use_preprocessor false \
--input_size $input_size \
--data_dir ${simu_feats_dir} \
--train_set ${simu_train_dataset} \
--valid_set ${simu_valid_dataset} \
--data_file_names "feats.scp" \
--resume true \
--init_param ${exp_dir}/exp/${simu_2spkr_model_dir}/$simu_2spkr_ave_id.pb \
--output_dir ${exp_dir}/exp/${simu_allspkr_model_dir} \
--config $simu_allspkr_diar_config \
--ngpu $gpu_num \
--num_worker_count $count \
--dist_init_method $init_method \
--dist_world_size $world_size \
--dist_rank $rank \
--local_rank $local_rank 1> ${exp_dir}/exp/${simu_allspkr_model_dir}/log/train.log.$i 2>&1
} &
done
wait
echo "averaging model parameters into ${exp_dir}/exp/$simu_allspkr_model_dir/$simu_allspkr_ave_id.pb"
models=`eval echo ${exp_dir}/exp/${simu_allspkr_model_dir}/{$simu_average_allspkr_start..$simu_average_allspkr_end}epoch.pb`
python local/model_averaging.py ${exp_dir}/exp/${simu_allspkr_model_dir}/$simu_allspkr_ave_id.pb $models
fi
# Training on simulated all-speaker data with chunk_size 2000
world_size=$gpu_num
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
echo "stage 3: Training on simulated all-speaker data with chunk_size 2000"
mkdir -p ${exp_dir}/exp/${simu_allspkr_chunk2000_model_dir}
mkdir -p ${exp_dir}/exp/${simu_allspkr_chunk2000_model_dir}/log
INIT_FILE=${exp_dir}/exp/${simu_allspkr_chunk2000_model_dir}/ddp_init
if [ -f $INIT_FILE ];then
rm -f $INIT_FILE
fi
init_method=file://$(readlink -f $INIT_FILE)
echo "$0: init method is $init_method"
for ((i = 0; i < $gpu_num; ++i)); do
{
rank=$i
local_rank=$i
        gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$((i+1)))
train.py \
--task_name diar \
--gpu_id $gpu_id \
--use_preprocessor false \
--input_size $input_size \
--data_dir ${simu_feats_dir_chunk2000} \
--train_set ${simu_train_dataset} \
--valid_set ${simu_valid_dataset} \
--data_file_names "feats.scp" \
--resume true \
--init_param ${exp_dir}/exp/${simu_allspkr_model_dir}/$simu_allspkr_ave_id.pb \
--output_dir ${exp_dir}/exp/${simu_allspkr_chunk2000_model_dir} \
--config $simu_allspkr_chunk2000_diar_config \
--ngpu $gpu_num \
--num_worker_count $count \
--dist_init_method $init_method \
--dist_world_size $world_size \
--dist_rank $rank \
--local_rank $local_rank 1> ${exp_dir}/exp/${simu_allspkr_chunk2000_model_dir}/log/train.log.$i 2>&1
} &
done
wait
fi
# Training on callhome all-speaker data with chunk_size 2000
world_size=$gpu_num
callhome_ave_id=avg${callhome_average_start}-${callhome_average_end}
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
echo "stage 4: Training on callhome all-speaker data with chunk_size 2000"
mkdir -p ${exp_dir}/exp/${callhome_model_dir}
mkdir -p ${exp_dir}/exp/${callhome_model_dir}/log
INIT_FILE=${exp_dir}/exp/${callhome_model_dir}/ddp_init
if [ -f $INIT_FILE ];then
rm -f $INIT_FILE
fi
init_method=file://$(readlink -f $INIT_FILE)
echo "$0: init method is $init_method"
for ((i = 0; i < $gpu_num; ++i)); do
{
rank=$i
local_rank=$i
        gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$((i+1)))
train.py \
--task_name diar \
--gpu_id $gpu_id \
--use_preprocessor false \
--input_size $input_size \
--data_dir ${callhome_feats_dir_chunk2000} \
--train_set ${callhome_train_dataset} \
--valid_set ${callhome_valid_dataset} \
--data_file_names "feats.scp" \
--resume true \
--init_param ${exp_dir}/exp/${simu_allspkr_chunk2000_model_dir}/1epoch.pb \
--output_dir ${exp_dir}/exp/${callhome_model_dir} \
--config $callhome_diar_config \
--ngpu $gpu_num \
--num_worker_count $count \
--dist_init_method $init_method \
--dist_world_size $world_size \
--dist_rank $rank \
--local_rank $local_rank 1> ${exp_dir}/exp/${callhome_model_dir}/log/train.log.$i 2>&1
} &
done
wait
echo "averaging model parameters into ${exp_dir}/exp/$callhome_model_dir/$callhome_ave_id.pb"
models=`eval echo ${exp_dir}/exp/${callhome_model_dir}/{$callhome_average_start..$callhome_average_end}epoch.pb`
python local/model_averaging.py ${exp_dir}/exp/${callhome_model_dir}/$callhome_ave_id.pb $models
fi
# inference and compute DER
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
echo "Inference"
mkdir -p ${exp_dir}/exp/${callhome_model_dir}/inference/log
CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES python local/infer.py \
--config_file ${exp_dir}/exp/${callhome_model_dir}/config.yaml \
--model_file ${exp_dir}/exp/${callhome_model_dir}/$callhome_ave_id.pb \
--output_rttm_file ${exp_dir}/exp/${callhome_model_dir}/inference/rttm \
--wav_scp_file $data_dir/eval/callhome2_spkall/wav.scp \
1> ${exp_dir}/exp/${callhome_model_dir}/inference/log/infer.log 2>&1
md-eval.pl -c 0.25 \
-r ${data_dir}/eval/${callhome_valid_dataset}/rttm \
-s ${exp_dir}/exp/${callhome_model_dir}/inference/rttm > ${exp_dir}/exp/${callhome_model_dir}/inference/result_med11_collar0.25 2>/dev/null || exit
fi
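The averaging steps above invoke local/model_averaging.py, which is not shown in this diff. As a rough sketch only, assuming the .pb checkpoints hold plain parameter state dicts (the actual script's interface and key handling may differ):

import sys
import torch

def average_checkpoints(out_path, ckpt_paths):
    # Sum parameters across checkpoints, then divide by the checkpoint count.
    # (Integer buffers, if any, would need special handling.)
    avg = None
    for path in ckpt_paths:
        state = torch.load(path, map_location="cpu")
        if avg is None:
            avg = state
        else:
            for k in avg:
                avg[k] = avg[k] + state[k]
    for k in avg:
        avg[k] = avg[k] / len(ckpt_paths)
    torch.save(avg, out_path)

if __name__ == "__main__":
    # e.g. python model_averaging.py avg91-100.pb 91epoch.pb ... 100epoch.pb
    average_checkpoints(sys.argv[1], sys.argv[2:])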

egs/callhome/sond/sond.yaml Normal file

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@ -0,0 +1,97 @@
from funasr.bin.diar_inference_launch import inference_launch
import os
def test_fbank_cpu_infer():
diar_config_path = "sond_fbank.yaml"
diar_model_path = "sond.pb"
output_dir = "./outputs"
data_path_and_name_and_type = [
("data/unit_test/test_feats.scp", "speech", "kaldi_ark"),
("data/unit_test/test_profile.scp", "profile", "kaldi_ark"),
]
pipeline = inference_launch(
mode="sond",
diar_train_config=diar_config_path,
diar_model_file=diar_model_path,
output_dir=output_dir,
num_workers=0,
log_level="INFO",
)
results = pipeline(data_path_and_name_and_type)
print(results)
def test_fbank_gpu_infer():
diar_config_path = "sond_fbank.yaml"
diar_model_path = "sond.pb"
output_dir = "./outputs"
data_path_and_name_and_type = [
("data/unit_test/test_feats.scp", "speech", "kaldi_ark"),
("data/unit_test/test_profile.scp", "profile", "kaldi_ark"),
]
pipeline = inference_launch(
mode="sond",
diar_train_config=diar_config_path,
diar_model_file=diar_model_path,
output_dir=output_dir,
ngpu=1,
num_workers=1,
log_level="INFO",
)
results = pipeline(data_path_and_name_and_type)
print(results)
def test_wav_gpu_infer():
diar_config_path = "config.yaml"
diar_model_path = "sond.pb"
output_dir = "./outputs"
data_path_and_name_and_type = [
("data/unit_test/test_wav.scp", "speech", "sound"),
("data/unit_test/test_profile.scp", "profile", "kaldi_ark"),
]
pipeline = inference_launch(
mode="sond",
diar_train_config=diar_config_path,
diar_model_file=diar_model_path,
output_dir=output_dir,
ngpu=1,
num_workers=1,
log_level="WARNING",
)
results = pipeline(data_path_and_name_and_type)
print(results)
def test_without_profile_gpu_infer():
diar_config_path = "config.yaml"
diar_model_path = "sond.pb"
output_dir = "./outputs"
raw_inputs = [[
"data/unit_test/raw_inputs/record.wav",
"data/unit_test/raw_inputs/spk1.wav",
"data/unit_test/raw_inputs/spk2.wav",
"data/unit_test/raw_inputs/spk3.wav",
"data/unit_test/raw_inputs/spk4.wav"
]]
pipeline = inference_launch(
mode="sond_demo",
diar_train_config=diar_config_path,
diar_model_file=diar_model_path,
output_dir=output_dir,
ngpu=1,
num_workers=1,
log_level="WARNING",
param_dict={},
)
results = pipeline(raw_inputs=raw_inputs)
print(results)
if __name__ == '__main__':
os.environ["CUDA_VISIBLE_DEVICES"] = "7"
test_fbank_cpu_infer()
# test_fbank_gpu_infer()
# test_wav_gpu_infer()
# test_without_profile_gpu_infer()

View File

@ -86,6 +86,12 @@ def build_args(args, parser, extra_task_params):
from funasr.build_utils.build_diar_model import class_choices_list
for class_choices in class_choices_list:
class_choices.add_arguments(task_parser)
task_parser.add_argument(
"--input_size",
type=int_or_none,
default=None,
help="The number of input dimension of the feature",
)
elif args.task_name == "sv":
from funasr.build_utils.build_sv_model import class_choices_list

View File

@ -4,8 +4,21 @@ from funasr.datasets.small_datasets.sequence_iter_factory import SequenceIterFac
def build_dataloader(args):
if args.dataset_type == "small":
train_iter_factory = SequenceIterFactory(args, mode="train")
valid_iter_factory = SequenceIterFactory(args, mode="valid")
if args.task_name == "diar" and args.model == "eend_ola":
from funasr.modules.eend_ola.eend_ola_dataloader import EENDOLADataLoader
train_iter_factory = EENDOLADataLoader(
data_file=args.train_data_path_and_name_and_type[0][0],
batch_size=args.dataset_conf["batch_conf"]["batch_size"],
num_workers=args.dataset_conf["num_workers"],
shuffle=True)
valid_iter_factory = EENDOLADataLoader(
data_file=args.valid_data_path_and_name_and_type[0][0],
batch_size=args.dataset_conf["batch_conf"]["batch_size"],
num_workers=0,
shuffle=False)
else:
train_iter_factory = SequenceIterFactory(args, mode="train")
valid_iter_factory = SequenceIterFactory(args, mode="valid")
elif args.dataset_type == "large":
train_iter_factory = LargeDataLoader(args, mode="train")
valid_iter_factory = LargeDataLoader(args, mode="valid")

View File

@ -192,18 +192,22 @@ class_choices_list = [
def build_diar_model(args):
# token_list
    if isinstance(args.token_list, str):
        with open(args.token_list, encoding="utf-8") as f:
            token_list = [line.rstrip() for line in f]
        # Overwriting token_list to keep it as "portable".
        args.token_list = list(token_list)
    elif isinstance(args.token_list, (tuple, list)):
        token_list = list(args.token_list)
    else:
        raise RuntimeError("token_list must be str or list")
    vocab_size = len(token_list)
    logging.info(f"Vocabulary size: {vocab_size}")
    if args.token_list is not None:
        if isinstance(args.token_list, str):
            with open(args.token_list, encoding="utf-8") as f:
                token_list = [line.rstrip() for line in f]
            # Overwriting token_list to keep it as "portable".
            args.token_list = list(token_list)
        elif isinstance(args.token_list, (tuple, list)):
            token_list = list(args.token_list)
        else:
            raise RuntimeError("token_list must be str or list")
        vocab_size = len(token_list)
        logging.info(f"Vocabulary size: {vocab_size}")
    else:
        token_list = None
        vocab_size = None
# frontend
if args.input_size is None:
@ -212,16 +216,14 @@ def build_diar_model(args):
frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
else:
frontend = frontend_class(**args.frontend_conf)
input_size = frontend.output_size()
else:
args.frontend = None
args.frontend_conf = {}
frontend = None
input_size = args.input_size
# encoder
encoder_class = encoder_choices.get_class(args.encoder)
encoder = encoder_class(input_size=input_size, **args.encoder_conf)
encoder = encoder_class(**args.encoder_conf)
if args.model == "sond":
# data augmentation for spectrogram
@ -294,7 +296,7 @@ def build_diar_model(args):
**args.model_conf,
)
elif args.model_name == "eend_ola":
elif args.model == "eend_ola":
# encoder-decoder attractor
encoder_decoder_attractor_class = encoder_decoder_attractor_choices.get_class(args.encoder_decoder_attractor)
encoder_decoder_attractor = encoder_decoder_attractor_class(**args.encoder_decoder_attractor_conf)

View File

@ -57,7 +57,7 @@ class SequenceIterFactory(AbsIterFactory):
data_path_and_name_and_type,
preprocess=preprocess_fn,
dest_sample_rate=dest_sample_rate,
speed_perturb=args.speed_perturb if mode=="train" else None,
speed_perturb=args.speed_perturb if mode == "train" else None,
)
# sampler
@ -84,7 +84,7 @@ class SequenceIterFactory(AbsIterFactory):
args.max_update = len(bs_list) * args.max_epoch
logging.info("Max update: {}".format(args.max_update))
if args.distributed and mode=="train":
if args.distributed and mode == "train":
world_size = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
for batch in batches:

View File

@ -1,21 +1,20 @@
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
from contextlib import contextmanager
from distutils.version import LooseVersion
from typing import Dict
from typing import Tuple
from typing import Dict, List, Tuple, Optional
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from funasr.models.base_model import FunASRModel
from funasr.models.frontend.wav_frontend import WavFrontendMel23
from funasr.modules.eend_ola.encoder import EENDOLATransformerEncoder
from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor
from funasr.modules.eend_ola.utils.losses import standard_loss, cal_power_loss, fast_batch_pit_n_speaker_loss
from funasr.modules.eend_ola.utils.power import create_powerlabel
from funasr.modules.eend_ola.utils.power import generate_mapping_dict
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.models.base_model import FunASRModel
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
pass
@ -33,12 +32,35 @@ def pad_attractor(att, max_n_speakers):
return att
def pad_labels(ts, out_size):
for i, t in enumerate(ts):
if t.shape[1] < out_size:
ts[i] = F.pad(
t,
(0, out_size - t.shape[1], 0, 0),
mode='constant',
value=0.
)
return ts
def pad_results(ys, out_size):
ys_padded = []
for i, y in enumerate(ys):
if y.shape[1] < out_size:
ys_padded.append(
torch.cat([y, torch.zeros(y.shape[0], out_size - y.shape[1]).to(torch.float32).to(y.device)], dim=1))
else:
ys_padded.append(y)
return ys_padded
class DiarEENDOLAModel(FunASRModel):
"""EEND-OLA diarization model"""
def __init__(
self,
frontend: WavFrontendMel23,
frontend: Optional[WavFrontendMel23],
encoder: EENDOLATransformerEncoder,
encoder_decoder_attractor: EncoderDecoderAttractor,
n_units: int = 256,
@ -47,11 +69,10 @@ class DiarEENDOLAModel(FunASRModel):
mapping_dict=None,
**kwargs,
):
super().__init__()
self.frontend = frontend
self.enc = encoder
self.eda = encoder_decoder_attractor
self.encoder_decoder_attractor = encoder_decoder_attractor
self.attractor_loss_weight = attractor_loss_weight
self.max_n_speaker = max_n_speaker
if mapping_dict is None:
@ -74,7 +95,8 @@ class DiarEENDOLAModel(FunASRModel):
def forward_post_net(self, logits, ilens):
maxlen = torch.max(ilens).to(torch.int).item()
logits = nn.utils.rnn.pad_sequence(logits, batch_first=True, padding_value=-1)
logits = nn.utils.rnn.pack_padded_sequence(logits, ilens.cpu().to(torch.int64), batch_first=True, enforce_sorted=False)
logits = nn.utils.rnn.pack_padded_sequence(logits, ilens.cpu().to(torch.int64), batch_first=True,
enforce_sorted=False)
outputs, (_, _) = self.postnet(logits)
outputs = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True, padding_value=-1, total_length=maxlen)[0]
outputs = [output[:ilens[i].to(torch.int).item()] for i, output in enumerate(outputs)]
@ -83,95 +105,45 @@ class DiarEENDOLAModel(FunASRModel):
def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
speech: List[torch.Tensor],
speaker_labels: List[torch.Tensor],
orders: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
"""
assert text_lengths.dim() == 1, text_lengths.shape
# Check that batch_size is unified
assert (
speech.shape[0]
== speech_lengths.shape[0]
== text.shape[0]
== text_lengths.shape[0]
), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
batch_size = speech.shape[0]
assert (len(speech) == len(speaker_labels)), (len(speech), len(speaker_labels))
speech_lengths = torch.tensor([len(sph) for sph in speech]).to(torch.int64)
speaker_labels_lengths = torch.tensor([spk.shape[-1] for spk in speaker_labels]).to(torch.int64)
batch_size = len(speech)
# for data-parallel
text = text[:, : text_lengths.max()]
# Encoder
encoder_out = self.forward_encoder(speech, speech_lengths)
# 1. Encoder
encoder_out, encoder_out_lens = self.enc(speech, speech_lengths)
intermediate_outs = None
if isinstance(encoder_out, tuple):
intermediate_outs = encoder_out[1]
encoder_out = encoder_out[0]
# Encoder-decoder attractor
attractor_loss, attractors = self.encoder_decoder_attractor([e[order] for e, order in zip(encoder_out, orders)],
speaker_labels_lengths)
speaker_logits = [torch.matmul(e, att.permute(1, 0)) for e, att in zip(encoder_out, attractors)]
# pit loss
pit_speaker_labels = fast_batch_pit_n_speaker_loss(speaker_logits, speaker_labels)
pit_loss = standard_loss(speaker_logits, pit_speaker_labels)
# pse loss
with torch.no_grad():
power_ts = [create_powerlabel(label.cpu().numpy(), self.mapping_dict, self.max_n_speaker).
to(encoder_out[0].device, non_blocking=True) for label in pit_speaker_labels]
pad_attractors = [pad_attractor(att, self.max_n_speaker) for att in attractors]
pse_speaker_logits = [torch.matmul(e, pad_att.permute(1, 0)) for e, pad_att in zip(encoder_out, pad_attractors)]
pse_speaker_logits = self.forward_post_net(pse_speaker_logits, speech_lengths)
pse_loss = cal_power_loss(pse_speaker_logits, power_ts)
loss = pse_loss + pit_loss + self.attractor_loss_weight * attractor_loss
loss_att, acc_att, cer_att, wer_att = None, None, None, None
loss_ctc, cer_ctc = None, None
stats = dict()
# 1. CTC branch
if self.ctc_weight != 0.0:
loss_ctc, cer_ctc = self._calc_ctc_loss(
encoder_out, encoder_out_lens, text, text_lengths
)
# Collect CTC branch stats
stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
stats["cer_ctc"] = cer_ctc
# Intermediate CTC (optional)
loss_interctc = 0.0
if self.interctc_weight != 0.0 and intermediate_outs is not None:
for layer_idx, intermediate_out in intermediate_outs:
# we assume intermediate_out has the same length & padding
# as those of encoder_out
loss_ic, cer_ic = self._calc_ctc_loss(
intermediate_out, encoder_out_lens, text, text_lengths
)
loss_interctc = loss_interctc + loss_ic
# Collect Intermedaite CTC stats
stats["loss_interctc_layer{}".format(layer_idx)] = (
loss_ic.detach() if loss_ic is not None else None
)
stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
loss_interctc = loss_interctc / len(intermediate_outs)
# calculate whole encoder loss
loss_ctc = (
1 - self.interctc_weight
) * loss_ctc + self.interctc_weight * loss_interctc
# 2b. Attention decoder branch
if self.ctc_weight != 1.0:
loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
encoder_out, encoder_out_lens, text, text_lengths
)
# 3. CTC-Att loss definition
if self.ctc_weight == 0.0:
loss = loss_att
elif self.ctc_weight == 1.0:
loss = loss_ctc
else:
loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att
# Collect Attn branch stats
stats["loss_att"] = loss_att.detach() if loss_att is not None else None
stats["acc"] = acc_att
stats["cer"] = cer_att
stats["wer"] = wer_att
stats["pse_loss"] = pse_loss.detach()
stats["pit_loss"] = pit_loss.detach()
stats["attractor_loss"] = attractor_loss.detach()
stats["batch_size"] = batch_size
# Collect total loss stats
stats["loss"] = torch.clone(loss.detach())
@ -182,21 +154,20 @@ class DiarEENDOLAModel(FunASRModel):
def estimate_sequential(self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
n_speakers: int = None,
shuffle: bool = True,
threshold: float = 0.5,
**kwargs):
speech = [s[:s_len] for s, s_len in zip(speech, speech_lengths)]
speech_lengths = torch.tensor([len(sph) for sph in speech]).to(torch.int64)
emb = self.forward_encoder(speech, speech_lengths)
if shuffle:
orders = [np.arange(e.shape[0]) for e in emb]
for order in orders:
np.random.shuffle(order)
attractors, probs = self.eda.estimate(
attractors, probs = self.encoder_decoder_attractor.estimate(
[e[torch.from_numpy(order).to(torch.long).to(speech[0].device)] for e, order in zip(emb, orders)])
else:
attractors, probs = self.eda.estimate(emb)
attractors, probs = self.encoder_decoder_attractor.estimate(emb)
attractors_active = []
for p, att, e in zip(probs, attractors, emb):
if n_speakers and n_speakers >= 0:

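The PSE branch above turns a multi-hot frame label into a single power-set class index via create_powerlabel and generate_mapping_dict (imported but not shown in this diff). A minimal sketch of the idea, using a plain binary encoding as an assumed mapping (the real mapping may order classes differently):

import numpy as np

label = np.array([1, 0, 1, 0])  # speakers 0 and 2 active in this frame
class_index = int((label * (2 ** np.arange(label.shape[0]))).sum())
print(class_index)  # -> 5, one of 2**4 = 16 power-set classes for max_n_speaker=4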
View File

@ -0,0 +1,57 @@
import logging
import kaldiio
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
def custom_collate(batch):
keys, speech, speaker_labels, orders = zip(*batch)
speech = [torch.from_numpy(np.copy(sph)).to(torch.float32) for sph in speech]
speaker_labels = [torch.from_numpy(np.copy(spk)).to(torch.float32) for spk in speaker_labels]
orders = [torch.from_numpy(np.copy(o)).to(torch.int64) for o in orders]
batch = dict(speech=speech,
speaker_labels=speaker_labels,
orders=orders)
return keys, batch
class EENDOLADataset(Dataset):
def __init__(
self,
data_file,
):
self.data_file = data_file
with open(data_file) as f:
lines = f.readlines()
self.samples = [line.strip().split() for line in lines]
logging.info("total samples: {}".format(len(self.samples)))
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
key, speech_path, speaker_label_path = self.samples[idx]
speech = kaldiio.load_mat(speech_path)
speaker_label = kaldiio.load_mat(speaker_label_path).reshape(speech.shape[0], -1)
order = np.arange(speech.shape[0])
np.random.shuffle(order)
return key, speech, speaker_label, order
class EENDOLADataLoader:
def __init__(self, data_file, batch_size, shuffle=True, num_workers=8):
dataset = EENDOLADataset(data_file)
self.data_loader = DataLoader(dataset,
batch_size=batch_size,
collate_fn=custom_collate,
shuffle=shuffle,
num_workers=num_workers)
def build_iter(self, epoch):
return self.data_loader
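A minimal usage sketch for the loader above; the feats.scp path is the one produced in run.sh stage 0, and each of its lines must read "<key> <feature_ark> <label_ark>" as EENDOLADataset expects:

loader = EENDOLADataLoader(
    data_file="data/ark_data/dump/simu_data/data/train/feats.scp",
    batch_size=8, shuffle=True, num_workers=4)
for keys, batch in loader.build_iter(epoch=0):
    speech = batch["speech"]          # list of (T, D) float32 tensors
    labels = batch["speaker_labels"]  # list of (T, n_spk) float32 tensors
    orders = batch["orders"]          # per-sample shuffled frame indices
    break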

View File

@ -91,6 +91,7 @@ class EENDOLATransformerEncoder(nn.Module):
dropout_rate: float = 0.1,
use_pos_emb: bool = False):
super(EENDOLATransformerEncoder, self).__init__()
self.linear_in = nn.Linear(idim, n_units)
self.lnorm_in = nn.LayerNorm(n_units)
self.n_layers = n_layers
self.dropout = nn.Dropout(dropout_rate)
@ -104,25 +105,10 @@ class EENDOLATransformerEncoder(nn.Module):
setattr(self, '{}{:d}'.format("ff_", i),
PositionwiseFeedForward(n_units, e_units, dropout_rate))
self.lnorm_out = nn.LayerNorm(n_units)
if use_pos_emb:
self.pos_enc = torch.nn.Sequential(
torch.nn.Linear(idim, n_units),
torch.nn.LayerNorm(n_units),
torch.nn.Dropout(dropout_rate),
torch.nn.ReLU(),
PositionalEncoding(n_units, dropout_rate),
)
else:
self.linear_in = nn.Linear(idim, n_units)
self.pos_enc = None
def __call__(self, x, x_mask=None):
BT_size = x.shape[0] * x.shape[1]
if self.pos_enc is not None:
e = self.pos_enc(x)
e = e.view(BT_size, -1)
else:
e = self.linear_in(x.reshape(BT_size, -1))
e = self.linear_in(x.reshape(BT_size, -1))
for i in range(self.n_layers):
e = getattr(self, '{}{:d}'.format("lnorm1_", i))(e)
s = getattr(self, '{}{:d}'.format("self_att_", i))(e, x.shape[0], x_mask)
@ -130,4 +116,4 @@ class EENDOLATransformerEncoder(nn.Module):
e = getattr(self, '{}{:d}'.format("lnorm2_", i))(e)
s = getattr(self, '{}{:d}'.format("ff_", i))(e)
e = e + self.dropout(s)
return self.lnorm_out(e)
return self.lnorm_out(e)

View File

@ -0,0 +1,286 @@
# Copyright 2019 Hitachi, Ltd. (author: Yusuke Fujita)
# Licensed under the MIT license.
#
# This module is for computing audio features
import numpy as np
import librosa
def get_input_dim(
frame_size,
context_size,
transform_type,
):
if transform_type.startswith('logmel23'):
frame_size = 23
elif transform_type.startswith('logmel'):
frame_size = 40
else:
fft_size = 1 << (frame_size - 1).bit_length()
frame_size = int(fft_size / 2) + 1
input_dim = (2 * context_size + 1) * frame_size
return input_dim
def transform(
Y,
transform_type=None,
dtype=np.float32):
""" Transform STFT feature
Args:
Y: STFT
(n_frames, n_bins)-shaped np.complex array
        transform_type:
            None, "log", "logmel", "logmel23", "logmel23_mn",
            "logmel23_swn", or "logmel23_mvn"
dtype: output data type
np.float32 is expected
Returns:
Y (numpy.array): transformed feature
"""
Y = np.abs(Y)
if not transform_type:
pass
elif transform_type == 'log':
Y = np.log(np.maximum(Y, 1e-10))
elif transform_type == 'logmel':
n_fft = 2 * (Y.shape[1] - 1)
sr = 16000
n_mels = 40
mel_basis = librosa.filters.mel(sr, n_fft, n_mels)
Y = np.dot(Y ** 2, mel_basis.T)
Y = np.log10(np.maximum(Y, 1e-10))
elif transform_type == 'logmel23':
n_fft = 2 * (Y.shape[1] - 1)
sr = 8000
n_mels = 23
mel_basis = librosa.filters.mel(sr, n_fft, n_mels)
Y = np.dot(Y ** 2, mel_basis.T)
Y = np.log10(np.maximum(Y, 1e-10))
elif transform_type == 'logmel23_mn':
n_fft = 2 * (Y.shape[1] - 1)
sr = 8000
n_mels = 23
mel_basis = librosa.filters.mel(sr, n_fft, n_mels)
Y = np.dot(Y ** 2, mel_basis.T)
Y = np.log10(np.maximum(Y, 1e-10))
mean = np.mean(Y, axis=0)
Y = Y - mean
elif transform_type == 'logmel23_swn':
n_fft = 2 * (Y.shape[1] - 1)
sr = 8000
n_mels = 23
mel_basis = librosa.filters.mel(sr, n_fft, n_mels)
Y = np.dot(Y ** 2, mel_basis.T)
Y = np.log10(np.maximum(Y, 1e-10))
# b = np.ones(300)/300
# mean = scipy.signal.convolve2d(Y, b[:, None], mode='same')
# simple 2-means based threshoding for mean calculation
powers = np.sum(Y, axis=1)
th = (np.max(powers) + np.min(powers)) / 2.0
for i in range(10):
th = (np.mean(powers[powers >= th]) + np.mean(powers[powers < th])) / 2
mean = np.mean(Y[powers > th, :], axis=0)
Y = Y - mean
elif transform_type == 'logmel23_mvn':
n_fft = 2 * (Y.shape[1] - 1)
sr = 8000
n_mels = 23
mel_basis = librosa.filters.mel(sr, n_fft, n_mels)
Y = np.dot(Y ** 2, mel_basis.T)
Y = np.log10(np.maximum(Y, 1e-10))
mean = np.mean(Y, axis=0)
Y = Y - mean
std = np.maximum(np.std(Y, axis=0), 1e-10)
Y = Y / std
else:
raise ValueError('Unknown transform_type: %s' % transform_type)
return Y.astype(dtype)
def subsample(Y, T, subsampling=1):
""" Frame subsampling
"""
Y_ss = Y[::subsampling]
T_ss = T[::subsampling]
return Y_ss, T_ss
def splice(Y, context_size=0):
""" Frame splicing
Args:
Y: feature
(n_frames, n_featdim)-shaped numpy array
        context_size:
            number of context frames concatenated on each side;
            if context_size = 5, 11 frames are concatenated.
Returns:
Y_spliced: spliced feature
(n_frames, n_featdim * (2 * context_size + 1))-shaped
"""
Y_pad = np.pad(
Y,
[(context_size, context_size), (0, 0)],
'constant')
Y_spliced = np.lib.stride_tricks.as_strided(
np.ascontiguousarray(Y_pad),
(Y.shape[0], Y.shape[1] * (2 * context_size + 1)),
(Y.itemsize * Y.shape[1], Y.itemsize), writeable=False)
return Y_spliced
def stft(
data,
frame_size=1024,
frame_shift=256):
""" Compute STFT features
Args:
data: audio signal
(n_samples,)-shaped np.float32 array
frame_size: number of samples in a frame (must be a power of two)
frame_shift: number of samples between frames
Returns:
stft: STFT frames
(n_frames, n_bins)-shaped np.complex64 array
"""
# round up to nearest power of 2
fft_size = 1 << (frame_size - 1).bit_length()
    # HACK: The last frame is omitted
# as librosa.stft produces such an excessive frame
if len(data) % frame_shift == 0:
return librosa.stft(data, n_fft=fft_size, win_length=frame_size,
hop_length=frame_shift).T[:-1]
else:
return librosa.stft(data, n_fft=fft_size, win_length=frame_size,
hop_length=frame_shift).T
def _count_frames(data_len, size, shift):
# HACK: Assuming librosa.stft(..., center=True)
n_frames = 1 + int(data_len / shift)
if data_len % shift == 0:
n_frames = n_frames - 1
return n_frames
def get_frame_labels(
kaldi_obj,
rec,
start=0,
end=None,
frame_size=1024,
frame_shift=256,
n_speakers=None):
""" Get frame-aligned labels of given recording
Args:
kaldi_obj (KaldiData)
rec (str): recording id
start (int): start frame index
end (int): end frame index
None means the last frame of recording
        frame_size (int): number of samples in a frame
frame_shift (int): number of shift samples
n_speakers (int): number of speakers
if None, the value is given from data
Returns:
T: label
(n_frames, n_speakers)-shaped np.int32 array
"""
filtered_segments = kaldi_obj.segments[kaldi_obj.segments['rec'] == rec]
speakers = np.unique(
[kaldi_obj.utt2spk[seg['utt']] for seg
in filtered_segments]).tolist()
if n_speakers is None:
n_speakers = len(speakers)
es = end * frame_shift if end is not None else None
data, rate = kaldi_obj.load_wav(rec, start * frame_shift, es)
n_frames = _count_frames(len(data), frame_size, frame_shift)
T = np.zeros((n_frames, n_speakers), dtype=np.int32)
if end is None:
end = n_frames
for seg in filtered_segments:
speaker_index = speakers.index(kaldi_obj.utt2spk[seg['utt']])
start_frame = np.rint(
seg['st'] * rate / frame_shift).astype(int)
end_frame = np.rint(
seg['et'] * rate / frame_shift).astype(int)
rel_start = rel_end = None
if start <= start_frame and start_frame < end:
rel_start = start_frame - start
if start < end_frame and end_frame <= end:
rel_end = end_frame - start
if rel_start is not None or rel_end is not None:
T[rel_start:rel_end, speaker_index] = 1
return T
def get_labeledSTFT(
kaldi_obj,
rec, start, end, frame_size, frame_shift,
n_speakers=None,
use_speaker_id=False):
""" Extracts STFT and corresponding labels
Extracts STFT and corresponding diarization labels for
given recording id and start/end times
Args:
kaldi_obj (KaldiData)
rec (str): recording id
start (int): start frame index
end (int): end frame index
frame_size (int): number of samples in a frame
frame_shift (int): number of shift samples
n_speakers (int): number of speakers
if None, the value is given from data
Returns:
Y: STFT
(n_frames, n_bins)-shaped np.complex64 array,
T: label
            (n_frames, n_speakers)-shaped np.int32 array.
"""
data, rate = kaldi_obj.load_wav(
rec, start * frame_shift, end * frame_shift)
Y = stft(data, frame_size, frame_shift)
filtered_segments = kaldi_obj.segments[rec]
# filtered_segments = kaldi_obj.segments[kaldi_obj.segments['rec'] == rec]
speakers = np.unique(
[kaldi_obj.utt2spk[seg['utt']] for seg
in filtered_segments]).tolist()
if n_speakers is None:
n_speakers = len(speakers)
T = np.zeros((Y.shape[0], n_speakers), dtype=np.int32)
if use_speaker_id:
all_speakers = sorted(kaldi_obj.spk2utt.keys())
S = np.zeros((Y.shape[0], len(all_speakers)), dtype=np.int32)
for seg in filtered_segments:
speaker_index = speakers.index(kaldi_obj.utt2spk[seg['utt']])
if use_speaker_id:
all_speaker_index = all_speakers.index(kaldi_obj.utt2spk[seg['utt']])
start_frame = np.rint(
seg['st'] * rate / frame_shift).astype(int)
end_frame = np.rint(
seg['et'] * rate / frame_shift).astype(int)
rel_start = rel_end = None
if start <= start_frame and start_frame < end:
rel_start = start_frame - start
if start < end_frame and end_frame <= end:
rel_end = end_frame - start
if rel_start is not None or rel_end is not None:
T[rel_start:rel_end, speaker_index] = 1
if use_speaker_id:
S[rel_start:rel_end, all_speaker_index] = 1
if use_speaker_id:
return Y, T, S
else:
return Y, T
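Two quick sanity checks for the helpers above; both values follow directly from the code:

import numpy as np

# 23 log-mel bins spliced with a +/-7 frame context give the 345-dim
# input_size used by this recipe: (2 * 7 + 1) * 23 = 345.
print(get_input_dim(frame_size=1024, context_size=7, transform_type='logmel23_mn'))  # 345

# splice() concatenates 2*context_size+1 zero-padded frames per output frame.
Y = np.arange(12, dtype=np.float32).reshape(6, 2)
print(splice(Y, context_size=2).shape)  # (6, 10)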

View File

@ -0,0 +1,162 @@
# Copyright 2019 Hitachi, Ltd. (author: Yusuke Fujita)
# Licensed under the MIT license.
#
# This library provides utilities for kaldi-style data directory.
from __future__ import print_function
import os
import sys
import numpy as np
import subprocess
import soundfile as sf
import io
from functools import lru_cache
def load_segments(segments_file):
""" load segments file as array """
if not os.path.exists(segments_file):
return None
return np.loadtxt(
segments_file,
dtype=[('utt', 'object'),
('rec', 'object'),
('st', 'f'),
('et', 'f')],
ndmin=1)
def load_segments_hash(segments_file):
ret = {}
if not os.path.exists(segments_file):
return None
for line in open(segments_file):
utt, rec, st, et = line.strip().split()
ret[utt] = (rec, float(st), float(et))
return ret
def load_segments_rechash(segments_file):
ret = {}
if not os.path.exists(segments_file):
return None
for line in open(segments_file):
utt, rec, st, et = line.strip().split()
if rec not in ret:
ret[rec] = []
ret[rec].append({'utt':utt, 'st':float(st), 'et':float(et)})
return ret
def load_wav_scp(wav_scp_file):
""" return dictionary { rec: wav_rxfilename } """
lines = [line.strip().split(None, 1) for line in open(wav_scp_file)]
return {x[0]: x[1] for x in lines}
@lru_cache(maxsize=1)
def load_wav(wav_rxfilename, start=0, end=None):
""" This function reads audio file and return data in numpy.float32 array.
"lru_cache" holds recently loaded audio so that can be called
many times on the same audio file.
OPTIMIZE: controls lru_cache size for random access,
considering memory size
"""
if wav_rxfilename.endswith('|'):
# input piped command
p = subprocess.Popen(wav_rxfilename[:-1], shell=True,
stdout=subprocess.PIPE)
data, samplerate = sf.read(io.BytesIO(p.stdout.read()),
dtype='float32')
# cannot seek
data = data[start:end]
elif wav_rxfilename == '-':
# stdin
data, samplerate = sf.read(sys.stdin, dtype='float32')
# cannot seek
data = data[start:end]
else:
# normal wav file
data, samplerate = sf.read(wav_rxfilename, start=start, stop=end)
return data, samplerate
def load_utt2spk(utt2spk_file):
""" returns dictionary { uttid: spkid } """
lines = [line.strip().split(None, 1) for line in open(utt2spk_file)]
return {x[0]: x[1] for x in lines}
def load_spk2utt(spk2utt_file):
""" returns dictionary { spkid: list of uttids } """
if not os.path.exists(spk2utt_file):
return None
lines = [line.strip().split() for line in open(spk2utt_file)]
return {x[0]: x[1:] for x in lines}
def load_reco2dur(reco2dur_file):
""" returns dictionary { recid: duration } """
if not os.path.exists(reco2dur_file):
return None
lines = [line.strip().split(None, 1) for line in open(reco2dur_file)]
return {x[0]: float(x[1]) for x in lines}
def process_wav(wav_rxfilename, process):
""" This function returns preprocessed wav_rxfilename
Args:
wav_rxfilename: input
process: command which can be connected via pipe,
use stdin and stdout
Returns:
wav_rxfilename: output piped command
"""
if wav_rxfilename.endswith('|'):
# input piped command
return wav_rxfilename + process + "|"
else:
# stdin "-" or normal file
return "cat {} | {} |".format(wav_rxfilename, process)
def extract_segments(wavs, segments=None):
""" This function returns generator of segmented audio as
(utterance id, numpy.float32 array)
TODO?: sampling rate is not converted.
"""
if segments is not None:
# segments should be sorted by rec-id
for seg in segments:
wav = wavs[seg['rec']]
data, samplerate = load_wav(wav)
st_sample = np.rint(seg['st'] * samplerate).astype(int)
et_sample = np.rint(seg['et'] * samplerate).astype(int)
yield seg['utt'], data[st_sample:et_sample]
else:
# segments file not found,
# wav.scp is used as segmented audio list
for rec in wavs:
data, samplerate = load_wav(wavs[rec])
yield rec, data
class KaldiData:
def __init__(self, data_dir):
self.data_dir = data_dir
self.segments = load_segments_rechash(
os.path.join(self.data_dir, 'segments'))
self.utt2spk = load_utt2spk(
os.path.join(self.data_dir, 'utt2spk'))
self.wavs = load_wav_scp(
os.path.join(self.data_dir, 'wav.scp'))
self.reco2dur = load_reco2dur(
os.path.join(self.data_dir, 'reco2dur'))
self.spk2utt = load_spk2utt(
os.path.join(self.data_dir, 'spk2utt'))
def load_wav(self, recid, start=0, end=None):
data, rate = load_wav(
self.wavs[recid], start, end)
return data, rate
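A short usage sketch for KaldiData, assuming a kaldi-style directory such as the data/eval/callhome2_spkall prepared earlier in this recipe:

kd = KaldiData("data/eval/callhome2_spkall")
rec = sorted(kd.wavs.keys())[0]
data, rate = kd.load_wav(rec)
print(rec, len(data) / rate, "seconds")
for seg in (kd.segments or {}).get(rec, []):
    print(seg['utt'], kd.utt2spk[seg['utt']], seg['st'], seg['et'])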

View File

@ -1,11 +1,10 @@
import numpy as np
import torch
import torch.nn.functional as F
from itertools import permutations
from torch import nn
from scipy.optimize import linear_sum_assignment
def standard_loss(ys, ts, label_delay=0):
def standard_loss(ys, ts):
losses = [F.binary_cross_entropy(torch.sigmoid(y), t) * len(y) for y, t in zip(ys, ts)]
loss = torch.sum(torch.stack(losses))
n_frames = torch.from_numpy(np.array(np.sum([t.shape[0] for t in ts]))).to(torch.float32).to(ys[0].device)
@ -13,55 +12,29 @@ def standard_loss(ys, ts, label_delay=0):
return loss
def batch_pit_n_speaker_loss(ys, ts, n_speakers_list):
    max_n_speakers = ts[0].shape[1]
    olens = [y.shape[0] for y in ys]
    ys = nn.utils.rnn.pad_sequence(ys, batch_first=True, padding_value=-1)
    ys_mask = [torch.ones(olen).to(ys.device) for olen in olens]
    ys_mask = torch.nn.utils.rnn.pad_sequence(ys_mask, batch_first=True, padding_value=0).unsqueeze(-1)
    losses = []
    for shift in range(max_n_speakers):
        ts_roll = [torch.roll(t, -shift, dims=1) for t in ts]
        ts_roll = nn.utils.rnn.pad_sequence(ts_roll, batch_first=True, padding_value=-1)
        loss = F.binary_cross_entropy(torch.sigmoid(ys), ts_roll, reduction='none')
        if ys_mask is not None:
            loss = loss * ys_mask
        loss = torch.sum(loss, dim=1)
        losses.append(loss)
    losses = torch.stack(losses, dim=2)
    perms = np.array(list(permutations(range(max_n_speakers)))).astype(np.float32)
    perms = torch.from_numpy(perms).to(losses.device)
    y_ind = torch.arange(max_n_speakers, dtype=torch.float32, device=losses.device)
    t_inds = torch.fmod(perms - y_ind, max_n_speakers).to(torch.long)
    losses_perm = []
    for t_ind in t_inds:
        losses_perm.append(
            torch.mean(losses[:, y_ind.to(torch.long), t_ind], dim=1))
    losses_perm = torch.stack(losses_perm, dim=1)
    def select_perm_indices(num, max_num):
        perms = list(permutations(range(max_num)))
        sub_perms = list(permutations(range(num)))
        return [
            [x[:num] for x in perms].index(perm)
            for perm in sub_perms]
    masks = torch.full_like(losses_perm, device=losses.device, fill_value=float('inf'))
    for i, t in enumerate(ts):
        n_speakers = n_speakers_list[i]
        indices = select_perm_indices(n_speakers, max_n_speakers)
        masks[i, indices] = 0
    losses_perm += masks
    min_loss = torch.sum(torch.min(losses_perm, dim=1)[0])
    n_frames = torch.from_numpy(np.array(np.sum([t.shape[0] for t in ts]))).to(losses.device)
    min_loss = min_loss / n_frames
    min_indices = torch.argmin(losses_perm, dim=1)
    labels_perm = [t[:, perms[idx].to(torch.long)] for t, idx in zip(ts, min_indices)]
    labels_perm = [t[:, :n_speakers] for t, n_speakers in zip(labels_perm, n_speakers_list)]
    return min_loss, labels_perm
def fast_batch_pit_n_speaker_loss(ys, ts):
    with torch.no_grad():
        bs = len(ys)
        indices = []
        for b in range(bs):
            y = ys[b].transpose(0, 1)
            t = ts[b].transpose(0, 1)
            C, _ = t.shape
            y = y[:, None, :].repeat(1, C, 1)
            t = t[None, :, :].repeat(C, 1, 1)
            bce_loss = F.binary_cross_entropy(torch.sigmoid(y), t, reduction="none").mean(-1)
            C = bce_loss.cpu()
            indices.append(linear_sum_assignment(C))
        labels_perm = [t[:, idx[1]] for t, idx in zip(ts, indices)]
    return labels_perm
def cal_power_loss(logits, power_ts):
losses = [F.cross_entropy(input=logit, target=power_t.to(torch.long)) * len(logit) for logit, power_t in
zip(logits, power_ts)]
loss = torch.sum(torch.stack(losses))
n_frames = torch.from_numpy(np.array(np.sum([power_t.shape[0] for power_t in power_ts]))).to(torch.float32).to(
power_ts[0].device)
loss = loss / n_frames
return loss
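fast_batch_pit_n_speaker_loss above replaces the explicit enumeration of all speaker permutations with one Hungarian assignment per utterance. A tiny worked example of that matching step, mirroring the shapes used in the function:

import torch
import torch.nn.functional as F
from scipy.optimize import linear_sum_assignment

y = torch.tensor([[5., -5.], [5., -5.], [-5., 5.]])  # logits, T=3 frames, C=2 speakers
t = torch.tensor([[0., 1.], [0., 1.], [1., 0.]])     # labels with speaker columns swapped
C = t.shape[1]
cost = F.binary_cross_entropy(
    torch.sigmoid(y.transpose(0, 1))[:, None, :].repeat(1, C, 1),
    t.transpose(0, 1)[None, :, :].repeat(C, 1, 1),
    reduction="none").mean(-1)                       # (C, C): cost of pairing pred i with label j
row, col = linear_sum_assignment(cost.numpy())
print(col)        # [1 0]: predicted speaker 0 best matches label column 1
print(t[:, col])  # permuted labels now align with the predictions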

View File

@ -196,12 +196,16 @@ def generate_data_list(args, data_dir, dataset, nj=64):
def prepare_data(args, distributed_option):
distributed = distributed_option.distributed
data_names = args.dataset_conf.get("data_names", "speech,text").split(",")
data_types = args.dataset_conf.get("data_types", "sound,text").split(",")
file_names = args.data_file_names.split(",")
batch_type = args.dataset_conf["batch_conf"]["batch_type"]
if not distributed or distributed_option.dist_rank == 0:
if hasattr(args, "filter_input") and args.filter_input:
filter_wav_text(args.data_dir, args.train_set)
filter_wav_text(args.data_dir, args.valid_set)
if args.dataset_type == "small":
if args.dataset_type == "small" and batch_type != "unsorted":
calc_shape(args, args.train_set)
calc_shape(args, args.valid_set)
@ -209,9 +213,6 @@ def prepare_data(args, distributed_option):
generate_data_list(args, args.data_dir, args.train_set)
generate_data_list(args, args.data_dir, args.valid_set)
data_names = args.dataset_conf.get("data_names", "speech,text").split(",")
data_types = args.dataset_conf.get("data_types", "sound,text").split(",")
file_names = args.data_file_names.split(",")
print("data_names: {}, data_types: {}, file_names: {}".format(data_names, data_types, file_names))
assert len(data_names) == len(data_types) == len(file_names)
if args.dataset_type == "small":